Skip to main content

floe_core/checks/
not_null.rs

1use polars::prelude::{col, lit, AnyValue, DataFrame, Expr, NULL};
2
3use super::{ColumnIndex, RowError, SparseRowErrors};
4use crate::errors::RunError;
5use crate::FloeResult;
6
7pub fn not_null_errors(
8    df: &DataFrame,
9    required_cols: &[String],
10    indices: &ColumnIndex,
11) -> FloeResult<Vec<Vec<RowError>>> {
12    let mut errors_per_row = vec![Vec::new(); df.height()];
13    if required_cols.is_empty() {
14        return Ok(errors_per_row);
15    }
16
17    let mut null_masks = Vec::with_capacity(required_cols.len());
18    for name in required_cols {
19        let index = indices.get(name).ok_or_else(|| {
20            Box::new(RunError(format!(
21                "required column {name} not found in dataframe"
22            )))
23        })?;
24        let mask = df
25            .select_at_idx(*index)
26            .ok_or_else(|| {
27                Box::new(RunError(format!(
28                    "required column {name} not found in dataframe"
29                )))
30            })?
31            .is_null();
32        null_masks.push(mask);
33    }
34
35    for (row_idx, errors) in errors_per_row.iter_mut().enumerate() {
36        for (col, mask) in required_cols.iter().zip(null_masks.iter()) {
37            if mask.get(row_idx).unwrap_or(false) {
38                errors.push(RowError::new("not_null", col, "required value missing"));
39            }
40        }
41    }
42
43    Ok(errors_per_row)
44}
45
46pub fn not_null_errors_sparse(
47    df: &DataFrame,
48    required_cols: &[String],
49    indices: &ColumnIndex,
50) -> FloeResult<SparseRowErrors> {
51    let mut errors = SparseRowErrors::new(df.height());
52    if required_cols.is_empty() || df.height() == 0 {
53        return Ok(errors);
54    }
55
56    let mut null_masks = Vec::with_capacity(required_cols.len());
57    for name in required_cols {
58        let index = indices.get(name).ok_or_else(|| {
59            Box::new(RunError(format!(
60                "required column {name} not found in dataframe"
61            )))
62        })?;
63        let mask = df
64            .select_at_idx(*index)
65            .ok_or_else(|| {
66                Box::new(RunError(format!(
67                    "required column {name} not found in dataframe"
68                )))
69            })?
70            .is_null();
71        null_masks.push((name, mask));
72    }
73
74    for row_idx in 0..df.height() {
75        for (col, mask) in null_masks.iter() {
76            if mask.get(row_idx).unwrap_or(false) {
77                errors.add_error(
78                    row_idx,
79                    RowError::new("not_null", col, "required value missing"),
80                );
81            }
82        }
83    }
84
85    Ok(errors)
86}
87
88pub fn not_null_expr(col_name: &str) -> (String, Expr) {
89    let err_col = format!("_e_nn_{col_name}");
90    let error_json = RowError::new("not_null", col_name, "required value missing").to_json();
91    let expr = polars::prelude::when(col(col_name).is_null())
92        .then(lit(error_json))
93        .otherwise(lit(NULL))
94        .alias(&err_col);
95    (err_col, expr)
96}
97
98pub fn not_null_counts(df: &DataFrame, required_cols: &[String]) -> FloeResult<Vec<(String, u64)>> {
99    if required_cols.is_empty() || df.height() == 0 {
100        return Ok(Vec::new());
101    }
102
103    let null_counts = df.null_count();
104    let mut counts = Vec::new();
105    for name in required_cols {
106        let series = null_counts.column(name).map_err(|err| {
107            Box::new(RunError(format!("required column {name} not found: {err}")))
108        })?;
109        let value = series.get(0).unwrap_or(AnyValue::UInt32(0));
110        let violations = match value {
111            AnyValue::UInt32(value) => value as u64,
112            AnyValue::UInt64(value) => value,
113            AnyValue::Int64(value) => value.max(0) as u64,
114            AnyValue::Int32(value) => value.max(0) as u64,
115            AnyValue::Null => 0,
116            _ => 0,
117        };
118        if violations > 0 {
119            counts.push((name.clone(), violations));
120        }
121    }
122    Ok(counts)
123}