// Source file: datasynth_eval/benchmarks/mod.rs

//! Benchmark suite definitions for ML evaluation.
//!
//! Provides standardized benchmark datasets for:
//! - Anomaly detection (AnomalyBench-1K)
//! - Fraud detection (FraudDetect-10K)
//! - Data quality detection (DataQuality-100K)
//! - Entity matching (EntityMatch-5K)
//! - Graph-based fraud network detection (GraphFraud-10K)
//! - ACFE-calibrated fraud detection (ACFE-Calibrated-1K, ACFE-Collusion-5K, and a 2K management-override suite)
//! - Industry-specific fraud detection (Manufacturing, Retail, Healthcare, Technology, Financial Services)
//!
//! Each benchmark defines:
//! - Dataset size and composition
//! - Ground truth labels
//! - Evaluation metrics
//! - Expected baseline performance
16
17pub mod acfe;
18pub mod industry;
19
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22
23pub use acfe::{
24    acfe_calibrated_1k, acfe_collusion_5k, acfe_management_override_2k, all_acfe_benchmarks,
25    AcfeAlignment, AcfeCalibration, AcfeCategoryDistribution,
26};
27pub use industry::{
28    all_industry_benchmarks, financial_services_fraud_5k, get_industry_benchmark,
29    healthcare_fraud_5k, manufacturing_fraud_5k, retail_fraud_10k, technology_fraud_3k,
30    IndustryBenchmarkAnalysis,
31};
32
33/// A benchmark suite definition.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct BenchmarkSuite {
36    /// Unique identifier for the benchmark
37    pub id: String,
38    /// Human-readable name
39    pub name: String,
40    /// Description of the benchmark
41    pub description: String,
42    /// Version of the benchmark specification
43    pub version: String,
44    /// Task type being evaluated
45    pub task_type: BenchmarkTaskType,
46    /// Dataset specification
47    pub dataset: DatasetSpec,
48    /// Evaluation configuration
49    pub evaluation: EvaluationSpec,
50    /// Expected baseline results
51    pub baselines: Vec<BaselineResult>,
52    /// Metadata
53    pub metadata: HashMap<String, String>,
54}
55
56/// Types of benchmark tasks.
57#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
58pub enum BenchmarkTaskType {
59    /// Binary classification: anomaly vs normal
60    AnomalyDetection,
61    /// Multi-class classification: fraud type
62    FraudClassification,
63    /// Binary classification: data quality issue detection
64    DataQualityDetection,
65    /// Entity resolution: matching records
66    EntityMatching,
67    /// Multi-label classification: anomaly types
68    MultiLabelAnomalyDetection,
69    /// Regression: amount prediction
70    AmountPrediction,
71    /// Time series: anomaly detection over time
72    TimeSeriesAnomalyDetection,
73    /// Graph: fraud network detection
74    GraphFraudDetection,
75}
76
77/// Dataset specification.
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct DatasetSpec {
80    /// Total number of records
81    pub total_records: usize,
82    /// Number of positive examples (anomalies/fraud/issues)
83    pub positive_count: usize,
84    /// Number of negative examples (normal/clean)
85    pub negative_count: usize,
86    /// Class distribution by label
87    pub class_distribution: HashMap<String, usize>,
88    /// Feature set description
89    pub features: FeatureSet,
90    /// Split ratios (train/val/test)
91    pub split_ratios: SplitRatios,
92    /// Seed for reproducibility
93    pub seed: u64,
94    /// Time span in days
95    pub time_span_days: u32,
96    /// Number of companies
97    pub num_companies: usize,
98}
99
100/// Feature set for the benchmark.
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct FeatureSet {
103    /// Number of numerical features
104    pub numerical_count: usize,
105    /// Number of categorical features
106    pub categorical_count: usize,
107    /// Number of temporal features
108    pub temporal_count: usize,
109    /// Number of text features
110    pub text_count: usize,
111    /// Feature names and descriptions
112    pub feature_descriptions: HashMap<String, String>,
113}
114
115/// Train/validation/test split ratios.
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct SplitRatios {
118    /// Training set ratio (0.0-1.0)
119    pub train: f64,
120    /// Validation set ratio (0.0-1.0)
121    pub validation: f64,
122    /// Test set ratio (0.0-1.0)
123    pub test: f64,
124    /// Temporal split (if true, test set is the latest data)
125    pub temporal_split: bool,
126}
127
128impl Default for SplitRatios {
129    fn default() -> Self {
130        Self {
131            train: 0.7,
132            validation: 0.15,
133            test: 0.15,
134            temporal_split: false,
135        }
136    }
137}
138
139/// Evaluation specification.
140#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct EvaluationSpec {
142    /// Primary metric for ranking
143    pub primary_metric: MetricType,
144    /// All metrics to compute
145    pub metrics: Vec<MetricType>,
146    /// Threshold for binary classification
147    pub classification_threshold: Option<f64>,
148    /// Cross-validation folds (if applicable)
149    pub cv_folds: Option<usize>,
150    /// Cost matrix for cost-sensitive evaluation
151    pub cost_matrix: Option<CostMatrix>,
152}
153
154/// Types of evaluation metrics.
155#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
156pub enum MetricType {
157    /// Area Under ROC Curve
158    AucRoc,
159    /// Area Under Precision-Recall Curve
160    AucPr,
161    /// Precision at K
162    PrecisionAtK(usize),
163    /// Recall at K
164    RecallAtK(usize),
165    /// F1 Score
166    F1Score,
167    /// Precision
168    Precision,
169    /// Recall
170    Recall,
171    /// Accuracy
172    Accuracy,
173    /// Matthews Correlation Coefficient
174    Mcc,
175    /// Mean Average Precision
176    Map,
177    /// Normalized Discounted Cumulative Gain
178    Ndcg,
179    /// Mean Squared Error
180    Mse,
181    /// Mean Absolute Error
182    Mae,
183    /// R-squared
184    R2,
185    /// Log Loss
186    LogLoss,
187    /// Cohen's Kappa
188    CohenKappa,
189    /// Macro F1
190    MacroF1,
191    /// Weighted F1
192    WeightedF1,
193}
194
195/// Cost matrix for cost-sensitive evaluation.
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct CostMatrix {
198    /// Cost of false positive
199    pub false_positive_cost: f64,
200    /// Cost of false negative
201    pub false_negative_cost: f64,
202    /// Cost of true positive (usually 0 or negative for reward)
203    pub true_positive_cost: f64,
204    /// Cost of true negative (usually 0)
205    pub true_negative_cost: f64,
206}
207
208impl Default for CostMatrix {
209    fn default() -> Self {
210        Self {
211            false_positive_cost: 1.0,
212            false_negative_cost: 10.0, // Missing fraud is usually worse
213            true_positive_cost: 0.0,
214            true_negative_cost: 0.0,
215        }
216    }
217}
218
219/// Expected baseline result for a benchmark.
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct BaselineResult {
222    /// Model/algorithm name
223    pub model_name: String,
224    /// Model type
225    pub model_type: BaselineModelType,
226    /// Metric results
227    pub metrics: HashMap<String, f64>,
228    /// Training time in seconds
229    pub training_time_seconds: Option<f64>,
230    /// Inference time per sample in milliseconds
231    pub inference_time_ms: Option<f64>,
232    /// Notes about the baseline
233    pub notes: Option<String>,
234}
235
236/// Types of baseline models.
237#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
238pub enum BaselineModelType {
239    /// Random baseline
240    Random,
241    /// Majority class baseline
242    MajorityClass,
243    /// Rule-based baseline
244    RuleBased,
245    /// Isolation Forest
246    IsolationForest,
247    /// One-Class SVM
248    OneClassSvm,
249    /// Local Outlier Factor
250    Lof,
251    /// Logistic Regression
252    LogisticRegression,
253    /// Random Forest
254    RandomForest,
255    /// XGBoost
256    XgBoost,
257    /// LightGBM
258    LightGbm,
259    /// Neural Network
260    NeuralNetwork,
261    /// Graph Neural Network
262    Gnn,
263    /// Autoencoder
264    Autoencoder,
265    /// Custom model
266    Custom(String),
267}
268
269/// Leaderboard entry for benchmark results.
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct LeaderboardEntry {
272    /// Rank on the leaderboard
273    pub rank: usize,
274    /// Submission name
275    pub submission_name: String,
276    /// Submitter name/organization
277    pub submitter: String,
278    /// Model description
279    pub model_description: String,
280    /// Primary metric score
281    pub primary_score: f64,
282    /// All metric scores
283    pub all_scores: HashMap<String, f64>,
284    /// Submission date
285    pub submission_date: String,
286    /// Is this a baseline entry
287    pub is_baseline: bool,
288}
289
290/// Builder for creating benchmark suites.
291pub struct BenchmarkBuilder {
292    suite: BenchmarkSuite,
293}
294
295impl BenchmarkBuilder {
296    /// Create a new benchmark builder.
297    pub fn new(id: &str, name: &str) -> Self {
298        Self {
299            suite: BenchmarkSuite {
300                id: id.to_string(),
301                name: name.to_string(),
302                description: String::new(),
303                version: "1.0.0".to_string(),
304                task_type: BenchmarkTaskType::AnomalyDetection,
305                dataset: DatasetSpec {
306                    total_records: 1000,
307                    positive_count: 50,
308                    negative_count: 950,
309                    class_distribution: HashMap::new(),
310                    features: FeatureSet {
311                        numerical_count: 10,
312                        categorical_count: 5,
313                        temporal_count: 3,
314                        text_count: 1,
315                        feature_descriptions: HashMap::new(),
316                    },
317                    split_ratios: SplitRatios::default(),
318                    seed: 42,
319                    time_span_days: 365,
320                    num_companies: 1,
321                },
322                evaluation: EvaluationSpec {
323                    primary_metric: MetricType::AucRoc,
324                    metrics: vec![MetricType::AucRoc, MetricType::AucPr, MetricType::F1Score],
325                    classification_threshold: Some(0.5),
326                    cv_folds: None,
327                    cost_matrix: None,
328                },
329                baselines: Vec::new(),
330                metadata: HashMap::new(),
331            },
332        }
333    }
334
335    /// Set the description.
336    pub fn description(mut self, desc: &str) -> Self {
337        self.suite.description = desc.to_string();
338        self
339    }
340
341    /// Set the task type.
342    pub fn task_type(mut self, task_type: BenchmarkTaskType) -> Self {
343        self.suite.task_type = task_type;
344        self
345    }
346
347    /// Set dataset size.
348    pub fn dataset_size(mut self, total: usize, positive: usize) -> Self {
349        self.suite.dataset.total_records = total;
350        self.suite.dataset.positive_count = positive;
351        self.suite.dataset.negative_count = total.saturating_sub(positive);
352        self
353    }
354
355    /// Set class distribution.
356    pub fn class_distribution(mut self, distribution: HashMap<String, usize>) -> Self {
357        self.suite.dataset.class_distribution = distribution;
358        self
359    }
360
361    /// Set split ratios.
362    pub fn split_ratios(mut self, train: f64, val: f64, test: f64, temporal: bool) -> Self {
363        self.suite.dataset.split_ratios = SplitRatios {
364            train,
365            validation: val,
366            test,
367            temporal_split: temporal,
368        };
369        self
370    }
371
372    /// Set primary metric.
373    pub fn primary_metric(mut self, metric: MetricType) -> Self {
374        self.suite.evaluation.primary_metric = metric;
375        self
376    }
377
378    /// Set all metrics.
379    pub fn metrics(mut self, metrics: Vec<MetricType>) -> Self {
380        self.suite.evaluation.metrics = metrics;
381        self
382    }
383
384    /// Add a baseline result.
385    pub fn add_baseline(mut self, baseline: BaselineResult) -> Self {
386        self.suite.baselines.push(baseline);
387        self
388    }
389
390    /// Set the seed.
391    pub fn seed(mut self, seed: u64) -> Self {
392        self.suite.dataset.seed = seed;
393        self
394    }
395
396    /// Set time span.
397    pub fn time_span_days(mut self, days: u32) -> Self {
398        self.suite.dataset.time_span_days = days;
399        self
400    }
401
402    /// Set number of companies.
403    pub fn num_companies(mut self, n: usize) -> Self {
404        self.suite.dataset.num_companies = n;
405        self
406    }
407
408    /// Add metadata.
409    pub fn metadata(mut self, key: &str, value: &str) -> Self {
410        self.suite
411            .metadata
412            .insert(key.to_string(), value.to_string());
413        self
414    }
415
416    /// Build the benchmark suite.
417    pub fn build(self) -> BenchmarkSuite {
418        self.suite
419    }
420}
421
// ============================================================================
// Pre-defined Benchmark Suites
// ============================================================================
425
426/// AnomalyBench-1K: 1000 transactions with known anomalies.
427pub fn anomaly_bench_1k() -> BenchmarkSuite {
428    let mut class_dist = HashMap::new();
429    class_dist.insert("normal".to_string(), 950);
430    class_dist.insert("anomaly".to_string(), 50);
431
432    BenchmarkBuilder::new("anomaly-bench-1k", "AnomalyBench-1K")
433        .description("1000 journal entry transactions with 5% anomaly rate. Balanced mix of fraud, error, and statistical anomalies.")
434        .task_type(BenchmarkTaskType::AnomalyDetection)
435        .dataset_size(1000, 50)
436        .class_distribution(class_dist)
437        .split_ratios(0.7, 0.15, 0.15, true)
438        .primary_metric(MetricType::AucPr)
439        .metrics(vec![
440            MetricType::AucRoc,
441            MetricType::AucPr,
442            MetricType::F1Score,
443            MetricType::PrecisionAtK(10),
444            MetricType::PrecisionAtK(50),
445            MetricType::Recall,
446        ])
447        .seed(42)
448        .time_span_days(90)
449        .num_companies(1)
450        .add_baseline(BaselineResult {
451            model_name: "Random".to_string(),
452            model_type: BaselineModelType::Random,
453            metrics: [
454                ("auc_roc".to_string(), 0.50),
455                ("auc_pr".to_string(), 0.05),
456                ("f1".to_string(), 0.09),
457            ].into_iter().collect(),
458            training_time_seconds: Some(0.0),
459            inference_time_ms: Some(0.01),
460            notes: Some("Random baseline for reference".to_string()),
461        })
462        .add_baseline(BaselineResult {
463            model_name: "IsolationForest".to_string(),
464            model_type: BaselineModelType::IsolationForest,
465            metrics: [
466                ("auc_roc".to_string(), 0.78),
467                ("auc_pr".to_string(), 0.42),
468                ("f1".to_string(), 0.45),
469            ].into_iter().collect(),
470            training_time_seconds: Some(0.5),
471            inference_time_ms: Some(0.1),
472            notes: Some("Unsupervised baseline".to_string()),
473        })
474        .add_baseline(BaselineResult {
475            model_name: "XGBoost".to_string(),
476            model_type: BaselineModelType::XgBoost,
477            metrics: [
478                ("auc_roc".to_string(), 0.92),
479                ("auc_pr".to_string(), 0.68),
480                ("f1".to_string(), 0.72),
481            ].into_iter().collect(),
482            training_time_seconds: Some(2.0),
483            inference_time_ms: Some(0.05),
484            notes: Some("Supervised baseline with full labels".to_string()),
485        })
486        .metadata("domain", "accounting")
487        .metadata("difficulty", "easy")
488        .build()
489}
490
491/// FraudDetect-10K: 10K transactions for fraud detection.
492pub fn fraud_detect_10k() -> BenchmarkSuite {
493    let mut class_dist = HashMap::new();
494    class_dist.insert("normal".to_string(), 9700);
495    class_dist.insert("fictitious_transaction".to_string(), 80);
496    class_dist.insert("duplicate_payment".to_string(), 60);
497    class_dist.insert("round_tripping".to_string(), 40);
498    class_dist.insert("threshold_manipulation".to_string(), 50);
499    class_dist.insert("self_approval".to_string(), 30);
500    class_dist.insert("other_fraud".to_string(), 40);
501
502    BenchmarkBuilder::new("fraud-detect-10k", "FraudDetect-10K")
503        .description("10K journal entries with multi-class fraud labels. Includes 6 fraud types with realistic class imbalance.")
504        .task_type(BenchmarkTaskType::FraudClassification)
505        .dataset_size(10000, 300)
506        .class_distribution(class_dist)
507        .split_ratios(0.7, 0.15, 0.15, true)
508        .primary_metric(MetricType::MacroF1)
509        .metrics(vec![
510            MetricType::AucRoc,
511            MetricType::MacroF1,
512            MetricType::WeightedF1,
513            MetricType::Recall,
514            MetricType::Precision,
515            MetricType::CohenKappa,
516        ])
517        .seed(12345)
518        .time_span_days(365)
519        .num_companies(3)
520        .add_baseline(BaselineResult {
521            model_name: "MajorityClass".to_string(),
522            model_type: BaselineModelType::MajorityClass,
523            metrics: [
524                ("macro_f1".to_string(), 0.07),
525                ("weighted_f1".to_string(), 0.94),
526            ].into_iter().collect(),
527            training_time_seconds: Some(0.0),
528            inference_time_ms: Some(0.01),
529            notes: Some("Predicts normal for all transactions".to_string()),
530        })
531        .add_baseline(BaselineResult {
532            model_name: "RandomForest".to_string(),
533            model_type: BaselineModelType::RandomForest,
534            metrics: [
535                ("macro_f1".to_string(), 0.58),
536                ("weighted_f1".to_string(), 0.96),
537                ("auc_roc".to_string(), 0.89),
538            ].into_iter().collect(),
539            training_time_seconds: Some(5.0),
540            inference_time_ms: Some(0.2),
541            notes: Some("Balanced class weights".to_string()),
542        })
543        .add_baseline(BaselineResult {
544            model_name: "LightGBM".to_string(),
545            model_type: BaselineModelType::LightGbm,
546            metrics: [
547                ("macro_f1".to_string(), 0.65),
548                ("weighted_f1".to_string(), 0.97),
549                ("auc_roc".to_string(), 0.93),
550            ].into_iter().collect(),
551            training_time_seconds: Some(3.0),
552            inference_time_ms: Some(0.05),
553            notes: Some("Optimized hyperparameters".to_string()),
554        })
555        .metadata("domain", "accounting")
556        .metadata("difficulty", "medium")
557        .build()
558}
559
560/// DataQuality-100K: 100K records for data quality detection.
561pub fn data_quality_100k() -> BenchmarkSuite {
562    let mut class_dist = HashMap::new();
563    class_dist.insert("clean".to_string(), 90000);
564    class_dist.insert("missing_value".to_string(), 3000);
565    class_dist.insert("typo".to_string(), 2000);
566    class_dist.insert("format_error".to_string(), 2000);
567    class_dist.insert("duplicate".to_string(), 1500);
568    class_dist.insert("encoding_issue".to_string(), 1000);
569    class_dist.insert("truncation".to_string(), 500);
570
571    BenchmarkBuilder::new("data-quality-100k", "DataQuality-100K")
572        .description("100K records with various data quality issues. Tests detection of missing values, typos, format errors, duplicates, and encoding issues.")
573        .task_type(BenchmarkTaskType::DataQualityDetection)
574        .dataset_size(100000, 10000)
575        .class_distribution(class_dist)
576        .split_ratios(0.8, 0.1, 0.1, false)
577        .primary_metric(MetricType::F1Score)
578        .metrics(vec![
579            MetricType::F1Score,
580            MetricType::Precision,
581            MetricType::Recall,
582            MetricType::AucRoc,
583            MetricType::MacroF1,
584        ])
585        .seed(99999)
586        .time_span_days(730) // 2 years
587        .num_companies(5)
588        .add_baseline(BaselineResult {
589            model_name: "RuleBased".to_string(),
590            model_type: BaselineModelType::RuleBased,
591            metrics: [
592                ("f1".to_string(), 0.72),
593                ("precision".to_string(), 0.85),
594                ("recall".to_string(), 0.62),
595            ].into_iter().collect(),
596            training_time_seconds: Some(0.0),
597            inference_time_ms: Some(0.5),
598            notes: Some("Regex patterns and null checks".to_string()),
599        })
600        .add_baseline(BaselineResult {
601            model_name: "LogisticRegression".to_string(),
602            model_type: BaselineModelType::LogisticRegression,
603            metrics: [
604                ("f1".to_string(), 0.78),
605                ("precision".to_string(), 0.80),
606                ("recall".to_string(), 0.76),
607            ].into_iter().collect(),
608            training_time_seconds: Some(2.0),
609            inference_time_ms: Some(0.02),
610            notes: Some("Character n-gram features".to_string()),
611        })
612        .add_baseline(BaselineResult {
613            model_name: "XGBoost".to_string(),
614            model_type: BaselineModelType::XgBoost,
615            metrics: [
616                ("f1".to_string(), 0.88),
617                ("precision".to_string(), 0.90),
618                ("recall".to_string(), 0.86),
619            ].into_iter().collect(),
620            training_time_seconds: Some(15.0),
621            inference_time_ms: Some(0.08),
622            notes: Some("Mixed feature types".to_string()),
623        })
624        .metadata("domain", "data_quality")
625        .metadata("difficulty", "medium")
626        .build()
627}
628
629/// EntityMatch-5K: 5K records for entity matching.
630pub fn entity_match_5k() -> BenchmarkSuite {
631    let mut class_dist = HashMap::new();
632    class_dist.insert("match".to_string(), 2000);
633    class_dist.insert("non_match".to_string(), 3000);
634
635    BenchmarkBuilder::new("entity-match-5k", "EntityMatch-5K")
636        .description("5K vendor/customer record pairs for entity matching. Includes name variations, typos, and abbreviations.")
637        .task_type(BenchmarkTaskType::EntityMatching)
638        .dataset_size(5000, 2000)
639        .class_distribution(class_dist)
640        .split_ratios(0.7, 0.15, 0.15, false)
641        .primary_metric(MetricType::F1Score)
642        .metrics(vec![
643            MetricType::F1Score,
644            MetricType::Precision,
645            MetricType::Recall,
646            MetricType::AucRoc,
647        ])
648        .seed(54321)
649        .time_span_days(365)
650        .num_companies(2)
651        .add_baseline(BaselineResult {
652            model_name: "ExactMatch".to_string(),
653            model_type: BaselineModelType::RuleBased,
654            metrics: [
655                ("f1".to_string(), 0.35),
656                ("precision".to_string(), 1.0),
657                ("recall".to_string(), 0.21),
658            ].into_iter().collect(),
659            training_time_seconds: Some(0.0),
660            inference_time_ms: Some(0.1),
661            notes: Some("Exact string matching only".to_string()),
662        })
663        .add_baseline(BaselineResult {
664            model_name: "FuzzyMatch".to_string(),
665            model_type: BaselineModelType::RuleBased,
666            metrics: [
667                ("f1".to_string(), 0.68),
668                ("precision".to_string(), 0.72),
669                ("recall".to_string(), 0.65),
670            ].into_iter().collect(),
671            training_time_seconds: Some(0.0),
672            inference_time_ms: Some(2.0),
673            notes: Some("Levenshtein distance threshold".to_string()),
674        })
675        .add_baseline(BaselineResult {
676            model_name: "Magellan".to_string(),
677            model_type: BaselineModelType::RandomForest,
678            metrics: [
679                ("f1".to_string(), 0.89),
680                ("precision".to_string(), 0.91),
681                ("recall".to_string(), 0.87),
682            ].into_iter().collect(),
683            training_time_seconds: Some(10.0),
684            inference_time_ms: Some(5.0),
685            notes: Some("Feature-based entity matcher".to_string()),
686        })
687        .metadata("domain", "master_data")
688        .metadata("difficulty", "hard")
689        .build()
690}
691
692/// GraphFraud-10K: 10K transactions with network structure.
693pub fn graph_fraud_10k() -> BenchmarkSuite {
694    let mut class_dist = HashMap::new();
695    class_dist.insert("normal".to_string(), 9500);
696    class_dist.insert("fraud_network".to_string(), 500);
697
698    BenchmarkBuilder::new("graph-fraud-10k", "GraphFraud-10K")
699        .description("10K transactions with entity graph structure. Fraud detection using network features and GNN models.")
700        .task_type(BenchmarkTaskType::GraphFraudDetection)
701        .dataset_size(10000, 500)
702        .class_distribution(class_dist)
703        .split_ratios(0.7, 0.15, 0.15, true)
704        .primary_metric(MetricType::AucPr)
705        .metrics(vec![
706            MetricType::AucPr,
707            MetricType::AucRoc,
708            MetricType::F1Score,
709            MetricType::PrecisionAtK(100),
710        ])
711        .seed(77777)
712        .time_span_days(365)
713        .num_companies(4)
714        .add_baseline(BaselineResult {
715            model_name: "NodeFeatures".to_string(),
716            model_type: BaselineModelType::XgBoost,
717            metrics: [
718                ("auc_pr".to_string(), 0.45),
719                ("auc_roc".to_string(), 0.78),
720            ].into_iter().collect(),
721            training_time_seconds: Some(5.0),
722            inference_time_ms: Some(0.1),
723            notes: Some("XGBoost on node features only".to_string()),
724        })
725        .add_baseline(BaselineResult {
726            model_name: "GraphSAGE".to_string(),
727            model_type: BaselineModelType::Gnn,
728            metrics: [
729                ("auc_pr".to_string(), 0.62),
730                ("auc_roc".to_string(), 0.88),
731            ].into_iter().collect(),
732            training_time_seconds: Some(60.0),
733            inference_time_ms: Some(5.0),
734            notes: Some("2-layer GraphSAGE".to_string()),
735        })
736        .add_baseline(BaselineResult {
737            model_name: "GAT".to_string(),
738            model_type: BaselineModelType::Gnn,
739            metrics: [
740                ("auc_pr".to_string(), 0.68),
741                ("auc_roc".to_string(), 0.91),
742            ].into_iter().collect(),
743            training_time_seconds: Some(90.0),
744            inference_time_ms: Some(8.0),
745            notes: Some("Graph Attention Network".to_string()),
746        })
747        .metadata("domain", "graph_analytics")
748        .metadata("difficulty", "hard")
749        .build()
750}
751
752/// Get all available benchmark suites.
753pub fn all_benchmarks() -> Vec<BenchmarkSuite> {
754    let mut benchmarks = vec![
755        anomaly_bench_1k(),
756        fraud_detect_10k(),
757        data_quality_100k(),
758        entity_match_5k(),
759        graph_fraud_10k(),
760    ];
761
762    // Add ACFE-calibrated benchmarks
763    benchmarks.extend(all_acfe_benchmarks());
764
765    // Add industry-specific benchmarks
766    benchmarks.extend(all_industry_benchmarks());
767
768    benchmarks
769}
770
771/// Get a benchmark by ID.
772pub fn get_benchmark(id: &str) -> Option<BenchmarkSuite> {
773    all_benchmarks().into_iter().find(|b| b.id == id)
774}
775
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// The suites defined directly in this module. ACFE and industry suites are
    /// built in their own submodules.
    fn core_benchmarks() -> Vec<BenchmarkSuite> {
        vec![
            anomaly_bench_1k(),
            fraud_detect_10k(),
            data_quality_100k(),
            entity_match_5k(),
            graph_fraud_10k(),
        ]
    }

    #[test]
    fn test_anomaly_bench_1k() {
        let bench = anomaly_bench_1k();
        assert_eq!(bench.id, "anomaly-bench-1k");
        assert_eq!(bench.dataset.total_records, 1000);
        assert_eq!(bench.dataset.positive_count, 50);
        assert_eq!(bench.baselines.len(), 3);
    }

    #[test]
    fn test_fraud_detect_10k() {
        let bench = fraud_detect_10k();
        assert_eq!(bench.id, "fraud-detect-10k");
        assert_eq!(bench.dataset.total_records, 10000);
        assert_eq!(bench.task_type, BenchmarkTaskType::FraudClassification);
    }

    #[test]
    fn test_data_quality_100k() {
        let bench = data_quality_100k();
        assert_eq!(bench.id, "data-quality-100k");
        assert_eq!(bench.dataset.total_records, 100000);
        assert!(bench.dataset.class_distribution.len() > 5);
    }

    #[test]
    fn test_entity_match_5k() {
        let bench = entity_match_5k();
        assert_eq!(bench.id, "entity-match-5k");
        assert_eq!(bench.task_type, BenchmarkTaskType::EntityMatching);
        assert_eq!(bench.dataset.total_records, 5000);
        assert_eq!(bench.dataset.positive_count, 2000);
    }

    #[test]
    fn test_graph_fraud_10k() {
        let bench = graph_fraud_10k();
        assert_eq!(bench.id, "graph-fraud-10k");
        assert_eq!(bench.task_type, BenchmarkTaskType::GraphFraudDetection);
        assert_eq!(bench.dataset.total_records, 10000);
        assert_eq!(bench.dataset.positive_count, 500);
    }

    #[test]
    fn test_core_dataset_counts_are_consistent() {
        for bench in core_benchmarks() {
            let ds = &bench.dataset;
            let labelled: usize = ds.class_distribution.values().sum();
            assert_eq!(
                labelled, ds.total_records,
                "{}: class_distribution sums to {labelled} but total_records is {}",
                bench.id, ds.total_records
            );
            assert_eq!(
                ds.positive_count + ds.negative_count,
                ds.total_records,
                "{}: positive_count + negative_count != total_records",
                bench.id
            );
        }
    }

    #[test]
    fn test_core_split_ratios_sum_to_one() {
        for bench in core_benchmarks() {
            let split = &bench.dataset.split_ratios;
            let sum = split.train + split.validation + split.test;
            assert!(
                (sum - 1.0).abs() < 1e-9,
                "{}: split ratios sum to {sum}",
                bench.id
            );
        }
    }

    #[test]
    fn test_core_primary_metric_is_computed() {
        for bench in core_benchmarks() {
            assert!(
                bench
                    .evaluation
                    .metrics
                    .contains(&bench.evaluation.primary_metric),
                "{}: primary metric {:?} missing from metrics",
                bench.id,
                bench.evaluation.primary_metric
            );
        }
    }

    #[test]
    fn test_all_benchmarks() {
        let benchmarks = all_benchmarks();
        // 5 original + 3 ACFE + 5 industry = 13
        assert_eq!(benchmarks.len(), 13);

        // Verify all have baselines
        for bench in &benchmarks {
            assert!(
                !bench.baselines.is_empty(),
                "Benchmark {} has no baselines",
                bench.id
            );
        }
    }

    #[test]
    fn test_benchmark_ids_are_unique() {
        // get_benchmark returns the first match, so a duplicated id would make
        // one suite unreachable.
        let benchmarks = all_benchmarks();
        let mut seen = HashSet::new();
        for bench in &benchmarks {
            assert!(
                seen.insert(bench.id.as_str()),
                "duplicate benchmark id {}",
                bench.id
            );
        }
    }

    #[test]
    fn test_get_benchmark() {
        assert!(get_benchmark("fraud-detect-10k").is_some());
        assert!(get_benchmark("nonexistent").is_none());
    }

    #[test]
    fn test_get_benchmark_finds_every_listed_id() {
        for bench in all_benchmarks() {
            let found = get_benchmark(&bench.id)
                .unwrap_or_else(|| panic!("get_benchmark({:?}) returned None", bench.id));
            assert_eq!(found.name, bench.name);
        }
    }

    #[test]
    fn test_builder() {
        let bench = BenchmarkBuilder::new("custom", "Custom Benchmark")
            .description("A custom test benchmark")
            .task_type(BenchmarkTaskType::AnomalyDetection)
            .dataset_size(500, 25)
            .seed(123)
            .build();

        assert_eq!(bench.id, "custom");
        assert_eq!(bench.dataset.total_records, 500);
        assert_eq!(bench.dataset.negative_count, 475);
        assert_eq!(bench.dataset.seed, 123);
    }

    #[test]
    fn test_builder_dataset_size_saturates_negative_count() {
        let bench = BenchmarkBuilder::new("tiny", "Tiny")
            .dataset_size(10, 20)
            .build();
        assert_eq!(bench.dataset.positive_count, 20);
        assert_eq!(bench.dataset.negative_count, 0);
    }
}