//! fmi_sim/sim/util.rs — CSV reading and input-projection utilities.

1use std::{
2    io::{Read, Seek},
3    path::Path,
4    sync::Arc,
5};
6
7use arrow::{
8    csv::{ReaderBuilder, reader::Format},
9    datatypes::{Field, Schema, SchemaRef},
10    record_batch::RecordBatch,
11};
12use comfy_table::Table;
13use itertools::Itertools;
14
15pub fn read_csv_file<P: AsRef<Path>>(path: P) -> anyhow::Result<RecordBatch> {
16    let mut file = std::fs::File::open(&path)?;
17    log::debug!("Reading CSV file {:?}", path.as_ref());
18    read_csv(&mut file)
19}
20
21/// Read a CSV file into a single RecordBatch.
22pub fn read_csv<R>(reader: &mut R) -> anyhow::Result<RecordBatch>
23where
24    R: Read + Seek,
25{
26    // Infer the schema with the first 100 records
27    let (file_schema, _) = Format::default()
28        .with_header(true)
29        .infer_schema(reader.by_ref(), Some(100))?;
30    reader.rewind()?;
31
32    log::debug!(
33        "Inferred schema: {:?}",
34        file_schema
35            .fields()
36            .iter()
37            .map(|f| f.name())
38            .collect::<Vec<_>>()
39    );
40
41    let _time = Arc::new(arrow::datatypes::Field::new(
42        "time",
43        arrow::datatypes::DataType::Float64,
44        false,
45    ));
46
47    // Create a non-nullible schema from the file schema
48    let file_schema = Arc::new(Schema::new(
49        file_schema
50            .fields()
51            .iter()
52            .map(|f| Arc::new(Field::new(f.name(), f.data_type().clone(), false)) as Arc<Field>)
53            .collect::<Vec<_>>(),
54    ));
55
56    let reader = ReaderBuilder::new(file_schema)
57        .with_header(true)
58        .build(reader)?;
59
60    let batches = reader.collect::<Result<Vec<_>, _>>()?;
61
62    Ok(arrow::compute::concat_batches(
63        &batches[0].schema(),
64        &batches,
65    )?)
66}
67
68/// Format the projected fields in a human-readable format
69pub fn pretty_format_projection(
70    input_data_schema: Arc<Schema>,
71    model_input_schema: Arc<Schema>,
72    time_field: Arc<Field>,
73) -> impl std::fmt::Display {
74    let mut table = Table::new();
75    table.load_preset(comfy_table::presets::ASCII_BORDERS_ONLY_CONDENSED);
76    table.set_header(vec!["Variable", "Input Type", "Model Type"]);
77    let rows_iter = input_data_schema.fields().iter().map(|input_field| {
78        let model_field_name = model_input_schema
79            .fields()
80            .iter()
81            .chain(std::iter::once(&time_field))
82            .find(|model_field| model_field.name() == input_field.name())
83            .map(|model_field| model_field.data_type());
84        vec![
85            input_field.name().to_string(),
86            input_field.data_type().to_string(),
87            model_field_name
88                .map(|t| t.to_string())
89                .unwrap_or("-None-".to_string()),
90        ]
91    });
92    table.add_rows(rows_iter);
93    table
94}
95
96/// Transform the `input_data` to match the `model_input_schema`. Input data columns are projected and
97/// cast to the corresponding input schema columns.
98///
99/// This is necessary because the `input_data` may have extra columns or have different datatypes.
100pub fn project_input_data(
101    input_data: &RecordBatch,
102    model_input_schema: SchemaRef,
103) -> anyhow::Result<RecordBatch> {
104    let input_data_schema = input_data.schema();
105
106    let time_field = Arc::new(Field::new(
107        "time",
108        arrow::datatypes::DataType::Float64,
109        false,
110    ));
111
112    // Create an iterator over the fields of the input data, starting with the time field
113    let fields_iter = std::iter::once(&time_field).chain(model_input_schema.fields().iter());
114
115    let (projected_fields, projected_columns): (Vec<_>, Vec<_>) = fields_iter
116        .filter_map(|field| {
117            input_data.column_by_name(field.name()).map(|col| {
118                arrow::compute::cast(col, field.data_type())
119                    .map(|col| (field.clone(), col))
120                    .map_err(|_| anyhow::anyhow!("Error casting type"))
121            })
122        })
123        .process_results(|pairs| pairs.unzip())?;
124
125    log::debug!(
126        "Projected input data schema:\n{}",
127        pretty_format_projection(input_data_schema, model_input_schema, time_field)
128    );
129
130    let input_data_schema = Arc::new(Schema::new(projected_fields));
131    RecordBatch::try_new(input_data_schema, projected_columns).map_err(anyhow::Error::from)
132}