use datafusion::arrow::array::*;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::functions_aggregate::expr_fn::{
    avg, count, max as agg_max, median, min as agg_min, stddev,
};
use datafusion::prelude::*;
use std::sync::Arc;

use crate::engine::DataEngine;
15
16impl DataEngine {
17 pub fn fill_null(
20 &self,
21 df: DataFrame,
22 column: &str,
23 strategy: &str,
24 fill_value: Option<f64>,
25 ) -> Result<DataFrame, String> {
26 let fill_expr = match strategy {
27 "value" => {
28 let val =
29 fill_value.ok_or("fill_null with 'value' strategy requires a fill_value")?;
30 coalesce(vec![col(column), lit(val)]).alias(column)
31 }
32 "zero" => coalesce(vec![col(column), lit(0.0)]).alias(column),
33 "mean" => {
34 let mean_df = df
36 .clone()
37 .aggregate(vec![], vec![avg(col(column)).alias("__mean")])
38 .map_err(|e| format!("fill_null mean aggregate error: {e}"))?;
39 let batches = self.collect(mean_df)?;
40 let mean_val = if !batches.is_empty() && batches[0].num_rows() > 0 {
41 let col_arr = batches[0].column(0);
42 if let Some(f64_arr) = col_arr.as_any().downcast_ref::<Float64Array>() {
43 if f64_arr.is_null(0) {
44 0.0
45 } else {
46 f64_arr.value(0)
47 }
48 } else {
49 0.0
50 }
51 } else {
52 0.0
53 };
54 coalesce(vec![col(column), lit(mean_val)]).alias(column)
55 }
56 "median" => {
57 let mean_df = df
60 .clone()
61 .aggregate(vec![], vec![avg(col(column)).alias("__mean")])
62 .map_err(|e| format!("fill_null median aggregate error: {e}"))?;
63 let batches = self.collect(mean_df)?;
64 let mean_val = if !batches.is_empty() && batches[0].num_rows() > 0 {
65 let col_arr = batches[0].column(0);
66 if let Some(f64_arr) = col_arr.as_any().downcast_ref::<Float64Array>() {
67 if f64_arr.is_null(0) {
68 0.0
69 } else {
70 f64_arr.value(0)
71 }
72 } else {
73 0.0
74 }
75 } else {
76 0.0
77 };
78 coalesce(vec![col(column), lit(mean_val)]).alias(column)
79 }
80 other => return Err(format!("Unknown fill_null strategy: {other}")),
81 };
82
83 let schema = df.schema().clone();
85 let mut select_exprs = Vec::new();
86 for field in schema.fields() {
87 if field.name() == column {
88 select_exprs.push(fill_expr.clone());
89 } else {
90 select_exprs.push(col(field.name()));
91 }
92 }
93
94 df.select(select_exprs)
95 .map_err(|e| format!("fill_null select error: {e}"))
96 }
97
98 pub fn drop_null(&self, df: DataFrame, column: &str) -> Result<DataFrame, String> {
100 df.filter(col(column).is_not_null())
101 .map_err(|e| format!("drop_null error: {e}"))
102 }
103
104 pub fn dedup(&self, df: DataFrame, columns: &[String]) -> Result<DataFrame, String> {
106 if columns.is_empty() {
107 return df.distinct().map_err(|e| format!("dedup error: {e}"));
108 }
109 let table_name = "__dedup_tmp";
111 self.ctx
112 .register_table(table_name, df.into_view())
113 .map_err(|e| format!("dedup register error: {e}"))?;
114
115 let cols_str = columns.join(", ");
116 let result = self.sql(&format!(
117 "SELECT DISTINCT ON ({cols_str}) * FROM {table_name}"
118 ));
119
120 match result {
122 Ok(r) => Ok(r),
123 Err(_) => {
124 let all_cols = self.sql(&format!("SELECT * FROM {table_name} GROUP BY {cols_str}"));
126 match all_cols {
127 Ok(r) => Ok(r),
128 Err(_) => {
129 self.sql(&format!("SELECT DISTINCT * FROM {table_name}"))
131 }
132 }
133 }
134 }
135 }
136
137 pub fn clamp(
139 &self,
140 df: DataFrame,
141 column: &str,
142 min_val: f64,
143 max_val: f64,
144 ) -> Result<DataFrame, String> {
145 let clamp_expr = when(col(column).lt(lit(min_val)), lit(min_val))
146 .when(col(column).gt(lit(max_val)), lit(max_val))
147 .otherwise(col(column))
148 .map_err(|e| format!("clamp expr error: {e}"))?
149 .alias(column);
150
151 let schema = df.schema().clone();
152 let mut select_exprs = Vec::new();
153 for field in schema.fields() {
154 if field.name() == column {
155 select_exprs.push(clamp_expr.clone());
156 } else {
157 select_exprs.push(col(field.name()));
158 }
159 }
160
161 df.select(select_exprs)
162 .map_err(|e| format!("clamp select error: {e}"))
163 }
164
165 pub fn data_profile(&self, df: DataFrame) -> Result<DataFrame, String> {
168 let schema = df.schema().clone();
169 let mut col_names = Vec::new();
170 let mut counts = Vec::new();
171 let mut null_counts = Vec::new();
172 let mut null_rates = Vec::new();
173 let mut mins = Vec::new();
174 let mut maxs = Vec::new();
175 let mut means = Vec::new();
176 let mut stddevs = Vec::new();
177
178 for field in schema.fields() {
179 let name = field.name();
180 let is_numeric = matches!(
181 field.data_type(),
182 DataType::Int8
183 | DataType::Int16
184 | DataType::Int32
185 | DataType::Int64
186 | DataType::UInt8
187 | DataType::UInt16
188 | DataType::UInt32
189 | DataType::UInt64
190 | DataType::Float32
191 | DataType::Float64
192 );
193
194 let mut agg_exprs = vec![count(col(name)).alias("__count")];
196 if is_numeric {
197 agg_exprs.push(agg_min(col(name)).alias("__min"));
198 agg_exprs.push(agg_max(col(name)).alias("__max"));
199 agg_exprs.push(avg(col(name)).alias("__mean"));
200 agg_exprs.push(stddev(col(name)).alias("__stddev"));
201 }
202
203 let agg_df = df
204 .clone()
205 .aggregate(vec![], agg_exprs)
206 .map_err(|e| format!("data_profile aggregate error for {name}: {e}"))?;
207 let batches = self.collect(agg_df)?;
208
209 if batches.is_empty() || batches[0].num_rows() == 0 {
210 continue;
211 }
212 let batch = &batches[0];
213
214 let non_null_cnt = Self::extract_i64_or_u64(batch.column(0));
215 let total = self.row_count(df.clone())?;
217 let null_cnt = total - non_null_cnt;
218 let nr = if total > 0 {
219 null_cnt as f64 / total as f64
220 } else {
221 0.0
222 };
223
224 col_names.push(name.clone());
225 counts.push(non_null_cnt);
226 null_counts.push(null_cnt);
227 null_rates.push(nr);
228
229 if is_numeric && batch.num_columns() >= 5 {
230 mins.push(Self::extract_f64(batch.column(1)));
231 maxs.push(Self::extract_f64(batch.column(2)));
232 means.push(Self::extract_f64(batch.column(3)));
233 stddevs.push(Self::extract_f64(batch.column(4)));
234 } else {
235 mins.push(f64::NAN);
236 maxs.push(f64::NAN);
237 means.push(f64::NAN);
238 stddevs.push(f64::NAN);
239 }
240 }
241
242 let result_schema = Arc::new(Schema::new(vec![
243 Field::new("column_name", DataType::Utf8, false),
244 Field::new("count", DataType::Int64, false),
245 Field::new("null_count", DataType::Int64, false),
246 Field::new("null_rate", DataType::Float64, false),
247 Field::new("min", DataType::Float64, true),
248 Field::new("max", DataType::Float64, true),
249 Field::new("mean", DataType::Float64, true),
250 Field::new("stddev", DataType::Float64, true),
251 ]));
252
253 let batch = RecordBatch::try_new(
254 result_schema,
255 vec![
256 Arc::new(StringArray::from(col_names)),
257 Arc::new(Int64Array::from(counts)),
258 Arc::new(Int64Array::from(null_counts)),
259 Arc::new(Float64Array::from(null_rates)),
260 Arc::new(Float64Array::from(mins)),
261 Arc::new(Float64Array::from(maxs)),
262 Arc::new(Float64Array::from(means)),
263 Arc::new(Float64Array::from(stddevs)),
264 ],
265 )
266 .map_err(|e| format!("data_profile batch error: {e}"))?;
267
268 self.register_batch("__data_profile", batch)?;
269 self.rt
270 .block_on(self.ctx.table("__data_profile"))
271 .map_err(|e| format!("data_profile table error: {e}"))
272 }
273
274 pub fn row_count(&self, df: DataFrame) -> Result<i64, String> {
276 let cnt = self
277 .rt
278 .block_on(df.count())
279 .map_err(|e| format!("row_count error: {e}"))?;
280 Ok(cnt as i64)
281 }
282
283 pub fn null_rate(&self, df: DataFrame, column: &str) -> Result<f64, String> {
285 let total = self
286 .rt
287 .block_on(df.clone().count())
288 .map_err(|e| format!("null_rate count error: {e}"))? as i64;
289 if total == 0 {
290 return Ok(0.0);
291 }
292 let non_null_df = df
293 .aggregate(vec![], vec![count(col(column)).alias("__non_null")])
294 .map_err(|e| format!("null_rate aggregate error: {e}"))?;
295 let batches = self.collect(non_null_df)?;
296 if batches.is_empty() || batches[0].num_rows() == 0 {
297 return Ok(0.0);
298 }
299 let non_null = Self::extract_i64_or_u64(batches[0].column(0));
300 Ok((total - non_null) as f64 / total as f64)
301 }
302
303 pub fn is_unique(&self, df: DataFrame, column: &str) -> Result<bool, String> {
305 let table_name = "__unique_check_tmp";
306 self.ctx
307 .register_table(table_name, df.into_view())
308 .map_err(|e| format!("is_unique register error: {e}"))?;
309
310 let result = self.sql(&format!(
311 "SELECT COUNT(DISTINCT \"{column}\") = COUNT(\"{column}\") AS is_uniq FROM {table_name} WHERE \"{column}\" IS NOT NULL"
312 ))?;
313
314 let batches = self.collect(result)?;
315 if batches.is_empty() || batches[0].num_rows() == 0 {
316 return Ok(true);
317 }
318 let col_arr = batches[0].column(0);
319 if let Some(bool_arr) = col_arr.as_any().downcast_ref::<BooleanArray>() {
320 Ok(!bool_arr.is_null(0) && bool_arr.value(0))
321 } else {
322 Ok(false)
323 }
324 }
325
326 fn extract_i64_or_u64(arr: &dyn Array) -> i64 {
328 if let Some(a) = arr.as_any().downcast_ref::<Int64Array>() {
329 if a.is_null(0) { 0 } else { a.value(0) }
330 } else if let Some(a) = arr.as_any().downcast_ref::<UInt64Array>() {
331 if a.is_null(0) { 0 } else { a.value(0) as i64 }
332 } else {
333 0
334 }
335 }
336
337 fn extract_f64(arr: &dyn Array) -> f64 {
339 if let Some(a) = arr.as_any().downcast_ref::<Float64Array>() {
340 if a.is_null(0) { f64::NAN } else { a.value(0) }
341 } else if let Some(a) = arr.as_any().downcast_ref::<Int64Array>() {
342 if a.is_null(0) {
343 f64::NAN
344 } else {
345 a.value(0) as f64
346 }
347 } else if let Some(a) = arr.as_any().downcast_ref::<Int32Array>() {
348 if a.is_null(0) {
349 f64::NAN
350 } else {
351 a.value(0) as f64
352 }
353 } else {
354 f64::NAN
355 }
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::{Float64Array, Int64Array, StringArray};
    use datafusion::arrow::datatypes::{DataType, Field, Schema};

    /// Builds an engine with a 5-row "test_data" table:
    /// id = 1..=5 (non-null), name has one NULL (row index 2),
    /// age has one NULL (row index 2; non-null mean = 29.5).
    fn make_test_engine_with_data() -> DataEngine {
        let engine = DataEngine::new();
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
                Arc::new(StringArray::from(vec![
                    Some("Alice"),
                    Some("Bob"),
                    None,
                    Some("Diana"),
                    Some("Eve"),
                ])),
                Arc::new(Float64Array::from(vec![
                    Some(30.0),
                    Some(25.0),
                    None,
                    Some(35.0),
                    Some(28.0),
                ])),
            ],
        )
        .unwrap();
        engine.register_batch("test_data", batch).unwrap();
        engine
    }

    /// "value" strategy: row count unchanged, NULL age replaced by 0.0.
    #[test]
    fn test_fill_null_value() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        let result = engine.fill_null(df, "age", "value", Some(0.0)).unwrap();
        let batches = engine.collect(result).unwrap();
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 5);
        // NOTE(review): assumes all 5 rows land in batches[0] in source order
        // — holds for a single in-memory batch; verify if batching changes.
        let age_col = batches[0].column_by_name("age").unwrap();
        let f64_arr = age_col.as_any().downcast_ref::<Float64Array>().unwrap();
        assert!(!f64_arr.is_null(2));
        assert_eq!(f64_arr.value(2), 0.0);
    }

    /// "mean" strategy: NULL age replaced by mean of {30, 25, 35, 28} = 29.5.
    #[test]
    fn test_fill_null_mean() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        let result = engine.fill_null(df, "age", "mean", None).unwrap();
        let batches = engine.collect(result).unwrap();
        let age_col = batches[0].column_by_name("age").unwrap();
        let f64_arr = age_col.as_any().downcast_ref::<Float64Array>().unwrap();
        assert!(!f64_arr.is_null(2));
        assert!((f64_arr.value(2) - 29.5).abs() < 0.01);
    }

    /// drop_null removes the single NULL-name row: 5 -> 4 rows.
    #[test]
    fn test_drop_null() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        let result = engine.drop_null(df, "name").unwrap();
        let batches = engine.collect(result).unwrap();
        let total: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total, 4);
    }

    /// dedup with empty column list = full-row DISTINCT: 4 rows -> 3
    /// (the duplicated (2, "b") row collapses).
    #[test]
    fn test_dedup() {
        let engine = DataEngine::new();
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("val", DataType::Utf8, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 2, 3])),
                Arc::new(StringArray::from(vec!["a", "b", "b", "c"])),
            ],
        )
        .unwrap();
        engine.register_batch("dup_data", batch).unwrap();
        let df = engine.rt.block_on(engine.ctx.table("dup_data")).unwrap();
        let result = engine.dedup(df, &[]).unwrap();
        let batches = engine.collect(result).unwrap();
        let total: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total, 3);
    }

    /// clamp to [26, 32]: 30 and 28 pass through, 25 -> 26, 35 -> 32.
    /// NOTE(review): assumes output preserves input row order within
    /// batches[0] — holds for the single-batch in-memory table here.
    #[test]
    fn test_clamp() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        let result = engine.clamp(df, "age", 26.0, 32.0).unwrap();
        let batches = engine.collect(result).unwrap();
        let age_col = batches[0].column_by_name("age").unwrap();
        let f64_arr = age_col.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(f64_arr.value(0), 30.0);
        assert_eq!(f64_arr.value(1), 26.0);
        assert_eq!(f64_arr.value(3), 32.0);
        assert_eq!(f64_arr.value(4), 28.0);
    }

    /// data_profile emits at least one row per profiled column
    /// (3 columns here; >= 2 allows for columns being skipped).
    #[test]
    fn test_data_profile() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        let result = engine.data_profile(df).unwrap();
        let batches = engine.collect(result).unwrap();
        assert!(!batches.is_empty());
        let total: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert!(total >= 2);
    }

    /// row_count on the 5-row fixture.
    #[test]
    fn test_row_count() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        let count = engine.row_count(df).unwrap();
        assert_eq!(count, 5);
    }

    /// null_rate("name") = 1/5 = 0.2; "id" has no duplicates -> unique.
    #[test]
    fn test_null_rate_and_is_unique() {
        let engine = make_test_engine_with_data();
        let df = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        let rate = engine.null_rate(df, "name").unwrap();
        assert!((rate - 0.2).abs() < 0.01);
        let df2 = engine.rt.block_on(engine.ctx.table("test_data")).unwrap();
        let unique = engine.is_unique(df2, "id").unwrap();
        assert!(unique);
    }
}