csv_managed/
stats.rs

1use std::collections::HashMap;
2
3use anyhow::{Context, Result, anyhow, bail};
4use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc};
5use encoding_rs::Encoding;
6use log::info;
7
8use crate::{
9    cli::StatsArgs,
10    data::Value,
11    filter, frequency, io_utils,
12    rows::{evaluate_filter_expressions, parse_typed_row},
13    schema::{self, ColumnType, DecimalSpec, Schema},
14    table,
15};
16
/// Entry point for the stats command: prints either a frequency table or a
/// summary-statistics table for the selected columns of `args.input`.
///
/// # Errors
///
/// Returns an error when stdin is used as input without a schema, when no
/// usable columns are resolved, or when any of the schema, filter, reader,
/// decoding, transformation, parsing, or accumulation steps fail.
pub fn execute(args: &StatsArgs) -> Result<()> {
    // Without an explicit schema the input would have to be read for inference,
    // so a dash (stdin) input is rejected up front.
    if args.schema.is_none() && io_utils::is_dash(&args.input) {
        return Err(anyhow!(
            "Reading from stdin requires --schema (or --meta) for stats operations"
        ));
    }

    let delimiter = io_utils::resolve_input_delimiter(&args.input, args.delimiter);
    let encoding = io_utils::resolve_encoding(args.input_encoding.as_deref())?;

    let schema = load_or_infer_schema(args, delimiter, encoding)?;

    let columns = resolve_columns(&schema, &args.columns, args.frequency)?;
    if columns.is_empty() {
        // The message depends on the mode because each mode selects its
        // default columns differently.
        if args.frequency {
            return Err(anyhow!(
                "No columns available for frequency analysis. Supply --columns to continue."
            ));
        }
        return Err(anyhow!(
            "No numeric or temporal columns available. Provide a schema file or explicit column list."
        ));
    }

    let filters = filter::parse_filters(&args.filters)?;

    // Frequency mode is delegated entirely to the `frequency` module and
    // returns before the summary-statistics pass below.
    if args.frequency {
        let freq_options = frequency::FrequencyOptions {
            top: args.top,
            // A limit of 0 means "no limit".
            row_limit: (args.limit > 0).then_some(args.limit),
            filters: &filters,
            filter_exprs: &args.filter_exprs,
        };
        let rows = frequency::compute_frequency_rows(
            &args.input,
            &schema,
            delimiter,
            encoding,
            &columns,
            &freq_options,
        )?;
        let headers = vec![
            "column".to_string(),
            "value".to_string(),
            "count".to_string(),
            "percent".to_string(),
        ];
        table::print_table(&headers, &rows);
        info!("Computed frequency counts for {} column(s)", columns.len());
        return Ok(());
    }

    let expects_headers = schema.expects_headers();
    let mut reader = io_utils::open_csv_reader_from_path(&args.input, delimiter, expects_headers)?;
    // With a header row the file's headers are validated against the schema;
    // otherwise the schema supplies the header names.
    let headers = if expects_headers {
        let headers = io_utils::reader_headers(&mut reader, encoding)?;
        schema
            .validate_headers(&headers)
            .with_context(|| format!("Validating headers for {:?}", args.input))?;
        headers
    } else {
        schema.headers()
    };
    let header_aliases = schema.header_alias_sets();

    let mut stats = StatsAccumulator::new(&columns, &schema);

    for (row_idx, record) in reader.byte_records().enumerate() {
        // `row_idx` counts every record read, including ones later skipped as
        // header-like or rejected by a filter, so the limit bounds rows read
        // rather than rows ingested. A limit of 0 disables the check.
        if args.limit > 0 && row_idx >= args.limit {
            break;
        }
        // NOTE(review): `row_idx + 2` presumably maps to a 1-based line number
        // after a header row — confirm for header-less inputs.
        let record = record.with_context(|| format!("Reading row {}", row_idx + 2))?;
        let mut decoded = io_utils::decode_record(&record, encoding)?;
        // Rows that look like a header row are dropped from the data.
        if schema::row_looks_like_header(&decoded, &header_aliases) {
            continue;
        }
        if schema.has_transformations() {
            schema
                .apply_transformations_to_row(&mut decoded)
                .with_context(|| {
                    format!(
                        "Applying datatype mappings to row {} in {:?}",
                        row_idx + 2,
                        args.input
                    )
                })?;
        }
        schema.apply_replacements_to_row(&mut decoded);
        let typed = parse_typed_row(&schema, &decoded)
            .with_context(|| format!("Parsing row {}", row_idx + 2))?;
        // Both filter kinds must accept the row before it is ingested.
        if !filters.is_empty()
            && !filter::evaluate_conditions(&filters, &schema, &headers, &decoded, &typed)?
        {
            continue;
        }
        if !args.filter_exprs.is_empty()
            && !evaluate_filter_expressions(
                &args.filter_exprs,
                &headers,
                &decoded,
                &typed,
                Some(row_idx + 1),
            )?
        {
            continue;
        }
        stats
            .ingest(&typed)
            .with_context(|| format!("Processing row {}", row_idx + 2))?;
    }

    let rows = stats.render_rows();
    let headers = vec![
        "column".to_string(),
        "count".to_string(),
        "min".to_string(),
        "max".to_string(),
        "mean".to_string(),
        "median".to_string(),
        "std_dev".to_string(),
    ];
    table::print_table(&headers, &rows);
    info!("Computed summary statistics for {} column(s)", rows.len());
    Ok(())
}
142
143fn load_or_infer_schema(
144    args: &StatsArgs,
145    delimiter: u8,
146    encoding: &'static Encoding,
147) -> Result<Schema> {
148    if let Some(path) = &args.schema {
149        Schema::load(path).with_context(|| format!("Loading schema from {path:?}"))
150    } else {
151        schema::infer_schema(&args.input, 0, delimiter, encoding, None)
152            .with_context(|| format!("Inferring schema from {input:?}", input = args.input))
153    }
154}
155
156fn resolve_columns(
157    schema: &Schema,
158    specified: &[String],
159    frequency_mode: bool,
160) -> Result<Vec<usize>> {
161    if frequency_mode {
162        if specified.is_empty() {
163            Ok((0..schema.columns.len()).collect())
164        } else {
165            specified
166                .iter()
167                .map(|name| {
168                    schema
169                        .column_index(name)
170                        .ok_or_else(|| anyhow!("Column '{name}' not found in schema"))
171                })
172                .collect()
173        }
174    } else if specified.is_empty() {
175        Ok(schema
176            .columns
177            .iter()
178            .enumerate()
179            .filter(|(_, col)| is_supported_datatype(&col.datatype))
180            .map(|(idx, _)| idx)
181            .collect())
182    } else {
183        specified
184            .iter()
185            .map(|name| {
186                let idx = schema
187                    .column_index(name)
188                    .ok_or_else(|| anyhow!("Column '{name}' not found in schema"))?;
189                let column = &schema.columns[idx];
190                if !is_supported_datatype(&column.datatype) {
191                    return Err(anyhow!(
192                        "Column '{}' is type {:?} and cannot be profiled for statistics",
193                        column.output_name(),
194                        column.datatype
195                    ));
196                }
197                Ok(idx)
198            })
199            .collect()
200    }
201}
202
203fn is_supported_datatype(datatype: &ColumnType) -> bool {
204    matches!(
205        datatype,
206        ColumnType::Integer
207            | ColumnType::Float
208            | ColumnType::Currency
209            | ColumnType::Decimal(_)
210            | ColumnType::Date
211            | ColumnType::DateTime
212            | ColumnType::Time
213    )
214}
215
/// Accumulates summary statistics for a chosen set of schema columns.
struct StatsAccumulator {
    /// Selected column indices, kept in the order supplied to `new`; this
    /// order drives both ingestion and the order of rendered rows.
    columns: Vec<usize>,
    /// Running statistics for each selected column, keyed by column index.
    data: HashMap<usize, ColumnStats>,
}
220
221impl StatsAccumulator {
222    fn new(columns: &[usize], schema: &Schema) -> Self {
223        let mut data = HashMap::new();
224        for idx in columns {
225            let stats = ColumnStats::with_column(
226                schema.columns[*idx].output_name().to_string(),
227                schema.columns[*idx].datatype.clone(),
228            );
229            data.insert(*idx, stats);
230        }
231        Self {
232            columns: columns.to_vec(),
233            data,
234        }
235    }
236
237    fn ingest(&mut self, typed_row: &[Option<Value>]) -> Result<()> {
238        for column_index in &self.columns {
239            if let Some(stats) = self.data.get_mut(column_index)
240                && let Some(Some(value)) = typed_row.get(*column_index)
241            {
242                let column_name = stats.name.clone();
243                stats
244                    .add_value(value)
245                    .with_context(|| format!("Column '{}'", column_name))?;
246            }
247        }
248        Ok(())
249    }
250
251    fn render_rows(&self) -> Vec<Vec<String>> {
252        let mut rows = Vec::new();
253        for column_index in &self.columns {
254            if let Some(stats) = self.data.get(column_index) {
255                rows.push(stats.render_row());
256            }
257        }
258        rows
259    }
260}
261
262struct ColumnStats {
263    name: String,
264    datatype: ColumnType,
265    values: Vec<f64>,
266    sum: f64,
267    sum_squares: f64,
268    count: usize,
269    min: Option<f64>,
270    max: Option<f64>,
271    currency_scale: Option<u32>,
272    decimal_scale: Option<u32>,
273}
274
275impl ColumnStats {
276    fn with_column(name: String, datatype: ColumnType) -> Self {
277        Self {
278            name,
279            datatype,
280            values: Vec::new(),
281            sum: 0.0,
282            sum_squares: 0.0,
283            count: 0,
284            min: None,
285            max: None,
286            currency_scale: None,
287            decimal_scale: None,
288        }
289    }
290
291    fn add_value(&mut self, value: &Value) -> Result<()> {
292        if let (ColumnType::Currency, Value::Currency(currency)) = (&self.datatype, value) {
293            let scale = currency.scale();
294            self.currency_scale = Some(
295                self.currency_scale
296                    .map_or(scale, |current| current.max(scale)),
297            );
298        }
299        if let (ColumnType::Decimal(_), Value::Decimal(decimal)) = (&self.datatype, value) {
300            let scale = decimal.scale();
301            self.decimal_scale = Some(
302                self.decimal_scale
303                    .map_or(scale, |current| current.max(scale)),
304            );
305        }
306        let numeric = value_to_metric(value, &self.datatype)?;
307        self.count += 1;
308        self.sum += numeric;
309        self.sum_squares += numeric * numeric;
310        self.min = Some(match self.min {
311            Some(current) => current.min(numeric),
312            None => numeric,
313        });
314        self.max = Some(match self.max {
315            Some(current) => current.max(numeric),
316            None => numeric,
317        });
318        self.values.push(numeric);
319        Ok(())
320    }
321
322    fn mean(&self) -> Option<f64> {
323        if self.count > 0 {
324            Some(self.sum / self.count as f64)
325        } else {
326            None
327        }
328    }
329
330    fn median(&self) -> Option<f64> {
331        if self.values.is_empty() {
332            return None;
333        }
334        let mut sorted = self.values.clone();
335        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
336        let mid = sorted.len() / 2;
337        if sorted.len().is_multiple_of(2) {
338            Some((sorted[mid - 1] + sorted[mid]) / 2.0)
339        } else {
340            Some(sorted[mid])
341        }
342    }
343
344    fn std_dev(&self) -> Option<f64> {
345        if self.count < 2 {
346            return None;
347        }
348        let mean = self.mean()?;
349        let variance =
350            (self.sum_squares - self.count as f64 * mean * mean) / (self.count as f64 - 1.0);
351        Some(variance.max(0.0).sqrt())
352    }
353
354    fn render_row(&self) -> Vec<String> {
355        vec![
356            self.name.clone(),
357            self.count.to_string(),
358            self.format_metric(self.min),
359            self.format_metric(self.max),
360            self.format_metric(self.mean()),
361            self.format_metric(self.median()),
362            self.format_std_dev(self.std_dev()),
363        ]
364    }
365
366    fn format_metric(&self, metric: Option<f64>) -> String {
367        metric
368            .map(|value| {
369                format_metric(
370                    value,
371                    &self.datatype,
372                    self.currency_scale,
373                    self.decimal_scale,
374                )
375            })
376            .unwrap_or_default()
377    }
378
379    fn format_std_dev(&self, metric: Option<f64>) -> String {
380        metric
381            .map(|value| {
382                format_std_dev_value(
383                    value,
384                    &self.datatype,
385                    self.currency_scale,
386                    self.decimal_scale,
387                )
388            })
389            .unwrap_or_default()
390    }
391}
392
/// Formats a plain number: no decimals when the value is whole, otherwise
/// exactly four decimal places.
fn format_number(value: f64) -> String {
    let decimals: usize = if value.fract() == 0.0 { 0 } else { 4 };
    format!("{value:.decimals$}")
}
400
401fn value_to_metric(value: &Value, datatype: &ColumnType) -> Result<f64> {
402    match (datatype, value) {
403        (ColumnType::Integer, Value::Integer(i)) => Ok(*i as f64),
404        (ColumnType::Float, Value::Float(f)) => Ok(*f),
405        (ColumnType::Float, Value::Integer(i)) => Ok(*i as f64),
406        (ColumnType::Currency, Value::Currency(c)) => c
407            .to_f64()
408            .ok_or_else(|| anyhow!("Currency value out of range for statistics")),
409        (ColumnType::Decimal(_), Value::Decimal(d)) => d
410            .to_f64()
411            .ok_or_else(|| anyhow!("Decimal value out of range for statistics")),
412        (ColumnType::Date, Value::Date(d)) => Ok(date_to_metric(d)),
413        (ColumnType::DateTime, Value::DateTime(dt)) => Ok(datetime_to_metric(dt)),
414        (ColumnType::Time, Value::Time(t)) => Ok(time_to_metric(t)),
415        _ => bail!("Value {:?} incompatible with datatype {datatype:?}", value),
416    }
417}
418
419fn date_to_metric(date: &NaiveDate) -> f64 {
420    date.num_days_from_ce() as f64
421}
422
423fn datetime_to_metric(dt: &NaiveDateTime) -> f64 {
424    dt.and_utc().timestamp() as f64
425}
426
427fn time_to_metric(time: &NaiveTime) -> f64 {
428    time.num_seconds_from_midnight() as f64
429}
430
431fn metric_to_date(metric: f64) -> Option<NaiveDate> {
432    NaiveDate::from_num_days_from_ce_opt(metric.round() as i32)
433}
434
435fn metric_to_datetime(metric: f64) -> Option<NaiveDateTime> {
436    if metric.is_nan() || metric.is_infinite() {
437        return None;
438    }
439    DateTime::<Utc>::from_timestamp(metric.round() as i64, 0).map(|dt| dt.naive_utc())
440}
441
442fn metric_to_time(metric: f64) -> Option<NaiveTime> {
443    let mut seconds = metric.round();
444    if seconds.is_nan() || seconds.is_infinite() {
445        return None;
446    }
447    if seconds < 0.0 {
448        seconds = 0.0;
449    }
450    if seconds >= 86_400.0 {
451        seconds = 86_399.0;
452    }
453    NaiveTime::from_num_seconds_from_midnight_opt(seconds as u32, 0)
454}
455
456fn format_metric(
457    value: f64,
458    datatype: &ColumnType,
459    currency_scale: Option<u32>,
460    decimal_scale: Option<u32>,
461) -> String {
462    match datatype {
463        ColumnType::Integer | ColumnType::Float => format_number(value),
464        ColumnType::Currency => format_currency_number(value, currency_scale),
465        ColumnType::Decimal(spec) => format_decimal_number(value, spec, decimal_scale),
466        ColumnType::Date => metric_to_date(value)
467            .map(|d| d.format("%Y-%m-%d").to_string())
468            .unwrap_or_default(),
469        ColumnType::DateTime => metric_to_datetime(value)
470            .map(|dt| dt.format("%Y-%m-%d %H:%M:%S").to_string())
471            .unwrap_or_default(),
472        ColumnType::Time => metric_to_time(value)
473            .map(|t| t.format("%H:%M:%S").to_string())
474            .unwrap_or_default(),
475        _ => String::new(),
476    }
477}
478
479fn format_std_dev_value(
480    value: f64,
481    datatype: &ColumnType,
482    currency_scale: Option<u32>,
483    decimal_scale: Option<u32>,
484) -> String {
485    match datatype {
486        ColumnType::Integer | ColumnType::Float => format_number(value),
487        ColumnType::Currency => format_currency_number(value, currency_scale),
488        ColumnType::Decimal(spec) => format_decimal_number(value, spec, decimal_scale),
489        ColumnType::Date => format_duration(value, "days"),
490        ColumnType::DateTime | ColumnType::Time => format_duration(value, "seconds"),
491        _ => String::new(),
492    }
493}
494
/// Formats a currency amount with four decimals when the observed scale is
/// exactly 4, and two decimals otherwise (including when no scale is known).
///
/// NaN and infinite values yield an empty string.
fn format_currency_number(value: f64, scale: Option<u32>) -> String {
    if !value.is_finite() {
        return String::new();
    }
    let precision: usize = if scale == Some(4) { 4 } else { 2 };
    format!("{value:.precision$}")
}
505
506fn format_decimal_number(value: f64, spec: &DecimalSpec, observed_scale: Option<u32>) -> String {
507    if value.is_nan() || value.is_infinite() {
508        return String::new();
509    }
510    let digits = observed_scale.unwrap_or(spec.scale) as usize;
511    if digits == 0 {
512        format!("{value:.0}")
513    } else {
514        format!("{value:.precision$}", precision = digits)
515    }
516}
517
518fn format_duration(value: f64, unit: &str) -> String {
519    let magnitude = format_number(value.abs());
520    if magnitude.is_empty() {
521        return String::new();
522    }
523    if value < 0.0 {
524        format!("-{magnitude} {unit}")
525    } else {
526        format!("{magnitude} {unit}")
527    }
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use encoding_rs::UTF_8;

    const DATA_FILE: &str = "big_5_players_stats_2023_2024.csv";
    const GOALS_COL: &str = "Performance_Gls";
    const ASSISTS_COL: &str = "Performance_Ast";

    /// Absolute path of the CSV fixture under `<crate root>/tests/data`.
    fn fixture_path() -> std::path::PathBuf {
        let mut path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        path.extend(["tests", "data", DATA_FILE]);
        path
    }

    /// Ingests the first 100 fixture records for two columns forced to
    /// integer type and checks that a populated row is rendered for each.
    #[test]
    fn accumulator_computes_stats_for_big5_subset() {
        let path = fixture_path();
        assert!(path.exists(), "fixture missing: {path:?}");

        let delimiter = crate::io_utils::resolve_input_delimiter(&path, None);
        let mut schema =
            crate::schema::infer_schema(&path, 200, delimiter, UTF_8, None).expect("infer schema");
        let goals_index = schema.column_index(GOALS_COL).expect("goals index");
        let assists_index = schema.column_index(ASSISTS_COL).expect("assists index");
        // Force both columns to integer regardless of what inference chose.
        for index in [goals_index, assists_index] {
            schema.columns[index].datatype = crate::schema::ColumnType::Integer;
        }

        let columns = vec![goals_index, assists_index];
        let mut accumulator = StatsAccumulator::new(&columns, &schema);
        let mut reader =
            crate::io_utils::open_csv_reader_from_path(&path, delimiter, true).expect("open csv");
        crate::io_utils::reader_headers(&mut reader, UTF_8).expect("headers");

        for record in reader.byte_records().take(100) {
            let record = record.expect("record");
            let mut decoded = crate::io_utils::decode_record(&record, UTF_8).expect("decode");
            schema.apply_replacements_to_row(&mut decoded);
            // Rows that fail to parse or ingest are deliberately skipped.
            let Ok(typed) = crate::rows::parse_typed_row(&schema, &decoded) else {
                continue;
            };
            let _ = accumulator.ingest(&typed);
        }

        let rows = accumulator.render_rows();
        assert_eq!(rows.len(), columns.len());
        let goal_stats = rows
            .iter()
            .find(|row| row[0] == GOALS_COL)
            .expect("goal stats");
        assert_ne!(goal_stats[1], "0");
        assert!(!goal_stats[4].is_empty());
    }
}
588}