1use crate::error::EvalResult;
6use chrono::{Datelike, NaiveDate};
7use rust_decimal::Decimal;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ConsistencyAnalysis {
15 pub total_records: usize,
17 pub rule_results: Vec<RuleResult>,
19 pub pass_rate: f64,
21 pub total_violations: usize,
23 pub violations_by_type: HashMap<String, usize>,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct RuleResult {
30 pub rule_name: String,
32 pub description: String,
34 pub records_checked: usize,
36 pub records_passed: usize,
38 pub pass_rate: f64,
40 pub example_violations: Vec<String>,
42}
43
44#[derive(Debug, Clone)]
46pub struct ConsistencyRule {
47 pub name: String,
49 pub description: String,
51 pub rule_type: RuleType,
53}
54
55pub enum RuleType {
57 DateOrdering {
59 earlier_field: String,
60 later_field: String,
61 },
62 MutualExclusion { field1: String, field2: String },
64 FiscalPeriodDateAlignment {
66 date_field: String,
67 period_field: String,
68 year_field: String,
69 },
70 AmountSign {
72 amount_field: String,
73 indicator_field: String,
74 positive_indicator: String,
75 },
76 RequiredIfPresent {
78 trigger_field: String,
79 required_field: String,
80 },
81 ValueRange {
83 field: String,
84 min: Option<Decimal>,
85 max: Option<Decimal>,
86 },
87 Custom {
89 checker: Arc<dyn Fn(&ConsistencyRecord) -> bool + Send + Sync>,
90 },
91}
92
93impl Clone for RuleType {
94 fn clone(&self) -> Self {
95 match self {
96 RuleType::DateOrdering {
97 earlier_field,
98 later_field,
99 } => RuleType::DateOrdering {
100 earlier_field: earlier_field.clone(),
101 later_field: later_field.clone(),
102 },
103 RuleType::MutualExclusion { field1, field2 } => RuleType::MutualExclusion {
104 field1: field1.clone(),
105 field2: field2.clone(),
106 },
107 RuleType::FiscalPeriodDateAlignment {
108 date_field,
109 period_field,
110 year_field,
111 } => RuleType::FiscalPeriodDateAlignment {
112 date_field: date_field.clone(),
113 period_field: period_field.clone(),
114 year_field: year_field.clone(),
115 },
116 RuleType::AmountSign {
117 amount_field,
118 indicator_field,
119 positive_indicator,
120 } => RuleType::AmountSign {
121 amount_field: amount_field.clone(),
122 indicator_field: indicator_field.clone(),
123 positive_indicator: positive_indicator.clone(),
124 },
125 RuleType::RequiredIfPresent {
126 trigger_field,
127 required_field,
128 } => RuleType::RequiredIfPresent {
129 trigger_field: trigger_field.clone(),
130 required_field: required_field.clone(),
131 },
132 RuleType::ValueRange { field, min, max } => RuleType::ValueRange {
133 field: field.clone(),
134 min: *min,
135 max: *max,
136 },
137 RuleType::Custom { checker } => RuleType::Custom {
138 checker: Arc::clone(checker),
139 },
140 }
141 }
142}
143
144impl std::fmt::Debug for RuleType {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 match self {
147 RuleType::DateOrdering {
148 earlier_field,
149 later_field,
150 } => f
151 .debug_struct("DateOrdering")
152 .field("earlier_field", earlier_field)
153 .field("later_field", later_field)
154 .finish(),
155 RuleType::MutualExclusion { field1, field2 } => f
156 .debug_struct("MutualExclusion")
157 .field("field1", field1)
158 .field("field2", field2)
159 .finish(),
160 RuleType::FiscalPeriodDateAlignment {
161 date_field,
162 period_field,
163 year_field,
164 } => f
165 .debug_struct("FiscalPeriodDateAlignment")
166 .field("date_field", date_field)
167 .field("period_field", period_field)
168 .field("year_field", year_field)
169 .finish(),
170 RuleType::AmountSign {
171 amount_field,
172 indicator_field,
173 positive_indicator,
174 } => f
175 .debug_struct("AmountSign")
176 .field("amount_field", amount_field)
177 .field("indicator_field", indicator_field)
178 .field("positive_indicator", positive_indicator)
179 .finish(),
180 RuleType::RequiredIfPresent {
181 trigger_field,
182 required_field,
183 } => f
184 .debug_struct("RequiredIfPresent")
185 .field("trigger_field", trigger_field)
186 .field("required_field", required_field)
187 .finish(),
188 RuleType::ValueRange { field, min, max } => f
189 .debug_struct("ValueRange")
190 .field("field", field)
191 .field("min", min)
192 .field("max", max)
193 .finish(),
194 RuleType::Custom { .. } => f
195 .debug_struct("Custom")
196 .field("checker", &"<custom_fn>")
197 .finish(),
198 }
199 }
200}
201
202#[derive(Debug, Clone, Default)]
204pub struct ConsistencyRecord {
205 pub string_fields: HashMap<String, String>,
207 pub decimal_fields: HashMap<String, Decimal>,
209 pub date_fields: HashMap<String, NaiveDate>,
211 pub integer_fields: HashMap<String, i64>,
213 pub boolean_fields: HashMap<String, bool>,
215}
216
217pub struct ConsistencyAnalyzer {
219 rules: Vec<ConsistencyRule>,
221 max_examples: usize,
223}
224
225impl ConsistencyAnalyzer {
226 pub fn new(rules: Vec<ConsistencyRule>) -> Self {
228 Self {
229 rules,
230 max_examples: 5,
231 }
232 }
233
234 pub fn with_default_rules() -> Self {
236 let rules = vec![
237 ConsistencyRule {
238 name: "date_ordering".to_string(),
239 description: "Document date must be on or before posting date".to_string(),
240 rule_type: RuleType::DateOrdering {
241 earlier_field: "document_date".to_string(),
242 later_field: "posting_date".to_string(),
243 },
244 },
245 ConsistencyRule {
246 name: "debit_credit_exclusion".to_string(),
247 description: "Each line must have either debit or credit, not both".to_string(),
248 rule_type: RuleType::MutualExclusion {
249 field1: "debit_amount".to_string(),
250 field2: "credit_amount".to_string(),
251 },
252 },
253 ConsistencyRule {
254 name: "fiscal_period_alignment".to_string(),
255 description: "Fiscal period must match posting date".to_string(),
256 rule_type: RuleType::FiscalPeriodDateAlignment {
257 date_field: "posting_date".to_string(),
258 period_field: "fiscal_period".to_string(),
259 year_field: "fiscal_year".to_string(),
260 },
261 },
262 ];
263
264 Self::new(rules)
265 }
266
267 pub fn analyze(&self, records: &[ConsistencyRecord]) -> EvalResult<ConsistencyAnalysis> {
269 let total_records = records.len();
270 let mut rule_results = Vec::new();
271 let mut total_violations = 0;
272 let mut violations_by_type: HashMap<String, usize> = HashMap::new();
273
274 for rule in &self.rules {
275 let mut records_checked = 0;
276 let mut records_passed = 0;
277 let mut example_violations = Vec::new();
278
279 for (idx, record) in records.iter().enumerate() {
280 let applicable = self.is_rule_applicable(rule, record);
281 if !applicable {
282 continue;
283 }
284
285 records_checked += 1;
286 let passed = self.check_rule(rule, record);
287
288 if passed {
289 records_passed += 1;
290 } else {
291 total_violations += 1;
292 *violations_by_type.entry(rule.name.clone()).or_insert(0) += 1;
293
294 if example_violations.len() < self.max_examples {
295 example_violations.push(format!("Record {}: {:?}", idx, record));
296 }
297 }
298 }
299
300 let pass_rate = if records_checked > 0 {
301 records_passed as f64 / records_checked as f64
302 } else {
303 1.0
304 };
305
306 rule_results.push(RuleResult {
307 rule_name: rule.name.clone(),
308 description: rule.description.clone(),
309 records_checked,
310 records_passed,
311 pass_rate,
312 example_violations,
313 });
314 }
315
316 let total_checked: usize = rule_results.iter().map(|r| r.records_checked).sum();
317 let total_passed: usize = rule_results.iter().map(|r| r.records_passed).sum();
318 let pass_rate = if total_checked > 0 {
319 total_passed as f64 / total_checked as f64
320 } else {
321 1.0
322 };
323
324 Ok(ConsistencyAnalysis {
325 total_records,
326 rule_results,
327 pass_rate,
328 total_violations,
329 violations_by_type,
330 })
331 }
332
333 fn is_rule_applicable(&self, rule: &ConsistencyRule, record: &ConsistencyRecord) -> bool {
335 match &rule.rule_type {
336 RuleType::DateOrdering {
337 earlier_field,
338 later_field,
339 } => {
340 record.date_fields.contains_key(earlier_field)
341 && record.date_fields.contains_key(later_field)
342 }
343 RuleType::MutualExclusion { field1, field2 } => {
344 record.decimal_fields.contains_key(field1)
345 || record.decimal_fields.contains_key(field2)
346 }
347 RuleType::FiscalPeriodDateAlignment {
348 date_field,
349 period_field,
350 year_field,
351 } => {
352 record.date_fields.contains_key(date_field)
353 && record.integer_fields.contains_key(period_field)
354 && record.integer_fields.contains_key(year_field)
355 }
356 RuleType::AmountSign {
357 amount_field,
358 indicator_field,
359 ..
360 } => {
361 record.decimal_fields.contains_key(amount_field)
362 && record.string_fields.contains_key(indicator_field)
363 }
364 RuleType::RequiredIfPresent { trigger_field, .. } => {
365 record.string_fields.contains_key(trigger_field)
366 || record.decimal_fields.contains_key(trigger_field)
367 }
368 RuleType::ValueRange { field, .. } => record.decimal_fields.contains_key(field),
369 RuleType::Custom { .. } => true,
370 }
371 }
372
373 fn check_rule(&self, rule: &ConsistencyRule, record: &ConsistencyRecord) -> bool {
375 match &rule.rule_type {
376 RuleType::DateOrdering {
377 earlier_field,
378 later_field,
379 } => {
380 let earlier = record.date_fields.get(earlier_field);
381 let later = record.date_fields.get(later_field);
382 match (earlier, later) {
383 (Some(e), Some(l)) => e <= l,
384 _ => true,
385 }
386 }
387 RuleType::MutualExclusion { field1, field2 } => {
388 let val1 = record
389 .decimal_fields
390 .get(field1)
391 .map(|v| *v != Decimal::ZERO)
392 .unwrap_or(false);
393 let val2 = record
394 .decimal_fields
395 .get(field2)
396 .map(|v| *v != Decimal::ZERO)
397 .unwrap_or(false);
398 val1 != val2
400 }
401 RuleType::FiscalPeriodDateAlignment {
402 date_field,
403 period_field,
404 year_field,
405 } => {
406 let date = record.date_fields.get(date_field);
407 let period = record.integer_fields.get(period_field);
408 let year = record.integer_fields.get(year_field);
409
410 match (date, period, year) {
411 (Some(d), Some(p), Some(y)) => d.month() as i64 == *p && d.year() as i64 == *y,
412 _ => true,
413 }
414 }
415 RuleType::AmountSign {
416 amount_field,
417 indicator_field,
418 positive_indicator,
419 } => {
420 let amount = record.decimal_fields.get(amount_field);
421 let indicator = record.string_fields.get(indicator_field);
422
423 match (amount, indicator) {
424 (Some(a), Some(i)) => {
425 let should_be_positive = i == positive_indicator;
426 let is_positive = *a >= Decimal::ZERO;
427 should_be_positive == is_positive
428 }
429 _ => true,
430 }
431 }
432 RuleType::RequiredIfPresent {
433 trigger_field,
434 required_field,
435 } => {
436 let has_trigger = record.string_fields.contains_key(trigger_field)
437 || record.decimal_fields.contains_key(trigger_field);
438
439 if !has_trigger {
440 return true;
441 }
442
443 record.string_fields.contains_key(required_field)
444 || record.decimal_fields.contains_key(required_field)
445 }
446 RuleType::ValueRange { field, min, max } => {
447 let value = record.decimal_fields.get(field);
448 match value {
449 Some(v) => {
450 let above_min = min.map(|m| *v >= m).unwrap_or(true);
451 let below_max = max.map(|m| *v <= m).unwrap_or(true);
452 above_min && below_max
453 }
454 None => true,
455 }
456 }
457 RuleType::Custom { checker } => checker(record),
458 }
459 }
460}
461
462impl Default for ConsistencyAnalyzer {
463 fn default() -> Self {
464 Self::with_default_rules()
465 }
466}
467
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a record with `document_date` and `posting_date` set to the
    /// given days of January 2024.
    fn record_with_dates(document_day: u32, posting_day: u32) -> ConsistencyRecord {
        let january = |day: u32| NaiveDate::from_ymd_opt(2024, 1, day).unwrap();

        let mut rec = ConsistencyRecord::default();
        rec.date_fields
            .insert("document_date".to_string(), january(document_day));
        rec.date_fields
            .insert("posting_date".to_string(), january(posting_day));
        rec
    }

    /// Builds a record with the given debit and credit amounts.
    fn record_with_amounts(debit: Decimal, credit: Decimal) -> ConsistencyRecord {
        let mut rec = ConsistencyRecord::default();
        rec.decimal_fields.insert("debit_amount".to_string(), debit);
        rec.decimal_fields
            .insert("credit_amount".to_string(), credit);
        rec
    }

    /// Runs the default rules over a single record and returns the pass rate
    /// reported for the rule called `rule_name`.
    fn default_rule_pass_rate(rec: ConsistencyRecord, rule_name: &str) -> f64 {
        let analysis = ConsistencyAnalyzer::with_default_rules()
            .analyze(&[rec])
            .unwrap();

        analysis
            .rule_results
            .iter()
            .find(|r| r.rule_name == rule_name)
            .unwrap()
            .pass_rate
    }

    #[test]
    fn test_date_ordering_pass() {
        let rec = record_with_dates(10, 15);
        assert_eq!(default_rule_pass_rate(rec, "date_ordering"), 1.0);
    }

    #[test]
    fn test_date_ordering_fail() {
        let rec = record_with_dates(20, 15);
        assert_eq!(default_rule_pass_rate(rec, "date_ordering"), 0.0);
    }

    #[test]
    fn test_mutual_exclusion() {
        let rec = record_with_amounts(Decimal::new(100, 0), Decimal::ZERO);
        assert_eq!(default_rule_pass_rate(rec, "debit_credit_exclusion"), 1.0);
    }

    #[test]
    fn test_mutual_exclusion_fail_both_nonzero() {
        let rec = record_with_amounts(Decimal::new(100, 0), Decimal::new(50, 0));
        assert_eq!(default_rule_pass_rate(rec, "debit_credit_exclusion"), 0.0);
    }
}