term_guard/constraints/histogram.rs

//! Histogram analysis constraint for value distribution analysis.

use crate::core::{Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus};
use crate::prelude::*;
use arrow::array::{Array, LargeStringArray, StringViewArray};
use async_trait::async_trait;
use datafusion::prelude::*;
use std::fmt;
use std::sync::Arc;
use tracing::instrument;

/// A bucket in a histogram representing a value and its frequency information.
#[derive(Debug, Clone, PartialEq)]
pub struct HistogramBucket {
    /// The value in this bucket
    pub value: String,
    /// The count of occurrences
    pub count: i64,
    /// The ratio of this value's count to the total non-null count
    pub ratio: f64,
}

/// A histogram representing the distribution of values in a column.
#[derive(Debug, Clone)]
pub struct Histogram {
    /// The buckets in the histogram, ordered by frequency (descending)
    pub buckets: Vec<HistogramBucket>,
    /// Total number of values (including nulls if present)
    pub total_count: i64,
    /// Number of distinct values
    pub distinct_count: usize,
    /// Number of null values
    pub null_count: i64,
}

impl Histogram {
    /// Creates a new histogram from buckets.
    pub fn new(buckets: Vec<HistogramBucket>, total_count: i64, null_count: i64) -> Self {
        let distinct_count = buckets.len();
        Self {
            buckets,
            total_count,
            distinct_count,
            null_count,
        }
    }

    /// Returns the ratio of the most common value, or 0.0 if the histogram is empty.
    pub fn most_common_ratio(&self) -> f64 {
        self.buckets.first().map(|b| b.ratio).unwrap_or(0.0)
    }

    /// Returns the ratio of the least common value, or 0.0 if the histogram is empty.
    pub fn least_common_ratio(&self) -> f64 {
        self.buckets.last().map(|b| b.ratio).unwrap_or(0.0)
    }

    /// Returns the number of buckets (distinct values).
    pub fn bucket_count(&self) -> usize {
        self.buckets.len()
    }

    /// Returns the top N most common values and their ratios.
    pub fn top_n(&self, n: usize) -> Vec<(&str, f64)> {
        self.buckets
            .iter()
            .take(n)
            .map(|b| (b.value.as_str(), b.ratio))
            .collect()
    }

    /// Checks if the distribution is roughly uniform (all values have similar frequencies).
    ///
    /// A distribution is considered roughly uniform if the ratio between the most common
    /// and least common values does not exceed the given threshold (e.g. 1.5).
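    ///
    /// # Examples
    ///
    /// A minimal doc-test sketch; it assumes `HistogramBucket` is exported from
    /// `term_guard::constraints` alongside `Histogram`:
    ///
    /// ```rust
    /// // Import path for `HistogramBucket` assumed to mirror `Histogram`'s.
    /// use term_guard::constraints::{Histogram, HistogramBucket};
    ///
    /// let buckets = vec![
    ///     HistogramBucket { value: "A".into(), count: 6, ratio: 0.6 },
    ///     HistogramBucket { value: "B".into(), count: 4, ratio: 0.4 },
    /// ];
    /// let hist = Histogram::new(buckets, 10, 0);
    ///
    /// // 0.6 / 0.4 = 1.5, so the distribution passes a 1.5 threshold but not a 1.4 one.
    /// assert!(hist.is_roughly_uniform(1.5));
    /// assert!(!hist.is_roughly_uniform(1.4));
    /// ```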
    pub fn is_roughly_uniform(&self, threshold: f64) -> bool {
        if self.buckets.is_empty() {
            return true;
        }

        let max_ratio = self.most_common_ratio();
        let min_ratio = self.least_common_ratio();

        if min_ratio == 0.0 {
            return false;
        }

        max_ratio / min_ratio <= threshold
    }

    /// Gets the ratio for a specific value, if it exists in the histogram.
    pub fn get_value_ratio(&self, value: &str) -> Option<f64> {
        self.buckets
            .iter()
            .find(|b| b.value == value)
            .map(|b| b.ratio)
    }

    /// Returns the Shannon entropy of the distribution, in nats (natural log).
    ///
    /// Higher entropy indicates a more uniform distribution.
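    ///
    /// # Examples
    ///
    /// A minimal doc-test sketch (again assuming `HistogramBucket` is exported next to
    /// `Histogram`): an even 50/50 split has entropy ln(2).
    ///
    /// ```rust
    /// // Import path for `HistogramBucket` assumed to mirror `Histogram`'s.
    /// use term_guard::constraints::{Histogram, HistogramBucket};
    ///
    /// let buckets = vec![
    ///     HistogramBucket { value: "A".into(), count: 5, ratio: 0.5 },
    ///     HistogramBucket { value: "B".into(), count: 5, ratio: 0.5 },
    /// ];
    /// let hist = Histogram::new(buckets, 10, 0);
    ///
    /// // -(0.5 * ln 0.5) - (0.5 * ln 0.5) = ln 2 ≈ 0.693
    /// assert!((hist.entropy() - std::f64::consts::LN_2).abs() < 1e-12);
    /// ```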
    pub fn entropy(&self) -> f64 {
        self.buckets
            .iter()
            .filter(|b| b.ratio > 0.0)
            .map(|b| -b.ratio * b.ratio.ln())
            .sum()
    }

    /// Checks if the distribution follows a power law (a few values dominate).
    ///
    /// Returns `true` if the top `top_n` values account for at least `threshold` of the distribution.
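    ///
    /// # Examples
    ///
    /// A minimal doc-test sketch under the same re-export assumption as above:
    ///
    /// ```rust
    /// // Import path for `HistogramBucket` assumed to mirror `Histogram`'s.
    /// use term_guard::constraints::{Histogram, HistogramBucket};
    ///
    /// let buckets = vec![
    ///     HistogramBucket { value: "hot".into(), count: 7, ratio: 0.7 },
    ///     HistogramBucket { value: "warm".into(), count: 2, ratio: 0.2 },
    ///     HistogramBucket { value: "cold".into(), count: 1, ratio: 0.1 },
    /// ];
    /// let hist = Histogram::new(buckets, 10, 0);
    ///
    /// // The single most common value already covers 70% of the non-null rows.
    /// assert!(hist.follows_power_law(1, 0.7));
    /// assert!(!hist.follows_power_law(1, 0.75));
    /// ```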
    pub fn follows_power_law(&self, top_n: usize, threshold: f64) -> bool {
        let top_sum: f64 = self.buckets.iter().take(top_n).map(|b| b.ratio).sum();
        top_sum >= threshold
    }

    /// Returns the null ratio in the data.
    pub fn null_ratio(&self) -> f64 {
        if self.total_count == 0 {
            0.0
        } else {
            self.null_count as f64 / self.total_count as f64
        }
    }
}

/// Type alias for a histogram assertion function.
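///
/// # Examples
///
/// A minimal sketch of building and reusing one; it assumes `HistogramAssertion` is
/// exported from `term_guard::constraints` like `Histogram` is:
///
/// ```rust
/// use std::sync::Arc;
/// // Export path for `HistogramAssertion` assumed to mirror `Histogram`'s.
/// use term_guard::constraints::{Histogram, HistogramAssertion, HistogramConstraint};
///
/// // Build the assertion once, then share it across several constraints.
/// let no_dominant_value: HistogramAssertion =
///     Arc::new(|hist: &Histogram| hist.most_common_ratio() < 0.5);
///
/// let status_check = HistogramConstraint::new("status", Arc::clone(&no_dominant_value));
/// let country_check = HistogramConstraint::new("country", no_dominant_value);
/// ```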
pub type HistogramAssertion = Arc<dyn Fn(&Histogram) -> bool + Send + Sync>;

/// A constraint that analyzes value distribution in a column and applies custom assertions.
///
/// This constraint computes a histogram of value frequencies and allows custom assertion
/// functions to validate the distribution characteristics.
///
/// # Examples
///
/// ```rust
/// use term_guard::constraints::{HistogramConstraint, Histogram};
/// use term_guard::core::Constraint;
/// use std::sync::Arc;
///
/// // Check that no single value dominates
/// let constraint = HistogramConstraint::new("status", Arc::new(|hist: &Histogram| {
///     hist.most_common_ratio() < 0.5
/// }));
///
/// // Verify distribution has expected number of categories
/// let constraint = HistogramConstraint::new("category", Arc::new(|hist| {
///     hist.bucket_count() >= 5 && hist.bucket_count() <= 10
/// }));
/// ```
#[derive(Clone)]
pub struct HistogramConstraint {
    column: String,
    assertion: HistogramAssertion,
    assertion_description: String,
}

impl fmt::Debug for HistogramConstraint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("HistogramConstraint")
            .field("column", &self.column)
            .field("assertion_description", &self.assertion_description)
            .finish()
    }
}

impl HistogramConstraint {
    /// Creates a new histogram constraint.
    ///
    /// # Arguments
    ///
    /// * `column` - The column to analyze
    /// * `assertion` - The assertion function to apply to the histogram
    pub fn new(column: impl Into<String>, assertion: HistogramAssertion) -> Self {
        Self {
            column: column.into(),
            assertion,
            assertion_description: "custom assertion".to_string(),
        }
    }

    /// Creates a new histogram constraint with a description.
    ///
    /// # Arguments
    ///
    /// * `column` - The column to analyze
    /// * `assertion` - The assertion function to apply to the histogram
    /// * `description` - A description of what the assertion checks
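    ///
    /// # Examples
    ///
    /// A minimal sketch; the custom description is surfaced in failure messages and
    /// in the constraint metadata:
    ///
    /// ```rust
    /// use term_guard::constraints::{Histogram, HistogramConstraint};
    /// use std::sync::Arc;
    ///
    /// let constraint = HistogramConstraint::new_with_description(
    ///     "status",
    ///     Arc::new(|hist: &Histogram| hist.most_common_ratio() < 0.5),
    ///     "no single status makes up half the rows",
    /// );
    /// ```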
    pub fn new_with_description(
        column: impl Into<String>,
        assertion: HistogramAssertion,
        description: impl Into<String>,
    ) -> Self {
        Self {
            column: column.into(),
            assertion,
            assertion_description: description.into(),
        }
    }
}

#[async_trait]
impl Constraint for HistogramConstraint {
    #[instrument(skip(self, ctx), fields(column = %self.column))]
    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
        // SQL query to compute value frequencies
        let sql = format!(
            r#"
            WITH value_counts AS (
                SELECT
                    CAST({} AS VARCHAR) as value,
                    COUNT(*) as count
                FROM data
                WHERE {} IS NOT NULL
                GROUP BY {}
            ),
            totals AS (
                SELECT
                    COUNT(*) as total_cnt,
                    SUM(CASE WHEN {} IS NULL THEN 1 ELSE 0 END) as null_cnt
                FROM data
            )
            SELECT
                vc.value,
                vc.count,
                vc.count * 1.0 / (t.total_cnt - t.null_cnt) as ratio,
                t.total_cnt as total_count,
                t.null_cnt as null_count
            FROM value_counts vc
            CROSS JOIN totals t
            ORDER BY vc.count DESC, vc.value
            "#,
            self.column, self.column, self.column, self.column
        );

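        // Shape of the query result, sketched for a column containing [A, A, B, NULL]:
        // one row per distinct non-null value, with the totals repeated on every row.
        //
        //   value | count | ratio  | total_count | null_count
        //   A     | 2     | 0.666… | 4           | 1
        //   B     | 1     | 0.333… | 4           | 1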
        let df = ctx.sql(&sql).await.map_err(|e| {
            TermError::constraint_evaluation(
                self.name(),
                format!("Failed to execute histogram query: {e}"),
            )
        })?;

        let batches = df.collect().await?;

        if batches.is_empty() || batches[0].num_rows() == 0 {
            return Ok(ConstraintResult::skipped("No data to analyze"));
        }

        // Extract histogram data from results
        let mut buckets = Vec::new();
        let mut total_count = 0i64;
        let mut null_count = 0i64;

        for batch in &batches {
            // DataFusion might return various string types
            let values_col = batch.column(0);
            let value_strings: Vec<String> = match values_col.data_type() {
                arrow::datatypes::DataType::Utf8 => {
                    let arr = values_col
                        .as_any()
                        .downcast_ref::<arrow::array::StringArray>()
                        .ok_or_else(|| {
                            TermError::constraint_evaluation(
                                self.name(),
                                "Failed to extract string values",
                            )
                        })?;
                    (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
                }
                arrow::datatypes::DataType::Utf8View => {
                    let arr = values_col
                        .as_any()
                        .downcast_ref::<StringViewArray>()
                        .ok_or_else(|| {
                            TermError::constraint_evaluation(
                                self.name(),
                                "Failed to extract string view values",
                            )
                        })?;
                    (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
                }
                arrow::datatypes::DataType::LargeUtf8 => {
                    let arr = values_col
                        .as_any()
                        .downcast_ref::<LargeStringArray>()
                        .ok_or_else(|| {
                            TermError::constraint_evaluation(
                                self.name(),
                                "Failed to extract large string values",
                            )
                        })?;
                    (0..arr.len()).map(|i| arr.value(i).to_string()).collect()
                }
                _ => {
                    return Err(TermError::constraint_evaluation(
                        self.name(),
                        format!("Unexpected value column type: {:?}", values_col.data_type()),
                    ));
                }
            };

            let count_array = batch
                .column(1)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .ok_or_else(|| {
                    TermError::constraint_evaluation(self.name(), "Failed to extract counts")
                })?;

            let ratio_array = batch
                .column(2)
                .as_any()
                .downcast_ref::<arrow::array::Float64Array>()
                .ok_or_else(|| {
                    TermError::constraint_evaluation(self.name(), "Failed to extract ratios")
                })?;

            let total_array = batch
                .column(3)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .ok_or_else(|| {
                    TermError::constraint_evaluation(self.name(), "Failed to extract total count")
                })?;

            let null_array = batch
                .column(4)
                .as_any()
                .downcast_ref::<arrow::array::Int64Array>()
                .ok_or_else(|| {
                    TermError::constraint_evaluation(self.name(), "Failed to extract null count")
                })?;

            // Get total and null counts from first row
            if batch.num_rows() > 0 {
                total_count = total_array.value(0);
                null_count = null_array.value(0);
            }

            // Collect buckets
            for (i, value) in value_strings.into_iter().enumerate() {
                let count = count_array.value(i);
                let ratio = ratio_array.value(i);

                buckets.push(HistogramBucket {
                    value,
                    count,
                    ratio,
                });
            }
        }

        // Create histogram
        let histogram = Histogram::new(buckets, total_count, null_count);

        // Apply assertion
        let assertion_result = (self.assertion)(&histogram);

        let status = if assertion_result {
            ConstraintStatus::Success
        } else {
            ConstraintStatus::Failure
        };

        let message = if status == ConstraintStatus::Failure {
            let most_common_pct = histogram.most_common_ratio() * 100.0;
            let null_pct = histogram.null_ratio() * 100.0;
            Some(format!(
                "Histogram assertion '{}' failed for column '{}'. Distribution: {} distinct values, most common ratio: {most_common_pct:.2}%, null ratio: {null_pct:.2}%",
                self.assertion_description,
                self.column,
                histogram.distinct_count
            ))
        } else {
            None
        };

        // Store histogram entropy as metric
        Ok(ConstraintResult {
            status,
            metric: Some(histogram.entropy()),
            message,
        })
    }

    fn name(&self) -> &str {
        "histogram"
    }

    fn column(&self) -> Option<&str> {
        Some(&self.column)
    }

    fn metadata(&self) -> ConstraintMetadata {
        ConstraintMetadata::for_column(&self.column)
            .with_description(format!(
                "Analyzes value distribution in column '{}' and applies assertion: {}",
                self.column, self.assertion_description
            ))
            .with_custom("assertion", &self.assertion_description)
            .with_custom("constraint_type", "histogram")
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::StringArray;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    async fn create_test_context_with_data(values: Vec<Option<&str>>) -> SessionContext {
        let ctx = SessionContext::new();

        let schema = Arc::new(Schema::new(vec![Field::new(
            "test_col",
            DataType::Utf8,
            true,
        )]));

        let array = StringArray::from(values);
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();

        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(provider)).unwrap();

        ctx
    }

    #[test]
    fn test_histogram_basic() {
        let buckets = vec![
            HistogramBucket {
                value: "A".to_string(),
                count: 50,
                ratio: 0.5,
            },
            HistogramBucket {
                value: "B".to_string(),
                count: 30,
                ratio: 0.3,
            },
            HistogramBucket {
                value: "C".to_string(),
                count: 20,
                ratio: 0.2,
            },
        ];

        let histogram = Histogram::new(buckets, 100, 0);

        assert_eq!(histogram.most_common_ratio(), 0.5);
        assert_eq!(histogram.least_common_ratio(), 0.2);
        assert_eq!(histogram.bucket_count(), 3);
        assert_eq!(histogram.null_ratio(), 0.0);
    }

    #[test]
    fn test_histogram_entropy() {
        // Uniform distribution should have higher entropy
        let uniform_buckets = vec![
            HistogramBucket {
                value: "A".to_string(),
                count: 25,
                ratio: 0.25,
            },
            HistogramBucket {
                value: "B".to_string(),
                count: 25,
                ratio: 0.25,
            },
            HistogramBucket {
                value: "C".to_string(),
                count: 25,
                ratio: 0.25,
            },
            HistogramBucket {
                value: "D".to_string(),
                count: 25,
                ratio: 0.25,
            },
        ];

        let uniform_hist = Histogram::new(uniform_buckets, 100, 0);

        // Skewed distribution should have lower entropy
        let skewed_buckets = vec![
            HistogramBucket {
                value: "A".to_string(),
                count: 90,
                ratio: 0.9,
            },
            HistogramBucket {
                value: "B".to_string(),
                count: 10,
                ratio: 0.1,
            },
        ];

        let skewed_hist = Histogram::new(skewed_buckets, 100, 0);

        assert!(uniform_hist.entropy() > skewed_hist.entropy());
    }

    #[tokio::test]
    async fn test_most_common_ratio_constraint() {
        // Create data where "A" appears 60% of the time
        let values = vec![
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"),
            Some("B"),
            Some("B"),
            Some("C"),
            Some("C"),
        ];
        let ctx = create_test_context_with_data(values).await;

        // Constraint that fails: most common should be < 50%
        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.most_common_ratio() < 0.5),
            "most common value appears less than 50%",
        );

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert!(result.message.is_some());

        // Constraint that passes: most common should be < 70%
        let constraint =
            HistogramConstraint::new("test_col", Arc::new(|hist| hist.most_common_ratio() < 0.7));

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_bucket_count_constraint() {
        // Create data with 4 distinct values
        let values = vec![
            Some("RED"),
            Some("BLUE"),
            Some("GREEN"),
            Some("YELLOW"),
            Some("RED"),
            Some("BLUE"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.bucket_count() >= 3 && hist.bucket_count() <= 5),
            "has between 3 and 5 distinct values",
        );

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_uniform_distribution_check() {
        // Create roughly uniform distribution
        let values = vec![
            Some("A"),
            Some("A"),
            Some("B"),
            Some("B"),
            Some("C"),
            Some("C"),
            Some("D"),
            Some("D"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint =
            HistogramConstraint::new("test_col", Arc::new(|hist| hist.is_roughly_uniform(1.5)));

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_power_law_distribution() {
        // Create power law distribution where top 2 values dominate
        let values = vec![
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular1"),
            Some("Popular2"),
            Some("Popular2"),
            Some("Popular2"),
            Some("Rare1"),
            Some("Rare2"),
            Some("Rare3"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| hist.follows_power_law(2, 0.7)),
            "top 2 values account for 70% of distribution",
        );

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_with_nulls() {
        let values = vec![
            Some("A"),
            Some("A"),
            None,
            None,
            None,
            Some("B"),
            Some("B"),
            Some("C"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new(
            "test_col",
            Arc::new(|hist| hist.null_ratio() > 0.3 && hist.null_ratio() < 0.4),
        );

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_empty_data() {
        let ctx = create_test_context_with_data(vec![]).await;

        let constraint = HistogramConstraint::new("test_col", Arc::new(|_| true));

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Skipped);
    }

    #[tokio::test]
    async fn test_specific_value_check() {
        let values = vec![
            Some("PENDING"),
            Some("PENDING"),
            Some("APPROVED"),
            Some("APPROVED"),
            Some("APPROVED"),
            Some("REJECTED"),
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new_with_description(
            "test_col",
            Arc::new(|hist| {
                // Check that APPROVED is the most common status
                hist.get_value_ratio("APPROVED").unwrap_or(0.0) > 0.4
            }),
            "APPROVED status is most common",
        );

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_top_n_values() {
        let values = vec![
            Some("A"),
            Some("A"),
            Some("A"),
            Some("A"), // 40%
            Some("B"),
            Some("B"),
            Some("B"), // 30%
            Some("C"),
            Some("C"), // 20%
            Some("D"), // 10%
        ];
        let ctx = create_test_context_with_data(values).await;

        let constraint = HistogramConstraint::new(
            "test_col",
            Arc::new(|hist| {
                let top_2 = hist.top_n(2);
                top_2.len() == 2 && top_2[0].1 == 0.4 && top_2[1].1 == 0.3
            }),
        );

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }

    #[tokio::test]
    async fn test_numeric_data_histogram() {
        use arrow::array::Int64Array;
        use arrow::datatypes::{DataType, Field, Schema};

        let ctx = SessionContext::new();

        let schema = Arc::new(Schema::new(vec![Field::new("age", DataType::Int64, true)]));

        let values = vec![
            Some(25),
            Some(25),
            Some(30),
            Some(30),
            Some(30),
            Some(35),
            Some(35),
            Some(40),
            Some(45),
            Some(50),
        ];
        let array = Int64Array::from(values);
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap();

        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(provider)).unwrap();

        let constraint = HistogramConstraint::new_with_description(
            "age",
            Arc::new(|hist| {
                // Check we have reasonable age distribution
                hist.bucket_count() >= 5 && hist.most_common_ratio() < 0.4
            }),
            "age distribution is reasonable",
        );

        let result = constraint.evaluate(&ctx).await.unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
    }
}