1use polars::prelude::{BooleanChunked, DataFrame, StringChunked};
2
3use super::{ColumnIndex, RowError, SparseRowErrors};
4use crate::errors::RunError;
5use crate::{config, FloeResult};
6
7pub fn cast_mismatch_errors(
10 raw_df: &DataFrame,
11 typed_df: &DataFrame,
12 columns: &[config::ColumnConfig],
13 raw_indices: &ColumnIndex,
14 typed_indices: &ColumnIndex,
15) -> FloeResult<Vec<Vec<RowError>>> {
16 let mut errors_per_row = vec![Vec::new(); typed_df.height()];
17 if typed_df.height() == 0 {
18 return Ok(errors_per_row);
19 }
20
21 for column in columns {
22 if is_string_type(&column.column_type) {
23 continue;
24 }
25 let raw_index = raw_indices
26 .get(&column.name)
27 .ok_or_else(|| Box::new(RunError(format!("raw column {} not found", column.name))))?;
28 let typed_index = typed_indices
29 .get(&column.name)
30 .ok_or_else(|| Box::new(RunError(format!("typed column {} not found", column.name))))?;
31 let raw = raw_df
32 .select_at_idx(*raw_index)
33 .ok_or_else(|| Box::new(RunError(format!("raw column {} not found", column.name))))?
34 .str()
35 .map_err(|err| {
36 Box::new(RunError(format!(
37 "raw column {} is not utf8: {err}",
38 column.name
39 )))
40 })?;
41 let typed_nulls = typed_df
42 .select_at_idx(*typed_index)
43 .ok_or_else(|| Box::new(RunError(format!("typed column {} not found", column.name))))?
44 .is_null();
45
46 append_cast_errors(&mut errors_per_row, &column.name, raw, &typed_nulls)?;
47 }
48
49 Ok(errors_per_row)
50}
51
52pub fn cast_mismatch_errors_sparse(
53 raw_df: &DataFrame,
54 typed_df: &DataFrame,
55 columns: &[config::ColumnConfig],
56 raw_indices: &ColumnIndex,
57 typed_indices: &ColumnIndex,
58) -> FloeResult<SparseRowErrors> {
59 let mut errors = SparseRowErrors::new(typed_df.height());
60 if typed_df.height() == 0 {
61 return Ok(errors);
62 }
63
64 for column in columns {
65 if is_string_type(&column.column_type) {
66 continue;
67 }
68 let raw_index = raw_indices
69 .get(&column.name)
70 .ok_or_else(|| Box::new(RunError(format!("raw column {} not found", column.name))))?;
71 let typed_index = typed_indices
72 .get(&column.name)
73 .ok_or_else(|| Box::new(RunError(format!("typed column {} not found", column.name))))?;
74 let raw = raw_df
75 .select_at_idx(*raw_index)
76 .ok_or_else(|| Box::new(RunError(format!("raw column {} not found", column.name))))?
77 .str()
78 .map_err(|err| {
79 Box::new(RunError(format!(
80 "raw column {} is not utf8: {err}",
81 column.name
82 )))
83 })?;
84 let typed_nulls = typed_df
85 .select_at_idx(*typed_index)
86 .ok_or_else(|| Box::new(RunError(format!("typed column {} not found", column.name))))?
87 .is_null();
88
89 let raw_not_null = raw.is_not_null();
90 let invalid_mask = typed_nulls & raw_not_null;
91 for (row_idx, invalid) in invalid_mask.into_iter().enumerate() {
92 if invalid == Some(true) {
93 errors.add_error(
94 row_idx,
95 RowError::new("cast_error", &column.name, "invalid value for target type"),
96 );
97 }
98 }
99 }
100
101 Ok(errors)
102}
103
104pub fn cast_mismatch_counts(
105 raw_df: &DataFrame,
106 typed_df: &DataFrame,
107 columns: &[config::ColumnConfig],
108) -> FloeResult<Vec<(String, u64, String)>> {
109 if typed_df.height() == 0 {
110 return Ok(Vec::new());
111 }
112
113 let mut counts = Vec::new();
114 for column in columns {
115 if is_string_type(&column.column_type) {
116 continue;
117 }
118
119 let raw = raw_df
120 .column(&column.name)
121 .map_err(|err| {
122 Box::new(RunError(format!(
123 "raw column {} not found: {err}",
124 column.name
125 )))
126 })?
127 .str()
128 .map_err(|err| {
129 Box::new(RunError(format!(
130 "raw column {} is not utf8: {err}",
131 column.name
132 )))
133 })?;
134 let typed_nulls = typed_df
135 .column(&column.name)
136 .map_err(|err| {
137 Box::new(RunError(format!(
138 "typed column {} not found: {err}",
139 column.name
140 )))
141 })?
142 .is_null();
143
144 let raw_not_null = raw.is_not_null();
145 let violations = (&typed_nulls & &raw_not_null).sum().unwrap_or(0) as u64;
146
147 if violations > 0 {
148 counts.push((column.name.clone(), violations, column.column_type.clone()));
149 }
150 }
151
152 Ok(counts)
153}
154
155fn append_cast_errors(
156 errors_per_row: &mut [Vec<RowError>],
157 column_name: &str,
158 raw: &StringChunked,
159 typed_nulls: &BooleanChunked,
160) -> FloeResult<()> {
161 let raw_not_null = raw.is_not_null();
162 let invalid_mask = typed_nulls & &raw_not_null;
163 for (row_idx, invalid) in invalid_mask.into_iter().enumerate() {
164 if invalid == Some(true) {
165 errors_per_row[row_idx].push(RowError::new(
166 "cast_error",
167 column_name,
168 "invalid value for target type",
169 ));
170 }
171 }
172 Ok(())
173}
174
/// Reports whether a configured column type names a string type.
///
/// Matching ignores ASCII case and any `-` or `_` characters, so values such
/// as `"STR"` or `"te_xt"` are accepted. Recognized names are `string`, `str`
/// and `text`. Surrounding whitespace is not stripped.
fn is_string_type(value: &str) -> bool {
    let canonical: String = value
        .chars()
        .filter(|ch| !matches!(ch, '-' | '_'))
        .map(|ch| ch.to_ascii_lowercase())
        .collect();
    ["string", "str", "text"].contains(&canonical.as_str())
}