datasynth_eval/behavioral_fidelity/
loader.rs1use std::collections::HashMap;
4use std::fs::File;
5use std::path::Path;
6
7use arrow::array::{
8 Array, Date32Array, Float64Array, StringArray, TimestampMicrosecondArray,
9 TimestampMillisecondArray, TimestampNanosecondArray,
10};
11use arrow::record_batch::RecordBatch;
12use chrono::{DateTime, NaiveDate, TimeZone, Utc};
13use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
14
15use super::entity_profile::{reference_corpus_aliases, synthetic_aliases};
16use super::error::{BehavioralFidelityError, BehavioralFidelityResult};
17use super::types::Record;
18
19pub fn load_parquet_records(path: &Path) -> BehavioralFidelityResult<Vec<Record>> {
21 let path = resolve_single_file(path, "parquet")?;
22 let file = File::open(&path)?;
23 let builder = ParquetRecordBatchReaderBuilder::try_new(file)
24 .map_err(|e| BehavioralFidelityError::Parquet(e.to_string()))?;
25 let reader = builder
26 .build()
27 .map_err(|e| BehavioralFidelityError::Parquet(e.to_string()))?;
28 let mut out = Vec::new();
29 let mut skipped = 0usize;
30 for batch_res in reader {
31 let batch = batch_res.map_err(|e| BehavioralFidelityError::Parquet(e.to_string()))?;
32 let aliases = pick_alias_map(&batch);
33 for row in 0..batch.num_rows() {
34 match extract_row(&batch, row, &aliases) {
35 Ok(rec) => out.push(rec),
36 Err(BehavioralFidelityError::Schema(_)) => skipped += 1,
37 Err(e) => return Err(e),
38 }
39 }
40 }
41 if skipped > 0 {
42 tracing::warn!(
43 "load_parquet_records: skipped {} rows with missing/malformed required dates",
44 skipped
45 );
46 }
47 Ok(out)
48}
49
50pub fn load_csv_records(path: &Path) -> BehavioralFidelityResult<Vec<Record>> {
53 let path = resolve_single_file(path, "csv")?;
54 let mut rdr = csv::ReaderBuilder::new()
55 .has_headers(true)
56 .from_path(&path)
57 .map_err(|e| BehavioralFidelityError::Io(std::io::Error::other(e.to_string())))?;
58 let headers: Vec<String> = rdr
59 .headers()
60 .map_err(|e| BehavioralFidelityError::Io(std::io::Error::other(e.to_string())))?
61 .iter()
62 .map(|s| s.to_string())
63 .collect();
64 let aliases = pick_alias_map_from_headers(&headers);
65 let header_idx: HashMap<&str, usize> = headers
66 .iter()
67 .enumerate()
68 .map(|(i, h)| (h.as_str(), i))
69 .collect();
70 let mut out = Vec::new();
71 let mut skipped = 0usize;
72 for rec in rdr.records() {
73 let rec =
74 rec.map_err(|e| BehavioralFidelityError::Io(std::io::Error::other(e.to_string())))?;
75 match extract_csv_row(&rec, &header_idx, &aliases) {
76 Ok(record) => out.push(record),
77 Err(BehavioralFidelityError::Schema(_)) => skipped += 1,
78 Err(e) => return Err(e),
79 }
80 }
81 if skipped > 0 {
82 tracing::warn!(
83 "load_csv_records: skipped {} rows with missing/malformed required dates",
84 skipped
85 );
86 }
87 Ok(out)
88}
89
90fn resolve_single_file(
91 path: &Path,
92 extension: &str,
93) -> BehavioralFidelityResult<std::path::PathBuf> {
94 if path.is_file() {
95 return Ok(path.to_path_buf());
96 }
97 if path.is_dir() {
98 for entry in std::fs::read_dir(path)? {
99 let entry = entry?;
100 if entry
101 .path()
102 .extension()
103 .is_some_and(|e| e.eq_ignore_ascii_case(extension))
104 {
105 return Ok(entry.path());
106 }
107 }
108 }
109 Err(BehavioralFidelityError::Io(std::io::Error::other(format!(
110 "no {extension} file at {}",
111 path.display()
112 ))))
113}
114
115fn pick_alias_map(batch: &RecordBatch) -> HashMap<&'static str, &'static str> {
116 let schema = batch.schema();
117 let cols: Vec<String> = schema
118 .fields()
119 .iter()
120 .map(|f| f.name().to_string())
121 .collect();
122 pick_alias_map_from_headers(&cols)
123}
124
125fn pick_alias_map_from_headers(cols: &[String]) -> HashMap<&'static str, &'static str> {
126 let has = |needle: &str| cols.iter().any(|c| c == needle);
127 if has("Tarding Partner") || has("Functional Amount") {
128 reference_corpus_aliases().into_iter().collect()
129 } else {
130 synthetic_aliases().into_iter().collect()
131 }
132}
133
134fn extract_row(
135 batch: &RecordBatch,
136 row: usize,
137 aliases: &HashMap<&'static str, &'static str>,
138) -> BehavioralFidelityResult<Record> {
139 let s = batch.schema();
140 let col_idx = |canon: &str| -> Option<usize> {
141 let real = aliases.get(canon)?;
142 s.fields().iter().position(|f| f.name() == *real)
143 };
144 let str_at = |canon: &str| -> Option<String> {
145 let i = col_idx(canon)?;
146 let arr = batch.column(i).as_any().downcast_ref::<StringArray>()?;
147 if arr.is_null(row) {
148 None
149 } else {
150 Some(arr.value(row).to_string())
151 }
152 };
153 let f64_at = |canon: &str| -> Option<f64> {
154 let i = col_idx(canon)?;
155 let arr = batch.column(i).as_any().downcast_ref::<Float64Array>()?;
156 if arr.is_null(row) {
157 None
158 } else {
159 Some(arr.value(row))
160 }
161 };
162 let date_at = |canon: &str| -> Option<NaiveDate> {
163 let i = col_idx(canon)?;
164 if let Some(arr) = batch.column(i).as_any().downcast_ref::<StringArray>() {
165 if arr.is_null(row) {
166 return None;
167 }
168 return NaiveDate::parse_from_str(arr.value(row), "%Y-%m-%d").ok();
169 }
170 if let Some(arr) = batch.column(i).as_any().downcast_ref::<Date32Array>() {
171 if arr.is_null(row) {
172 return None;
173 }
174 return arr.value_as_date(row);
175 }
176 None
177 };
178 let ts_at = |canon: &str| -> Option<DateTime<Utc>> {
179 let i = col_idx(canon)?;
180 if let Some(arr) = batch
181 .column(i)
182 .as_any()
183 .downcast_ref::<TimestampMillisecondArray>()
184 {
185 if arr.is_null(row) {
186 return None;
187 }
188 return Utc.timestamp_millis_opt(arr.value(row)).single();
189 }
190 if let Some(arr) = batch
191 .column(i)
192 .as_any()
193 .downcast_ref::<TimestampMicrosecondArray>()
194 {
195 if arr.is_null(row) {
196 return None;
197 }
198 return Utc.timestamp_micros(arr.value(row)).single();
199 }
200 if let Some(arr) = batch
201 .column(i)
202 .as_any()
203 .downcast_ref::<TimestampNanosecondArray>()
204 {
205 if arr.is_null(row) {
206 return None;
207 }
208 let nanos = arr.value(row);
209 return Some(Utc.timestamp_nanos(nanos));
210 }
211 if let Some(arr) = batch.column(i).as_any().downcast_ref::<StringArray>() {
212 if arr.is_null(row) {
213 return None;
214 }
215 return DateTime::parse_from_rfc3339(arr.value(row))
216 .ok()
217 .map(|dt| dt.with_timezone(&Utc));
218 }
219 None
220 };
221
222 Ok(Record {
223 source: str_at("Source").unwrap_or_default(),
224 gl_account: str_at("GLAccount").unwrap_or_default(),
225 cost_center: str_at("CostCenter"),
226 profit_center: str_at("ProfitCenter"),
227 trading_partner: str_at("TradingPartner"),
228 je_number: str_at("JENumber").unwrap_or_default(),
229 je_line_number: str_at("JELineNumber").unwrap_or_default(),
230 effective_date: date_at("EffectiveDate")
231 .ok_or_else(|| BehavioralFidelityError::Schema("missing EffectiveDate".into()))?,
232 entry_date: date_at("EntryDate")
233 .ok_or_else(|| BehavioralFidelityError::Schema("missing EntryDate".into()))?,
234 created_at: ts_at("CreatedAt"),
235 functional_amount: f64_at("FunctionalAmount").unwrap_or(0.0),
236 header_text: str_at("HeaderText").unwrap_or_default(),
238 line_text: str_at("LineText").unwrap_or_default(),
239 })
240}
241
242fn extract_csv_row(
243 rec: &csv::StringRecord,
244 headers: &HashMap<&str, usize>,
245 aliases: &HashMap<&'static str, &'static str>,
246) -> BehavioralFidelityResult<Record> {
247 let get = |canon: &str| -> Option<String> {
248 let real = aliases.get(canon)?;
249 let i = headers.get(*real)?;
250 let v = rec.get(*i)?;
251 if v.is_empty() {
252 None
253 } else {
254 Some(v.to_string())
255 }
256 };
257 Ok(Record {
258 source: get("Source").unwrap_or_default(),
259 gl_account: get("GLAccount").unwrap_or_default(),
260 cost_center: get("CostCenter"),
261 profit_center: get("ProfitCenter"),
262 trading_partner: get("TradingPartner"),
263 je_number: get("JENumber").unwrap_or_default(),
264 je_line_number: get("JELineNumber").unwrap_or_default(),
265 effective_date: get("EffectiveDate")
266 .and_then(|s| NaiveDate::parse_from_str(&s, "%Y-%m-%d").ok())
267 .ok_or_else(|| BehavioralFidelityError::Schema("missing EffectiveDate".into()))?,
268 entry_date: get("EntryDate")
269 .and_then(|s| NaiveDate::parse_from_str(&s, "%Y-%m-%d").ok())
270 .ok_or_else(|| BehavioralFidelityError::Schema("missing EntryDate".into()))?,
271 created_at: get("CreatedAt").and_then(|s| {
272 DateTime::parse_from_rfc3339(&s)
273 .ok()
274 .map(|d| d.with_timezone(&Utc))
275 }),
276 functional_amount: get("FunctionalAmount")
277 .and_then(|s| s.parse().ok())
278 .unwrap_or(0.0),
279 header_text: get("HeaderText").unwrap_or_default(),
281 line_text: get("LineText").unwrap_or_default(),
282 })
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow::array::{Float64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use parquet::arrow::ArrowWriter;
    use tempfile::tempdir;

    /// Builds a two-row batch that uses the reference-corpus column names
    /// (including the `Tarding Partner` spelling the loader keys on). The
    /// second row has nulls in every nullable column.
    fn build_reference_corpus_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("JE Number", DataType::Utf8, false),
            Field::new("GL Account Number", DataType::Utf8, false),
            Field::new("Functional Amount", DataType::Float64, false),
            Field::new("Effective Date", DataType::Utf8, false),
            Field::new("Entry Date", DataType::Utf8, false),
            Field::new("Source", DataType::Utf8, false),
            Field::new("Cost Center", DataType::Utf8, true),
            Field::new("Profit Center", DataType::Utf8, true),
            Field::new("Tarding Partner", DataType::Utf8, true),
            Field::new("JE Line Number", DataType::Utf8, false),
        ]));
        let arr_je = StringArray::from(vec!["2022-0090-001", "2022-0090-001"]);
        let arr_gl = StringArray::from(vec!["1100", "2000"]);
        let arr_amt = Float64Array::from(vec![100.0, -100.0]);
        let arr_eff = StringArray::from(vec!["2022-04-25", "2022-04-25"]);
        let arr_ent = StringArray::from(vec!["2022-04-14", "2022-04-14"]);
        let arr_src = StringArray::from(vec!["KR", "KR"]);
        let arr_cc = StringArray::from(vec![Some("CC100"), None]);
        let arr_pc = StringArray::from(vec![Some("PC100"), None]);
        let arr_tp = StringArray::from(vec![Some("TP1"), None]);
        let arr_line = StringArray::from(vec!["001", "002"]);
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(arr_je),
                Arc::new(arr_gl),
                Arc::new(arr_amt),
                Arc::new(arr_eff),
                Arc::new(arr_ent),
                Arc::new(arr_src),
                Arc::new(arr_cc),
                Arc::new(arr_pc),
                Arc::new(arr_tp),
                Arc::new(arr_line),
            ],
        )
        .unwrap()
    }

    /// Round-trips the reference-corpus batch through a Parquet file and
    /// checks row count, a text field, nullable handling and date parsing.
    #[test]
    fn load_parquet_reference_corpus_shape() {
        let batch = build_reference_corpus_batch();
        // The file lives inside a `TempDir`, which deletes itself (and the
        // file) when dropped -- including when an assertion below panics --
        // so a failing run does not leave a stray parquet file behind.
        let dir = tempdir().unwrap();
        let parquet_path = dir.path().join("reference_corpus.parquet");
        {
            let file = File::create(&parquet_path).unwrap();
            let mut writer = ArrowWriter::try_new(file, batch.schema(), None).unwrap();
            writer.write(&batch).unwrap();
            writer.close().unwrap();
        }
        let records = load_parquet_records(&parquet_path).unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].source, "KR");
        assert_eq!(records[0].cost_center.as_deref(), Some("CC100"));
        assert_eq!(records[1].cost_center, None);
        assert_eq!(
            records[0].entry_date,
            NaiveDate::from_ymd_opt(2022, 4, 14).unwrap()
        );
    }
}