Skip to main content

datasynth_eval/quality/
consistency.rs

1//! Cross-field consistency evaluation.
2//!
3//! Validates consistency rules across related fields within records.
4
5use crate::error::EvalResult;
6use chrono::{Datelike, NaiveDate};
7use rust_decimal::Decimal;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// Results of consistency analysis.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ConsistencyAnalysis {
15    /// Total records analyzed.
16    pub total_records: usize,
17    /// Per-rule results.
18    pub rule_results: Vec<RuleResult>,
19    /// Overall pass rate (0.0-1.0).
20    pub pass_rate: f64,
21    /// Total violations.
22    pub total_violations: usize,
23    /// Violations by rule type.
24    pub violations_by_type: HashMap<String, usize>,
25}
26
27/// Result for a single consistency rule.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct RuleResult {
30    /// Rule name.
31    pub rule_name: String,
32    /// Rule description.
33    pub description: String,
34    /// Number of records checked.
35    pub records_checked: usize,
36    /// Number of records passed.
37    pub records_passed: usize,
38    /// Pass rate for this rule.
39    pub pass_rate: f64,
40    /// Example violations.
41    pub example_violations: Vec<String>,
42}
43
44/// A consistency rule definition.
45#[derive(Debug, Clone)]
46pub struct ConsistencyRule {
47    /// Rule name.
48    pub name: String,
49    /// Rule description.
50    pub description: String,
51    /// Rule type.
52    pub rule_type: RuleType,
53}
54
55/// Type of consistency rule.
56pub enum RuleType {
57    /// Date ordering (e.g., document_date <= posting_date).
58    DateOrdering {
59        earlier_field: String,
60        later_field: String,
61    },
62    /// Mutual exclusion (e.g., debit XOR credit).
63    MutualExclusion { field1: String, field2: String },
64    /// Fiscal period matches date.
65    FiscalPeriodDateAlignment {
66        date_field: String,
67        period_field: String,
68        year_field: String,
69    },
70    /// Amount sign consistency.
71    AmountSign {
72        amount_field: String,
73        indicator_field: String,
74        positive_indicator: String,
75    },
76    /// Required if present (if field A has value, field B must too).
77    RequiredIfPresent {
78        trigger_field: String,
79        required_field: String,
80    },
81    /// Value range (field must be within range).
82    ValueRange {
83        field: String,
84        min: Option<Decimal>,
85        max: Option<Decimal>,
86    },
87    /// Custom rule with closure.
88    Custom {
89        checker: Arc<dyn Fn(&ConsistencyRecord) -> bool + Send + Sync>,
90    },
91}
92
93impl Clone for RuleType {
94    fn clone(&self) -> Self {
95        match self {
96            RuleType::DateOrdering {
97                earlier_field,
98                later_field,
99            } => RuleType::DateOrdering {
100                earlier_field: earlier_field.clone(),
101                later_field: later_field.clone(),
102            },
103            RuleType::MutualExclusion { field1, field2 } => RuleType::MutualExclusion {
104                field1: field1.clone(),
105                field2: field2.clone(),
106            },
107            RuleType::FiscalPeriodDateAlignment {
108                date_field,
109                period_field,
110                year_field,
111            } => RuleType::FiscalPeriodDateAlignment {
112                date_field: date_field.clone(),
113                period_field: period_field.clone(),
114                year_field: year_field.clone(),
115            },
116            RuleType::AmountSign {
117                amount_field,
118                indicator_field,
119                positive_indicator,
120            } => RuleType::AmountSign {
121                amount_field: amount_field.clone(),
122                indicator_field: indicator_field.clone(),
123                positive_indicator: positive_indicator.clone(),
124            },
125            RuleType::RequiredIfPresent {
126                trigger_field,
127                required_field,
128            } => RuleType::RequiredIfPresent {
129                trigger_field: trigger_field.clone(),
130                required_field: required_field.clone(),
131            },
132            RuleType::ValueRange { field, min, max } => RuleType::ValueRange {
133                field: field.clone(),
134                min: *min,
135                max: *max,
136            },
137            RuleType::Custom { checker } => RuleType::Custom {
138                checker: Arc::clone(checker),
139            },
140        }
141    }
142}
143
144impl std::fmt::Debug for RuleType {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        match self {
147            RuleType::DateOrdering {
148                earlier_field,
149                later_field,
150            } => f
151                .debug_struct("DateOrdering")
152                .field("earlier_field", earlier_field)
153                .field("later_field", later_field)
154                .finish(),
155            RuleType::MutualExclusion { field1, field2 } => f
156                .debug_struct("MutualExclusion")
157                .field("field1", field1)
158                .field("field2", field2)
159                .finish(),
160            RuleType::FiscalPeriodDateAlignment {
161                date_field,
162                period_field,
163                year_field,
164            } => f
165                .debug_struct("FiscalPeriodDateAlignment")
166                .field("date_field", date_field)
167                .field("period_field", period_field)
168                .field("year_field", year_field)
169                .finish(),
170            RuleType::AmountSign {
171                amount_field,
172                indicator_field,
173                positive_indicator,
174            } => f
175                .debug_struct("AmountSign")
176                .field("amount_field", amount_field)
177                .field("indicator_field", indicator_field)
178                .field("positive_indicator", positive_indicator)
179                .finish(),
180            RuleType::RequiredIfPresent {
181                trigger_field,
182                required_field,
183            } => f
184                .debug_struct("RequiredIfPresent")
185                .field("trigger_field", trigger_field)
186                .field("required_field", required_field)
187                .finish(),
188            RuleType::ValueRange { field, min, max } => f
189                .debug_struct("ValueRange")
190                .field("field", field)
191                .field("min", min)
192                .field("max", max)
193                .finish(),
194            RuleType::Custom { .. } => f
195                .debug_struct("Custom")
196                .field("checker", &"<custom_fn>")
197                .finish(),
198        }
199    }
200}
201
202/// A record for consistency checking.
203#[derive(Debug, Clone, Default)]
204pub struct ConsistencyRecord {
205    /// String field values.
206    pub string_fields: HashMap<String, String>,
207    /// Decimal field values.
208    pub decimal_fields: HashMap<String, Decimal>,
209    /// Date field values.
210    pub date_fields: HashMap<String, NaiveDate>,
211    /// Integer field values.
212    pub integer_fields: HashMap<String, i64>,
213    /// Boolean field values.
214    pub boolean_fields: HashMap<String, bool>,
215}
216
217/// Analyzer for cross-field consistency.
218pub struct ConsistencyAnalyzer {
219    /// Rules to check.
220    rules: Vec<ConsistencyRule>,
221    /// Maximum example violations to collect per rule.
222    max_examples: usize,
223}
224
225impl ConsistencyAnalyzer {
226    /// Create a new analyzer with specified rules.
227    pub fn new(rules: Vec<ConsistencyRule>) -> Self {
228        Self {
229            rules,
230            max_examples: 5,
231        }
232    }
233
234    /// Create with default accounting rules.
235    pub fn with_default_rules() -> Self {
236        let rules = vec![
237            ConsistencyRule {
238                name: "date_ordering".to_string(),
239                description: "Document date must be on or before posting date".to_string(),
240                rule_type: RuleType::DateOrdering {
241                    earlier_field: "document_date".to_string(),
242                    later_field: "posting_date".to_string(),
243                },
244            },
245            ConsistencyRule {
246                name: "debit_credit_exclusion".to_string(),
247                description: "Each line must have either debit or credit, not both".to_string(),
248                rule_type: RuleType::MutualExclusion {
249                    field1: "debit_amount".to_string(),
250                    field2: "credit_amount".to_string(),
251                },
252            },
253            ConsistencyRule {
254                name: "fiscal_period_alignment".to_string(),
255                description: "Fiscal period must match posting date".to_string(),
256                rule_type: RuleType::FiscalPeriodDateAlignment {
257                    date_field: "posting_date".to_string(),
258                    period_field: "fiscal_period".to_string(),
259                    year_field: "fiscal_year".to_string(),
260                },
261            },
262        ];
263
264        Self::new(rules)
265    }
266
267    /// Analyze consistency of records.
268    pub fn analyze(&self, records: &[ConsistencyRecord]) -> EvalResult<ConsistencyAnalysis> {
269        let total_records = records.len();
270        let mut rule_results = Vec::new();
271        let mut total_violations = 0;
272        let mut violations_by_type: HashMap<String, usize> = HashMap::new();
273
274        for rule in &self.rules {
275            let mut records_checked = 0;
276            let mut records_passed = 0;
277            let mut example_violations = Vec::new();
278
279            for (idx, record) in records.iter().enumerate() {
280                let applicable = self.is_rule_applicable(rule, record);
281                if !applicable {
282                    continue;
283                }
284
285                records_checked += 1;
286                let passed = self.check_rule(rule, record);
287
288                if passed {
289                    records_passed += 1;
290                } else {
291                    total_violations += 1;
292                    *violations_by_type.entry(rule.name.clone()).or_insert(0) += 1;
293
294                    if example_violations.len() < self.max_examples {
295                        example_violations.push(format!("Record {}: {:?}", idx, record));
296                    }
297                }
298            }
299
300            let pass_rate = if records_checked > 0 {
301                records_passed as f64 / records_checked as f64
302            } else {
303                1.0
304            };
305
306            rule_results.push(RuleResult {
307                rule_name: rule.name.clone(),
308                description: rule.description.clone(),
309                records_checked,
310                records_passed,
311                pass_rate,
312                example_violations,
313            });
314        }
315
316        let total_checked: usize = rule_results.iter().map(|r| r.records_checked).sum();
317        let total_passed: usize = rule_results.iter().map(|r| r.records_passed).sum();
318        let pass_rate = if total_checked > 0 {
319            total_passed as f64 / total_checked as f64
320        } else {
321            1.0
322        };
323
324        Ok(ConsistencyAnalysis {
325            total_records,
326            rule_results,
327            pass_rate,
328            total_violations,
329            violations_by_type,
330        })
331    }
332
333    /// Check if rule is applicable to record (has required fields).
334    fn is_rule_applicable(&self, rule: &ConsistencyRule, record: &ConsistencyRecord) -> bool {
335        match &rule.rule_type {
336            RuleType::DateOrdering {
337                earlier_field,
338                later_field,
339            } => {
340                record.date_fields.contains_key(earlier_field)
341                    && record.date_fields.contains_key(later_field)
342            }
343            RuleType::MutualExclusion { field1, field2 } => {
344                record.decimal_fields.contains_key(field1)
345                    || record.decimal_fields.contains_key(field2)
346            }
347            RuleType::FiscalPeriodDateAlignment {
348                date_field,
349                period_field,
350                year_field,
351            } => {
352                record.date_fields.contains_key(date_field)
353                    && record.integer_fields.contains_key(period_field)
354                    && record.integer_fields.contains_key(year_field)
355            }
356            RuleType::AmountSign {
357                amount_field,
358                indicator_field,
359                ..
360            } => {
361                record.decimal_fields.contains_key(amount_field)
362                    && record.string_fields.contains_key(indicator_field)
363            }
364            RuleType::RequiredIfPresent { trigger_field, .. } => {
365                record.string_fields.contains_key(trigger_field)
366                    || record.decimal_fields.contains_key(trigger_field)
367            }
368            RuleType::ValueRange { field, .. } => record.decimal_fields.contains_key(field),
369            RuleType::Custom { .. } => true,
370        }
371    }
372
373    /// Check if record passes rule.
374    fn check_rule(&self, rule: &ConsistencyRule, record: &ConsistencyRecord) -> bool {
375        match &rule.rule_type {
376            RuleType::DateOrdering {
377                earlier_field,
378                later_field,
379            } => {
380                let earlier = record.date_fields.get(earlier_field);
381                let later = record.date_fields.get(later_field);
382                match (earlier, later) {
383                    (Some(e), Some(l)) => e <= l,
384                    _ => true,
385                }
386            }
387            RuleType::MutualExclusion { field1, field2 } => {
388                let val1 = record
389                    .decimal_fields
390                    .get(field1)
391                    .map(|v| *v != Decimal::ZERO)
392                    .unwrap_or(false);
393                let val2 = record
394                    .decimal_fields
395                    .get(field2)
396                    .map(|v| *v != Decimal::ZERO)
397                    .unwrap_or(false);
398                // XOR: exactly one should be non-zero
399                val1 != val2
400            }
401            RuleType::FiscalPeriodDateAlignment {
402                date_field,
403                period_field,
404                year_field,
405            } => {
406                let date = record.date_fields.get(date_field);
407                let period = record.integer_fields.get(period_field);
408                let year = record.integer_fields.get(year_field);
409
410                match (date, period, year) {
411                    (Some(d), Some(p), Some(y)) => d.month() as i64 == *p && d.year() as i64 == *y,
412                    _ => true,
413                }
414            }
415            RuleType::AmountSign {
416                amount_field,
417                indicator_field,
418                positive_indicator,
419            } => {
420                let amount = record.decimal_fields.get(amount_field);
421                let indicator = record.string_fields.get(indicator_field);
422
423                match (amount, indicator) {
424                    (Some(a), Some(i)) => {
425                        let should_be_positive = i == positive_indicator;
426                        let is_positive = *a >= Decimal::ZERO;
427                        should_be_positive == is_positive
428                    }
429                    _ => true,
430                }
431            }
432            RuleType::RequiredIfPresent {
433                trigger_field,
434                required_field,
435            } => {
436                let has_trigger = record.string_fields.contains_key(trigger_field)
437                    || record.decimal_fields.contains_key(trigger_field);
438
439                if !has_trigger {
440                    return true;
441                }
442
443                record.string_fields.contains_key(required_field)
444                    || record.decimal_fields.contains_key(required_field)
445            }
446            RuleType::ValueRange { field, min, max } => {
447                let value = record.decimal_fields.get(field);
448                match value {
449                    Some(v) => {
450                        let above_min = min.map(|m| *v >= m).unwrap_or(true);
451                        let below_max = max.map(|m| *v <= m).unwrap_or(true);
452                        above_min && below_max
453                    }
454                    None => true,
455                }
456            }
457            RuleType::Custom { checker } => checker(record),
458        }
459    }
460}
461
462impl Default for ConsistencyAnalyzer {
463    fn default() -> Self {
464        Self::with_default_rules()
465    }
466}
467
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a valid calendar date.
    fn date(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).expect("valid test date")
    }

    /// Fetch the result for the rule called `name`, panicking if it was not evaluated.
    fn rule_result<'a>(analysis: &'a ConsistencyAnalysis, name: &str) -> &'a RuleResult {
        analysis
            .rule_results
            .iter()
            .find(|r| r.rule_name == name)
            .unwrap_or_else(|| panic!("no result for rule {}", name))
    }

    /// Record carrying the two dates read by the default `date_ordering` rule.
    fn dated_record(document: NaiveDate, posting: NaiveDate) -> ConsistencyRecord {
        let mut record = ConsistencyRecord::default();
        record.date_fields.insert("document_date".to_string(), document);
        record.date_fields.insert("posting_date".to_string(), posting);
        record
    }

    /// Analyzer that runs exactly one rule.
    fn single_rule(name: &str, rule_type: RuleType) -> ConsistencyAnalyzer {
        ConsistencyAnalyzer::new(vec![ConsistencyRule {
            name: name.to_string(),
            description: String::new(),
            rule_type,
        }])
    }

    #[test]
    fn test_date_ordering_pass() {
        let record = dated_record(date(2024, 1, 10), date(2024, 1, 15));
        let result = ConsistencyAnalyzer::with_default_rules()
            .analyze(&[record])
            .unwrap();
        assert_eq!(rule_result(&result, "date_ordering").pass_rate, 1.0);
    }

    #[test]
    fn test_date_ordering_fail() {
        let record = dated_record(date(2024, 1, 20), date(2024, 1, 15));
        let result = ConsistencyAnalyzer::with_default_rules()
            .analyze(&[record])
            .unwrap();
        assert_eq!(rule_result(&result, "date_ordering").pass_rate, 0.0);
    }

    #[test]
    fn test_mutual_exclusion() {
        let mut record = ConsistencyRecord::default();
        record
            .decimal_fields
            .insert("debit_amount".to_string(), Decimal::new(100, 0));
        record
            .decimal_fields
            .insert("credit_amount".to_string(), Decimal::ZERO);

        let result = ConsistencyAnalyzer::with_default_rules()
            .analyze(&[record])
            .unwrap();
        assert_eq!(rule_result(&result, "debit_credit_exclusion").pass_rate, 1.0);
    }

    #[test]
    fn test_mutual_exclusion_fail_both_nonzero() {
        let mut record = ConsistencyRecord::default();
        record
            .decimal_fields
            .insert("debit_amount".to_string(), Decimal::new(100, 0));
        record
            .decimal_fields
            .insert("credit_amount".to_string(), Decimal::new(50, 0));

        let result = ConsistencyAnalyzer::with_default_rules()
            .analyze(&[record])
            .unwrap();
        assert_eq!(rule_result(&result, "debit_credit_exclusion").pass_rate, 0.0);
    }

    #[test]
    fn test_fiscal_period_alignment() {
        let mut aligned = ConsistencyRecord::default();
        aligned
            .date_fields
            .insert("posting_date".to_string(), date(2024, 3, 15));
        aligned.integer_fields.insert("fiscal_period".to_string(), 3);
        aligned.integer_fields.insert("fiscal_year".to_string(), 2024);

        let mut wrong_period = aligned.clone();
        wrong_period
            .integer_fields
            .insert("fiscal_period".to_string(), 4);

        let mut wrong_year = aligned.clone();
        wrong_year
            .integer_fields
            .insert("fiscal_year".to_string(), 2023);

        let result = ConsistencyAnalyzer::with_default_rules()
            .analyze(&[aligned, wrong_period, wrong_year])
            .unwrap();
        let fiscal = rule_result(&result, "fiscal_period_alignment");
        assert_eq!(fiscal.records_checked, 3);
        assert_eq!(fiscal.records_passed, 1);
    }

    #[test]
    fn test_amount_sign() {
        let rule = RuleType::AmountSign {
            amount_field: "amount".to_string(),
            indicator_field: "dc_indicator".to_string(),
            positive_indicator: "D".to_string(),
        };
        let make = |amount: i64, indicator: &str| {
            let mut record = ConsistencyRecord::default();
            record
                .decimal_fields
                .insert("amount".to_string(), Decimal::new(amount, 0));
            record
                .string_fields
                .insert("dc_indicator".to_string(), indicator.to_string());
            record
        };
        let records = [make(10, "D"), make(-10, "C"), make(-10, "D"), make(10, "C")];

        let result = single_rule("amount_sign", rule).analyze(&records).unwrap();
        let sign = rule_result(&result, "amount_sign");
        assert_eq!(sign.records_checked, 4);
        assert_eq!(sign.records_passed, 2);
    }

    #[test]
    fn test_required_if_present() {
        let rule = RuleType::RequiredIfPresent {
            trigger_field: "vendor_id".to_string(),
            required_field: "vendor_name".to_string(),
        };
        let mut complete = ConsistencyRecord::default();
        complete
            .string_fields
            .insert("vendor_id".to_string(), "V001".to_string());
        complete
            .string_fields
            .insert("vendor_name".to_string(), "Acme".to_string());

        let mut incomplete = ConsistencyRecord::default();
        incomplete
            .string_fields
            .insert("vendor_id".to_string(), "V002".to_string());

        // Without the trigger field the rule does not apply at all.
        let unrelated = ConsistencyRecord::default();

        let result = single_rule("vendor_name_required", rule)
            .analyze(&[complete, incomplete, unrelated])
            .unwrap();
        let required = rule_result(&result, "vendor_name_required");
        assert_eq!(required.records_checked, 2);
        assert_eq!(required.records_passed, 1);
    }

    #[test]
    fn test_value_range() {
        let rule = RuleType::ValueRange {
            field: "discount_pct".to_string(),
            min: Some(Decimal::ZERO),
            max: Some(Decimal::new(100, 0)),
        };
        let make = |value: i64| {
            let mut record = ConsistencyRecord::default();
            record
                .decimal_fields
                .insert("discount_pct".to_string(), Decimal::new(value, 0));
            record
        };
        // Bounds are inclusive; the record without the field is skipped.
        let records = [
            make(0),
            make(100),
            make(101),
            make(-1),
            ConsistencyRecord::default(),
        ];

        let result = single_rule("discount_range", rule).analyze(&records).unwrap();
        let range = rule_result(&result, "discount_range");
        assert_eq!(range.records_checked, 4);
        assert_eq!(range.records_passed, 2);
    }

    #[test]
    fn test_custom_rule_applies_to_every_record() {
        let rule = RuleType::Custom {
            checker: Arc::new(|record: &ConsistencyRecord| {
                record
                    .boolean_fields
                    .get("approved")
                    .copied()
                    .unwrap_or(false)
            }),
        };
        let mut approved = ConsistencyRecord::default();
        approved.boolean_fields.insert("approved".to_string(), true);

        let result = single_rule("approved", rule)
            .analyze(&[approved, ConsistencyRecord::default()])
            .unwrap();
        let custom = rule_result(&result, "approved");
        assert_eq!(custom.records_checked, 2);
        assert_eq!(custom.records_passed, 1);
    }

    #[test]
    fn test_aggregate_counts() {
        let mut bad = dated_record(date(2024, 1, 20), date(2024, 1, 15));
        bad.decimal_fields
            .insert("debit_amount".to_string(), Decimal::new(100, 0));
        bad.decimal_fields
            .insert("credit_amount".to_string(), Decimal::new(50, 0));
        let good = dated_record(date(2024, 1, 10), date(2024, 1, 15));

        let result = ConsistencyAnalyzer::with_default_rules()
            .analyze(&[bad, good])
            .unwrap();

        assert_eq!(result.total_records, 2);
        assert_eq!(result.total_violations, 2);
        assert_eq!(result.violations_by_type.get("date_ordering"), Some(&1));
        assert_eq!(
            result.violations_by_type.get("debit_credit_exclusion"),
            Some(&1)
        );
        // Rules without violations get no entry.
        assert!(!result
            .violations_by_type
            .contains_key("fiscal_period_alignment"));
        // Pooled over 3 checks (2 date_ordering + 1 debit_credit_exclusion), 1 passing.
        assert!((result.pass_rate - 1.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_example_violations_are_capped() {
        let records: Vec<_> = (0..7)
            .map(|_| dated_record(date(2024, 1, 20), date(2024, 1, 15)))
            .collect();

        let result = ConsistencyAnalyzer::with_default_rules()
            .analyze(&records)
            .unwrap();
        let date_rule = rule_result(&result, "date_ordering");
        assert_eq!(date_rule.records_checked, 7);
        assert_eq!(date_rule.records_passed, 0);
        assert_eq!(date_rule.example_violations.len(), 5);
        assert!(date_rule.example_violations[0].starts_with("Record 0: "));
    }

    #[test]
    fn test_empty_input_is_vacuously_consistent() {
        let result = ConsistencyAnalyzer::default().analyze(&[]).unwrap();
        assert_eq!(result.total_records, 0);
        assert_eq!(result.total_violations, 0);
        assert_eq!(result.pass_rate, 1.0);
        assert!(result
            .rule_results
            .iter()
            .all(|r| r.records_checked == 0 && r.pass_rate == 1.0));
    }
}