Skip to main content

datasynth_eval/behavioral_fidelity/
loader.rs

1//! Parquet / CSV loader producing canonical `Record`s.
2
3use std::collections::HashMap;
4use std::fs::File;
5use std::path::Path;
6
7use arrow::array::{
8    Array, Date32Array, Float64Array, StringArray, TimestampMicrosecondArray,
9    TimestampMillisecondArray, TimestampNanosecondArray,
10};
11use arrow::record_batch::RecordBatch;
12use chrono::{DateTime, NaiveDate, TimeZone, Utc};
13use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
14
15use super::entity_profile::{reference_corpus_aliases, synthetic_aliases};
16use super::error::{BehavioralFidelityError, BehavioralFidelityResult};
17use super::types::Record;
18
19/// Load all records from a parquet file (or single-file directory).
20pub fn load_parquet_records(path: &Path) -> BehavioralFidelityResult<Vec<Record>> {
21    let path = resolve_single_file(path, "parquet")?;
22    let file = File::open(&path)?;
23    let builder = ParquetRecordBatchReaderBuilder::try_new(file)
24        .map_err(|e| BehavioralFidelityError::Parquet(e.to_string()))?;
25    let reader = builder
26        .build()
27        .map_err(|e| BehavioralFidelityError::Parquet(e.to_string()))?;
28    let mut out = Vec::new();
29    let mut skipped = 0usize;
30    for batch_res in reader {
31        let batch = batch_res.map_err(|e| BehavioralFidelityError::Parquet(e.to_string()))?;
32        let aliases = pick_alias_map(&batch);
33        for row in 0..batch.num_rows() {
34            match extract_row(&batch, row, &aliases) {
35                Ok(rec) => out.push(rec),
36                Err(BehavioralFidelityError::Schema(_)) => skipped += 1,
37                Err(e) => return Err(e),
38            }
39        }
40    }
41    if skipped > 0 {
42        tracing::warn!(
43            "load_parquet_records: skipped {} rows with missing/malformed required dates",
44            skipped
45        );
46    }
47    Ok(out)
48}
49
50/// Load all records from a CSV file. Header line is required and is used
51/// to pick the alias map.
52pub fn load_csv_records(path: &Path) -> BehavioralFidelityResult<Vec<Record>> {
53    let path = resolve_single_file(path, "csv")?;
54    let mut rdr = csv::ReaderBuilder::new()
55        .has_headers(true)
56        .from_path(&path)
57        .map_err(|e| BehavioralFidelityError::Io(std::io::Error::other(e.to_string())))?;
58    let headers: Vec<String> = rdr
59        .headers()
60        .map_err(|e| BehavioralFidelityError::Io(std::io::Error::other(e.to_string())))?
61        .iter()
62        .map(|s| s.to_string())
63        .collect();
64    let aliases = pick_alias_map_from_headers(&headers);
65    let header_idx: HashMap<&str, usize> = headers
66        .iter()
67        .enumerate()
68        .map(|(i, h)| (h.as_str(), i))
69        .collect();
70    let mut out = Vec::new();
71    let mut skipped = 0usize;
72    for rec in rdr.records() {
73        let rec =
74            rec.map_err(|e| BehavioralFidelityError::Io(std::io::Error::other(e.to_string())))?;
75        match extract_csv_row(&rec, &header_idx, &aliases) {
76            Ok(record) => out.push(record),
77            Err(BehavioralFidelityError::Schema(_)) => skipped += 1,
78            Err(e) => return Err(e),
79        }
80    }
81    if skipped > 0 {
82        tracing::warn!(
83            "load_csv_records: skipped {} rows with missing/malformed required dates",
84            skipped
85        );
86    }
87    Ok(out)
88}
89
90fn resolve_single_file(
91    path: &Path,
92    extension: &str,
93) -> BehavioralFidelityResult<std::path::PathBuf> {
94    if path.is_file() {
95        return Ok(path.to_path_buf());
96    }
97    if path.is_dir() {
98        for entry in std::fs::read_dir(path)? {
99            let entry = entry?;
100            if entry
101                .path()
102                .extension()
103                .is_some_and(|e| e.eq_ignore_ascii_case(extension))
104            {
105                return Ok(entry.path());
106            }
107        }
108    }
109    Err(BehavioralFidelityError::Io(std::io::Error::other(format!(
110        "no {extension} file at {}",
111        path.display()
112    ))))
113}
114
115fn pick_alias_map(batch: &RecordBatch) -> HashMap<&'static str, &'static str> {
116    let schema = batch.schema();
117    let cols: Vec<String> = schema
118        .fields()
119        .iter()
120        .map(|f| f.name().to_string())
121        .collect();
122    pick_alias_map_from_headers(&cols)
123}
124
125fn pick_alias_map_from_headers(cols: &[String]) -> HashMap<&'static str, &'static str> {
126    let has = |needle: &str| cols.iter().any(|c| c == needle);
127    if has("Tarding Partner") || has("Functional Amount") {
128        reference_corpus_aliases().into_iter().collect()
129    } else {
130        synthetic_aliases().into_iter().collect()
131    }
132}
133
134fn extract_row(
135    batch: &RecordBatch,
136    row: usize,
137    aliases: &HashMap<&'static str, &'static str>,
138) -> BehavioralFidelityResult<Record> {
139    let s = batch.schema();
140    let col_idx = |canon: &str| -> Option<usize> {
141        let real = aliases.get(canon)?;
142        s.fields().iter().position(|f| f.name() == *real)
143    };
144    let str_at = |canon: &str| -> Option<String> {
145        let i = col_idx(canon)?;
146        let arr = batch.column(i).as_any().downcast_ref::<StringArray>()?;
147        if arr.is_null(row) {
148            None
149        } else {
150            Some(arr.value(row).to_string())
151        }
152    };
153    let f64_at = |canon: &str| -> Option<f64> {
154        let i = col_idx(canon)?;
155        let arr = batch.column(i).as_any().downcast_ref::<Float64Array>()?;
156        if arr.is_null(row) {
157            None
158        } else {
159            Some(arr.value(row))
160        }
161    };
162    let date_at = |canon: &str| -> Option<NaiveDate> {
163        let i = col_idx(canon)?;
164        if let Some(arr) = batch.column(i).as_any().downcast_ref::<StringArray>() {
165            if arr.is_null(row) {
166                return None;
167            }
168            return NaiveDate::parse_from_str(arr.value(row), "%Y-%m-%d").ok();
169        }
170        if let Some(arr) = batch.column(i).as_any().downcast_ref::<Date32Array>() {
171            if arr.is_null(row) {
172                return None;
173            }
174            return arr.value_as_date(row);
175        }
176        None
177    };
178    let ts_at = |canon: &str| -> Option<DateTime<Utc>> {
179        let i = col_idx(canon)?;
180        if let Some(arr) = batch
181            .column(i)
182            .as_any()
183            .downcast_ref::<TimestampMillisecondArray>()
184        {
185            if arr.is_null(row) {
186                return None;
187            }
188            return Utc.timestamp_millis_opt(arr.value(row)).single();
189        }
190        if let Some(arr) = batch
191            .column(i)
192            .as_any()
193            .downcast_ref::<TimestampMicrosecondArray>()
194        {
195            if arr.is_null(row) {
196                return None;
197            }
198            return Utc.timestamp_micros(arr.value(row)).single();
199        }
200        if let Some(arr) = batch
201            .column(i)
202            .as_any()
203            .downcast_ref::<TimestampNanosecondArray>()
204        {
205            if arr.is_null(row) {
206                return None;
207            }
208            let nanos = arr.value(row);
209            return Some(Utc.timestamp_nanos(nanos));
210        }
211        if let Some(arr) = batch.column(i).as_any().downcast_ref::<StringArray>() {
212            if arr.is_null(row) {
213                return None;
214            }
215            return DateTime::parse_from_rfc3339(arr.value(row))
216                .ok()
217                .map(|dt| dt.with_timezone(&Utc));
218        }
219        None
220    };
221
222    Ok(Record {
223        source: str_at("Source").unwrap_or_default(),
224        gl_account: str_at("GLAccount").unwrap_or_default(),
225        cost_center: str_at("CostCenter"),
226        profit_center: str_at("ProfitCenter"),
227        trading_partner: str_at("TradingPartner"),
228        je_number: str_at("JENumber").unwrap_or_default(),
229        je_line_number: str_at("JELineNumber").unwrap_or_default(),
230        effective_date: date_at("EffectiveDate")
231            .ok_or_else(|| BehavioralFidelityError::Schema("missing EffectiveDate".into()))?,
232        entry_date: date_at("EntryDate")
233            .ok_or_else(|| BehavioralFidelityError::Schema("missing EntryDate".into()))?,
234        created_at: ts_at("CreatedAt"),
235        functional_amount: f64_at("FunctionalAmount").unwrap_or(0.0),
236        // SP4.4 W7.3 — header/line text from JE Description / JE Line Description columns.
237        header_text: str_at("HeaderText").unwrap_or_default(),
238        line_text: str_at("LineText").unwrap_or_default(),
239    })
240}
241
242fn extract_csv_row(
243    rec: &csv::StringRecord,
244    headers: &HashMap<&str, usize>,
245    aliases: &HashMap<&'static str, &'static str>,
246) -> BehavioralFidelityResult<Record> {
247    let get = |canon: &str| -> Option<String> {
248        let real = aliases.get(canon)?;
249        let i = headers.get(*real)?;
250        let v = rec.get(*i)?;
251        if v.is_empty() {
252            None
253        } else {
254            Some(v.to_string())
255        }
256    };
257    Ok(Record {
258        source: get("Source").unwrap_or_default(),
259        gl_account: get("GLAccount").unwrap_or_default(),
260        cost_center: get("CostCenter"),
261        profit_center: get("ProfitCenter"),
262        trading_partner: get("TradingPartner"),
263        je_number: get("JENumber").unwrap_or_default(),
264        je_line_number: get("JELineNumber").unwrap_or_default(),
265        effective_date: get("EffectiveDate")
266            .and_then(|s| NaiveDate::parse_from_str(&s, "%Y-%m-%d").ok())
267            .ok_or_else(|| BehavioralFidelityError::Schema("missing EffectiveDate".into()))?,
268        entry_date: get("EntryDate")
269            .and_then(|s| NaiveDate::parse_from_str(&s, "%Y-%m-%d").ok())
270            .ok_or_else(|| BehavioralFidelityError::Schema("missing EntryDate".into()))?,
271        created_at: get("CreatedAt").and_then(|s| {
272            DateTime::parse_from_rfc3339(&s)
273                .ok()
274                .map(|d| d.with_timezone(&Utc))
275        }),
276        functional_amount: get("FunctionalAmount")
277            .and_then(|s| s.parse().ok())
278            .unwrap_or(0.0),
279        // SP4.4 W7.3 — header/line text from JE Description / JE Line Description columns.
280        header_text: get("HeaderText").unwrap_or_default(),
281        line_text: get("LineText").unwrap_or_default(),
282    })
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use arrow::array::{Float64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use parquet::arrow::ArrowWriter;
    use tempfile::Builder;

    /// Two-row batch whose column names follow the reference-corpus layout,
    /// including its misspelt `Tarding Partner` column. The second row has
    /// nulls in the three nullable columns.
    fn build_reference_corpus_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("JE Number", DataType::Utf8, false),
            Field::new("GL Account Number", DataType::Utf8, false),
            Field::new("Functional Amount", DataType::Float64, false),
            Field::new("Effective Date", DataType::Utf8, false),
            Field::new("Entry Date", DataType::Utf8, false),
            Field::new("Source", DataType::Utf8, false),
            Field::new("Cost Center", DataType::Utf8, true),
            Field::new("Profit Center", DataType::Utf8, true),
            Field::new("Tarding Partner", DataType::Utf8, true),
            Field::new("JE Line Number", DataType::Utf8, false),
        ]));
        let arr_je = StringArray::from(vec!["2022-0090-001", "2022-0090-001"]);
        let arr_gl = StringArray::from(vec!["1100", "2000"]);
        let arr_amt = Float64Array::from(vec![100.0, -100.0]);
        let arr_eff = StringArray::from(vec!["2022-04-25", "2022-04-25"]);
        let arr_ent = StringArray::from(vec!["2022-04-14", "2022-04-14"]);
        let arr_src = StringArray::from(vec!["KR", "KR"]);
        let arr_cc = StringArray::from(vec![Some("CC100"), None]);
        let arr_pc = StringArray::from(vec![Some("PC100"), None]);
        let arr_tp = StringArray::from(vec![Some("TP1"), None]);
        let arr_line = StringArray::from(vec!["001", "002"]);
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(arr_je),
                Arc::new(arr_gl),
                Arc::new(arr_amt),
                Arc::new(arr_eff),
                Arc::new(arr_ent),
                Arc::new(arr_src),
                Arc::new(arr_cc),
                Arc::new(arr_pc),
                Arc::new(arr_tp),
                Arc::new(arr_line),
            ],
        )
        .unwrap()
    }

    /// Round-trips the reference-corpus batch through a parquet file and
    /// checks row count, a plain string field, nullable handling and date
    /// parsing.
    #[test]
    fn load_parquet_reference_corpus_shape() {
        let batch = build_reference_corpus_batch();
        // Write into the temp file itself (created with a `.parquet` suffix)
        // so it is removed when `tmp` drops, even if an assertion below
        // panics; no sibling path needs manual cleanup.
        let tmp = Builder::new().suffix(".parquet").tempfile().unwrap();
        {
            let file = tmp.reopen().unwrap();
            let mut writer = ArrowWriter::try_new(file, batch.schema(), None).unwrap();
            writer.write(&batch).unwrap();
            writer.close().unwrap();
        }
        let records = load_parquet_records(tmp.path()).unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].source, "KR");
        assert_eq!(records[0].cost_center.as_deref(), Some("CC100"));
        assert_eq!(records[1].cost_center, None);
        assert_eq!(
            records[0].entry_date,
            NaiveDate::from_ymd_opt(2022, 4, 14).unwrap()
        );
    }
}