1use crate::core::{Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus};
4use crate::prelude::*;
5use arrow::array::{Array, LargeStringArray, StringViewArray};
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use std::fmt;
9use std::sync::Arc;
10use tracing::instrument;
11
/// One distinct value observed in a column, together with how often it occurs.
#[derive(Debug, Clone, PartialEq)]
pub struct HistogramBucket {
    /// The distinct value, rendered as a string.
    pub value: String,
    /// Number of rows holding this value.
    pub count: i64,
    /// Fraction of rows holding this value (the constraint in this file
    /// computes it relative to the non-null rows).
    pub ratio: f64,
}

/// Frequency distribution of the distinct values of a column.
///
/// The positional helpers ([`Histogram::most_common_ratio`],
/// [`Histogram::least_common_ratio`], [`Histogram::top_n`],
/// [`Histogram::follows_power_law`]) read `buckets` front to back and therefore
/// expect it to be ordered from most to least frequent. [`Histogram::new`]
/// establishes that order; code that fills the public fields directly has to
/// maintain it itself.
#[derive(Debug, Clone)]
pub struct Histogram {
    /// Buckets ordered by descending `count`.
    pub buckets: Vec<HistogramBucket>,
    /// Total number of rows, nulls included.
    pub total_count: i64,
    /// Number of distinct values; equals `buckets.len()` when built via `new`.
    pub distinct_count: usize,
    /// Number of null rows.
    pub null_count: i64,
}

impl Histogram {
    /// Builds a histogram from its buckets and row totals.
    ///
    /// The buckets are sorted by descending `count` so that the positional
    /// helpers are correct for any input order. The sort is stable: input that
    /// is already ordered is left untouched and buckets with equal counts keep
    /// the order the caller supplied.
    pub fn new(mut buckets: Vec<HistogramBucket>, total_count: i64, null_count: i64) -> Self {
        buckets.sort_by_key(|bucket| std::cmp::Reverse(bucket.count));
        let distinct_count = buckets.len();
        Self {
            buckets,
            total_count,
            distinct_count,
            null_count,
        }
    }

    /// Ratio of the first (most frequent) bucket, or `0.0` when there are no buckets.
    pub fn most_common_ratio(&self) -> f64 {
        self.buckets.first().map_or(0.0, |bucket| bucket.ratio)
    }

    /// Ratio of the last (least frequent) bucket, or `0.0` when there are no buckets.
    pub fn least_common_ratio(&self) -> f64 {
        self.buckets.last().map_or(0.0, |bucket| bucket.ratio)
    }

    /// Number of buckets, i.e. of distinct values.
    pub fn bucket_count(&self) -> usize {
        self.buckets.len()
    }

    /// Returns `(value, ratio)` pairs for at most the first `n` buckets.
    pub fn top_n(&self, n: usize) -> Vec<(&str, f64)> {
        self.buckets
            .iter()
            .take(n)
            .map(|bucket| (bucket.value.as_str(), bucket.ratio))
            .collect()
    }

    /// Checks whether the most frequent value is at most `threshold` times as
    /// frequent as the least frequent one.
    ///
    /// An empty histogram counts as uniform. A least-common ratio of exactly
    /// `0.0` is reported as non-uniform instead of dividing by zero.
    pub fn is_roughly_uniform(&self, threshold: f64) -> bool {
        if self.buckets.is_empty() {
            return true;
        }

        let max_ratio = self.most_common_ratio();
        let min_ratio = self.least_common_ratio();

        if min_ratio == 0.0 {
            return false;
        }

        max_ratio / min_ratio <= threshold
    }

    /// Looks up the ratio of the first bucket whose value equals `value`.
    pub fn get_value_ratio(&self, value: &str) -> Option<f64> {
        self.buckets
            .iter()
            .find(|bucket| bucket.value == value)
            .map(|bucket| bucket.ratio)
    }

    /// Shannon entropy of the bucket ratios, computed with the natural
    /// logarithm (so the unit is nats).
    ///
    /// Buckets whose ratio is not strictly positive are skipped because
    /// `ln` is undefined for them.
    pub fn entropy(&self) -> f64 {
        self.buckets
            .iter()
            .filter(|bucket| bucket.ratio > 0.0)
            .map(|bucket| -bucket.ratio * bucket.ratio.ln())
            .sum()
    }

    /// Returns `true` when the first `top_n` buckets together account for at
    /// least `threshold` of the distribution.
    ///
    /// This is a concentration check on the head of the distribution, not a
    /// statistical power-law fit.
    pub fn follows_power_law(&self, top_n: usize, threshold: f64) -> bool {
        let top_sum: f64 = self
            .buckets
            .iter()
            .take(top_n)
            .map(|bucket| bucket.ratio)
            .sum();
        top_sum >= threshold
    }

    /// Fraction of all rows that are null, or `0.0` when `total_count` is zero.
    pub fn null_ratio(&self) -> f64 {
        if self.total_count == 0 {
            0.0
        } else {
            self.null_count as f64 / self.total_count as f64
        }
    }
}
127
128pub type HistogramAssertion = Arc<dyn Fn(&Histogram) -> bool + Send + Sync>;
130
131#[derive(Clone)]
154pub struct HistogramConstraint {
155 column: String,
156 assertion: HistogramAssertion,
157 assertion_description: String,
158}
159
160impl fmt::Debug for HistogramConstraint {
161 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162 f.debug_struct("HistogramConstraint")
163 .field("column", &self.column)
164 .field("assertion_description", &self.assertion_description)
165 .finish()
166 }
167}
168
169impl HistogramConstraint {
170 pub fn new(column: impl Into<String>, assertion: HistogramAssertion) -> Self {
177 Self {
178 column: column.into(),
179 assertion,
180 assertion_description: "custom assertion".to_string(),
181 }
182 }
183
184 pub fn new_with_description(
192 column: impl Into<String>,
193 assertion: HistogramAssertion,
194 description: impl Into<String>,
195 ) -> Self {
196 Self {
197 column: column.into(),
198 assertion,
199 assertion_description: description.into(),
200 }
201 }
202}
203
204#[async_trait]
205impl Constraint for HistogramConstraint {
206 #[instrument(skip(self, ctx), fields(column = %self.column))]
207 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
208 let sql = format!(
210 r#"
211 WITH value_counts AS (
212 SELECT
213 CAST({} AS VARCHAR) as value,
214 COUNT(*) as count
215 FROM data
216 WHERE {} IS NOT NULL
217 GROUP BY {}
218 ),
219 totals AS (
220 SELECT
221 COUNT(*) as total_cnt,
222 SUM(CASE WHEN {} IS NULL THEN 1 ELSE 0 END) as null_cnt
223 FROM data
224 )
225 SELECT
226 vc.value,
227 vc.count,
228 vc.count * 1.0 / (t.total_cnt - t.null_cnt) as ratio,
229 t.total_cnt as total_count,
230 t.null_cnt as null_count
231 FROM value_counts vc
232 CROSS JOIN totals t
233 ORDER BY vc.count DESC, vc.value
234 "#,
235 self.column, self.column, self.column, self.column
236 );
237
238 let df = ctx.sql(&sql).await.map_err(|e| {
239 TermError::constraint_evaluation(
240 self.name(),
241 format!("Failed to execute histogram query: {e}"),
242 )
243 })?;
244
245 let batches = df.collect().await?;
246
247 if batches.is_empty() || batches[0].num_rows() == 0 {
248 return Ok(ConstraintResult::skipped("No data to analyze"));
249 }
250
251 let mut buckets = Vec::new();
253 let mut total_count = 0i64;
254 let mut null_count = 0i64;
255
256 for batch in &batches {
257 let values_col = batch.column(0);
259 let value_strings: Vec<String> = match values_col.data_type() {
260 arrow::datatypes::DataType::Utf8 => {
261 let arr = values_col
262 .as_any()
263 .downcast_ref::<arrow::array::StringArray>()
264 .ok_or_else(|| {
265 TermError::constraint_evaluation(
266 self.name(),
267 "Failed to extract string values",
268 )
269 })?;
270 (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
271 }
272 arrow::datatypes::DataType::Utf8View => {
273 let arr = values_col
274 .as_any()
275 .downcast_ref::<StringViewArray>()
276 .ok_or_else(|| {
277 TermError::constraint_evaluation(
278 self.name(),
279 "Failed to extract string view values",
280 )
281 })?;
282 (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
283 }
284 arrow::datatypes::DataType::LargeUtf8 => {
285 let arr = values_col
286 .as_any()
287 .downcast_ref::<LargeStringArray>()
288 .ok_or_else(|| {
289 TermError::constraint_evaluation(
290 self.name(),
291 "Failed to extract large string values",
292 )
293 })?;
294 (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
295 }
296 _ => {
297 return Err(TermError::constraint_evaluation(
298 self.name(),
299 format!("Unexpected value column type: {:?}", values_col.data_type()),
300 ));
301 }
302 };
303
304 let count_array = batch
305 .column(1)
306 .as_any()
307 .downcast_ref::<arrow::array::Int64Array>()
308 .ok_or_else(|| {
309 TermError::constraint_evaluation(self.name(), "Failed to extract counts")
310 })?;
311
312 let ratio_array = batch
313 .column(2)
314 .as_any()
315 .downcast_ref::<arrow::array::Float64Array>()
316 .ok_or_else(|| {
317 TermError::constraint_evaluation(self.name(), "Failed to extract ratios")
318 })?;
319
320 let total_array = batch
321 .column(3)
322 .as_any()
323 .downcast_ref::<arrow::array::Int64Array>()
324 .ok_or_else(|| {
325 TermError::constraint_evaluation(self.name(), "Failed to extract total count")
326 })?;
327
328 let null_array = batch
329 .column(4)
330 .as_any()
331 .downcast_ref::<arrow::array::Int64Array>()
332 .ok_or_else(|| {
333 TermError::constraint_evaluation(self.name(), "Failed to extract null count")
334 })?;
335
336 if batch.num_rows() > 0 {
338 total_count = total_array.value(0);
339 null_count = null_array.value(0);
340 }
341
342 for (i, value) in value_strings.into_iter().enumerate() {
344 let count = count_array.value(i);
345 let ratio = ratio_array.value(i);
346
347 buckets.push(HistogramBucket {
348 value,
349 count,
350 ratio,
351 });
352 }
353 }
354
355 let histogram = Histogram::new(buckets, total_count, null_count);
357
358 let assertion_result = (self.assertion)(&histogram);
360
361 let status = if assertion_result {
362 ConstraintStatus::Success
363 } else {
364 ConstraintStatus::Failure
365 };
366
367 let message = if status == ConstraintStatus::Failure {
368 let most_common_pct = histogram.most_common_ratio() * 100.0;
369 let null_pct = histogram.null_ratio() * 100.0;
370 Some(format!(
371 "Histogram assertion '{}' failed for column '{}'. Distribution: {} distinct values, most common ratio: {most_common_pct:.2}%, null ratio: {null_pct:.2}%",
372 self.assertion_description,
373 self.column,
374 histogram.distinct_count
375 ))
376 } else {
377 None
378 };
379
380 Ok(ConstraintResult {
382 status,
383 metric: Some(histogram.entropy()),
384 message,
385 })
386 }
387
388 fn name(&self) -> &str {
389 "histogram"
390 }
391
392 fn column(&self) -> Option<&str> {
393 Some(&self.column)
394 }
395
396 fn metadata(&self) -> ConstraintMetadata {
397 ConstraintMetadata::for_column(&self.column)
398 .with_description(format!(
399 "Analyzes value distribution in column '{}' and applies assertion: {}",
400 self.column, self.assertion_description
401 ))
402 .with_custom("assertion", &self.assertion_description)
403 .with_custom("constraint_type", "histogram")
404 }
405}
406
#[cfg(test)]
mod tests {
    //! Unit tests for [`Histogram`] and end-to-end tests that run
    //! [`HistogramConstraint`] against small in-memory tables.

    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::{ArrayRef, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Tolerance for comparing ratios computed by the query engine.
    const RATIO_EPSILON: f64 = 1e-9;

    /// Shorthand for building a bucket in the pure `Histogram` tests.
    fn bucket(value: &str, count: i64, ratio: f64) -> HistogramBucket {
        HistogramBucket {
            value: value.to_string(),
            count,
            ratio,
        }
    }

    /// Registers a one-column, nullable, single-batch table named `data`.
    fn register_single_column(name: &str, data_type: DataType, column: ArrayRef) -> SessionContext {
        let schema = Arc::new(Schema::new(vec![Field::new(name, data_type, true)]));
        let batch = RecordBatch::try_new(schema.clone(), vec![column]).unwrap();
        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(provider)).unwrap();
        ctx
    }

    /// Creates a context whose `data` table has a string column `test_col`.
    async fn create_test_context_with_data(values: Vec<Option<&str>>) -> SessionContext {
        register_single_column(
            "test_col",
            DataType::Utf8,
            Arc::new(StringArray::from(values)),
        )
    }

    /// Evaluates `constraint` and unwraps the result.
    async fn run(constraint: &HistogramConstraint, ctx: &SessionContext) -> ConstraintResult {
        constraint.evaluate(ctx).await.unwrap()
    }

    #[test]
    fn test_histogram_basic() {
        let histogram = Histogram::new(
            vec![bucket("A", 50, 0.5), bucket("B", 30, 0.3), bucket("C", 20, 0.2)],
            100,
            0,
        );

        assert_eq!(histogram.most_common_ratio(), 0.5);
        assert_eq!(histogram.least_common_ratio(), 0.2);
        assert_eq!(histogram.bucket_count(), 3);
        assert_eq!(histogram.null_ratio(), 0.0);
    }

    #[test]
    fn test_histogram_entropy() {
        let uniform_hist = Histogram::new(
            vec![
                bucket("A", 25, 0.25),
                bucket("B", 25, 0.25),
                bucket("C", 25, 0.25),
                bucket("D", 25, 0.25),
            ],
            100,
            0,
        );
        let skewed_hist = Histogram::new(vec![bucket("A", 90, 0.9), bucket("B", 10, 0.1)], 100, 0);

        // A uniform distribution carries more entropy than a skewed one.
        assert!(uniform_hist.entropy() > skewed_hist.entropy());
    }

    #[tokio::test]
    async fn test_most_common_ratio_constraint() {
        // "A" makes up 60% of the rows.
        let values = vec![
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("B"),
            Some("B"),
            Some("C"),
            Some("C"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let too_strict = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.most_common_ratio() < 0.5),
            "most common value appears less than 50%",
        );
        let result = run(&too_strict, &ctx).await;
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert!(result.message.is_some());

        let lenient =
            HistogramConstraint::new("test_col", Arc::new(|hist| hist.most_common_ratio() < 0.7));
        let result = run(&lenient, &ctx).await;
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_bucket_count_constraint() {
        let values = vec![
            Some("RED"),
            Some("BLUE"),
            Some("GREEN"),
            Some("YELLOW"),
            Some("RED"),
            Some("BLUE"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.bucket_count() >= 3 && hist.bucket_count() <= 5),
            "has between 3 and 5 distinct values",
        );

        let result = run(&constraint, &ctx).await;
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_uniform_distribution_check() {
        let values = vec![
            Some("A"),
            Some("A"),
            Some("B"),
            Some("B"),
            Some("C"),
            Some("C"),
            Some("D"),
            Some("D"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint =
            HistogramConstraint::new("test_col", Arc::new(|hist| hist.is_roughly_uniform(1.5)));

        let result = run(&constraint, &ctx).await;
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_power_law_distribution() {
        // The two popular values cover 7 of 10 rows.
        let values = vec![
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular2"),
            Some("Popular2"),
            Some("Popular2"),
            Some("Rare1"),
            Some("Rare2"),
            Some("Rare3"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.follows_power_law(2, 0.7)),
            "top 2 values account for 70% of distribution",
        );

        let result = run(&constraint, &ctx).await;
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_with_nulls() {
        // 3 of 8 rows are null (37.5%).
        let values = vec![
            Some("A"),
            Some("A"),
            None,
            None,
            None,
            Some("B"),
            Some("B"),
            Some("C"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new(
            "test_col",
            Arc::new(|hist| hist.null_ratio() > 0.3 && hist.null_ratio() < 0.4),
        );

        let result = run(&constraint, &ctx).await;
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_empty_data() {
        let ctx = create_test_context_with_data(vec![]).await;

        let constraint = HistogramConstraint::new("test_col", Arc::new(|_| true));

        let result = run(&constraint, &ctx).await;
        assert_eq!(result.status, ConstraintStatus::Skipped);
    }

    #[tokio::test]
    async fn test_specific_value_check() {
        let values = vec![
            Some("PENDING"),
            Some("PENDING"),
            Some("APPROVED"),
            Some("APPROVED"),
            Some("APPROVED"),
            Some("REJECTED"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.get_value_ratio("APPROVED").unwrap_or(0.0) > 0.4),
            "APPROVED status is most common",
        );

        let result = run(&constraint, &ctx).await;
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_top_n_values() {
        // A: 40%, B: 30%, C: 20%, D: 10%.
        let values = vec![
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("B"),
            Some("B"),
            Some("B"),
            Some("C"),
            Some("C"),
            Some("D"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new(
            "test_col",
            Arc::new(|hist| {
                let top_2 = hist.top_n(2);
                // The ratios come out of a floating-point division in the
                // query engine, so compare with a tolerance instead of `==`.
                top_2.len() == 2
                    && (top_2[0].1 - 0.4).abs() < RATIO_EPSILON
                    && (top_2[1].1 - 0.3).abs() < RATIO_EPSILON
            }),
        );

        let result = run(&constraint, &ctx).await;
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_numeric_data_histogram() {
        // Six distinct ages; the most common (30) covers 30% of the rows.
        let values = vec![
            Some(25),
            Some(25),
            Some(30),
            Some(30),
            Some(30),
            Some(35),
            Some(35),
            Some(40),
            Some(45),
            Some(50),
        ];
        let ctx = register_single_column(
            "age",
            DataType::Int64,
            Arc::new(Int64Array::from(values)),
        );

        let constraint = HistogramConstraint::new_with_description(
            "age",
            Arc::new(|hist| hist.bucket_count() >= 5 && hist.most_common_ratio() < 0.4),
            "age distribution is reasonable",
        );

        let result = run(&constraint, &ctx).await;
        assert_eq!(result.status, ConstraintStatus::Success);
    }
}