
datasynth_eval/benchmarks/mod.rs

//! Benchmark suite definitions for ML evaluation.
//!
//! Provides standardized benchmark datasets for:
//! - Anomaly detection (AnomalyBench-1K)
//! - Fraud detection (FraudDetect-10K)
//! - Data quality detection (DataQuality-100K)
//! - Entity matching (EntityMatch-5K)
//! - Graph fraud detection (GraphFraud-10K)
//!
//! Each benchmark defines:
//! - Dataset size and composition
//! - Ground truth labels
//! - Evaluation metrics
//! - Expected baseline performance
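//!
//! # Example
//!
//! A minimal usage sketch (assuming this module is reachable as
//! `datasynth_eval::benchmarks`; adjust the path to your crate layout):
//!
//! ```ignore
//! use datasynth_eval::benchmarks::{get_benchmark, BenchmarkBuilder, BenchmarkTaskType};
//!
//! // Look up a pre-defined suite by its ID.
//! let suite = get_benchmark("anomaly-bench-1k").expect("known benchmark id");
//! assert_eq!(suite.dataset.total_records, 1000);
//!
//! // Or assemble a custom suite with the builder.
//! let custom = BenchmarkBuilder::new("my-bench", "My Benchmark")
//!     .task_type(BenchmarkTaskType::AnomalyDetection)
//!     .dataset_size(2_000, 100)
//!     .build();
//! assert_eq!(custom.dataset.negative_count, 1_900);
//! ```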

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// A benchmark suite definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkSuite {
    /// Unique identifier for the benchmark
    pub id: String,
    /// Human-readable name
    pub name: String,
    /// Description of the benchmark
    pub description: String,
    /// Version of the benchmark specification
    pub version: String,
    /// Task type being evaluated
    pub task_type: BenchmarkTaskType,
    /// Dataset specification
    pub dataset: DatasetSpec,
    /// Evaluation configuration
    pub evaluation: EvaluationSpec,
    /// Expected baseline results
    pub baselines: Vec<BaselineResult>,
    /// Metadata
    pub metadata: HashMap<String, String>,
}

/// Types of benchmark tasks.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum BenchmarkTaskType {
    /// Binary classification: anomaly vs normal
    AnomalyDetection,
    /// Multi-class classification: fraud type
    FraudClassification,
    /// Binary classification: data quality issue detection
    DataQualityDetection,
    /// Entity resolution: matching records
    EntityMatching,
    /// Multi-label classification: anomaly types
    MultiLabelAnomalyDetection,
    /// Regression: amount prediction
    AmountPrediction,
    /// Time series: anomaly detection over time
    TimeSeriesAnomalyDetection,
    /// Graph: fraud network detection
    GraphFraudDetection,
}

/// Dataset specification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetSpec {
    /// Total number of records
    pub total_records: usize,
    /// Number of positive examples (anomalies/fraud/issues)
    pub positive_count: usize,
    /// Number of negative examples (normal/clean)
    pub negative_count: usize,
    /// Class distribution by label
    pub class_distribution: HashMap<String, usize>,
    /// Feature set description
    pub features: FeatureSet,
    /// Split ratios (train/val/test)
    pub split_ratios: SplitRatios,
    /// Seed for reproducibility
    pub seed: u64,
    /// Time span in days
    pub time_span_days: u32,
    /// Number of companies
    pub num_companies: usize,
}

/// Feature set for the benchmark.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureSet {
    /// Number of numerical features
    pub numerical_count: usize,
    /// Number of categorical features
    pub categorical_count: usize,
    /// Number of temporal features
    pub temporal_count: usize,
    /// Number of text features
    pub text_count: usize,
    /// Feature names and descriptions
    pub feature_descriptions: HashMap<String, String>,
}

/// Train/validation/test split ratios.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SplitRatios {
    /// Training set ratio (0.0-1.0)
    pub train: f64,
    /// Validation set ratio (0.0-1.0)
    pub validation: f64,
    /// Test set ratio (0.0-1.0)
    pub test: f64,
    /// Temporal split (if true, test set is the latest data)
    pub temporal_split: bool,
}

impl Default for SplitRatios {
    fn default() -> Self {
        Self {
            train: 0.7,
            validation: 0.15,
            test: 0.15,
            temporal_split: false,
        }
    }
}

/// Evaluation specification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationSpec {
    /// Primary metric for ranking
    pub primary_metric: MetricType,
    /// All metrics to compute
    pub metrics: Vec<MetricType>,
    /// Threshold for binary classification
    pub classification_threshold: Option<f64>,
    /// Cross-validation folds (if applicable)
    pub cv_folds: Option<usize>,
    /// Cost matrix for cost-sensitive evaluation
    pub cost_matrix: Option<CostMatrix>,
}

/// Types of evaluation metrics.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum MetricType {
    /// Area Under ROC Curve
    AucRoc,
    /// Area Under Precision-Recall Curve
    AucPr,
    /// Precision at K
    PrecisionAtK(usize),
    /// Recall at K
    RecallAtK(usize),
    /// F1 Score
    F1Score,
    /// Precision
    Precision,
    /// Recall
    Recall,
    /// Accuracy
    Accuracy,
    /// Matthews Correlation Coefficient
    Mcc,
    /// Mean Average Precision
    Map,
    /// Normalized Discounted Cumulative Gain
    Ndcg,
    /// Mean Squared Error
    Mse,
    /// Mean Absolute Error
    Mae,
    /// R-squared
    R2,
    /// Log Loss
    LogLoss,
    /// Cohen's Kappa
    CohenKappa,
    /// Macro F1
    MacroF1,
    /// Weighted F1
    WeightedF1,
}

/// Cost matrix for cost-sensitive evaluation.
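///
/// A minimal illustration of cost-sensitive scoring with hypothetical
/// confusion-matrix counts (`tp`, `fp`, `fn_`, and `tn` are not part of this
/// module's API; the weighting scheme is only a sketch):
///
/// ```ignore
/// let costs = CostMatrix::default();
/// let (tp, fp, fn_, tn) = (40.0_f64, 10.0, 5.0, 945.0);
/// let total_cost = tp * costs.true_positive_cost
///     + fp * costs.false_positive_cost
///     + fn_ * costs.false_negative_cost
///     + tn * costs.true_negative_cost;
/// // 10 false positives at cost 1.0 plus 5 false negatives at cost 10.0.
/// assert_eq!(total_cost, 60.0);
/// ```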
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostMatrix {
    /// Cost of false positive
    pub false_positive_cost: f64,
    /// Cost of false negative
    pub false_negative_cost: f64,
    /// Cost of true positive (usually 0 or negative for reward)
    pub true_positive_cost: f64,
    /// Cost of true negative (usually 0)
    pub true_negative_cost: f64,
}

impl Default for CostMatrix {
    fn default() -> Self {
        Self {
            false_positive_cost: 1.0,
            false_negative_cost: 10.0, // Missing fraud is usually worse
            true_positive_cost: 0.0,
            true_negative_cost: 0.0,
        }
    }
}

/// Expected baseline result for a benchmark.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaselineResult {
    /// Model/algorithm name
    pub model_name: String,
    /// Model type
    pub model_type: BaselineModelType,
    /// Metric results
    pub metrics: HashMap<String, f64>,
    /// Training time in seconds
    pub training_time_seconds: Option<f64>,
    /// Inference time per sample in milliseconds
    pub inference_time_ms: Option<f64>,
    /// Notes about the baseline
    pub notes: Option<String>,
}

/// Types of baseline models.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum BaselineModelType {
    /// Random baseline
    Random,
    /// Majority class baseline
    MajorityClass,
    /// Rule-based baseline
    RuleBased,
    /// Isolation Forest
    IsolationForest,
    /// One-Class SVM
    OneClassSvm,
    /// Local Outlier Factor
    Lof,
    /// Logistic Regression
    LogisticRegression,
    /// Random Forest
    RandomForest,
    /// XGBoost
    XgBoost,
    /// LightGBM
    LightGbm,
    /// Neural Network
    NeuralNetwork,
    /// Graph Neural Network
    Gnn,
    /// Autoencoder
    Autoencoder,
    /// Custom model
    Custom(String),
}

/// Leaderboard entry for benchmark results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeaderboardEntry {
    /// Rank on the leaderboard
    pub rank: usize,
    /// Submission name
    pub submission_name: String,
    /// Submitter name/organization
    pub submitter: String,
    /// Model description
    pub model_description: String,
    /// Primary metric score
    pub primary_score: f64,
    /// All metric scores
    pub all_scores: HashMap<String, f64>,
    /// Submission date
    pub submission_date: String,
    /// Is this a baseline entry
    pub is_baseline: bool,
}

/// Builder for creating benchmark suites.
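///
/// A minimal sketch of the fluent API, adapted from `test_builder` at the
/// bottom of this module (the metric choice here is illustrative):
///
/// ```ignore
/// let suite = BenchmarkBuilder::new("custom", "Custom Benchmark")
///     .description("A custom test benchmark")
///     .task_type(BenchmarkTaskType::AnomalyDetection)
///     .dataset_size(500, 25)
///     .primary_metric(MetricType::AucPr)
///     .seed(123)
///     .build();
///
/// // `dataset_size` derives the negative count from total minus positives.
/// assert_eq!(suite.dataset.negative_count, 475);
/// ```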
pub struct BenchmarkBuilder {
    suite: BenchmarkSuite,
}

impl BenchmarkBuilder {
    /// Create a new benchmark builder.
    pub fn new(id: &str, name: &str) -> Self {
        Self {
            suite: BenchmarkSuite {
                id: id.to_string(),
                name: name.to_string(),
                description: String::new(),
                version: "1.0.0".to_string(),
                task_type: BenchmarkTaskType::AnomalyDetection,
                dataset: DatasetSpec {
                    total_records: 1000,
                    positive_count: 50,
                    negative_count: 950,
                    class_distribution: HashMap::new(),
                    features: FeatureSet {
                        numerical_count: 10,
                        categorical_count: 5,
                        temporal_count: 3,
                        text_count: 1,
                        feature_descriptions: HashMap::new(),
                    },
                    split_ratios: SplitRatios::default(),
                    seed: 42,
                    time_span_days: 365,
                    num_companies: 1,
                },
                evaluation: EvaluationSpec {
                    primary_metric: MetricType::AucRoc,
                    metrics: vec![MetricType::AucRoc, MetricType::AucPr, MetricType::F1Score],
                    classification_threshold: Some(0.5),
                    cv_folds: None,
                    cost_matrix: None,
                },
                baselines: Vec::new(),
                metadata: HashMap::new(),
            },
        }
    }

    /// Set the description.
    pub fn description(mut self, desc: &str) -> Self {
        self.suite.description = desc.to_string();
        self
    }

    /// Set the task type.
    pub fn task_type(mut self, task_type: BenchmarkTaskType) -> Self {
        self.suite.task_type = task_type;
        self
    }

    /// Set dataset size.
    pub fn dataset_size(mut self, total: usize, positive: usize) -> Self {
        self.suite.dataset.total_records = total;
        self.suite.dataset.positive_count = positive;
        self.suite.dataset.negative_count = total.saturating_sub(positive);
        self
    }

    /// Set class distribution.
    pub fn class_distribution(mut self, distribution: HashMap<String, usize>) -> Self {
        self.suite.dataset.class_distribution = distribution;
        self
    }

    /// Set split ratios.
    pub fn split_ratios(mut self, train: f64, val: f64, test: f64, temporal: bool) -> Self {
        self.suite.dataset.split_ratios = SplitRatios {
            train,
            validation: val,
            test,
            temporal_split: temporal,
        };
        self
    }

    /// Set primary metric.
    pub fn primary_metric(mut self, metric: MetricType) -> Self {
        self.suite.evaluation.primary_metric = metric;
        self
    }

    /// Set all metrics.
    pub fn metrics(mut self, metrics: Vec<MetricType>) -> Self {
        self.suite.evaluation.metrics = metrics;
        self
    }

    /// Add a baseline result.
    pub fn add_baseline(mut self, baseline: BaselineResult) -> Self {
        self.suite.baselines.push(baseline);
        self
    }

    /// Set the seed.
    pub fn seed(mut self, seed: u64) -> Self {
        self.suite.dataset.seed = seed;
        self
    }

    /// Set time span.
    pub fn time_span_days(mut self, days: u32) -> Self {
        self.suite.dataset.time_span_days = days;
        self
    }

    /// Set number of companies.
    pub fn num_companies(mut self, n: usize) -> Self {
        self.suite.dataset.num_companies = n;
        self
    }

    /// Add metadata.
    pub fn metadata(mut self, key: &str, value: &str) -> Self {
        self.suite
            .metadata
            .insert(key.to_string(), value.to_string());
        self
    }

    /// Build the benchmark suite.
    pub fn build(self) -> BenchmarkSuite {
        self.suite
    }
}

// ============================================================================
// Pre-defined Benchmark Suites
// ============================================================================

/// AnomalyBench-1K: 1000 transactions with known anomalies.
pub fn anomaly_bench_1k() -> BenchmarkSuite {
    let mut class_dist = HashMap::new();
    class_dist.insert("normal".to_string(), 950);
    class_dist.insert("anomaly".to_string(), 50);

    BenchmarkBuilder::new("anomaly-bench-1k", "AnomalyBench-1K")
        .description("1000 journal entry transactions with 5% anomaly rate. Balanced mix of fraud, error, and statistical anomalies.")
        .task_type(BenchmarkTaskType::AnomalyDetection)
        .dataset_size(1000, 50)
        .class_distribution(class_dist)
        .split_ratios(0.7, 0.15, 0.15, true)
        .primary_metric(MetricType::AucPr)
        .metrics(vec![
            MetricType::AucRoc,
            MetricType::AucPr,
            MetricType::F1Score,
            MetricType::PrecisionAtK(10),
            MetricType::PrecisionAtK(50),
            MetricType::Recall,
        ])
        .seed(42)
        .time_span_days(90)
        .num_companies(1)
        .add_baseline(BaselineResult {
            model_name: "Random".to_string(),
            model_type: BaselineModelType::Random,
            metrics: [
                ("auc_roc".to_string(), 0.50),
                ("auc_pr".to_string(), 0.05),
                ("f1".to_string(), 0.09),
            ].into_iter().collect(),
            training_time_seconds: Some(0.0),
            inference_time_ms: Some(0.01),
            notes: Some("Random baseline for reference".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "IsolationForest".to_string(),
            model_type: BaselineModelType::IsolationForest,
            metrics: [
                ("auc_roc".to_string(), 0.78),
                ("auc_pr".to_string(), 0.42),
                ("f1".to_string(), 0.45),
            ].into_iter().collect(),
            training_time_seconds: Some(0.5),
            inference_time_ms: Some(0.1),
            notes: Some("Unsupervised baseline".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "XGBoost".to_string(),
            model_type: BaselineModelType::XgBoost,
            metrics: [
                ("auc_roc".to_string(), 0.92),
                ("auc_pr".to_string(), 0.68),
                ("f1".to_string(), 0.72),
            ].into_iter().collect(),
            training_time_seconds: Some(2.0),
            inference_time_ms: Some(0.05),
            notes: Some("Supervised baseline with full labels".to_string()),
        })
        .metadata("domain", "accounting")
        .metadata("difficulty", "easy")
        .build()
}

/// FraudDetect-10K: 10K transactions for fraud detection.
pub fn fraud_detect_10k() -> BenchmarkSuite {
    let mut class_dist = HashMap::new();
    class_dist.insert("normal".to_string(), 9700);
    class_dist.insert("fictitious_transaction".to_string(), 80);
    class_dist.insert("duplicate_payment".to_string(), 60);
    class_dist.insert("round_tripping".to_string(), 40);
    class_dist.insert("threshold_manipulation".to_string(), 50);
    class_dist.insert("self_approval".to_string(), 30);
    class_dist.insert("other_fraud".to_string(), 40);

    BenchmarkBuilder::new("fraud-detect-10k", "FraudDetect-10K")
        .description("10K journal entries with multi-class fraud labels. Includes 6 fraud types with realistic class imbalance.")
        .task_type(BenchmarkTaskType::FraudClassification)
        .dataset_size(10000, 300)
        .class_distribution(class_dist)
        .split_ratios(0.7, 0.15, 0.15, true)
        .primary_metric(MetricType::MacroF1)
        .metrics(vec![
            MetricType::AucRoc,
            MetricType::MacroF1,
            MetricType::WeightedF1,
            MetricType::Recall,
            MetricType::Precision,
            MetricType::CohenKappa,
        ])
        .seed(12345)
        .time_span_days(365)
        .num_companies(3)
        .add_baseline(BaselineResult {
            model_name: "MajorityClass".to_string(),
            model_type: BaselineModelType::MajorityClass,
            metrics: [
                ("macro_f1".to_string(), 0.07),
                ("weighted_f1".to_string(), 0.94),
            ].into_iter().collect(),
            training_time_seconds: Some(0.0),
            inference_time_ms: Some(0.01),
            notes: Some("Predicts normal for all transactions".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "RandomForest".to_string(),
            model_type: BaselineModelType::RandomForest,
            metrics: [
                ("macro_f1".to_string(), 0.58),
                ("weighted_f1".to_string(), 0.96),
                ("auc_roc".to_string(), 0.89),
            ].into_iter().collect(),
            training_time_seconds: Some(5.0),
            inference_time_ms: Some(0.2),
            notes: Some("Balanced class weights".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "LightGBM".to_string(),
            model_type: BaselineModelType::LightGbm,
            metrics: [
                ("macro_f1".to_string(), 0.65),
                ("weighted_f1".to_string(), 0.97),
                ("auc_roc".to_string(), 0.93),
            ].into_iter().collect(),
            training_time_seconds: Some(3.0),
            inference_time_ms: Some(0.05),
            notes: Some("Optimized hyperparameters".to_string()),
        })
        .metadata("domain", "accounting")
        .metadata("difficulty", "medium")
        .build()
}

/// DataQuality-100K: 100K records for data quality detection.
pub fn data_quality_100k() -> BenchmarkSuite {
    let mut class_dist = HashMap::new();
    class_dist.insert("clean".to_string(), 90000);
    class_dist.insert("missing_value".to_string(), 3000);
    class_dist.insert("typo".to_string(), 2000);
    class_dist.insert("format_error".to_string(), 2000);
    class_dist.insert("duplicate".to_string(), 1500);
    class_dist.insert("encoding_issue".to_string(), 1000);
    class_dist.insert("truncation".to_string(), 500);

    BenchmarkBuilder::new("data-quality-100k", "DataQuality-100K")
        .description("100K records with various data quality issues. Tests detection of missing values, typos, format errors, duplicates, and encoding issues.")
        .task_type(BenchmarkTaskType::DataQualityDetection)
        .dataset_size(100000, 10000)
        .class_distribution(class_dist)
        .split_ratios(0.8, 0.1, 0.1, false)
        .primary_metric(MetricType::F1Score)
        .metrics(vec![
            MetricType::F1Score,
            MetricType::Precision,
            MetricType::Recall,
            MetricType::AucRoc,
            MetricType::MacroF1,
        ])
        .seed(99999)
        .time_span_days(730) // 2 years
        .num_companies(5)
        .add_baseline(BaselineResult {
            model_name: "RuleBased".to_string(),
            model_type: BaselineModelType::RuleBased,
            metrics: [
                ("f1".to_string(), 0.72),
                ("precision".to_string(), 0.85),
                ("recall".to_string(), 0.62),
            ].into_iter().collect(),
            training_time_seconds: Some(0.0),
            inference_time_ms: Some(0.5),
            notes: Some("Regex patterns and null checks".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "LogisticRegression".to_string(),
            model_type: BaselineModelType::LogisticRegression,
            metrics: [
                ("f1".to_string(), 0.78),
                ("precision".to_string(), 0.80),
                ("recall".to_string(), 0.76),
            ].into_iter().collect(),
            training_time_seconds: Some(2.0),
            inference_time_ms: Some(0.02),
            notes: Some("Character n-gram features".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "XGBoost".to_string(),
            model_type: BaselineModelType::XgBoost,
            metrics: [
                ("f1".to_string(), 0.88),
                ("precision".to_string(), 0.90),
                ("recall".to_string(), 0.86),
            ].into_iter().collect(),
            training_time_seconds: Some(15.0),
            inference_time_ms: Some(0.08),
            notes: Some("Mixed feature types".to_string()),
        })
        .metadata("domain", "data_quality")
        .metadata("difficulty", "medium")
        .build()
}

/// EntityMatch-5K: 5K records for entity matching.
pub fn entity_match_5k() -> BenchmarkSuite {
    let mut class_dist = HashMap::new();
    class_dist.insert("match".to_string(), 2000);
    class_dist.insert("non_match".to_string(), 3000);

    BenchmarkBuilder::new("entity-match-5k", "EntityMatch-5K")
        .description("5K vendor/customer record pairs for entity matching. Includes name variations, typos, and abbreviations.")
        .task_type(BenchmarkTaskType::EntityMatching)
        .dataset_size(5000, 2000)
        .class_distribution(class_dist)
        .split_ratios(0.7, 0.15, 0.15, false)
        .primary_metric(MetricType::F1Score)
        .metrics(vec![
            MetricType::F1Score,
            MetricType::Precision,
            MetricType::Recall,
            MetricType::AucRoc,
        ])
        .seed(54321)
        .time_span_days(365)
        .num_companies(2)
        .add_baseline(BaselineResult {
            model_name: "ExactMatch".to_string(),
            model_type: BaselineModelType::RuleBased,
            metrics: [
                ("f1".to_string(), 0.35),
                ("precision".to_string(), 1.0),
                ("recall".to_string(), 0.21),
            ].into_iter().collect(),
            training_time_seconds: Some(0.0),
            inference_time_ms: Some(0.1),
            notes: Some("Exact string matching only".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "FuzzyMatch".to_string(),
            model_type: BaselineModelType::RuleBased,
            metrics: [
                ("f1".to_string(), 0.68),
                ("precision".to_string(), 0.72),
                ("recall".to_string(), 0.65),
            ].into_iter().collect(),
            training_time_seconds: Some(0.0),
            inference_time_ms: Some(2.0),
            notes: Some("Levenshtein distance threshold".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "Magellan".to_string(),
            model_type: BaselineModelType::RandomForest,
            metrics: [
                ("f1".to_string(), 0.89),
                ("precision".to_string(), 0.91),
                ("recall".to_string(), 0.87),
            ].into_iter().collect(),
            training_time_seconds: Some(10.0),
            inference_time_ms: Some(5.0),
            notes: Some("Feature-based entity matcher".to_string()),
        })
        .metadata("domain", "master_data")
        .metadata("difficulty", "hard")
        .build()
}

/// GraphFraud-10K: 10K transactions with network structure.
pub fn graph_fraud_10k() -> BenchmarkSuite {
    let mut class_dist = HashMap::new();
    class_dist.insert("normal".to_string(), 9500);
    class_dist.insert("fraud_network".to_string(), 500);

    BenchmarkBuilder::new("graph-fraud-10k", "GraphFraud-10K")
        .description("10K transactions with entity graph structure. Fraud detection using network features and GNN models.")
        .task_type(BenchmarkTaskType::GraphFraudDetection)
        .dataset_size(10000, 500)
        .class_distribution(class_dist)
        .split_ratios(0.7, 0.15, 0.15, true)
        .primary_metric(MetricType::AucPr)
        .metrics(vec![
            MetricType::AucPr,
            MetricType::AucRoc,
            MetricType::F1Score,
            MetricType::PrecisionAtK(100),
        ])
        .seed(77777)
        .time_span_days(365)
        .num_companies(4)
        .add_baseline(BaselineResult {
            model_name: "NodeFeatures".to_string(),
            model_type: BaselineModelType::XgBoost,
            metrics: [
                ("auc_pr".to_string(), 0.45),
                ("auc_roc".to_string(), 0.78),
            ].into_iter().collect(),
            training_time_seconds: Some(5.0),
            inference_time_ms: Some(0.1),
            notes: Some("XGBoost on node features only".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "GraphSAGE".to_string(),
            model_type: BaselineModelType::Gnn,
            metrics: [
                ("auc_pr".to_string(), 0.62),
                ("auc_roc".to_string(), 0.88),
            ].into_iter().collect(),
            training_time_seconds: Some(60.0),
            inference_time_ms: Some(5.0),
            notes: Some("2-layer GraphSAGE".to_string()),
        })
        .add_baseline(BaselineResult {
            model_name: "GAT".to_string(),
            model_type: BaselineModelType::Gnn,
            metrics: [
                ("auc_pr".to_string(), 0.68),
                ("auc_roc".to_string(), 0.91),
            ].into_iter().collect(),
            training_time_seconds: Some(90.0),
            inference_time_ms: Some(8.0),
            notes: Some("Graph Attention Network".to_string()),
        })
        .metadata("domain", "graph_analytics")
        .metadata("difficulty", "hard")
        .build()
}

/// Get all available benchmark suites.
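///
/// A short usage sketch:
///
/// ```ignore
/// // Print each pre-defined suite with its dataset size.
/// for bench in all_benchmarks() {
///     println!("{}: {} records", bench.id, bench.dataset.total_records);
/// }
/// ```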
pub fn all_benchmarks() -> Vec<BenchmarkSuite> {
    vec![
        anomaly_bench_1k(),
        fraud_detect_10k(),
        data_quality_100k(),
        entity_match_5k(),
        graph_fraud_10k(),
    ]
}

/// Get a benchmark by ID.
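///
/// A short usage sketch (the ID matches one of the pre-defined suites above):
///
/// ```ignore
/// let bench = get_benchmark("fraud-detect-10k").expect("known benchmark id");
/// assert_eq!(bench.dataset.total_records, 10_000);
/// assert!(get_benchmark("nonexistent").is_none());
/// ```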
pub fn get_benchmark(id: &str) -> Option<BenchmarkSuite> {
    all_benchmarks().into_iter().find(|b| b.id == id)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_anomaly_bench_1k() {
        let bench = anomaly_bench_1k();
        assert_eq!(bench.id, "anomaly-bench-1k");
        assert_eq!(bench.dataset.total_records, 1000);
        assert_eq!(bench.dataset.positive_count, 50);
        assert_eq!(bench.baselines.len(), 3);
    }

    #[test]
    fn test_fraud_detect_10k() {
        let bench = fraud_detect_10k();
        assert_eq!(bench.id, "fraud-detect-10k");
        assert_eq!(bench.dataset.total_records, 10000);
        assert_eq!(bench.task_type, BenchmarkTaskType::FraudClassification);
    }

    #[test]
    fn test_data_quality_100k() {
        let bench = data_quality_100k();
        assert_eq!(bench.id, "data-quality-100k");
        assert_eq!(bench.dataset.total_records, 100000);
        assert!(bench.dataset.class_distribution.len() > 5);
    }

    #[test]
    fn test_all_benchmarks() {
        let benchmarks = all_benchmarks();
        assert_eq!(benchmarks.len(), 5);

        // Verify all have baselines
        for bench in &benchmarks {
            assert!(
                !bench.baselines.is_empty(),
                "Benchmark {} has no baselines",
                bench.id
            );
        }
    }

    #[test]
    fn test_get_benchmark() {
        assert!(get_benchmark("fraud-detect-10k").is_some());
        assert!(get_benchmark("nonexistent").is_none());
    }

    #[test]
    fn test_builder() {
        let bench = BenchmarkBuilder::new("custom", "Custom Benchmark")
            .description("A custom test benchmark")
            .task_type(BenchmarkTaskType::AnomalyDetection)
            .dataset_size(500, 25)
            .seed(123)
            .build();

        assert_eq!(bench.id, "custom");
        assert_eq!(bench.dataset.total_records, 500);
        assert_eq!(bench.dataset.seed, 123);
    }
}