Skip to main content

floe_core/checks/
cast.rs

1use polars::prelude::{BooleanChunked, DataFrame, StringChunked};
2
3use super::{ColumnIndex, RowError, SparseRowErrors};
4use crate::errors::RunError;
5use crate::{config, FloeResult};
6
7/// Detect cast mismatches by comparing raw (string) values to typed values.
8/// If a raw value exists but the typed value is null, we treat it as a cast error.
9pub fn cast_mismatch_errors(
10    raw_df: &DataFrame,
11    typed_df: &DataFrame,
12    columns: &[config::ColumnConfig],
13    raw_indices: &ColumnIndex,
14    typed_indices: &ColumnIndex,
15) -> FloeResult<Vec<Vec<RowError>>> {
16    let mut errors_per_row = vec![Vec::new(); typed_df.height()];
17    if typed_df.height() == 0 {
18        return Ok(errors_per_row);
19    }
20
21    for column in columns {
22        if is_string_type(&column.column_type) {
23            continue;
24        }
25        let raw_index = raw_indices
26            .get(&column.name)
27            .ok_or_else(|| Box::new(RunError(format!("raw column {} not found", column.name))))?;
28        let typed_index = typed_indices
29            .get(&column.name)
30            .ok_or_else(|| Box::new(RunError(format!("typed column {} not found", column.name))))?;
31        let raw = raw_df
32            .select_at_idx(*raw_index)
33            .ok_or_else(|| Box::new(RunError(format!("raw column {} not found", column.name))))?
34            .str()
35            .map_err(|err| {
36                Box::new(RunError(format!(
37                    "raw column {} is not utf8: {err}",
38                    column.name
39                )))
40            })?;
41        let typed_nulls = typed_df
42            .select_at_idx(*typed_index)
43            .ok_or_else(|| Box::new(RunError(format!("typed column {} not found", column.name))))?
44            .is_null();
45
46        append_cast_errors(&mut errors_per_row, &column.name, raw, &typed_nulls)?;
47    }
48
49    Ok(errors_per_row)
50}
51
52pub fn cast_mismatch_errors_sparse(
53    raw_df: &DataFrame,
54    typed_df: &DataFrame,
55    columns: &[config::ColumnConfig],
56    raw_indices: &ColumnIndex,
57    typed_indices: &ColumnIndex,
58) -> FloeResult<SparseRowErrors> {
59    let mut errors = SparseRowErrors::new(typed_df.height());
60    if typed_df.height() == 0 {
61        return Ok(errors);
62    }
63
64    for column in columns {
65        if is_string_type(&column.column_type) {
66            continue;
67        }
68        let raw_index = raw_indices
69            .get(&column.name)
70            .ok_or_else(|| Box::new(RunError(format!("raw column {} not found", column.name))))?;
71        let typed_index = typed_indices
72            .get(&column.name)
73            .ok_or_else(|| Box::new(RunError(format!("typed column {} not found", column.name))))?;
74        let raw = raw_df
75            .select_at_idx(*raw_index)
76            .ok_or_else(|| Box::new(RunError(format!("raw column {} not found", column.name))))?
77            .str()
78            .map_err(|err| {
79                Box::new(RunError(format!(
80                    "raw column {} is not utf8: {err}",
81                    column.name
82                )))
83            })?;
84        let typed_nulls = typed_df
85            .select_at_idx(*typed_index)
86            .ok_or_else(|| Box::new(RunError(format!("typed column {} not found", column.name))))?
87            .is_null();
88
89        let raw_not_null = raw.is_not_null();
90        let invalid_mask = typed_nulls & raw_not_null;
91        for (row_idx, invalid) in invalid_mask.into_iter().enumerate() {
92            if invalid == Some(true) {
93                errors.add_error(
94                    row_idx,
95                    RowError::new("cast_error", &column.name, "invalid value for target type"),
96                );
97            }
98        }
99    }
100
101    Ok(errors)
102}
103
104pub fn cast_mismatch_counts(
105    raw_df: &DataFrame,
106    typed_df: &DataFrame,
107    columns: &[config::ColumnConfig],
108) -> FloeResult<Vec<(String, u64, String)>> {
109    if typed_df.height() == 0 {
110        return Ok(Vec::new());
111    }
112
113    let mut counts = Vec::new();
114    for column in columns {
115        if is_string_type(&column.column_type) {
116            continue;
117        }
118
119        let raw = raw_df
120            .column(&column.name)
121            .map_err(|err| {
122                Box::new(RunError(format!(
123                    "raw column {} not found: {err}",
124                    column.name
125                )))
126            })?
127            .str()
128            .map_err(|err| {
129                Box::new(RunError(format!(
130                    "raw column {} is not utf8: {err}",
131                    column.name
132                )))
133            })?;
134        let typed_nulls = typed_df
135            .column(&column.name)
136            .map_err(|err| {
137                Box::new(RunError(format!(
138                    "typed column {} not found: {err}",
139                    column.name
140                )))
141            })?
142            .is_null();
143
144        let raw_not_null = raw.is_not_null();
145        let violations = (&typed_nulls & &raw_not_null).sum().unwrap_or(0) as u64;
146
147        if violations > 0 {
148            counts.push((column.name.clone(), violations, column.column_type.clone()));
149        }
150    }
151
152    Ok(counts)
153}
154
155fn append_cast_errors(
156    errors_per_row: &mut [Vec<RowError>],
157    column_name: &str,
158    raw: &StringChunked,
159    typed_nulls: &BooleanChunked,
160) -> FloeResult<()> {
161    let raw_not_null = raw.is_not_null();
162    let invalid_mask = typed_nulls & &raw_not_null;
163    for (row_idx, invalid) in invalid_mask.into_iter().enumerate() {
164        if invalid == Some(true) {
165            errors_per_row[row_idx].push(RowError::new(
166                "cast_error",
167                column_name,
168                "invalid value for target type",
169            ));
170        }
171    }
172    Ok(())
173}
174
/// Report whether a configured column type names a string type.
///
/// The comparison ignores ASCII case and any `-` or `_` characters; the
/// accepted spellings after that normalization are `string`, `str` and `text`.
fn is_string_type(value: &str) -> bool {
    const STRING_ALIASES: [&str; 3] = ["string", "str", "text"];

    let canonical: String = value
        .chars()
        .filter(|ch| !matches!(ch, '-' | '_'))
        .map(|ch| ch.to_ascii_lowercase())
        .collect();

    STRING_ALIASES.contains(&canonical.as_str())
}