fmi_sim/sim/
io.rs

1use std::{fmt::Display, sync::Arc};
2
3use anyhow::Context;
4use arrow::{
5    array::{ArrayBuilder, ArrayRef, Float64Array, Float64Builder, downcast_array, make_builder},
6    datatypes::{DataType, Field, Schema},
7    downcast_primitive_array,
8    record_batch::RecordBatch,
9};
10use fmi::traits::FmiInstance;
11
12use crate::Error;
13
14use super::{
15    interpolation::{Interpolate, PreLookup, find_index},
16    params::SimParams,
17    traits::{ImportSchemaBuilder, InstSetValues},
18    util::project_input_data,
19};
20
/// Container for holding initial values for the FMU.
///
/// Each entry pairs a key of type `VR` (presumably the FMU's value-reference type — confirm
/// against the code that builds this struct) with an Arrow array holding the value(s) to
/// assign. Structural parameters and ordinary variables are kept in separate lists.
pub struct StartValues<VR> {
    /// Start values for structural parameters, as `(VR, values)` pairs.
    pub structural_parameters: Vec<(VR, ArrayRef)>,
    /// Start values for all other variables, as `(VR, values)` pairs.
    pub variables: Vec<(VR, ArrayRef)>,
}
26
/// Input data for a simulation together with the FMU inputs it drives.
pub struct InputState<Inst: FmiInstance> {
    /// Optional input table. NOTE(review): `apply_input` and `next_input_event` look up a
    /// `"time"` column in it — confirm any other consumers share that expectation.
    pub(crate) input_data: Option<RecordBatch>,
    /// Continuous model inputs as `(field, value reference)` pairs; the input column for each
    /// entry is looked up in `input_data` by the field's name.
    pub(crate) continuous_inputs: Vec<(Field, Inst::ValueRef)>,
    /// Discrete model inputs as `(field, value reference)` pairs; the input column for each
    /// entry is looked up in `input_data` by the field's name.
    pub(crate) discrete_inputs: Vec<(Field, Inst::ValueRef)>,
}
34
35impl<Inst> Display for InputState<Inst>
36where
37    Inst: FmiInstance,
38{
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        let continuous_inputs = self
41            .continuous_inputs
42            .iter()
43            .map(|(field, _)| field.name())
44            .collect::<Vec<_>>();
45
46        let discrete_inputs = self
47            .discrete_inputs
48            .iter()
49            .map(|(field, _)| field.name())
50            .collect::<Vec<_>>();
51
52        f.write_str("InputState {\n")?;
53        if let Some(input_data) = &self.input_data {
54            writeln!(
55                f,
56                "input_data:\n{}",
57                arrow::util::pretty::pretty_format_batches(std::slice::from_ref(input_data))
58                    .unwrap()
59            )?;
60        } else {
61            writeln!(f, "input_data: None")?;
62        }
63        writeln!(f, "continuous_inputs: {continuous_inputs:?}")?;
64        writeln!(f, "discrete_inputs: {discrete_inputs:?}")?;
65        write!(f, "}}")
66    }
67}
68
69impl<Inst> InputState<Inst>
70where
71    Inst: FmiInstance,
72{
73    pub fn new<Import: ImportSchemaBuilder<ValueRef = Inst::ValueRef>>(
74        import: &Import,
75        input_data: Option<RecordBatch>,
76    ) -> anyhow::Result<Self> {
77        let model_input_schema = Arc::new(import.inputs_schema());
78        let continuous_inputs = import.continuous_inputs().collect();
79        let discrete_inputs = import.discrete_inputs().collect();
80
81        let input_data = input_data
82            .map(|input_data| project_input_data(&input_data, model_input_schema.clone()))
83            .transpose()?;
84
85        Ok(Self {
86            input_data,
87            continuous_inputs,
88            discrete_inputs,
89        })
90    }
91}
92
93impl<Inst> InputState<Inst>
94where
95    Inst: InstSetValues,
96{
97    pub fn apply_input<I: Interpolate>(
98        &mut self,
99        time: f64,
100        inst: &mut Inst,
101        discrete: bool,
102        continuous: bool,
103        after_event: bool,
104    ) -> Result<(), Error> {
105        if let Some(input_data) = &self.input_data {
106            let time_array: Float64Array = downcast_array(
107                input_data
108                    .column_by_name("time")
109                    .context("Input data must have a column named 'time' with the time values")?,
110            );
111
112            if continuous {
113                let pl = PreLookup::new(&time_array, time, after_event);
114
115                for (field, vr) in &self.continuous_inputs {
116                    if let Some(input_col) = input_data.column_by_name(field.name()) {
117                        assert_eq!(input_col.data_type(), field.data_type());
118
119                        //let ary = arrow::compute::cast(input_col, field.data_type()).map_err(|_| anyhow::anyhow!("Error casting type"))?;
120
121                        //log::trace!( "Applying continuous input {}={input_col:?} at time {time}", field.name());
122
123                        inst.set_interpolated::<I>(*vr, &pl, input_col)?;
124                    }
125                }
126            }
127
128            if discrete {
129                // TODO: Refactor the interpolation code to separate index lookup from interpolation
130                let input_idx = find_index(&time_array, time, after_event);
131
132                for (field, vr) in &self.discrete_inputs {
133                    if let Some(input_col) = input_data.column_by_name(field.name()) {
134                        let ary = arrow::compute::cast(input_col, field.data_type())
135                            .map_err(|_| anyhow::anyhow!("Error casting type"))?;
136
137                        let values = &ary.slice(input_idx, 1);
138
139                        //log::trace!( "Applying discrete input {}={values:#?} at time {time:.2}", field.name());
140
141                        inst.set_array(&[*vr], values);
142                    }
143                }
144            }
145        }
146
147        Ok(())
148    }
149
150    pub fn next_input_event(&self, time: f64) -> f64 {
151        if let Some(input_data) = &self.input_data {
152            let time_array: Float64Array =
153                downcast_array(input_data.column_by_name("time").unwrap());
154
155            for i in 0..(time_array.len() - 1) {
156                let t0 = time_array.value(i);
157                let t1 = time_array.value(i + 1);
158
159                if time >= t1 {
160                    continue;
161                }
162
163                if t0 == t1 {
164                    return t0; // discrete change of a continuous variable
165                }
166
167                // TODO: This could be computed once and cached
168
169                // skip continuous variables
170                for (field, _vr) in &self.discrete_inputs {
171                    if let Some(input_col) = input_data.column_by_name(field.name()) {
172                        if downcast_primitive_array!(
173                            input_col => input_col.value(i) != input_col.value(i + 1),
174                            t => panic!("Unsupported datatype {}", t)
175                        ) {
176                            return t1;
177                        }
178                    }
179                }
180            }
181        }
182        f64::INFINITY
183    }
184}
185
/// Collects the recorded values of a single output variable into an Arrow column.
pub struct Recorder<Inst: FmiInstance> {
    /// Name, data type and nullability of the output column.
    pub(crate) field: Field,
    /// Value reference of the recorded variable in the FMU instance.
    pub(crate) value_reference: Inst::ValueRef,
    /// Type-erased Arrow builder that accumulates the recorded values; its concrete type
    /// must match `field`'s data type.
    pub(crate) builder: Box<dyn ArrayBuilder>,
}
191
/// Output recording state: a shared time column plus one [`Recorder`] per recorded output.
pub struct RecorderState<Inst: FmiInstance> {
    /// Builder for the `"time"` column of the output batch.
    pub(crate) time: Float64Builder,
    /// One recorder per recorded output variable, in output column order.
    pub(crate) recorders: Vec<Recorder<Inst>>,
}
196
197impl<Inst> RecorderState<Inst>
198where
199    Inst: FmiInstance,
200{
201    pub fn new<Import: ImportSchemaBuilder<ValueRef = Inst::ValueRef>>(
202        import: &Import,
203        sim_params: &SimParams,
204    ) -> Self {
205        let num_points = ((sim_params.stop_time - sim_params.start_time)
206            / sim_params.output_interval)
207            .ceil() as usize;
208
209        let time = Float64Builder::with_capacity(num_points);
210
211        let recorders = import
212            .outputs()
213            .map(|(field, vr)| {
214                let builder = make_builder(field.data_type(), num_points);
215                Recorder {
216                    field,
217                    value_reference: vr,
218                    builder,
219                }
220            })
221            .collect();
222
223        Self { time, recorders }
224    }
225
226    /// Finish the output state and return the RecordBatch.
227    pub fn finish(self) -> RecordBatch {
228        let Self {
229            mut time,
230            recorders,
231        } = self;
232
233        let recorders = recorders.into_iter().map(
234            |Recorder {
235                 field,
236                 value_reference: _,
237                 mut builder,
238             }| { (field, builder.finish()) },
239        );
240
241        let time = std::iter::once((
242            Field::new("time", DataType::Float64, false),
243            Arc::new(time.finish()) as _,
244        ));
245
246        let (fields, columns): (Vec<_>, Vec<_>) = time.chain(recorders).unzip();
247        let schema = Arc::new(Schema::new(fields));
248        RecordBatch::try_new(schema, columns).unwrap()
249    }
250}