Skip to main content

floe_core/checks/
not_null.rs

1use polars::prelude::{AnyValue, DataFrame};
2
3use super::{ColumnIndex, RowError};
4use crate::errors::RunError;
5use crate::FloeResult;
6
7pub fn not_null_errors(
8    df: &DataFrame,
9    required_cols: &[String],
10    indices: &ColumnIndex,
11) -> FloeResult<Vec<Vec<RowError>>> {
12    let mut errors_per_row = vec![Vec::new(); df.height()];
13    if required_cols.is_empty() {
14        return Ok(errors_per_row);
15    }
16
17    let mut null_masks = Vec::with_capacity(required_cols.len());
18    for name in required_cols {
19        let index = indices.get(name).ok_or_else(|| {
20            Box::new(RunError(format!(
21                "required column {name} not found in dataframe"
22            )))
23        })?;
24        let mask = df
25            .select_at_idx(*index)
26            .ok_or_else(|| {
27                Box::new(RunError(format!(
28                    "required column {name} not found in dataframe"
29                )))
30            })?
31            .is_null();
32        null_masks.push(mask);
33    }
34
35    for (row_idx, errors) in errors_per_row.iter_mut().enumerate() {
36        for (col, mask) in required_cols.iter().zip(null_masks.iter()) {
37            if mask.get(row_idx).unwrap_or(false) {
38                errors.push(RowError::new("not_null", col, "required value missing"));
39            }
40        }
41    }
42
43    Ok(errors_per_row)
44}
45
46pub fn not_null_counts(df: &DataFrame, required_cols: &[String]) -> FloeResult<Vec<(String, u64)>> {
47    if required_cols.is_empty() || df.height() == 0 {
48        return Ok(Vec::new());
49    }
50
51    let null_counts = df.null_count();
52    let mut counts = Vec::new();
53    for name in required_cols {
54        let series = null_counts.column(name).map_err(|err| {
55            Box::new(RunError(format!("required column {name} not found: {err}")))
56        })?;
57        let value = series.get(0).unwrap_or(AnyValue::UInt32(0));
58        let violations = match value {
59            AnyValue::UInt32(value) => value as u64,
60            AnyValue::UInt64(value) => value,
61            AnyValue::Int64(value) => value.max(0) as u64,
62            AnyValue::Int32(value) => value.max(0) as u64,
63            AnyValue::Null => 0,
64            _ => 0,
65        };
66        if violations > 0 {
67            counts.push((name.clone(), violations));
68        }
69    }
70    Ok(counts)
71}