use serde::{Deserialize, Serialize};
use std::collections::HashMap;

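/// Specification of a complete benchmark suite: identity, task type, dataset
/// characteristics, evaluation protocol, and reference baseline results.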
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkSuite {
    pub id: String,
    pub name: String,
    pub description: String,
    pub version: String,
    pub task_type: BenchmarkTaskType,
    pub dataset: DatasetSpec,
    pub evaluation: EvaluationSpec,
    pub baselines: Vec<BaselineResult>,
    pub metadata: HashMap<String, String>,
}

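/// The prediction task a benchmark evaluates.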
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum BenchmarkTaskType {
    AnomalyDetection,
    FraudClassification,
    DataQualityDetection,
    EntityMatching,
    MultiLabelAnomalyDetection,
    AmountPrediction,
    TimeSeriesAnomalyDetection,
    GraphFraudDetection,
}

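/// Size, class balance, feature makeup, and generation parameters of a benchmark dataset.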
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetSpec {
    pub total_records: usize,
    pub positive_count: usize,
    pub negative_count: usize,
    pub class_distribution: HashMap<String, usize>,
    pub features: FeatureSet,
    pub split_ratios: SplitRatios,
    pub seed: u64,
    pub time_span_days: u32,
    pub num_companies: usize,
}

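/// Counts of feature columns by type, plus per-feature descriptions.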
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureSet {
    pub numerical_count: usize,
    pub categorical_count: usize,
    pub temporal_count: usize,
    pub text_count: usize,
    pub feature_descriptions: HashMap<String, String>,
}

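/// Train/validation/test split fractions; `temporal_split` requests a
/// time-ordered split rather than a random one.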
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitRatios {
    pub train: f64,
    pub validation: f64,
    pub test: f64,
    pub temporal_split: bool,
}

impl Default for SplitRatios {
    fn default() -> Self {
        Self {
            train: 0.7,
            validation: 0.15,
            test: 0.15,
            temporal_split: false,
        }
    }
}

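/// How submissions are scored: the primary metric, all reported metrics, and
/// optional threshold, cross-validation, and cost-matrix settings.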
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationSpec {
    pub primary_metric: MetricType,
    pub metrics: Vec<MetricType>,
    pub classification_threshold: Option<f64>,
    pub cv_folds: Option<usize>,
    pub cost_matrix: Option<CostMatrix>,
}

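/// Supported evaluation metrics. Ranking metrics such as `PrecisionAtK` and
/// `RecallAtK` carry the cutoff `k` as payload.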
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum MetricType {
    AucRoc,
    AucPr,
    PrecisionAtK(usize),
    RecallAtK(usize),
    F1Score,
    Precision,
    Recall,
    Accuracy,
    Mcc,
    Map,
    Ndcg,
    Mse,
    Mae,
    R2,
    LogLoss,
    CohenKappa,
    MacroF1,
    WeightedF1,
}

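/// Per-outcome costs for cost-sensitive evaluation.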
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostMatrix {
    pub false_positive_cost: f64,
    pub false_negative_cost: f64,
    pub true_positive_cost: f64,
    pub true_negative_cost: f64,
}

impl Default for CostMatrix {
    fn default() -> Self {
        Self {
            false_positive_cost: 1.0,
            false_negative_cost: 10.0,
            true_positive_cost: 0.0,
            true_negative_cost: 0.0,
        }
    }
}

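/// A reference result for a baseline model on a benchmark.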
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaselineResult {
    pub model_name: String,
    pub model_type: BaselineModelType,
    pub metrics: HashMap<String, f64>,
    pub training_time_seconds: Option<f64>,
    pub inference_time_ms: Option<f64>,
    pub notes: Option<String>,
}

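/// Families of baseline models, from trivial references to supervised and
/// graph-based learners; `Custom` covers anything else.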
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum BaselineModelType {
    Random,
    MajorityClass,
    RuleBased,
    IsolationForest,
    OneClassSvm,
    Lof,
    LogisticRegression,
    RandomForest,
    XgBoost,
    LightGbm,
    NeuralNetwork,
    Gnn,
    Autoencoder,
    Custom(String),
}

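/// A single row on a benchmark leaderboard.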
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeaderboardEntry {
    pub rank: usize,
    pub submission_name: String,
    pub submitter: String,
    pub model_description: String,
    pub primary_score: f64,
    pub all_scores: HashMap<String, f64>,
    pub submission_date: String,
    pub is_baseline: bool,
}

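/// Fluent builder for assembling a [`BenchmarkSuite`].
///
/// Illustrative usage (not compiled; adjust the import path to this crate's layout):
///
/// ```ignore
/// let suite = BenchmarkBuilder::new("my-bench", "MyBench")
///     .description("Toy benchmark")
///     .dataset_size(200, 10)
///     .seed(7)
///     .build();
/// // negative_count is derived as total - positive
/// assert_eq!(suite.dataset.negative_count, 190);
/// ```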
pub struct BenchmarkBuilder {
    suite: BenchmarkSuite,
}

impl BenchmarkBuilder {
    pub fn new(id: &str, name: &str) -> Self {
        Self {
            suite: BenchmarkSuite {
                id: id.to_string(),
                name: name.to_string(),
                description: String::new(),
                version: "1.0.0".to_string(),
                task_type: BenchmarkTaskType::AnomalyDetection,
                dataset: DatasetSpec {
                    total_records: 1000,
                    positive_count: 50,
                    negative_count: 950,
                    class_distribution: HashMap::new(),
                    features: FeatureSet {
                        numerical_count: 10,
                        categorical_count: 5,
                        temporal_count: 3,
                        text_count: 1,
                        feature_descriptions: HashMap::new(),
                    },
                    split_ratios: SplitRatios::default(),
                    seed: 42,
                    time_span_days: 365,
                    num_companies: 1,
                },
                evaluation: EvaluationSpec {
                    primary_metric: MetricType::AucRoc,
                    metrics: vec![MetricType::AucRoc, MetricType::AucPr, MetricType::F1Score],
                    classification_threshold: Some(0.5),
                    cv_folds: None,
                    cost_matrix: None,
                },
                baselines: Vec::new(),
                metadata: HashMap::new(),
            },
        }
    }

    pub fn description(mut self, desc: &str) -> Self {
        self.suite.description = desc.to_string();
        self
    }

    pub fn task_type(mut self, task_type: BenchmarkTaskType) -> Self {
        self.suite.task_type = task_type;
        self
    }

    pub fn dataset_size(mut self, total: usize, positive: usize) -> Self {
        self.suite.dataset.total_records = total;
        self.suite.dataset.positive_count = positive;
        self.suite.dataset.negative_count = total.saturating_sub(positive);
        self
    }

    pub fn class_distribution(mut self, distribution: HashMap<String, usize>) -> Self {
        self.suite.dataset.class_distribution = distribution;
        self
    }

    pub fn split_ratios(mut self, train: f64, val: f64, test: f64, temporal: bool) -> Self {
        self.suite.dataset.split_ratios = SplitRatios {
            train,
            validation: val,
            test,
            temporal_split: temporal,
        };
        self
    }

    pub fn primary_metric(mut self, metric: MetricType) -> Self {
        self.suite.evaluation.primary_metric = metric;
        self
    }

    pub fn metrics(mut self, metrics: Vec<MetricType>) -> Self {
        self.suite.evaluation.metrics = metrics;
        self
    }

    pub fn add_baseline(mut self, baseline: BaselineResult) -> Self {
        self.suite.baselines.push(baseline);
        self
    }

    pub fn seed(mut self, seed: u64) -> Self {
        self.suite.dataset.seed = seed;
        self
    }

    pub fn time_span_days(mut self, days: u32) -> Self {
        self.suite.dataset.time_span_days = days;
        self
    }

    pub fn num_companies(mut self, n: usize) -> Self {
        self.suite.dataset.num_companies = n;
        self
    }

    pub fn metadata(mut self, key: &str, value: &str) -> Self {
        self.suite
            .metadata
            .insert(key.to_string(), value.to_string());
        self
    }

    pub fn build(self) -> BenchmarkSuite {
        self.suite
    }
}

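/// AnomalyBench-1K: 1,000 journal entries with a 5% anomaly rate, PR-AUC as the
/// primary metric, and random, Isolation Forest, and XGBoost baselines.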
pub fn anomaly_bench_1k() -> BenchmarkSuite {
    let mut class_dist = HashMap::new();
    class_dist.insert("normal".to_string(), 950);
    class_dist.insert("anomaly".to_string(), 50);

    BenchmarkBuilder::new("anomaly-bench-1k", "AnomalyBench-1K")
        .description("1000 journal entry transactions with 5% anomaly rate. Balanced mix of fraud, error, and statistical anomalies.")
        .task_type(BenchmarkTaskType::AnomalyDetection)
        .dataset_size(1000, 50)
        .class_distribution(class_dist)
        .split_ratios(0.7, 0.15, 0.15, true)
        .primary_metric(MetricType::AucPr)
        .metrics(vec![
            MetricType::AucRoc,
            MetricType::AucPr,
            MetricType::F1Score,
            MetricType::PrecisionAtK(10),
            MetricType::PrecisionAtK(50),
            MetricType::Recall,
        ])
        .seed(42)
        .time_span_days(90)
        .num_companies(1)
        .add_baseline(BaselineResult {
            model_name: "Random".to_string(),
            model_type: BaselineModelType::Random,
            metrics: [
                ("auc_roc".to_string(), 0.50),
                ("auc_pr".to_string(), 0.05),
                ("f1".to_string(), 0.09),
            ].into_iter().collect(),
            training_time_seconds: Some(0.0),
            inference_time_ms: Some(0.01),
            notes: Some("Random baseline for reference".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "IsolationForest".to_string(),
            model_type: BaselineModelType::IsolationForest,
            metrics: [
                ("auc_roc".to_string(), 0.78),
                ("auc_pr".to_string(), 0.42),
                ("f1".to_string(), 0.45),
            ].into_iter().collect(),
            training_time_seconds: Some(0.5),
            inference_time_ms: Some(0.1),
            notes: Some("Unsupervised baseline".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "XGBoost".to_string(),
            model_type: BaselineModelType::XgBoost,
            metrics: [
                ("auc_roc".to_string(), 0.92),
                ("auc_pr".to_string(), 0.68),
                ("f1".to_string(), 0.72),
            ].into_iter().collect(),
            training_time_seconds: Some(2.0),
            inference_time_ms: Some(0.05),
            notes: Some("Supervised baseline with full labels".to_string()),
        })
        .metadata("domain", "accounting")
        .metadata("difficulty", "easy")
        .build()
}

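/// FraudDetect-10K: 10,000 journal entries with six fraud classes (3% positive
/// overall), scored primarily by macro-F1.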
pub fn fraud_detect_10k() -> BenchmarkSuite {
    let mut class_dist = HashMap::new();
    class_dist.insert("normal".to_string(), 9700);
    class_dist.insert("fictitious_transaction".to_string(), 80);
    class_dist.insert("duplicate_payment".to_string(), 60);
    class_dist.insert("round_tripping".to_string(), 40);
    class_dist.insert("threshold_manipulation".to_string(), 50);
    class_dist.insert("self_approval".to_string(), 30);
    class_dist.insert("other_fraud".to_string(), 40);

    BenchmarkBuilder::new("fraud-detect-10k", "FraudDetect-10K")
        .description("10K journal entries with multi-class fraud labels. Includes 6 fraud types with realistic class imbalance.")
        .task_type(BenchmarkTaskType::FraudClassification)
        .dataset_size(10000, 300)
        .class_distribution(class_dist)
        .split_ratios(0.7, 0.15, 0.15, true)
        .primary_metric(MetricType::MacroF1)
        .metrics(vec![
            MetricType::AucRoc,
            MetricType::MacroF1,
            MetricType::WeightedF1,
            MetricType::Recall,
            MetricType::Precision,
            MetricType::CohenKappa,
        ])
        .seed(12345)
        .time_span_days(365)
        .num_companies(3)
        .add_baseline(BaselineResult {
            model_name: "MajorityClass".to_string(),
            model_type: BaselineModelType::MajorityClass,
            metrics: [
                ("macro_f1".to_string(), 0.07),
                ("weighted_f1".to_string(), 0.94),
            ].into_iter().collect(),
            training_time_seconds: Some(0.0),
            inference_time_ms: Some(0.01),
            notes: Some("Predicts normal for all transactions".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "RandomForest".to_string(),
            model_type: BaselineModelType::RandomForest,
            metrics: [
                ("macro_f1".to_string(), 0.58),
                ("weighted_f1".to_string(), 0.96),
                ("auc_roc".to_string(), 0.89),
            ].into_iter().collect(),
            training_time_seconds: Some(5.0),
            inference_time_ms: Some(0.2),
            notes: Some("Balanced class weights".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "LightGBM".to_string(),
            model_type: BaselineModelType::LightGbm,
            metrics: [
                ("macro_f1".to_string(), 0.65),
                ("weighted_f1".to_string(), 0.97),
                ("auc_roc".to_string(), 0.93),
            ].into_iter().collect(),
            training_time_seconds: Some(3.0),
            inference_time_ms: Some(0.05),
            notes: Some("Optimized hyperparameters".to_string()),
        })
        .metadata("domain", "accounting")
        .metadata("difficulty", "medium")
        .build()
}

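/// DataQuality-100K: 100,000 records with 10% data-quality issues across six
/// error types, scored primarily by F1.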
pub fn data_quality_100k() -> BenchmarkSuite {
    let mut class_dist = HashMap::new();
    class_dist.insert("clean".to_string(), 90000);
    class_dist.insert("missing_value".to_string(), 3000);
    class_dist.insert("typo".to_string(), 2000);
    class_dist.insert("format_error".to_string(), 2000);
    class_dist.insert("duplicate".to_string(), 1500);
    class_dist.insert("encoding_issue".to_string(), 1000);
    class_dist.insert("truncation".to_string(), 500);

    BenchmarkBuilder::new("data-quality-100k", "DataQuality-100K")
        .description("100K records with various data quality issues. Tests detection of missing values, typos, format errors, duplicates, and encoding issues.")
        .task_type(BenchmarkTaskType::DataQualityDetection)
        .dataset_size(100000, 10000)
        .class_distribution(class_dist)
        .split_ratios(0.8, 0.1, 0.1, false)
        .primary_metric(MetricType::F1Score)
        .metrics(vec![
            MetricType::F1Score,
            MetricType::Precision,
            MetricType::Recall,
            MetricType::AucRoc,
            MetricType::MacroF1,
        ])
        .seed(99999)
        .time_span_days(730)
        .num_companies(5)
        .add_baseline(BaselineResult {
            model_name: "RuleBased".to_string(),
            model_type: BaselineModelType::RuleBased,
            metrics: [
                ("f1".to_string(), 0.72),
                ("precision".to_string(), 0.85),
                ("recall".to_string(), 0.62),
            ].into_iter().collect(),
            training_time_seconds: Some(0.0),
            inference_time_ms: Some(0.5),
            notes: Some("Regex patterns and null checks".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "LogisticRegression".to_string(),
            model_type: BaselineModelType::LogisticRegression,
            metrics: [
                ("f1".to_string(), 0.78),
                ("precision".to_string(), 0.80),
                ("recall".to_string(), 0.76),
            ].into_iter().collect(),
            training_time_seconds: Some(2.0),
            inference_time_ms: Some(0.02),
            notes: Some("Character n-gram features".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "XGBoost".to_string(),
            model_type: BaselineModelType::XgBoost,
            metrics: [
                ("f1".to_string(), 0.88),
                ("precision".to_string(), 0.90),
                ("recall".to_string(), 0.86),
            ].into_iter().collect(),
            training_time_seconds: Some(15.0),
            inference_time_ms: Some(0.08),
            notes: Some("Mixed feature types".to_string()),
        })
        .metadata("domain", "data_quality")
        .metadata("difficulty", "medium")
        .build()
}

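/// EntityMatch-5K: 5,000 vendor/customer record pairs (40% matches) with name
/// variations, typos, and abbreviations, scored primarily by F1.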
pub fn entity_match_5k() -> BenchmarkSuite {
    let mut class_dist = HashMap::new();
    class_dist.insert("match".to_string(), 2000);
    class_dist.insert("non_match".to_string(), 3000);

    BenchmarkBuilder::new("entity-match-5k", "EntityMatch-5K")
        .description("5K vendor/customer record pairs for entity matching. Includes name variations, typos, and abbreviations.")
        .task_type(BenchmarkTaskType::EntityMatching)
        .dataset_size(5000, 2000)
        .class_distribution(class_dist)
        .split_ratios(0.7, 0.15, 0.15, false)
        .primary_metric(MetricType::F1Score)
        .metrics(vec![
            MetricType::F1Score,
            MetricType::Precision,
            MetricType::Recall,
            MetricType::AucRoc,
        ])
        .seed(54321)
        .time_span_days(365)
        .num_companies(2)
        .add_baseline(BaselineResult {
            model_name: "ExactMatch".to_string(),
            model_type: BaselineModelType::RuleBased,
            metrics: [
                ("f1".to_string(), 0.35),
                ("precision".to_string(), 1.0),
                ("recall".to_string(), 0.21),
            ].into_iter().collect(),
            training_time_seconds: Some(0.0),
            inference_time_ms: Some(0.1),
            notes: Some("Exact string matching only".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "FuzzyMatch".to_string(),
            model_type: BaselineModelType::RuleBased,
            metrics: [
                ("f1".to_string(), 0.68),
                ("precision".to_string(), 0.72),
                ("recall".to_string(), 0.65),
            ].into_iter().collect(),
            training_time_seconds: Some(0.0),
            inference_time_ms: Some(2.0),
            notes: Some("Levenshtein distance threshold".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "Magellan".to_string(),
            model_type: BaselineModelType::RandomForest,
            metrics: [
                ("f1".to_string(), 0.89),
                ("precision".to_string(), 0.91),
                ("recall".to_string(), 0.87),
            ].into_iter().collect(),
            training_time_seconds: Some(10.0),
            inference_time_ms: Some(5.0),
            notes: Some("Feature-based entity matcher".to_string()),
        })
        .metadata("domain", "master_data")
        .metadata("difficulty", "hard")
        .build()
}

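/// GraphFraud-10K: 10,000 transactions with entity graph structure and 5%
/// fraud-network positives, scored primarily by PR-AUC against tabular and GNN baselines.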
pub fn graph_fraud_10k() -> BenchmarkSuite {
    let mut class_dist = HashMap::new();
    class_dist.insert("normal".to_string(), 9500);
    class_dist.insert("fraud_network".to_string(), 500);

    BenchmarkBuilder::new("graph-fraud-10k", "GraphFraud-10K")
        .description("10K transactions with entity graph structure. Fraud detection using network features and GNN models.")
        .task_type(BenchmarkTaskType::GraphFraudDetection)
        .dataset_size(10000, 500)
        .class_distribution(class_dist)
        .split_ratios(0.7, 0.15, 0.15, true)
        .primary_metric(MetricType::AucPr)
        .metrics(vec![
            MetricType::AucPr,
            MetricType::AucRoc,
            MetricType::F1Score,
            MetricType::PrecisionAtK(100),
        ])
        .seed(77777)
        .time_span_days(365)
        .num_companies(4)
        .add_baseline(BaselineResult {
            model_name: "NodeFeatures".to_string(),
            model_type: BaselineModelType::XgBoost,
            metrics: [
                ("auc_pr".to_string(), 0.45),
                ("auc_roc".to_string(), 0.78),
            ].into_iter().collect(),
            training_time_seconds: Some(5.0),
            inference_time_ms: Some(0.1),
            notes: Some("XGBoost on node features only".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "GraphSAGE".to_string(),
            model_type: BaselineModelType::Gnn,
            metrics: [
                ("auc_pr".to_string(), 0.62),
                ("auc_roc".to_string(), 0.88),
            ].into_iter().collect(),
            training_time_seconds: Some(60.0),
            inference_time_ms: Some(5.0),
            notes: Some("2-layer GraphSAGE".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "GAT".to_string(),
            model_type: BaselineModelType::Gnn,
            metrics: [
                ("auc_pr".to_string(), 0.68),
                ("auc_roc".to_string(), 0.91),
            ].into_iter().collect(),
            training_time_seconds: Some(90.0),
            inference_time_ms: Some(8.0),
            notes: Some("Graph Attention Network".to_string()),
        })
        .metadata("domain", "graph_analytics")
        .metadata("difficulty", "hard")
        .build()
}

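/// Returns every built-in benchmark suite.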
pub fn all_benchmarks() -> Vec<BenchmarkSuite> {
    vec![
        anomaly_bench_1k(),
        fraud_detect_10k(),
        data_quality_100k(),
        entity_match_5k(),
        graph_fraud_10k(),
    ]
}

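/// Looks up a built-in benchmark suite by its `id`, e.g. `"fraud-detect-10k"`.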
pub fn get_benchmark(id: &str) -> Option<BenchmarkSuite> {
    all_benchmarks().into_iter().find(|b| b.id == id)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_anomaly_bench_1k() {
        let bench = anomaly_bench_1k();
        assert_eq!(bench.id, "anomaly-bench-1k");
        assert_eq!(bench.dataset.total_records, 1000);
        assert_eq!(bench.dataset.positive_count, 50);
        assert_eq!(bench.baselines.len(), 3);
    }

    #[test]
    fn test_fraud_detect_10k() {
        let bench = fraud_detect_10k();
        assert_eq!(bench.id, "fraud-detect-10k");
        assert_eq!(bench.dataset.total_records, 10000);
        assert_eq!(bench.task_type, BenchmarkTaskType::FraudClassification);
    }

    #[test]
    fn test_data_quality_100k() {
        let bench = data_quality_100k();
        assert_eq!(bench.id, "data-quality-100k");
        assert_eq!(bench.dataset.total_records, 100000);
        assert!(bench.dataset.class_distribution.len() > 5);
    }

    #[test]
    fn test_all_benchmarks() {
        let benchmarks = all_benchmarks();
        assert_eq!(benchmarks.len(), 5);

        for bench in &benchmarks {
            assert!(
                !bench.baselines.is_empty(),
                "Benchmark {} has no baselines",
                bench.id
            );
        }
    }

    #[test]
    fn test_get_benchmark() {
        assert!(get_benchmark("fraud-detect-10k").is_some());
        assert!(get_benchmark("nonexistent").is_none());
    }

    #[test]
    fn test_builder() {
        let bench = BenchmarkBuilder::new("custom", "Custom Benchmark")
            .description("A custom test benchmark")
            .task_type(BenchmarkTaskType::AnomalyDetection)
            .dataset_size(500, 25)
            .seed(123)
            .build();

        assert_eq!(bench.id, "custom");
        assert_eq!(bench.dataset.total_records, 500);
        assert_eq!(bench.dataset.seed, 123);
    }
}