1use std::{fmt::Display, sync::Arc};
2
3use anyhow::Context;
4use arrow::{
5 array::{ArrayBuilder, ArrayRef, Float64Array, Float64Builder, downcast_array, make_builder},
6 datatypes::{DataType, Field, Schema},
7 downcast_primitive_array,
8 record_batch::RecordBatch,
9};
10use fmi::traits::FmiInstance;
11
12use crate::Error;
13
14use super::{
15 interpolation::{Interpolate, PreLookup, find_index},
16 params::SimParams,
17 traits::{ImportSchemaBuilder, InstSetValues},
18 util::project_input_data,
19};
20
21pub struct StartValues<VR> {
23 pub structural_parameters: Vec<(VR, ArrayRef)>,
24 pub variables: Vec<(VR, ArrayRef)>,
25}
26
27pub struct InputState<Inst: FmiInstance> {
28 pub(crate) input_data: Option<RecordBatch>,
29 pub(crate) continuous_inputs: Vec<(Field, Inst::ValueRef)>,
31 pub(crate) discrete_inputs: Vec<(Field, Inst::ValueRef)>,
33}
34
35impl<Inst> Display for InputState<Inst>
36where
37 Inst: FmiInstance,
38{
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 let continuous_inputs = self
41 .continuous_inputs
42 .iter()
43 .map(|(field, _)| field.name())
44 .collect::<Vec<_>>();
45
46 let discrete_inputs = self
47 .discrete_inputs
48 .iter()
49 .map(|(field, _)| field.name())
50 .collect::<Vec<_>>();
51
52 f.write_str("InputState {\n")?;
53 if let Some(input_data) = &self.input_data {
54 writeln!(
55 f,
56 "input_data:\n{}",
57 arrow::util::pretty::pretty_format_batches(std::slice::from_ref(input_data))
58 .unwrap()
59 )?;
60 } else {
61 writeln!(f, "input_data: None")?;
62 }
63 writeln!(f, "continuous_inputs: {continuous_inputs:?}")?;
64 writeln!(f, "discrete_inputs: {discrete_inputs:?}")?;
65 write!(f, "}}")
66 }
67}
68
69impl<Inst> InputState<Inst>
70where
71 Inst: FmiInstance,
72{
73 pub fn new<Import: ImportSchemaBuilder<ValueRef = Inst::ValueRef>>(
74 import: &Import,
75 input_data: Option<RecordBatch>,
76 ) -> anyhow::Result<Self> {
77 let model_input_schema = Arc::new(import.inputs_schema());
78 let continuous_inputs = import.continuous_inputs().collect();
79 let discrete_inputs = import.discrete_inputs().collect();
80
81 let input_data = input_data
82 .map(|input_data| project_input_data(&input_data, model_input_schema.clone()))
83 .transpose()?;
84
85 Ok(Self {
86 input_data,
87 continuous_inputs,
88 discrete_inputs,
89 })
90 }
91}
92
93impl<Inst> InputState<Inst>
94where
95 Inst: InstSetValues,
96{
97 pub fn apply_input<I: Interpolate>(
98 &mut self,
99 time: f64,
100 inst: &mut Inst,
101 discrete: bool,
102 continuous: bool,
103 after_event: bool,
104 ) -> Result<(), Error> {
105 if let Some(input_data) = &self.input_data {
106 let time_array: Float64Array = downcast_array(
107 input_data
108 .column_by_name("time")
109 .context("Input data must have a column named 'time' with the time values")?,
110 );
111
112 if continuous {
113 let pl = PreLookup::new(&time_array, time, after_event);
114
115 for (field, vr) in &self.continuous_inputs {
116 if let Some(input_col) = input_data.column_by_name(field.name()) {
117 assert_eq!(input_col.data_type(), field.data_type());
118
119 inst.set_interpolated::<I>(*vr, &pl, input_col)?;
124 }
125 }
126 }
127
128 if discrete {
129 let input_idx = find_index(&time_array, time, after_event);
131
132 for (field, vr) in &self.discrete_inputs {
133 if let Some(input_col) = input_data.column_by_name(field.name()) {
134 let ary = arrow::compute::cast(input_col, field.data_type())
135 .map_err(|_| anyhow::anyhow!("Error casting type"))?;
136
137 let values = &ary.slice(input_idx, 1);
138
139 inst.set_array(&[*vr], values);
142 }
143 }
144 }
145 }
146
147 Ok(())
148 }
149
150 pub fn next_input_event(&self, time: f64) -> f64 {
151 if let Some(input_data) = &self.input_data {
152 let time_array: Float64Array =
153 downcast_array(input_data.column_by_name("time").unwrap());
154
155 for i in 0..(time_array.len() - 1) {
156 let t0 = time_array.value(i);
157 let t1 = time_array.value(i + 1);
158
159 if time >= t1 {
160 continue;
161 }
162
163 if t0 == t1 {
164 return t0; }
166
167 for (field, _vr) in &self.discrete_inputs {
171 if let Some(input_col) = input_data.column_by_name(field.name()) {
172 if downcast_primitive_array!(
173 input_col => input_col.value(i) != input_col.value(i + 1),
174 t => panic!("Unsupported datatype {}", t)
175 ) {
176 return t1;
177 }
178 }
179 }
180 }
181 }
182 f64::INFINITY
183 }
184}
185
186pub struct Recorder<Inst: FmiInstance> {
187 pub(crate) field: Field,
188 pub(crate) value_reference: Inst::ValueRef,
189 pub(crate) builder: Box<dyn ArrayBuilder>,
190}
191
192pub struct RecorderState<Inst: FmiInstance> {
193 pub(crate) time: Float64Builder,
194 pub(crate) recorders: Vec<Recorder<Inst>>,
195}
196
197impl<Inst> RecorderState<Inst>
198where
199 Inst: FmiInstance,
200{
201 pub fn new<Import: ImportSchemaBuilder<ValueRef = Inst::ValueRef>>(
202 import: &Import,
203 sim_params: &SimParams,
204 ) -> Self {
205 let num_points = ((sim_params.stop_time - sim_params.start_time)
206 / sim_params.output_interval)
207 .ceil() as usize;
208
209 let time = Float64Builder::with_capacity(num_points);
210
211 let recorders = import
212 .outputs()
213 .map(|(field, vr)| {
214 let builder = make_builder(field.data_type(), num_points);
215 Recorder {
216 field,
217 value_reference: vr,
218 builder,
219 }
220 })
221 .collect();
222
223 Self { time, recorders }
224 }
225
226 pub fn finish(self) -> RecordBatch {
228 let Self {
229 mut time,
230 recorders,
231 } = self;
232
233 let recorders = recorders.into_iter().map(
234 |Recorder {
235 field,
236 value_reference: _,
237 mut builder,
238 }| { (field, builder.finish()) },
239 );
240
241 let time = std::iter::once((
242 Field::new("time", DataType::Float64, false),
243 Arc::new(time.finish()) as _,
244 ));
245
246 let (fields, columns): (Vec<_>, Vec<_>) = time.chain(recorders).unzip();
247 let schema = Arc::new(Schema::new(fields));
248 RecordBatch::try_new(schema, columns).unwrap()
249 }
250}