Skip to main content

floe_core/checks/
cast.rs

1use polars::prelude::{col, lit, BooleanChunked, DataFrame, Expr, StringChunked, NULL};
2
3use super::{ColumnIndex, RowError, SparseRowErrors};
4use crate::errors::RunError;
5use crate::{config, FloeResult};
6
7/// Detect cast mismatches by comparing raw (string) values to typed values.
8/// If a raw value exists but the typed value is null, we treat it as a cast error.
9pub fn cast_mismatch_errors(
10    raw_df: &DataFrame,
11    typed_df: &DataFrame,
12    columns: &[config::ColumnConfig],
13    raw_indices: &ColumnIndex,
14    typed_indices: &ColumnIndex,
15) -> FloeResult<Vec<Vec<RowError>>> {
16    let mut errors_per_row = vec![Vec::new(); typed_df.height()];
17    if typed_df.height() == 0 {
18        return Ok(errors_per_row);
19    }
20
21    for column in columns {
22        if is_string_type(&column.column_type) {
23            continue;
24        }
25        let raw_index = raw_indices
26            .get(&column.name)
27            .ok_or_else(|| Box::new(RunError(format!("raw column {} not found", column.name))))?;
28        let typed_index = typed_indices
29            .get(&column.name)
30            .ok_or_else(|| Box::new(RunError(format!("typed column {} not found", column.name))))?;
31        let raw = raw_df
32            .select_at_idx(*raw_index)
33            .ok_or_else(|| Box::new(RunError(format!("raw column {} not found", column.name))))?
34            .str()
35            .map_err(|err| {
36                Box::new(RunError(format!(
37                    "raw column {} is not utf8: {err}",
38                    column.name
39                )))
40            })?;
41        let typed_nulls = typed_df
42            .select_at_idx(*typed_index)
43            .ok_or_else(|| Box::new(RunError(format!("typed column {} not found", column.name))))?
44            .is_null();
45
46        append_cast_errors(&mut errors_per_row, &column.name, raw, &typed_nulls)?;
47    }
48
49    Ok(errors_per_row)
50}
51
52pub fn cast_mismatch_errors_sparse(
53    raw_df: &DataFrame,
54    typed_df: &DataFrame,
55    columns: &[config::ColumnConfig],
56    raw_indices: &ColumnIndex,
57    typed_indices: &ColumnIndex,
58) -> FloeResult<SparseRowErrors> {
59    let mut errors = SparseRowErrors::new(typed_df.height());
60    if typed_df.height() == 0 {
61        return Ok(errors);
62    }
63
64    for column in columns {
65        if is_string_type(&column.column_type) {
66            continue;
67        }
68        let raw_index = raw_indices
69            .get(&column.name)
70            .ok_or_else(|| Box::new(RunError(format!("raw column {} not found", column.name))))?;
71        let typed_index = typed_indices
72            .get(&column.name)
73            .ok_or_else(|| Box::new(RunError(format!("typed column {} not found", column.name))))?;
74        let raw = raw_df
75            .select_at_idx(*raw_index)
76            .ok_or_else(|| Box::new(RunError(format!("raw column {} not found", column.name))))?
77            .str()
78            .map_err(|err| {
79                Box::new(RunError(format!(
80                    "raw column {} is not utf8: {err}",
81                    column.name
82                )))
83            })?;
84        let typed_nulls = typed_df
85            .select_at_idx(*typed_index)
86            .ok_or_else(|| Box::new(RunError(format!("typed column {} not found", column.name))))?
87            .is_null();
88
89        let raw_not_null = raw.is_not_null();
90        let invalid_mask = typed_nulls & raw_not_null;
91        for (row_idx, invalid) in invalid_mask.into_iter().enumerate() {
92            if invalid == Some(true) {
93                errors.add_error(
94                    row_idx,
95                    RowError::new("cast_error", &column.name, "invalid value for target type"),
96                );
97            }
98        }
99    }
100
101    Ok(errors)
102}
103
104pub fn cast_mismatch_expr(typed_col: &str, raw_col: &str) -> (String, Expr) {
105    let err_col = format!("_e_cast_{typed_col}");
106    let error_json =
107        RowError::new("cast_error", typed_col, "invalid value for target type").to_json();
108    let expr = polars::prelude::when(col(raw_col).is_not_null().and(col(typed_col).is_null()))
109        .then(lit(error_json))
110        .otherwise(lit(NULL))
111        .alias(&err_col);
112    (err_col, expr)
113}
114
115pub fn cast_mismatch_counts(
116    raw_df: &DataFrame,
117    typed_df: &DataFrame,
118    columns: &[config::ColumnConfig],
119) -> FloeResult<Vec<(String, u64, String)>> {
120    if typed_df.height() == 0 {
121        return Ok(Vec::new());
122    }
123
124    let mut counts = Vec::new();
125    for column in columns {
126        if is_string_type(&column.column_type) {
127            continue;
128        }
129
130        let raw = raw_df
131            .column(&column.name)
132            .map_err(|err| {
133                Box::new(RunError(format!(
134                    "raw column {} not found: {err}",
135                    column.name
136                )))
137            })?
138            .str()
139            .map_err(|err| {
140                Box::new(RunError(format!(
141                    "raw column {} is not utf8: {err}",
142                    column.name
143                )))
144            })?;
145        let typed_nulls = typed_df
146            .column(&column.name)
147            .map_err(|err| {
148                Box::new(RunError(format!(
149                    "typed column {} not found: {err}",
150                    column.name
151                )))
152            })?
153            .is_null();
154
155        let raw_not_null = raw.is_not_null();
156        let violations = (&typed_nulls & &raw_not_null).sum().unwrap_or(0) as u64;
157
158        if violations > 0 {
159            counts.push((column.name.clone(), violations, column.column_type.clone()));
160        }
161    }
162
163    Ok(counts)
164}
165
166fn append_cast_errors(
167    errors_per_row: &mut [Vec<RowError>],
168    column_name: &str,
169    raw: &StringChunked,
170    typed_nulls: &BooleanChunked,
171) -> FloeResult<()> {
172    let raw_not_null = raw.is_not_null();
173    let invalid_mask = typed_nulls & &raw_not_null;
174    for (row_idx, invalid) in invalid_mask.into_iter().enumerate() {
175        if invalid == Some(true) {
176            errors_per_row[row_idx].push(RowError::new(
177                "cast_error",
178                column_name,
179                "invalid value for target type",
180            ));
181        }
182    }
183    Ok(())
184}
185
/// Returns `true` when a configured column type names a string type.
///
/// Matching is ASCII case-insensitive and ignores every `-` and `_` character.
/// The accepted names are `string`, `str` and `text`, so `"String"`, `"STR"`
/// and `"te_xt"` all match.
pub(crate) fn is_string_type(value: &str) -> bool {
    const STRING_TYPE_NAMES: [&str; 3] = ["string", "str", "text"];
    let canonical: String = value
        .chars()
        .filter(|&ch| ch != '-' && ch != '_')
        .map(|ch| ch.to_ascii_lowercase())
        .collect();
    STRING_TYPE_NAMES.contains(&canonical.as_str())
}
189}