1pub mod acfe;
18pub mod industry;
19
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22
23pub use acfe::{
24 acfe_calibrated_1k, acfe_collusion_5k, acfe_management_override_2k, all_acfe_benchmarks,
25 AcfeAlignment, AcfeCalibration, AcfeCategoryDistribution,
26};
27pub use industry::{
28 all_industry_benchmarks, financial_services_fraud_5k, get_industry_benchmark,
29 healthcare_fraud_5k, manufacturing_fraud_5k, retail_fraud_10k, technology_fraud_3k,
30 IndustryBenchmarkAnalysis,
31};
32
/// A self-contained benchmark definition: identity, task type, dataset
/// specification, evaluation protocol, and published baseline results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkSuite {
    /// Stable machine-readable identifier, e.g. "anomaly-bench-1k".
    pub id: String,
    /// Human-readable display name, e.g. "AnomalyBench-1K".
    pub name: String,
    /// Free-text description of the benchmark contents and goals.
    pub description: String,
    /// Version of this benchmark definition (builder default "1.0.0").
    pub version: String,
    /// The ML task this benchmark evaluates.
    pub task_type: BenchmarkTaskType,
    /// Shape and composition of the underlying dataset.
    pub dataset: DatasetSpec,
    /// Metrics and protocol used to score models on this benchmark.
    pub evaluation: EvaluationSpec,
    /// Reference results for standard baseline models.
    pub baselines: Vec<BaselineResult>,
    /// Arbitrary key/value tags (e.g. "domain", "difficulty").
    pub metadata: HashMap<String, String>,
}
55
/// The kind of modeling task a benchmark evaluates.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum BenchmarkTaskType {
    /// Binary normal-vs-anomaly detection.
    AnomalyDetection,
    /// Multi-class classification of fraud types.
    FraudClassification,
    /// Detection of data quality issues (missing values, typos, etc.).
    DataQualityDetection,
    /// Deciding whether two records refer to the same real-world entity.
    EntityMatching,
    /// Anomaly detection where a record may carry several labels at once.
    MultiLabelAnomalyDetection,
    /// Regression task predicting transaction amounts.
    AmountPrediction,
    /// Anomaly detection over time-ordered data.
    TimeSeriesAnomalyDetection,
    /// Fraud detection using entity-graph structure.
    GraphFraudDetection,
}
76
/// Describes the dataset behind a benchmark: record counts, class balance,
/// feature makeup, split strategy, and generation parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetSpec {
    /// Total number of records in the dataset.
    pub total_records: usize,
    /// Number of positive-class (anomalous/fraudulent) records.
    pub positive_count: usize,
    /// Number of negative-class (normal/clean) records.
    pub negative_count: usize,
    /// Per-label record counts, keyed by class name (e.g. "normal").
    pub class_distribution: HashMap<String, usize>,
    /// Counts and descriptions of the available features.
    pub features: FeatureSet,
    /// Train/validation/test split configuration.
    pub split_ratios: SplitRatios,
    /// Seed for reproducible dataset generation.
    pub seed: u64,
    /// Time span covered by the records, in days.
    pub time_span_days: u32,
    /// Number of distinct companies represented in the data.
    pub num_companies: usize,
}
99
/// Summary of the feature columns available in a benchmark dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureSet {
    /// Number of numerical features.
    pub numerical_count: usize,
    /// Number of categorical features.
    pub categorical_count: usize,
    /// Number of temporal (date/time) features.
    pub temporal_count: usize,
    /// Number of free-text features.
    pub text_count: usize,
    /// Optional human-readable description per feature name.
    pub feature_descriptions: HashMap<String, String>,
}
114
/// Train/validation/test split proportions for a benchmark dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitRatios {
    /// Fraction of records assigned to the training split.
    pub train: f64,
    /// Fraction of records assigned to the validation split.
    pub validation: f64,
    /// Fraction of records assigned to the test split.
    pub test: f64,
    /// If true, split chronologically rather than randomly
    /// (per the field name; the generator itself is not visible here).
    pub temporal_split: bool,
}
127
128impl Default for SplitRatios {
129 fn default() -> Self {
130 Self {
131 train: 0.7,
132 validation: 0.15,
133 test: 0.15,
134 temporal_split: false,
135 }
136 }
137}
138
/// How submissions to a benchmark are scored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationSpec {
    /// The headline metric for ranking results.
    pub primary_metric: MetricType,
    /// All metrics reported for this benchmark.
    pub metrics: Vec<MetricType>,
    /// Decision threshold for binary classification, if applicable.
    pub classification_threshold: Option<f64>,
    /// Number of cross-validation folds, if cross-validation is used.
    pub cv_folds: Option<usize>,
    /// Optional misclassification cost matrix for cost-sensitive scoring.
    pub cost_matrix: Option<CostMatrix>,
}
153
/// Evaluation metrics supported by benchmark specifications.
///
/// `Eq`/`Hash`-friendly: the only payloads are integer cutoffs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum MetricType {
    /// Area under the ROC curve.
    AucRoc,
    /// Area under the precision-recall curve.
    AucPr,
    /// Precision among the top-K ranked predictions (payload is K).
    PrecisionAtK(usize),
    /// Recall among the top-K ranked predictions (payload is K).
    RecallAtK(usize),
    /// Harmonic mean of precision and recall.
    F1Score,
    /// Positive predictive value.
    Precision,
    /// True positive rate.
    Recall,
    /// Fraction of correct predictions.
    Accuracy,
    /// Matthews correlation coefficient.
    Mcc,
    /// Mean average precision.
    Map,
    /// Normalized discounted cumulative gain.
    Ndcg,
    /// Mean squared error (regression).
    Mse,
    /// Mean absolute error (regression).
    Mae,
    /// Coefficient of determination (regression).
    R2,
    /// Cross-entropy loss over predicted probabilities.
    LogLoss,
    /// Cohen's kappa (chance-corrected agreement).
    CohenKappa,
    /// F1 averaged uniformly across classes.
    MacroF1,
    /// F1 averaged weighted by class support.
    WeightedF1,
}
194
/// Per-outcome costs for cost-sensitive evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostMatrix {
    /// Cost of flagging a normal record (false alarm).
    pub false_positive_cost: f64,
    /// Cost of missing a positive record.
    pub false_negative_cost: f64,
    /// Cost charged for a correctly flagged positive.
    pub true_positive_cost: f64,
    /// Cost charged for a correctly passed negative.
    pub true_negative_cost: f64,
}
207
208impl Default for CostMatrix {
209 fn default() -> Self {
210 Self {
211 false_positive_cost: 1.0,
212 false_negative_cost: 10.0, true_positive_cost: 0.0,
214 true_negative_cost: 0.0,
215 }
216 }
217}
218
/// Published reference result for a baseline model on a benchmark.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaselineResult {
    /// Display name of the baseline, e.g. "XGBoost".
    pub model_name: String,
    /// Category of the baseline model.
    pub model_type: BaselineModelType,
    /// Metric scores keyed by metric name (e.g. "auc_roc", "f1").
    pub metrics: HashMap<String, f64>,
    /// Wall-clock training time in seconds, when recorded.
    pub training_time_seconds: Option<f64>,
    /// Inference latency in milliseconds, when recorded.
    pub inference_time_ms: Option<f64>,
    /// Free-text notes about the configuration or caveats.
    pub notes: Option<String>,
}
235
/// Families of baseline models reported alongside benchmarks.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum BaselineModelType {
    /// Uniform-random predictions.
    Random,
    /// Always predicts the most frequent class.
    MajorityClass,
    /// Hand-written rules (regexes, thresholds, exact/fuzzy matching).
    RuleBased,
    /// Isolation Forest (unsupervised).
    IsolationForest,
    /// One-class support vector machine.
    OneClassSvm,
    /// Local Outlier Factor.
    Lof,
    /// Logistic regression.
    LogisticRegression,
    /// Random forest.
    RandomForest,
    /// XGBoost gradient boosting.
    XgBoost,
    /// LightGBM gradient boosting.
    LightGbm,
    /// Generic neural network.
    NeuralNetwork,
    /// Graph neural network (e.g. GraphSAGE, GAT).
    Gnn,
    /// Reconstruction-based autoencoder.
    Autoencoder,
    /// Any other model, identified by name.
    Custom(String),
}
268
/// One row of a benchmark leaderboard.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeaderboardEntry {
    /// Position on the leaderboard (best = 1).
    pub rank: usize,
    /// Name of the submission.
    pub submission_name: String,
    /// Who submitted the result.
    pub submitter: String,
    /// Short description of the submitted model.
    pub model_description: String,
    /// Score on the benchmark's primary metric.
    pub primary_score: f64,
    /// All reported scores, keyed by metric name.
    pub all_scores: HashMap<String, f64>,
    /// Submission date as a string (format not enforced here).
    pub submission_date: String,
    /// True when this entry is one of the published baselines.
    pub is_baseline: bool,
}
289
/// Fluent builder for [`BenchmarkSuite`], used by the benchmark
/// constructors in this module.
pub struct BenchmarkBuilder {
    // Suite under construction; finalized and returned by `build()`.
    suite: BenchmarkSuite,
}
294
295impl BenchmarkBuilder {
296 pub fn new(id: &str, name: &str) -> Self {
298 Self {
299 suite: BenchmarkSuite {
300 id: id.to_string(),
301 name: name.to_string(),
302 description: String::new(),
303 version: "1.0.0".to_string(),
304 task_type: BenchmarkTaskType::AnomalyDetection,
305 dataset: DatasetSpec {
306 total_records: 1000,
307 positive_count: 50,
308 negative_count: 950,
309 class_distribution: HashMap::new(),
310 features: FeatureSet {
311 numerical_count: 10,
312 categorical_count: 5,
313 temporal_count: 3,
314 text_count: 1,
315 feature_descriptions: HashMap::new(),
316 },
317 split_ratios: SplitRatios::default(),
318 seed: 42,
319 time_span_days: 365,
320 num_companies: 1,
321 },
322 evaluation: EvaluationSpec {
323 primary_metric: MetricType::AucRoc,
324 metrics: vec![MetricType::AucRoc, MetricType::AucPr, MetricType::F1Score],
325 classification_threshold: Some(0.5),
326 cv_folds: None,
327 cost_matrix: None,
328 },
329 baselines: Vec::new(),
330 metadata: HashMap::new(),
331 },
332 }
333 }
334
335 pub fn description(mut self, desc: &str) -> Self {
337 self.suite.description = desc.to_string();
338 self
339 }
340
341 pub fn task_type(mut self, task_type: BenchmarkTaskType) -> Self {
343 self.suite.task_type = task_type;
344 self
345 }
346
347 pub fn dataset_size(mut self, total: usize, positive: usize) -> Self {
349 self.suite.dataset.total_records = total;
350 self.suite.dataset.positive_count = positive;
351 self.suite.dataset.negative_count = total.saturating_sub(positive);
352 self
353 }
354
355 pub fn class_distribution(mut self, distribution: HashMap<String, usize>) -> Self {
357 self.suite.dataset.class_distribution = distribution;
358 self
359 }
360
361 pub fn split_ratios(mut self, train: f64, val: f64, test: f64, temporal: bool) -> Self {
363 self.suite.dataset.split_ratios = SplitRatios {
364 train,
365 validation: val,
366 test,
367 temporal_split: temporal,
368 };
369 self
370 }
371
372 pub fn primary_metric(mut self, metric: MetricType) -> Self {
374 self.suite.evaluation.primary_metric = metric;
375 self
376 }
377
378 pub fn metrics(mut self, metrics: Vec<MetricType>) -> Self {
380 self.suite.evaluation.metrics = metrics;
381 self
382 }
383
384 pub fn add_baseline(mut self, baseline: BaselineResult) -> Self {
386 self.suite.baselines.push(baseline);
387 self
388 }
389
390 pub fn seed(mut self, seed: u64) -> Self {
392 self.suite.dataset.seed = seed;
393 self
394 }
395
396 pub fn time_span_days(mut self, days: u32) -> Self {
398 self.suite.dataset.time_span_days = days;
399 self
400 }
401
402 pub fn num_companies(mut self, n: usize) -> Self {
404 self.suite.dataset.num_companies = n;
405 self
406 }
407
408 pub fn metadata(mut self, key: &str, value: &str) -> Self {
410 self.suite
411 .metadata
412 .insert(key.to_string(), value.to_string());
413 self
414 }
415
416 pub fn build(self) -> BenchmarkSuite {
418 self.suite
419 }
420}
421
422pub fn anomaly_bench_1k() -> BenchmarkSuite {
428 let mut class_dist = HashMap::new();
429 class_dist.insert("normal".to_string(), 950);
430 class_dist.insert("anomaly".to_string(), 50);
431
432 BenchmarkBuilder::new("anomaly-bench-1k", "AnomalyBench-1K")
433 .description("1000 journal entry transactions with 5% anomaly rate. Balanced mix of fraud, error, and statistical anomalies.")
434 .task_type(BenchmarkTaskType::AnomalyDetection)
435 .dataset_size(1000, 50)
436 .class_distribution(class_dist)
437 .split_ratios(0.7, 0.15, 0.15, true)
438 .primary_metric(MetricType::AucPr)
439 .metrics(vec![
440 MetricType::AucRoc,
441 MetricType::AucPr,
442 MetricType::F1Score,
443 MetricType::PrecisionAtK(10),
444 MetricType::PrecisionAtK(50),
445 MetricType::Recall,
446 ])
447 .seed(42)
448 .time_span_days(90)
449 .num_companies(1)
450 .add_baseline(BaselineResult {
451 model_name: "Random".to_string(),
452 model_type: BaselineModelType::Random,
453 metrics: [
454 ("auc_roc".to_string(), 0.50),
455 ("auc_pr".to_string(), 0.05),
456 ("f1".to_string(), 0.09),
457 ].into_iter().collect(),
458 training_time_seconds: Some(0.0),
459 inference_time_ms: Some(0.01),
460 notes: Some("Random baseline for reference".to_string()),
461 })
462 .add_baseline(BaselineResult {
463 model_name: "IsolationForest".to_string(),
464 model_type: BaselineModelType::IsolationForest,
465 metrics: [
466 ("auc_roc".to_string(), 0.78),
467 ("auc_pr".to_string(), 0.42),
468 ("f1".to_string(), 0.45),
469 ].into_iter().collect(),
470 training_time_seconds: Some(0.5),
471 inference_time_ms: Some(0.1),
472 notes: Some("Unsupervised baseline".to_string()),
473 })
474 .add_baseline(BaselineResult {
475 model_name: "XGBoost".to_string(),
476 model_type: BaselineModelType::XgBoost,
477 metrics: [
478 ("auc_roc".to_string(), 0.92),
479 ("auc_pr".to_string(), 0.68),
480 ("f1".to_string(), 0.72),
481 ].into_iter().collect(),
482 training_time_seconds: Some(2.0),
483 inference_time_ms: Some(0.05),
484 notes: Some("Supervised baseline with full labels".to_string()),
485 })
486 .metadata("domain", "accounting")
487 .metadata("difficulty", "easy")
488 .build()
489}
490
491pub fn fraud_detect_10k() -> BenchmarkSuite {
493 let mut class_dist = HashMap::new();
494 class_dist.insert("normal".to_string(), 9700);
495 class_dist.insert("fictitious_transaction".to_string(), 80);
496 class_dist.insert("duplicate_payment".to_string(), 60);
497 class_dist.insert("round_tripping".to_string(), 40);
498 class_dist.insert("threshold_manipulation".to_string(), 50);
499 class_dist.insert("self_approval".to_string(), 30);
500 class_dist.insert("other_fraud".to_string(), 40);
501
502 BenchmarkBuilder::new("fraud-detect-10k", "FraudDetect-10K")
503 .description("10K journal entries with multi-class fraud labels. Includes 6 fraud types with realistic class imbalance.")
504 .task_type(BenchmarkTaskType::FraudClassification)
505 .dataset_size(10000, 300)
506 .class_distribution(class_dist)
507 .split_ratios(0.7, 0.15, 0.15, true)
508 .primary_metric(MetricType::MacroF1)
509 .metrics(vec![
510 MetricType::AucRoc,
511 MetricType::MacroF1,
512 MetricType::WeightedF1,
513 MetricType::Recall,
514 MetricType::Precision,
515 MetricType::CohenKappa,
516 ])
517 .seed(12345)
518 .time_span_days(365)
519 .num_companies(3)
520 .add_baseline(BaselineResult {
521 model_name: "MajorityClass".to_string(),
522 model_type: BaselineModelType::MajorityClass,
523 metrics: [
524 ("macro_f1".to_string(), 0.07),
525 ("weighted_f1".to_string(), 0.94),
526 ].into_iter().collect(),
527 training_time_seconds: Some(0.0),
528 inference_time_ms: Some(0.01),
529 notes: Some("Predicts normal for all transactions".to_string()),
530 })
531 .add_baseline(BaselineResult {
532 model_name: "RandomForest".to_string(),
533 model_type: BaselineModelType::RandomForest,
534 metrics: [
535 ("macro_f1".to_string(), 0.58),
536 ("weighted_f1".to_string(), 0.96),
537 ("auc_roc".to_string(), 0.89),
538 ].into_iter().collect(),
539 training_time_seconds: Some(5.0),
540 inference_time_ms: Some(0.2),
541 notes: Some("Balanced class weights".to_string()),
542 })
543 .add_baseline(BaselineResult {
544 model_name: "LightGBM".to_string(),
545 model_type: BaselineModelType::LightGbm,
546 metrics: [
547 ("macro_f1".to_string(), 0.65),
548 ("weighted_f1".to_string(), 0.97),
549 ("auc_roc".to_string(), 0.93),
550 ].into_iter().collect(),
551 training_time_seconds: Some(3.0),
552 inference_time_ms: Some(0.05),
553 notes: Some("Optimized hyperparameters".to_string()),
554 })
555 .metadata("domain", "accounting")
556 .metadata("difficulty", "medium")
557 .build()
558}
559
560pub fn data_quality_100k() -> BenchmarkSuite {
562 let mut class_dist = HashMap::new();
563 class_dist.insert("clean".to_string(), 90000);
564 class_dist.insert("missing_value".to_string(), 3000);
565 class_dist.insert("typo".to_string(), 2000);
566 class_dist.insert("format_error".to_string(), 2000);
567 class_dist.insert("duplicate".to_string(), 1500);
568 class_dist.insert("encoding_issue".to_string(), 1000);
569 class_dist.insert("truncation".to_string(), 500);
570
571 BenchmarkBuilder::new("data-quality-100k", "DataQuality-100K")
572 .description("100K records with various data quality issues. Tests detection of missing values, typos, format errors, duplicates, and encoding issues.")
573 .task_type(BenchmarkTaskType::DataQualityDetection)
574 .dataset_size(100000, 10000)
575 .class_distribution(class_dist)
576 .split_ratios(0.8, 0.1, 0.1, false)
577 .primary_metric(MetricType::F1Score)
578 .metrics(vec![
579 MetricType::F1Score,
580 MetricType::Precision,
581 MetricType::Recall,
582 MetricType::AucRoc,
583 MetricType::MacroF1,
584 ])
585 .seed(99999)
586 .time_span_days(730) .num_companies(5)
588 .add_baseline(BaselineResult {
589 model_name: "RuleBased".to_string(),
590 model_type: BaselineModelType::RuleBased,
591 metrics: [
592 ("f1".to_string(), 0.72),
593 ("precision".to_string(), 0.85),
594 ("recall".to_string(), 0.62),
595 ].into_iter().collect(),
596 training_time_seconds: Some(0.0),
597 inference_time_ms: Some(0.5),
598 notes: Some("Regex patterns and null checks".to_string()),
599 })
600 .add_baseline(BaselineResult {
601 model_name: "LogisticRegression".to_string(),
602 model_type: BaselineModelType::LogisticRegression,
603 metrics: [
604 ("f1".to_string(), 0.78),
605 ("precision".to_string(), 0.80),
606 ("recall".to_string(), 0.76),
607 ].into_iter().collect(),
608 training_time_seconds: Some(2.0),
609 inference_time_ms: Some(0.02),
610 notes: Some("Character n-gram features".to_string()),
611 })
612 .add_baseline(BaselineResult {
613 model_name: "XGBoost".to_string(),
614 model_type: BaselineModelType::XgBoost,
615 metrics: [
616 ("f1".to_string(), 0.88),
617 ("precision".to_string(), 0.90),
618 ("recall".to_string(), 0.86),
619 ].into_iter().collect(),
620 training_time_seconds: Some(15.0),
621 inference_time_ms: Some(0.08),
622 notes: Some("Mixed feature types".to_string()),
623 })
624 .metadata("domain", "data_quality")
625 .metadata("difficulty", "medium")
626 .build()
627}
628
629pub fn entity_match_5k() -> BenchmarkSuite {
631 let mut class_dist = HashMap::new();
632 class_dist.insert("match".to_string(), 2000);
633 class_dist.insert("non_match".to_string(), 3000);
634
635 BenchmarkBuilder::new("entity-match-5k", "EntityMatch-5K")
636 .description("5K vendor/customer record pairs for entity matching. Includes name variations, typos, and abbreviations.")
637 .task_type(BenchmarkTaskType::EntityMatching)
638 .dataset_size(5000, 2000)
639 .class_distribution(class_dist)
640 .split_ratios(0.7, 0.15, 0.15, false)
641 .primary_metric(MetricType::F1Score)
642 .metrics(vec![
643 MetricType::F1Score,
644 MetricType::Precision,
645 MetricType::Recall,
646 MetricType::AucRoc,
647 ])
648 .seed(54321)
649 .time_span_days(365)
650 .num_companies(2)
651 .add_baseline(BaselineResult {
652 model_name: "ExactMatch".to_string(),
653 model_type: BaselineModelType::RuleBased,
654 metrics: [
655 ("f1".to_string(), 0.35),
656 ("precision".to_string(), 1.0),
657 ("recall".to_string(), 0.21),
658 ].into_iter().collect(),
659 training_time_seconds: Some(0.0),
660 inference_time_ms: Some(0.1),
661 notes: Some("Exact string matching only".to_string()),
662 })
663 .add_baseline(BaselineResult {
664 model_name: "FuzzyMatch".to_string(),
665 model_type: BaselineModelType::RuleBased,
666 metrics: [
667 ("f1".to_string(), 0.68),
668 ("precision".to_string(), 0.72),
669 ("recall".to_string(), 0.65),
670 ].into_iter().collect(),
671 training_time_seconds: Some(0.0),
672 inference_time_ms: Some(2.0),
673 notes: Some("Levenshtein distance threshold".to_string()),
674 })
675 .add_baseline(BaselineResult {
676 model_name: "Magellan".to_string(),
677 model_type: BaselineModelType::RandomForest,
678 metrics: [
679 ("f1".to_string(), 0.89),
680 ("precision".to_string(), 0.91),
681 ("recall".to_string(), 0.87),
682 ].into_iter().collect(),
683 training_time_seconds: Some(10.0),
684 inference_time_ms: Some(5.0),
685 notes: Some("Feature-based entity matcher".to_string()),
686 })
687 .metadata("domain", "master_data")
688 .metadata("difficulty", "hard")
689 .build()
690}
691
692pub fn graph_fraud_10k() -> BenchmarkSuite {
694 let mut class_dist = HashMap::new();
695 class_dist.insert("normal".to_string(), 9500);
696 class_dist.insert("fraud_network".to_string(), 500);
697
698 BenchmarkBuilder::new("graph-fraud-10k", "GraphFraud-10K")
699 .description("10K transactions with entity graph structure. Fraud detection using network features and GNN models.")
700 .task_type(BenchmarkTaskType::GraphFraudDetection)
701 .dataset_size(10000, 500)
702 .class_distribution(class_dist)
703 .split_ratios(0.7, 0.15, 0.15, true)
704 .primary_metric(MetricType::AucPr)
705 .metrics(vec![
706 MetricType::AucPr,
707 MetricType::AucRoc,
708 MetricType::F1Score,
709 MetricType::PrecisionAtK(100),
710 ])
711 .seed(77777)
712 .time_span_days(365)
713 .num_companies(4)
714 .add_baseline(BaselineResult {
715 model_name: "NodeFeatures".to_string(),
716 model_type: BaselineModelType::XgBoost,
717 metrics: [
718 ("auc_pr".to_string(), 0.45),
719 ("auc_roc".to_string(), 0.78),
720 ].into_iter().collect(),
721 training_time_seconds: Some(5.0),
722 inference_time_ms: Some(0.1),
723 notes: Some("XGBoost on node features only".to_string()),
724 })
725 .add_baseline(BaselineResult {
726 model_name: "GraphSAGE".to_string(),
727 model_type: BaselineModelType::Gnn,
728 metrics: [
729 ("auc_pr".to_string(), 0.62),
730 ("auc_roc".to_string(), 0.88),
731 ].into_iter().collect(),
732 training_time_seconds: Some(60.0),
733 inference_time_ms: Some(5.0),
734 notes: Some("2-layer GraphSAGE".to_string()),
735 })
736 .add_baseline(BaselineResult {
737 model_name: "GAT".to_string(),
738 model_type: BaselineModelType::Gnn,
739 metrics: [
740 ("auc_pr".to_string(), 0.68),
741 ("auc_roc".to_string(), 0.91),
742 ].into_iter().collect(),
743 training_time_seconds: Some(90.0),
744 inference_time_ms: Some(8.0),
745 notes: Some("Graph Attention Network".to_string()),
746 })
747 .metadata("domain", "graph_analytics")
748 .metadata("difficulty", "hard")
749 .build()
750}
751
752pub fn all_benchmarks() -> Vec<BenchmarkSuite> {
754 let mut benchmarks = vec![
755 anomaly_bench_1k(),
756 fraud_detect_10k(),
757 data_quality_100k(),
758 entity_match_5k(),
759 graph_fraud_10k(),
760 ];
761
762 benchmarks.extend(all_acfe_benchmarks());
764
765 benchmarks.extend(all_industry_benchmarks());
767
768 benchmarks
769}
770
771pub fn get_benchmark(id: &str) -> Option<BenchmarkSuite> {
773 all_benchmarks().into_iter().find(|b| b.id == id)
774}
775
#[cfg(test)]
mod tests {
    use super::*;

    // Spot-checks the AnomalyBench-1K definition: id, record counts, and
    // that all three baselines were registered.
    #[test]
    fn test_anomaly_bench_1k() {
        let bench = anomaly_bench_1k();
        assert_eq!(bench.id, "anomaly-bench-1k");
        assert_eq!(bench.dataset.total_records, 1000);
        assert_eq!(bench.dataset.positive_count, 50);
        assert_eq!(bench.baselines.len(), 3);
    }

    // FraudDetect-10K should be a 10K-record multi-class fraud benchmark.
    #[test]
    fn test_fraud_detect_10k() {
        let bench = fraud_detect_10k();
        assert_eq!(bench.id, "fraud-detect-10k");
        assert_eq!(bench.dataset.total_records, 10000);
        assert_eq!(bench.task_type, BenchmarkTaskType::FraudClassification);
    }

    // DataQuality-100K defines seven classes (clean + six defect types),
    // so the distribution must have more than five entries.
    #[test]
    fn test_data_quality_100k() {
        let bench = data_quality_100k();
        assert_eq!(bench.id, "data-quality-100k");
        assert_eq!(bench.dataset.total_records, 100000);
        assert!(bench.dataset.class_distribution.len() > 5);
    }

    // The full registry: 5 core suites plus the ACFE and industry suites
    // (13 total per the re-exported constructors); every suite must ship
    // at least one baseline.
    #[test]
    fn test_all_benchmarks() {
        let benchmarks = all_benchmarks();
        assert_eq!(benchmarks.len(), 13);

        for bench in &benchmarks {
            assert!(
                !bench.baselines.is_empty(),
                "Benchmark {} has no baselines",
                bench.id
            );
        }
    }

    // Lookup by id succeeds for a known suite and fails for an unknown one.
    #[test]
    fn test_get_benchmark() {
        assert!(get_benchmark("fraud-detect-10k").is_some());
        assert!(get_benchmark("nonexistent").is_none());
    }

    // Builder round-trip: the configured id, size, and seed must appear in
    // the built suite.
    #[test]
    fn test_builder() {
        let bench = BenchmarkBuilder::new("custom", "Custom Benchmark")
            .description("A custom test benchmark")
            .task_type(BenchmarkTaskType::AnomalyDetection)
            .dataset_size(500, 25)
            .seed(123)
            .build();

        assert_eq!(bench.id, "custom");
        assert_eq!(bench.dataset.total_records, 500);
        assert_eq!(bench.dataset.seed, 123);
    }
}