use std::{
    io::{Read, Seek},
    path::Path,
    sync::Arc,
};

use anyhow::Context;
use arrow::{
    csv::{ReaderBuilder, reader::Format},
    datatypes::{Field, Schema, SchemaRef},
    record_batch::RecordBatch,
};
use comfy_table::Table;
use itertools::Itertools;
14
15pub fn read_csv_file<P: AsRef<Path>>(path: P) -> anyhow::Result<RecordBatch> {
16 let mut file = std::fs::File::open(&path)?;
17 log::debug!("Reading CSV file {:?}", path.as_ref());
18 read_csv(&mut file)
19}
20
21pub fn read_csv<R>(reader: &mut R) -> anyhow::Result<RecordBatch>
23where
24 R: Read + Seek,
25{
26 let (file_schema, _) = Format::default()
28 .with_header(true)
29 .infer_schema(reader.by_ref(), Some(100))?;
30 reader.rewind()?;
31
32 log::debug!(
33 "Inferred schema: {:?}",
34 file_schema
35 .fields()
36 .iter()
37 .map(|f| f.name())
38 .collect::<Vec<_>>()
39 );
40
41 let _time = Arc::new(arrow::datatypes::Field::new(
42 "time",
43 arrow::datatypes::DataType::Float64,
44 false,
45 ));
46
47 let file_schema = Arc::new(Schema::new(
49 file_schema
50 .fields()
51 .iter()
52 .map(|f| Arc::new(Field::new(f.name(), f.data_type().clone(), false)) as Arc<Field>)
53 .collect::<Vec<_>>(),
54 ));
55
56 let reader = ReaderBuilder::new(file_schema)
57 .with_header(true)
58 .build(reader)?;
59
60 let batches = reader.collect::<Result<Vec<_>, _>>()?;
61
62 Ok(arrow::compute::concat_batches(
63 &batches[0].schema(),
64 &batches,
65 )?)
66}
67
68pub fn pretty_format_projection(
70 input_data_schema: Arc<Schema>,
71 model_input_schema: Arc<Schema>,
72 time_field: Arc<Field>,
73) -> impl std::fmt::Display {
74 let mut table = Table::new();
75 table.load_preset(comfy_table::presets::ASCII_BORDERS_ONLY_CONDENSED);
76 table.set_header(vec!["Variable", "Input Type", "Model Type"]);
77 let rows_iter = input_data_schema.fields().iter().map(|input_field| {
78 let model_field_name = model_input_schema
79 .fields()
80 .iter()
81 .chain(std::iter::once(&time_field))
82 .find(|model_field| model_field.name() == input_field.name())
83 .map(|model_field| model_field.data_type());
84 vec![
85 input_field.name().to_string(),
86 input_field.data_type().to_string(),
87 model_field_name
88 .map(|t| t.to_string())
89 .unwrap_or("-None-".to_string()),
90 ]
91 });
92 table.add_rows(rows_iter);
93 table
94}
95
96pub fn project_input_data(
101 input_data: &RecordBatch,
102 model_input_schema: SchemaRef,
103) -> anyhow::Result<RecordBatch> {
104 let input_data_schema = input_data.schema();
105
106 let time_field = Arc::new(Field::new(
107 "time",
108 arrow::datatypes::DataType::Float64,
109 false,
110 ));
111
112 let fields_iter = std::iter::once(&time_field).chain(model_input_schema.fields().iter());
114
115 let (projected_fields, projected_columns): (Vec<_>, Vec<_>) = fields_iter
116 .filter_map(|field| {
117 input_data.column_by_name(field.name()).map(|col| {
118 arrow::compute::cast(col, field.data_type())
119 .map(|col| (field.clone(), col))
120 .map_err(|_| anyhow::anyhow!("Error casting type"))
121 })
122 })
123 .process_results(|pairs| pairs.unzip())?;
124
125 log::debug!(
126 "Projected input data schema:\n{}",
127 pretty_format_projection(input_data_schema, model_input_schema, time_field)
128 );
129
130 let input_data_schema = Arc::new(Schema::new(projected_fields));
131 RecordBatch::try_new(input_data_schema, projected_columns).map_err(anyhow::Error::from)
132}