use std::collections::{BTreeMap, HashSet};
use std::path::Path;
use arrow::array::Array;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use datasynth_core::distributions::behavioral_priors::ManualSharePrior;
use crate::error::{FingerprintError, FingerprintResult};
use super::tb_extractor::{find_column_index, float64_column, string_column};
/// Default observation gate for per-source manual shares.
///
/// Not referenced in this file's visible code; presumably intended as the
/// default value for the `min_observations_per_source` argument of
/// `extract_manual_share_from_parquet` — confirm against callers.
pub const DEFAULT_MIN_MANUAL_OBSERVATIONS: usize = 100;

/// Lower-case aliases for the manual/system indicator column. They are
/// compared against lower-cased schema field names via `find_column_index`.
const MANUAL_CANDIDATES: &[&str] = &[
    "systemmanual",
    "system_manual",
    "system manual",
    "manual",
];

/// Lower-case aliases for the optional posting-source column.
const SOURCE_CANDIDATES: &[&str] = &[
    "source",
    "blart",
    "doctype",
];

/// Lower-case aliases for the optional journal-entry-number column.
const JE_NUMBER_CANDIDATES: &[&str] = &[
    "je number",
    "je_number",
    "document_number",
    "doc_number",
];

/// Lower-case aliases for the optional amount column.
const AMOUNT_CANDIDATES: &[&str] = &[
    "functional amount",
    "functional_amount",
    "amount",
];
/// Estimates the share of manually posted journal entries from a JE-line
/// parquet file.
///
/// The manual/system indicator column is located by case-insensitive name
/// (`MANUAL_CANDIDATES`). Optional columns refine the estimate:
///
/// * a source column (`SOURCE_CANDIDATES`) enables per-source shares;
/// * a JE-number column (`JE_NUMBER_CANDIDATES`) switches counting from lines
///   to journal entries. Only the first counted line of each JE number is
///   used, so its indicator and source stand for the whole entry. Lines whose
///   JE number is null or blank are each counted individually.
/// * an amount column (`AMOUNT_CANDIDATES`) drops lines whose amount is null
///   or exactly `0.0`. This filter runs before JE de-duplication, so a JE
///   whose leading lines have zero amount is still represented by its first
///   non-zero line.
///
/// A line counts as manual when its trimmed indicator, lower-cased, contains
/// the substring `"manual"`. Null or blank indicators are skipped.
///
/// # Arguments
/// * `path` - the parquet file to read.
/// * `min_observations_per_source` - sources with fewer counted observations
///   are omitted from `by_source` but still contribute to `overall`.
///
/// # Returns
/// `Ok(None)` when no indicator column exists or no line survives the
/// filters; otherwise `Ok(Some(prior))`, where `n_observations` is the number
/// of counted lines/JEs.
///
/// # Errors
/// * `FingerprintError::Io` when the file cannot be opened. The underlying
///   `std::io::ErrorKind` is preserved.
/// * `FingerprintError::InvalidFormat` when the parquet reader cannot be
///   built or a record batch fails to decode.
pub fn extract_manual_share_from_parquet(
    path: &Path,
    min_observations_per_source: usize,
) -> FingerprintResult<Option<ManualSharePrior>> {
    // Keep the real failure kind (NotFound, PermissionDenied, ...) rather than
    // reporting every open failure as NotFound, and name the offending file.
    let file = std::fs::File::open(path).map_err(|e| {
        FingerprintError::Io(std::io::Error::new(
            e.kind(),
            format!("JE parquet open failed for {}: {e}", path.display()),
        ))
    })?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file).map_err(|e| {
        FingerprintError::InvalidFormat(format!("JE parquet: cannot build reader: {e}"))
    })?;

    // Field names are lower-cased so the candidate lists can stay lower-case.
    let col_names: Vec<String> = builder
        .schema()
        .fields()
        .iter()
        .map(|f| f.name().to_lowercase())
        .collect();
    let Some(manual_idx) = find_column_index(&col_names, MANUAL_CANDIDATES) else {
        // Without an indicator column there is nothing to estimate.
        return Ok(None);
    };
    let source_idx = find_column_index(&col_names, SOURCE_CANDIDATES);
    let je_idx = find_column_index(&col_names, JE_NUMBER_CANDIDATES);
    let amount_idx = find_column_index(&col_names, AMOUNT_CANDIDATES);

    let reader = builder.build().map_err(|e| {
        FingerprintError::InvalidFormat(format!("JE parquet: cannot open reader: {e}"))
    })?;

    let mut total = 0usize;
    let mut total_manual = 0usize;
    // source -> (counted observations, manual observations)
    let mut per_source: BTreeMap<String, (usize, usize)> = BTreeMap::new();
    // JE numbers already counted; shared across batches so a JE split over
    // batch boundaries is still counted once.
    let mut seen_jes: HashSet<String> = HashSet::new();

    for batch_res in reader {
        let batch = batch_res.map_err(|e| {
            FingerprintError::InvalidFormat(format!("JE parquet: batch read error: {e}"))
        })?;
        // NOTE(review): when `string_column` yields None for the indicator
        // column (presumably a non-string column type) the whole batch is
        // skipped — confirm this is intended.
        let Some(manual_arr) = string_column(&batch, manual_idx) else {
            continue;
        };
        let source_arr = source_idx.and_then(|i| string_column(&batch, i));
        let je_arr = je_idx.and_then(|i| string_column(&batch, i));
        let amount_vals = amount_idx.map(|i| float64_column(&batch, i));

        for row in 0..batch.num_rows() {
            if manual_arr.is_null(row) {
                continue;
            }
            let raw = manual_arr.value(row).trim();
            if raw.is_empty() {
                continue;
            }
            // Filter amounts before JE de-duplication: a zero-amount line must
            // not claim its JE number, so a later non-zero line can.
            if let Some(amts) = &amount_vals {
                match amts[row] {
                    Some(v) if v != 0.0 => {}
                    _ => continue,
                }
            }
            // Count each JE once; later lines of an already-seen JE are dropped.
            if let Some(jes) = &je_arr {
                if !jes.is_null(row) {
                    let je = jes.value(row).trim();
                    if !je.is_empty() && !seen_jes.insert(je.to_string()) {
                        continue;
                    }
                }
            }
            // Substring match: any indicator containing "manual" (any case) counts.
            let is_manual = raw.to_lowercase().contains("manual");
            total += 1;
            if is_manual {
                total_manual += 1;
            }
            if let Some(sources) = &source_arr {
                if !sources.is_null(row) {
                    let source = sources.value(row).trim();
                    if !source.is_empty() {
                        let entry = per_source.entry(source.to_string()).or_default();
                        entry.0 += 1;
                        if is_manual {
                            entry.1 += 1;
                        }
                    }
                }
            }
        }
    }

    if total == 0 {
        return Ok(None);
    }

    // Sources below the observation gate still count toward `overall`.
    let by_source: BTreeMap<String, f64> = per_source
        .into_iter()
        .filter(|(_, (n, _))| *n >= min_observations_per_source)
        .map(|(source, (n, manual))| (source, manual as f64 / n as f64))
        .collect();

    Ok(Some(ManualSharePrior {
        overall: total_manual as f64 / total as f64,
        by_source,
        n_observations: total,
    }))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use parquet::arrow::ArrowWriter;
    use std::fs::File;
    use std::sync::Arc;

    /// Writes a single record batch to `path` as a parquet file, using the
    /// batch's own schema for the writer.
    fn write_batch(path: &Path, batch: &RecordBatch) {
        let file = File::create(path).expect("create");
        let mut writer = ArrowWriter::try_new(file, batch.schema(), None).expect("writer");
        writer.write(batch).expect("write");
        writer.close().expect("close");
    }

    /// Writes a two-column parquet file: `source_col` plus `SystemManual`,
    /// one row per `(source, indicator)` pair.
    fn write_parquet(path: &Path, source_col: &str, rows: &[(&str, &str)]) {
        let schema = Arc::new(Schema::new(vec![
            Field::new(source_col, DataType::Utf8, true),
            Field::new("SystemManual", DataType::Utf8, true),
        ]));
        let sources: Vec<Option<&str>> = rows.iter().map(|(s, _)| Some(*s)).collect();
        let manuals: Vec<Option<&str>> = rows.iter().map(|(_, m)| Some(*m)).collect();
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(sources)),
                Arc::new(StringArray::from(manuals)),
            ],
        )
        .expect("batch");
        write_batch(path, &batch);
    }

    #[test]
    fn shares_computed_per_source_and_overall() {
        let tmp = tempfile::tempdir().expect("tmpdir");
        let path = tmp.path().join("je.parquet");
        let mut rows: Vec<(&str, &str)> = Vec::new();
        rows.extend([("SA", "Manual"); 3]);
        rows.push(("SA", "System"));
        rows.push(("RE", "Manual"));
        rows.extend([("RE", "System"); 3]);
        write_parquet(&path, "Source", &rows);
        let ms = extract_manual_share_from_parquet(&path, 1)
            .expect("extract")
            .expect("Some");
        assert_eq!(ms.n_observations, 8);
        assert!((ms.overall - 0.5).abs() < 1e-12);
        assert!((ms.by_source["SA"] - 0.75).abs() < 1e-12);
        assert!((ms.by_source["RE"] - 0.25).abs() < 1e-12);
    }

    #[test]
    fn sources_below_observation_gate_roll_into_overall_only() {
        let tmp = tempfile::tempdir().expect("tmpdir");
        let path = tmp.path().join("je.parquet");
        let mut rows: Vec<(&str, &str)> = Vec::new();
        rows.extend([("SA", "Manual"); 10]);
        rows.push(("ZZ", "System"));
        write_parquet(&path, "Source", &rows);
        let ms = extract_manual_share_from_parquet(&path, 5)
            .expect("extract")
            .expect("Some");
        assert_eq!(ms.n_observations, 11);
        assert!(ms.by_source.contains_key("SA"));
        assert!(
            !ms.by_source.contains_key("ZZ"),
            "below-gate source must not be emitted per-source"
        );
        assert!((ms.overall - 10.0 / 11.0).abs() < 1e-12);
    }

    #[test]
    fn missing_manual_column_yields_none() {
        let tmp = tempfile::tempdir().expect("tmpdir");
        let path = tmp.path().join("je.parquet");
        let schema = Arc::new(Schema::new(vec![Field::new(
            "Source",
            DataType::Utf8,
            true,
        )]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(StringArray::from(vec![Some("SA")]))],
        )
        .expect("batch");
        write_batch(&path, &batch);
        let ms = extract_manual_share_from_parquet(&path, 1).expect("extract");
        assert!(ms.is_none(), "no indicator column must yield None");
    }

    #[test]
    fn missing_file_is_reported_as_io_error() {
        let tmp = tempfile::tempdir().expect("tmpdir");
        let path = tmp.path().join("does-not-exist.parquet");
        let res = extract_manual_share_from_parquet(&path, 1);
        assert!(
            matches!(
                res,
                Err(FingerprintError::Io(ref e)) if e.kind() == std::io::ErrorKind::NotFound
            ),
            "opening a missing file must surface an Io/NotFound error"
        );
    }

    #[test]
    fn null_indicator_rows_are_skipped() {
        let tmp = tempfile::tempdir().expect("tmpdir");
        let path = tmp.path().join("je.parquet");
        let schema = Arc::new(Schema::new(vec![
            Field::new("Source", DataType::Utf8, true),
            Field::new("SystemManual", DataType::Utf8, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![Some("SA"), Some("SA"), Some("SA")])),
                Arc::new(StringArray::from(vec![Some("Manual"), None, Some("System")])),
            ],
        )
        .expect("batch");
        write_batch(&path, &batch);
        let ms = extract_manual_share_from_parquet(&path, 1)
            .expect("extract")
            .expect("Some");
        assert_eq!(ms.n_observations, 2, "null indicator must not be counted");
        assert!((ms.overall - 0.5).abs() < 1e-12);
    }

    #[test]
    fn je_number_column_switches_shares_to_per_je() {
        let tmp = tempfile::tempdir().expect("tmpdir");
        let path = tmp.path().join("je.parquet");
        let schema = Arc::new(Schema::new(vec![
            Field::new("JE Number", DataType::Utf8, true),
            Field::new("Source", DataType::Utf8, true),
            Field::new("SystemManual", DataType::Utf8, true),
        ]));
        let rows: Vec<(&str, &str, &str)> = vec![
            ("2024/J1", "SA", "Manual"),
            ("2024/J1", "SA", "Manual"),
            ("2024/J1", "SA", "Manual"),
            ("2024/J2", "SA", "System"),
            ("2024/J3", "RE", "Manual"),
            ("2024/J4", "RE", "System"),
            ("2024/J4", "RE", "System"),
            ("2024/J4", "RE", "System"),
        ];
        let je: Vec<Option<&str>> = rows.iter().map(|(j, _, _)| Some(*j)).collect();
        let src: Vec<Option<&str>> = rows.iter().map(|(_, s, _)| Some(*s)).collect();
        let man: Vec<Option<&str>> = rows.iter().map(|(_, _, m)| Some(*m)).collect();
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(je)),
                Arc::new(StringArray::from(src)),
                Arc::new(StringArray::from(man)),
            ],
        )
        .expect("batch");
        write_batch(&path, &batch);
        let ms = extract_manual_share_from_parquet(&path, 1)
            .expect("extract")
            .expect("Some");
        assert_eq!(ms.n_observations, 4, "JEs counted, not lines");
        assert!((ms.overall - 0.5).abs() < 1e-12);
        assert!(
            (ms.by_source["SA"] - 0.5).abs() < 1e-12,
            "SA per-JE 0.5, not line-level 0.75"
        );
        assert!(
            (ms.by_source["RE"] - 0.5).abs() < 1e-12,
            "RE per-JE 0.5, not line-level 0.25"
        );
    }

    #[test]
    fn zero_amount_jes_excluded_when_amount_column_present() {
        let tmp = tempfile::tempdir().expect("tmpdir");
        let path = tmp.path().join("je.parquet");
        let schema = Arc::new(Schema::new(vec![
            Field::new("JE Number", DataType::Utf8, true),
            Field::new("Source", DataType::Utf8, true),
            Field::new("SystemManual", DataType::Utf8, true),
            Field::new("Functional Amount", DataType::Float64, true),
        ]));
        let rows: Vec<(&str, &str, &str, f64)> = vec![
            ("2024/J1", "SA", "Manual", 0.0),
            ("2024/J1", "SA", "Manual", 100.0),
            ("2024/J2", "SA", "Manual", 0.0),
            ("2024/J3", "SA", "System", 50.0),
        ];
        let je: Vec<Option<&str>> = rows.iter().map(|(j, _, _, _)| Some(*j)).collect();
        let src: Vec<Option<&str>> = rows.iter().map(|(_, s, _, _)| Some(*s)).collect();
        let man: Vec<Option<&str>> = rows.iter().map(|(_, _, m, _)| Some(*m)).collect();
        let amt: Vec<Option<f64>> = rows.iter().map(|(_, _, _, a)| Some(*a)).collect();
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(je)),
                Arc::new(StringArray::from(src)),
                Arc::new(StringArray::from(man)),
                Arc::new(Float64Array::from(amt)),
            ],
        )
        .expect("batch");
        write_batch(&path, &batch);
        let ms = extract_manual_share_from_parquet(&path, 1)
            .expect("extract")
            .expect("Some");
        assert_eq!(ms.n_observations, 2, "zero-amount JE must not count");
        assert!(
            (ms.overall - 0.5).abs() < 1e-12,
            "1 manual of 2 amount-bearing JEs"
        );
        assert!((ms.by_source["SA"] - 0.5).abs() < 1e-12);
    }

    #[test]
    fn indicator_matching_is_case_insensitive_and_alias_tolerant() {
        let tmp = tempfile::tempdir().expect("tmpdir");
        let path = tmp.path().join("je.parquet");
        let schema = Arc::new(Schema::new(vec![
            Field::new("Source", DataType::Utf8, true),
            Field::new("system_manual", DataType::Utf8, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![
                    Some("SA"),
                    Some("SA"),
                    Some("SA"),
                    Some("SA"),
                ])),
                Arc::new(StringArray::from(vec![
                    Some("MANUAL"),
                    Some("manual"),
                    Some("SYSTEM"),
                    Some(""),
                ])),
            ],
        )
        .expect("batch");
        write_batch(&path, &batch);
        let ms = extract_manual_share_from_parquet(&path, 1)
            .expect("extract")
            .expect("Some");
        assert_eq!(ms.n_observations, 3);
        assert!((ms.overall - 2.0 / 3.0).abs() < 1e-12);
        assert!((ms.by_source["SA"] - 2.0 / 3.0).abs() < 1e-12);
    }
}