datasynth_eval/ml/baselines.rs

//! ML Baseline Tasks and Metrics
//!
//! Defines standard ML tasks for benchmarking synthetic data quality:
//! - Anomaly Detection (isolation forest, autoencoder baselines)
//! - Entity Matching (duplicate detection)
//! - Link Prediction (graph-based fraud detection)
//! - Time Series Forecasting (amount and volume prediction)

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// ML task type for benchmarking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MLTaskType {
    /// Anomaly detection (fraud, errors, outliers).
    AnomalyDetection,
    /// Entity matching (duplicate detection, record linkage).
    EntityMatching,
    /// Link prediction (graph-based fraud, relationship discovery).
    LinkPrediction,
    /// Time series forecasting (amount, volume prediction).
    TimeSeriesForecasting,
}

impl std::fmt::Display for MLTaskType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MLTaskType::AnomalyDetection => write!(f, "Anomaly Detection"),
            MLTaskType::EntityMatching => write!(f, "Entity Matching"),
            MLTaskType::LinkPrediction => write!(f, "Link Prediction"),
            MLTaskType::TimeSeriesForecasting => write!(f, "Time Series Forecasting"),
        }
    }
}

/// Baseline algorithm for a task.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BaselineAlgorithm {
    // Anomaly Detection
    IsolationForest,
    LocalOutlierFactor,
    OneClassSVM,
    Autoencoder,

    // Entity Matching
    ExactMatch,
    JaccardSimilarity,
    LevenshteinDistance,
    TFIDFCosine,

    // Link Prediction
    CommonNeighbors,
    AdamicAdar,
    ResourceAllocation,
    GraphNeuralNetwork,

    // Time Series
    ARIMA,
    ExponentialSmoothing,
    Prophet,
    LSTM,
}

impl std::fmt::Display for BaselineAlgorithm {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BaselineAlgorithm::IsolationForest => write!(f, "Isolation Forest"),
            BaselineAlgorithm::LocalOutlierFactor => write!(f, "Local Outlier Factor"),
            BaselineAlgorithm::OneClassSVM => write!(f, "One-Class SVM"),
            BaselineAlgorithm::Autoencoder => write!(f, "Autoencoder"),
            BaselineAlgorithm::ExactMatch => write!(f, "Exact Match"),
            BaselineAlgorithm::JaccardSimilarity => write!(f, "Jaccard Similarity"),
            BaselineAlgorithm::LevenshteinDistance => write!(f, "Levenshtein Distance"),
            BaselineAlgorithm::TFIDFCosine => write!(f, "TF-IDF Cosine"),
            BaselineAlgorithm::CommonNeighbors => write!(f, "Common Neighbors"),
            BaselineAlgorithm::AdamicAdar => write!(f, "Adamic-Adar"),
            BaselineAlgorithm::ResourceAllocation => write!(f, "Resource Allocation"),
            BaselineAlgorithm::GraphNeuralNetwork => write!(f, "Graph Neural Network"),
            BaselineAlgorithm::ARIMA => write!(f, "ARIMA"),
            BaselineAlgorithm::ExponentialSmoothing => write!(f, "Exponential Smoothing"),
            BaselineAlgorithm::Prophet => write!(f, "Prophet"),
            BaselineAlgorithm::LSTM => write!(f, "LSTM"),
        }
    }
}

/// Classification metrics for binary/multiclass tasks.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ClassificationMetrics {
    /// Accuracy (correct predictions / total).
    pub accuracy: f64,
    /// Precision (true positives / predicted positives).
    pub precision: f64,
    /// Recall (true positives / actual positives).
    pub recall: f64,
    /// F1 score (harmonic mean of precision and recall).
    pub f1_score: f64,
    /// Area under ROC curve.
    pub auc_roc: f64,
    /// Area under precision-recall curve.
    pub auc_pr: f64,
    /// Matthews correlation coefficient.
    pub mcc: f64,
}

impl ClassificationMetrics {
    /// Create metrics from confusion matrix values.
    pub fn from_confusion(tp: u64, tn: u64, fp: u64, fn_: u64) -> Self {
        let total = (tp + tn + fp + fn_) as f64;
        let accuracy = if total > 0.0 {
            (tp + tn) as f64 / total
        } else {
            0.0
        };

        let precision = if tp + fp > 0 {
            tp as f64 / (tp + fp) as f64
        } else {
            0.0
        };
        let recall = if tp + fn_ > 0 {
            tp as f64 / (tp + fn_) as f64
        } else {
            0.0
        };
        let f1_score = if precision + recall > 0.0 {
            2.0 * precision * recall / (precision + recall)
        } else {
            0.0
        };

        // Matthews Correlation Coefficient:
        // MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)).
        // Computed in f64 to avoid u64 overflow on large confusion matrices.
        let mcc_num = tp as f64 * tn as f64 - fp as f64 * fn_ as f64;
        let mcc_denom =
            ((tp + fp) as f64 * (tp + fn_) as f64 * (tn + fp) as f64 * (tn + fn_) as f64).sqrt();
        let mcc = if mcc_denom > 0.0 {
            mcc_num / mcc_denom
        } else {
            0.0
        };

        Self {
            accuracy,
            precision,
            recall,
            f1_score,
            auc_roc: 0.0, // Requires probability scores
            auc_pr: 0.0,  // Requires probability scores
            mcc,
        }
    }
}

/// Regression metrics for continuous prediction tasks.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RegressionMetrics {
    /// Mean Absolute Error.
    pub mae: f64,
    /// Mean Squared Error.
    pub mse: f64,
    /// Root Mean Squared Error.
    pub rmse: f64,
    /// Mean Absolute Percentage Error.
    pub mape: f64,
    /// R-squared (coefficient of determination).
    pub r2: f64,
}

impl RegressionMetrics {
    /// Create metrics from predictions and actuals.
    pub fn from_predictions(predictions: &[f64], actuals: &[f64]) -> Self {
        if predictions.len() != actuals.len() || predictions.is_empty() {
            return Self::default();
        }

        let n = predictions.len() as f64;

        // Calculate errors
        let errors: Vec<f64> = predictions
            .iter()
            .zip(actuals.iter())
            .map(|(p, a)| p - a)
            .collect();

        let mae = errors.iter().map(|e| e.abs()).sum::<f64>() / n;
        let mse = errors.iter().map(|e| e * e).sum::<f64>() / n;
        let rmse = mse.sqrt();

        // MAPE, skipping near-zero actuals to avoid division by zero.
        // Average over the terms actually included, not the full sample count.
        let mape_terms: Vec<f64> = predictions
            .iter()
            .zip(actuals.iter())
            .filter(|(_, a)| a.abs() > 1e-10)
            .map(|(p, a)| ((p - a) / a).abs())
            .collect();
        let mape = if mape_terms.is_empty() {
            0.0
        } else {
            100.0 * mape_terms.iter().sum::<f64>() / mape_terms.len() as f64
        };

        // R-squared
        let actual_mean = actuals.iter().sum::<f64>() / n;
        let ss_tot: f64 = actuals.iter().map(|a| (a - actual_mean).powi(2)).sum();
        let ss_res: f64 = errors.iter().map(|e| e * e).sum();
        let r2 = if ss_tot > 0.0 {
            1.0 - (ss_res / ss_tot)
        } else {
            0.0
        };

        Self {
            mae,
            mse,
            rmse,
            mape,
            r2,
        }
    }
}

/// Ranking metrics for link prediction and recommendation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RankingMetrics {
    /// Mean Reciprocal Rank.
    pub mrr: f64,
    /// Hits@1 (top-1 accuracy).
    pub hits_at_1: f64,
    /// Hits@10.
    pub hits_at_10: f64,
    /// Hits@100.
    pub hits_at_100: f64,
    /// Normalized Discounted Cumulative Gain.
    pub ndcg: f64,
}
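
// Illustrative sketch (not part of the original API): RankingMetrics has no
// constructor above, so this shows one plausible way to fill the fields from
// per-query ranks of the true item. Assumes `ranks` holds 1-indexed positions
// of the correct candidate for each test query; `ndcg` is left at zero because
// it needs graded relevance over full candidate lists, which ranks alone do
// not provide.
impl RankingMetrics {
    /// Build MRR and Hits@k from 1-indexed ranks of the true item per query.
    pub fn from_ranks(ranks: &[usize]) -> Self {
        if ranks.is_empty() {
            return Self::default();
        }
        debug_assert!(ranks.iter().all(|&r| r >= 1), "ranks must be 1-indexed");
        let n = ranks.len() as f64;
        // MRR: mean of 1/rank across queries.
        let mrr = ranks.iter().map(|&r| 1.0 / r as f64).sum::<f64>() / n;
        // Hits@k: fraction of queries whose true item lands in the top k.
        let hits = |k: usize| ranks.iter().filter(|&&r| r <= k).count() as f64 / n;
        Self {
            mrr,
            hits_at_1: hits(1),
            hits_at_10: hits(10),
            hits_at_100: hits(100),
            ndcg: 0.0, // Requires graded relevance scores.
        }
    }
}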

/// ML baseline task definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaselineTask {
    /// Task identifier.
    pub id: String,
    /// Task type.
    pub task_type: MLTaskType,
    /// Human-readable description.
    pub description: String,
    /// Required data fields.
    pub required_fields: Vec<String>,
    /// Target field (label).
    pub target_field: String,
    /// Recommended baseline algorithms.
    pub recommended_algorithms: Vec<BaselineAlgorithm>,
    /// Expected baseline performance (for reference).
    pub expected_metrics: ExpectedMetrics,
}

/// Expected performance metrics for a task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpectedMetrics {
    /// Minimum acceptable F1/R2/MRR depending on task type.
    pub min_acceptable: f64,
    /// Good performance threshold.
    pub good_threshold: f64,
    /// Excellent performance threshold.
    pub excellent_threshold: f64,
    /// Primary metric name.
    pub primary_metric: String,
}

/// Baseline task results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaselineResult {
    /// Task that was evaluated.
    pub task: BaselineTask,
    /// Algorithm used.
    pub algorithm: BaselineAlgorithm,
    /// Classification metrics (if applicable).
    pub classification_metrics: Option<ClassificationMetrics>,
    /// Regression metrics (if applicable).
    pub regression_metrics: Option<RegressionMetrics>,
    /// Ranking metrics (if applicable).
    pub ranking_metrics: Option<RankingMetrics>,
    /// Training time in seconds.
    pub training_time_secs: f64,
    /// Inference time per sample in milliseconds.
    pub inference_time_ms: f64,
    /// Number of samples used for training.
    pub train_samples: usize,
    /// Number of samples used for testing.
    pub test_samples: usize,
    /// Performance grade (excellent, good, acceptable, poor).
    pub grade: PerformanceGrade,
}

/// Performance grade for baseline results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PerformanceGrade {
    Excellent,
    Good,
    Acceptable,
    Poor,
}

impl std::fmt::Display for PerformanceGrade {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PerformanceGrade::Excellent => write!(f, "Excellent"),
            PerformanceGrade::Good => write!(f, "Good"),
            PerformanceGrade::Acceptable => write!(f, "Acceptable"),
            PerformanceGrade::Poor => write!(f, "Poor"),
        }
    }
}
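
// Illustrative sketch (not part of the original API): one way a primary-metric
// value could be mapped onto a PerformanceGrade using the ExpectedMetrics
// thresholds defined above. Assumes higher is better for the primary metric,
// which holds for the f1_score, r2, and mrr metrics used in the predefined
// accounting tasks.
pub fn grade_primary_metric(expected: &ExpectedMetrics, value: f64) -> PerformanceGrade {
    if value >= expected.excellent_threshold {
        PerformanceGrade::Excellent
    } else if value >= expected.good_threshold {
        PerformanceGrade::Good
    } else if value >= expected.min_acceptable {
        PerformanceGrade::Acceptable
    } else {
        PerformanceGrade::Poor
    }
}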

/// Collection of baseline results for all tasks.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BaselineEvaluation {
    /// Results for each task.
    pub results: Vec<BaselineResult>,
    /// Summary statistics.
    pub summary: BaselineSummary,
}

/// Summary of baseline evaluation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BaselineSummary {
    /// Number of tasks evaluated.
    pub tasks_evaluated: usize,
    /// Tasks meeting minimum threshold.
    pub tasks_passing: usize,
    /// Tasks with good performance.
    pub tasks_good: usize,
    /// Tasks with excellent performance.
    pub tasks_excellent: usize,
    /// Average primary metric across all tasks.
    pub average_primary_metric: f64,
    /// Best performing algorithm per task.
    pub best_algorithms: HashMap<String, BaselineAlgorithm>,
}

/// Get predefined baseline tasks for synthetic accounting data.
pub fn get_accounting_baseline_tasks() -> Vec<BaselineTask> {
    vec![
        // Anomaly Detection Tasks
        BaselineTask {
            id: "anomaly_fraud_detection".to_string(),
            task_type: MLTaskType::AnomalyDetection,
            description:
                "Detect fraudulent journal entries based on amount, timing, and user patterns"
                    .to_string(),
            required_fields: vec![
                "amount".to_string(),
                "posting_date".to_string(),
                "created_by".to_string(),
                "account_number".to_string(),
                "is_fraud".to_string(),
            ],
            target_field: "is_fraud".to_string(),
            recommended_algorithms: vec![
                BaselineAlgorithm::IsolationForest,
                BaselineAlgorithm::Autoencoder,
                BaselineAlgorithm::LocalOutlierFactor,
            ],
            expected_metrics: ExpectedMetrics {
                min_acceptable: 0.60,
                good_threshold: 0.75,
                excellent_threshold: 0.90,
                primary_metric: "f1_score".to_string(),
            },
        },
        BaselineTask {
            id: "anomaly_error_detection".to_string(),
            task_type: MLTaskType::AnomalyDetection,
            description: "Detect data entry errors and anomalies in journal entries".to_string(),
            required_fields: vec![
                "amount".to_string(),
                "account_number".to_string(),
                "is_anomaly".to_string(),
            ],
            target_field: "is_anomaly".to_string(),
            recommended_algorithms: vec![
                BaselineAlgorithm::IsolationForest,
                BaselineAlgorithm::OneClassSVM,
            ],
            expected_metrics: ExpectedMetrics {
                min_acceptable: 0.50,
                good_threshold: 0.70,
                excellent_threshold: 0.85,
                primary_metric: "f1_score".to_string(),
            },
        },
        // Entity Matching Tasks
        BaselineTask {
            id: "entity_vendor_matching".to_string(),
            task_type: MLTaskType::EntityMatching,
            description: "Match duplicate or similar vendor records".to_string(),
            required_fields: vec![
                "vendor_name".to_string(),
                "vendor_address".to_string(),
                "tax_id".to_string(),
            ],
            target_field: "is_duplicate".to_string(),
            recommended_algorithms: vec![
                BaselineAlgorithm::TFIDFCosine,
                BaselineAlgorithm::LevenshteinDistance,
                BaselineAlgorithm::JaccardSimilarity,
            ],
            expected_metrics: ExpectedMetrics {
                min_acceptable: 0.80,
                good_threshold: 0.90,
                excellent_threshold: 0.95,
                primary_metric: "f1_score".to_string(),
            },
        },
        BaselineTask {
            id: "entity_customer_matching".to_string(),
            task_type: MLTaskType::EntityMatching,
            description: "Match duplicate or similar customer records".to_string(),
            required_fields: vec![
                "customer_name".to_string(),
                "customer_address".to_string(),
                "customer_email".to_string(),
            ],
            target_field: "is_duplicate".to_string(),
            recommended_algorithms: vec![
                BaselineAlgorithm::TFIDFCosine,
                BaselineAlgorithm::LevenshteinDistance,
            ],
            expected_metrics: ExpectedMetrics {
                min_acceptable: 0.80,
                good_threshold: 0.90,
                excellent_threshold: 0.95,
                primary_metric: "f1_score".to_string(),
            },
        },
        // Link Prediction Tasks
        BaselineTask {
            id: "link_fraud_network".to_string(),
            task_type: MLTaskType::LinkPrediction,
            description: "Predict fraudulent transaction links in entity graph".to_string(),
            required_fields: vec![
                "source_entity".to_string(),
                "target_entity".to_string(),
                "transaction_amount".to_string(),
                "is_suspicious".to_string(),
            ],
            target_field: "is_suspicious".to_string(),
            recommended_algorithms: vec![
                BaselineAlgorithm::GraphNeuralNetwork,
                BaselineAlgorithm::AdamicAdar,
                BaselineAlgorithm::CommonNeighbors,
            ],
            expected_metrics: ExpectedMetrics {
                min_acceptable: 0.10,
                good_threshold: 0.25,
                excellent_threshold: 0.40,
                primary_metric: "mrr".to_string(),
            },
        },
        BaselineTask {
            id: "link_intercompany".to_string(),
            task_type: MLTaskType::LinkPrediction,
            description: "Predict intercompany transaction relationships".to_string(),
            required_fields: vec![
                "company_from".to_string(),
                "company_to".to_string(),
                "transaction_type".to_string(),
            ],
            target_field: "has_relationship".to_string(),
            recommended_algorithms: vec![
                BaselineAlgorithm::CommonNeighbors,
                BaselineAlgorithm::ResourceAllocation,
            ],
            expected_metrics: ExpectedMetrics {
                min_acceptable: 0.20,
                good_threshold: 0.35,
                excellent_threshold: 0.50,
                primary_metric: "mrr".to_string(),
            },
        },
        // Time Series Forecasting Tasks
        BaselineTask {
            id: "forecast_transaction_volume".to_string(),
            task_type: MLTaskType::TimeSeriesForecasting,
            description: "Forecast daily transaction volume".to_string(),
            required_fields: vec!["date".to_string(), "transaction_count".to_string()],
            target_field: "transaction_count".to_string(),
            recommended_algorithms: vec![
                BaselineAlgorithm::Prophet,
                BaselineAlgorithm::ARIMA,
                BaselineAlgorithm::ExponentialSmoothing,
            ],
            expected_metrics: ExpectedMetrics {
                min_acceptable: 0.70,
                good_threshold: 0.85,
                excellent_threshold: 0.95,
                primary_metric: "r2".to_string(),
            },
        },
        BaselineTask {
            id: "forecast_transaction_amount".to_string(),
            task_type: MLTaskType::TimeSeriesForecasting,
            description: "Forecast daily transaction amounts".to_string(),
            required_fields: vec!["date".to_string(), "total_amount".to_string()],
            target_field: "total_amount".to_string(),
            recommended_algorithms: vec![
                BaselineAlgorithm::LSTM,
                BaselineAlgorithm::Prophet,
                BaselineAlgorithm::ARIMA,
            ],
            expected_metrics: ExpectedMetrics {
                min_acceptable: 0.60,
                good_threshold: 0.80,
                excellent_threshold: 0.90,
                primary_metric: "r2".to_string(),
            },
        },
    ]
}

/// Configuration for baseline evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaselineConfig {
    /// Which task types to evaluate.
    pub task_types: Vec<MLTaskType>,
    /// Train/test split ratio (e.g., 0.8 for 80% train).
    pub train_ratio: f64,
    /// Random seed for reproducibility.
    pub seed: u64,
    /// Whether to run all algorithms or just the primary one.
    pub run_all_algorithms: bool,
    /// Maximum training time per algorithm in seconds.
    pub max_training_time_secs: u64,
}

impl Default for BaselineConfig {
    fn default() -> Self {
        Self {
            task_types: vec![
                MLTaskType::AnomalyDetection,
                MLTaskType::EntityMatching,
                MLTaskType::LinkPrediction,
                MLTaskType::TimeSeriesForecasting,
            ],
            train_ratio: 0.8,
            seed: 42,
            run_all_algorithms: false,
            max_training_time_secs: 300,
        }
    }
}
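
// Illustrative sketch (not part of the original API): one way the `task_types`
// field of BaselineConfig could be used to select the predefined accounting
// tasks for an evaluation run.
pub fn tasks_for_config(config: &BaselineConfig) -> Vec<BaselineTask> {
    get_accounting_baseline_tasks()
        .into_iter()
        .filter(|task| config.task_types.contains(&task.task_type))
        .collect()
}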

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classification_metrics_from_confusion() {
        // Perfect classifier
        let metrics = ClassificationMetrics::from_confusion(100, 100, 0, 0);
        assert!((metrics.accuracy - 1.0).abs() < 0.001);
        assert!((metrics.precision - 1.0).abs() < 0.001);
        assert!((metrics.recall - 1.0).abs() < 0.001);
        assert!((metrics.f1_score - 1.0).abs() < 0.001);

        // Random classifier (50/50)
        let metrics = ClassificationMetrics::from_confusion(50, 50, 50, 50);
        assert!((metrics.accuracy - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_regression_metrics_from_predictions() {
        let predictions = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let actuals = vec![1.0, 2.0, 3.0, 4.0, 5.0];

        let metrics = RegressionMetrics::from_predictions(&predictions, &actuals);
        assert!((metrics.mae).abs() < 0.001);
        assert!((metrics.mse).abs() < 0.001);
        assert!((metrics.r2 - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_get_accounting_baseline_tasks() {
        let tasks = get_accounting_baseline_tasks();
        assert!(!tasks.is_empty());

        // Check we have all task types
        let has_anomaly = tasks
            .iter()
            .any(|t| t.task_type == MLTaskType::AnomalyDetection);
        let has_entity = tasks
            .iter()
            .any(|t| t.task_type == MLTaskType::EntityMatching);
        let has_link = tasks
            .iter()
            .any(|t| t.task_type == MLTaskType::LinkPrediction);
        let has_ts = tasks
            .iter()
            .any(|t| t.task_type == MLTaskType::TimeSeriesForecasting);

        assert!(has_anomaly, "Should have anomaly detection tasks");
        assert!(has_entity, "Should have entity matching tasks");
        assert!(has_link, "Should have link prediction tasks");
        assert!(has_ts, "Should have time series tasks");
    }

    #[test]
    fn test_baseline_config_default() {
        let config = BaselineConfig::default();
        assert_eq!(config.train_ratio, 0.8);
        assert_eq!(config.task_types.len(), 4);
    }
}