1use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum MLTaskType {
16 AnomalyDetection,
18 EntityMatching,
20 LinkPrediction,
22 TimeSeriesForecasting,
24}
25
26impl std::fmt::Display for MLTaskType {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 match self {
29 MLTaskType::AnomalyDetection => write!(f, "Anomaly Detection"),
30 MLTaskType::EntityMatching => write!(f, "Entity Matching"),
31 MLTaskType::LinkPrediction => write!(f, "Link Prediction"),
32 MLTaskType::TimeSeriesForecasting => write!(f, "Time Series Forecasting"),
33 }
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39#[serde(rename_all = "snake_case")]
40pub enum BaselineAlgorithm {
41 IsolationForest,
43 LocalOutlierFactor,
44 OneClassSVM,
45 Autoencoder,
46
47 ExactMatch,
49 JaccardSimilarity,
50 LevenshteinDistance,
51 TFIDFCosine,
52
53 CommonNeighbors,
55 AdamicAdar,
56 ResourceAllocation,
57 GraphNeuralNetwork,
58
59 ARIMA,
61 ExponentialSmoothing,
62 Prophet,
63 LSTM,
64}
65
66impl std::fmt::Display for BaselineAlgorithm {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 match self {
69 BaselineAlgorithm::IsolationForest => write!(f, "Isolation Forest"),
70 BaselineAlgorithm::LocalOutlierFactor => write!(f, "Local Outlier Factor"),
71 BaselineAlgorithm::OneClassSVM => write!(f, "One-Class SVM"),
72 BaselineAlgorithm::Autoencoder => write!(f, "Autoencoder"),
73 BaselineAlgorithm::ExactMatch => write!(f, "Exact Match"),
74 BaselineAlgorithm::JaccardSimilarity => write!(f, "Jaccard Similarity"),
75 BaselineAlgorithm::LevenshteinDistance => write!(f, "Levenshtein Distance"),
76 BaselineAlgorithm::TFIDFCosine => write!(f, "TF-IDF Cosine"),
77 BaselineAlgorithm::CommonNeighbors => write!(f, "Common Neighbors"),
78 BaselineAlgorithm::AdamicAdar => write!(f, "Adamic-Adar"),
79 BaselineAlgorithm::ResourceAllocation => write!(f, "Resource Allocation"),
80 BaselineAlgorithm::GraphNeuralNetwork => write!(f, "Graph Neural Network"),
81 BaselineAlgorithm::ARIMA => write!(f, "ARIMA"),
82 BaselineAlgorithm::ExponentialSmoothing => write!(f, "Exponential Smoothing"),
83 BaselineAlgorithm::Prophet => write!(f, "Prophet"),
84 BaselineAlgorithm::LSTM => write!(f, "LSTM"),
85 }
86 }
87}
88
89#[derive(Debug, Clone, Default, Serialize, Deserialize)]
91pub struct ClassificationMetrics {
92 pub accuracy: f64,
94 pub precision: f64,
96 pub recall: f64,
98 pub f1_score: f64,
100 pub auc_roc: f64,
102 pub auc_pr: f64,
104 pub mcc: f64,
106}
107
108impl ClassificationMetrics {
109 pub fn from_confusion(tp: u64, tn: u64, fp: u64, fn_: u64) -> Self {
111 let total = (tp + tn + fp + fn_) as f64;
112 let accuracy = if total > 0.0 {
113 (tp + tn) as f64 / total
114 } else {
115 0.0
116 };
117
118 let precision = if tp + fp > 0 {
119 tp as f64 / (tp + fp) as f64
120 } else {
121 0.0
122 };
123 let recall = if tp + fn_ > 0 {
124 tp as f64 / (tp + fn_) as f64
125 } else {
126 0.0
127 };
128 let f1_score = if precision + recall > 0.0 {
129 2.0 * precision * recall / (precision + recall)
130 } else {
131 0.0
132 };
133
134 let mcc_num = (tp * tn) as f64 - (fp * fn_) as f64;
136 let mcc_denom =
137 ((tp + fp) as f64 * (tp + fn_) as f64 * (tn + fp) as f64 * (tn + fn_) as f64).sqrt();
138 let mcc = if mcc_denom > 0.0 {
139 mcc_num / mcc_denom
140 } else {
141 0.0
142 };
143
144 Self {
145 accuracy,
146 precision,
147 recall,
148 f1_score,
149 auc_roc: 0.0, auc_pr: 0.0, mcc,
152 }
153 }
154}
155
156#[derive(Debug, Clone, Default, Serialize, Deserialize)]
158pub struct RegressionMetrics {
159 pub mae: f64,
161 pub mse: f64,
163 pub rmse: f64,
165 pub mape: f64,
167 pub r2: f64,
169}
170
171impl RegressionMetrics {
172 pub fn from_predictions(predictions: &[f64], actuals: &[f64]) -> Self {
174 if predictions.len() != actuals.len() || predictions.is_empty() {
175 return Self::default();
176 }
177
178 let n = predictions.len() as f64;
179
180 let errors: Vec<f64> = predictions
182 .iter()
183 .zip(actuals.iter())
184 .map(|(p, a)| p - a)
185 .collect();
186
187 let mae = errors.iter().map(|e| e.abs()).sum::<f64>() / n;
188 let mse = errors.iter().map(|e| e * e).sum::<f64>() / n;
189 let rmse = mse.sqrt();
190
191 let mape = predictions
193 .iter()
194 .zip(actuals.iter())
195 .filter(|(_, a)| a.abs() > 1e-10)
196 .map(|(p, a)| ((p - a) / a).abs())
197 .sum::<f64>()
198 / n
199 * 100.0;
200
201 let actual_mean = actuals.iter().sum::<f64>() / n;
203 let ss_tot: f64 = actuals.iter().map(|a| (a - actual_mean).powi(2)).sum();
204 let ss_res: f64 = errors.iter().map(|e| e * e).sum();
205 let r2 = if ss_tot > 0.0 {
206 1.0 - (ss_res / ss_tot)
207 } else {
208 0.0
209 };
210
211 Self {
212 mae,
213 mse,
214 rmse,
215 mape,
216 r2,
217 }
218 }
219}
220
221#[derive(Debug, Clone, Default, Serialize, Deserialize)]
223pub struct RankingMetrics {
224 pub mrr: f64,
226 pub hits_at_1: f64,
228 pub hits_at_10: f64,
230 pub hits_at_100: f64,
232 pub ndcg: f64,
234}
235
236#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct BaselineTask {
239 pub id: String,
241 pub task_type: MLTaskType,
243 pub description: String,
245 pub required_fields: Vec<String>,
247 pub target_field: String,
249 pub recommended_algorithms: Vec<BaselineAlgorithm>,
251 pub expected_metrics: ExpectedMetrics,
253}
254
255#[derive(Debug, Clone, Serialize, Deserialize)]
257pub struct ExpectedMetrics {
258 pub min_acceptable: f64,
260 pub good_threshold: f64,
262 pub excellent_threshold: f64,
264 pub primary_metric: String,
266}
267
268#[derive(Debug, Clone, Serialize, Deserialize)]
270pub struct BaselineResult {
271 pub task: BaselineTask,
273 pub algorithm: BaselineAlgorithm,
275 pub classification_metrics: Option<ClassificationMetrics>,
277 pub regression_metrics: Option<RegressionMetrics>,
279 pub ranking_metrics: Option<RankingMetrics>,
281 pub training_time_secs: f64,
283 pub inference_time_ms: f64,
285 pub train_samples: usize,
287 pub test_samples: usize,
289 pub grade: PerformanceGrade,
291}
292
293#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
295#[serde(rename_all = "lowercase")]
296pub enum PerformanceGrade {
297 Excellent,
298 Good,
299 Acceptable,
300 Poor,
301}
302
303impl std::fmt::Display for PerformanceGrade {
304 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305 match self {
306 PerformanceGrade::Excellent => write!(f, "Excellent"),
307 PerformanceGrade::Good => write!(f, "Good"),
308 PerformanceGrade::Acceptable => write!(f, "Acceptable"),
309 PerformanceGrade::Poor => write!(f, "Poor"),
310 }
311 }
312}
313
314#[derive(Debug, Clone, Default, Serialize, Deserialize)]
316pub struct BaselineEvaluation {
317 pub results: Vec<BaselineResult>,
319 pub summary: BaselineSummary,
321}
322
323#[derive(Debug, Clone, Default, Serialize, Deserialize)]
325pub struct BaselineSummary {
326 pub tasks_evaluated: usize,
328 pub tasks_passing: usize,
330 pub tasks_good: usize,
332 pub tasks_excellent: usize,
334 pub average_primary_metric: f64,
336 pub best_algorithms: HashMap<String, BaselineAlgorithm>,
338}
339
340pub fn get_accounting_baseline_tasks() -> Vec<BaselineTask> {
342 vec![
343 BaselineTask {
345 id: "anomaly_fraud_detection".to_string(),
346 task_type: MLTaskType::AnomalyDetection,
347 description:
348 "Detect fraudulent journal entries based on amount, timing, and user patterns"
349 .to_string(),
350 required_fields: vec![
351 "amount".to_string(),
352 "posting_date".to_string(),
353 "created_by".to_string(),
354 "account_number".to_string(),
355 "is_fraud".to_string(),
356 ],
357 target_field: "is_fraud".to_string(),
358 recommended_algorithms: vec![
359 BaselineAlgorithm::IsolationForest,
360 BaselineAlgorithm::Autoencoder,
361 BaselineAlgorithm::LocalOutlierFactor,
362 ],
363 expected_metrics: ExpectedMetrics {
364 min_acceptable: 0.60,
365 good_threshold: 0.75,
366 excellent_threshold: 0.90,
367 primary_metric: "f1_score".to_string(),
368 },
369 },
370 BaselineTask {
371 id: "anomaly_error_detection".to_string(),
372 task_type: MLTaskType::AnomalyDetection,
373 description: "Detect data entry errors and anomalies in journal entries".to_string(),
374 required_fields: vec![
375 "amount".to_string(),
376 "account_number".to_string(),
377 "is_anomaly".to_string(),
378 ],
379 target_field: "is_anomaly".to_string(),
380 recommended_algorithms: vec![
381 BaselineAlgorithm::IsolationForest,
382 BaselineAlgorithm::OneClassSVM,
383 ],
384 expected_metrics: ExpectedMetrics {
385 min_acceptable: 0.50,
386 good_threshold: 0.70,
387 excellent_threshold: 0.85,
388 primary_metric: "f1_score".to_string(),
389 },
390 },
391 BaselineTask {
393 id: "entity_vendor_matching".to_string(),
394 task_type: MLTaskType::EntityMatching,
395 description: "Match duplicate or similar vendor records".to_string(),
396 required_fields: vec![
397 "vendor_name".to_string(),
398 "vendor_address".to_string(),
399 "tax_id".to_string(),
400 ],
401 target_field: "is_duplicate".to_string(),
402 recommended_algorithms: vec![
403 BaselineAlgorithm::TFIDFCosine,
404 BaselineAlgorithm::LevenshteinDistance,
405 BaselineAlgorithm::JaccardSimilarity,
406 ],
407 expected_metrics: ExpectedMetrics {
408 min_acceptable: 0.80,
409 good_threshold: 0.90,
410 excellent_threshold: 0.95,
411 primary_metric: "f1_score".to_string(),
412 },
413 },
414 BaselineTask {
415 id: "entity_customer_matching".to_string(),
416 task_type: MLTaskType::EntityMatching,
417 description: "Match duplicate or similar customer records".to_string(),
418 required_fields: vec![
419 "customer_name".to_string(),
420 "customer_address".to_string(),
421 "customer_email".to_string(),
422 ],
423 target_field: "is_duplicate".to_string(),
424 recommended_algorithms: vec![
425 BaselineAlgorithm::TFIDFCosine,
426 BaselineAlgorithm::LevenshteinDistance,
427 ],
428 expected_metrics: ExpectedMetrics {
429 min_acceptable: 0.80,
430 good_threshold: 0.90,
431 excellent_threshold: 0.95,
432 primary_metric: "f1_score".to_string(),
433 },
434 },
435 BaselineTask {
437 id: "link_fraud_network".to_string(),
438 task_type: MLTaskType::LinkPrediction,
439 description: "Predict fraudulent transaction links in entity graph".to_string(),
440 required_fields: vec![
441 "source_entity".to_string(),
442 "target_entity".to_string(),
443 "transaction_amount".to_string(),
444 "is_suspicious".to_string(),
445 ],
446 target_field: "is_suspicious".to_string(),
447 recommended_algorithms: vec![
448 BaselineAlgorithm::GraphNeuralNetwork,
449 BaselineAlgorithm::AdamicAdar,
450 BaselineAlgorithm::CommonNeighbors,
451 ],
452 expected_metrics: ExpectedMetrics {
453 min_acceptable: 0.10,
454 good_threshold: 0.25,
455 excellent_threshold: 0.40,
456 primary_metric: "mrr".to_string(),
457 },
458 },
459 BaselineTask {
460 id: "link_intercompany".to_string(),
461 task_type: MLTaskType::LinkPrediction,
462 description: "Predict intercompany transaction relationships".to_string(),
463 required_fields: vec![
464 "company_from".to_string(),
465 "company_to".to_string(),
466 "transaction_type".to_string(),
467 ],
468 target_field: "has_relationship".to_string(),
469 recommended_algorithms: vec![
470 BaselineAlgorithm::CommonNeighbors,
471 BaselineAlgorithm::ResourceAllocation,
472 ],
473 expected_metrics: ExpectedMetrics {
474 min_acceptable: 0.20,
475 good_threshold: 0.35,
476 excellent_threshold: 0.50,
477 primary_metric: "mrr".to_string(),
478 },
479 },
480 BaselineTask {
482 id: "forecast_transaction_volume".to_string(),
483 task_type: MLTaskType::TimeSeriesForecasting,
484 description: "Forecast daily transaction volume".to_string(),
485 required_fields: vec!["date".to_string(), "transaction_count".to_string()],
486 target_field: "transaction_count".to_string(),
487 recommended_algorithms: vec![
488 BaselineAlgorithm::Prophet,
489 BaselineAlgorithm::ARIMA,
490 BaselineAlgorithm::ExponentialSmoothing,
491 ],
492 expected_metrics: ExpectedMetrics {
493 min_acceptable: 0.70,
494 good_threshold: 0.85,
495 excellent_threshold: 0.95,
496 primary_metric: "r2".to_string(),
497 },
498 },
499 BaselineTask {
500 id: "forecast_transaction_amount".to_string(),
501 task_type: MLTaskType::TimeSeriesForecasting,
502 description: "Forecast daily transaction amounts".to_string(),
503 required_fields: vec!["date".to_string(), "total_amount".to_string()],
504 target_field: "total_amount".to_string(),
505 recommended_algorithms: vec![
506 BaselineAlgorithm::LSTM,
507 BaselineAlgorithm::Prophet,
508 BaselineAlgorithm::ARIMA,
509 ],
510 expected_metrics: ExpectedMetrics {
511 min_acceptable: 0.60,
512 good_threshold: 0.80,
513 excellent_threshold: 0.90,
514 primary_metric: "r2".to_string(),
515 },
516 },
517 ]
518}
519
520#[derive(Debug, Clone, Serialize, Deserialize)]
522pub struct BaselineConfig {
523 pub task_types: Vec<MLTaskType>,
525 pub train_ratio: f64,
527 pub seed: u64,
529 pub run_all_algorithms: bool,
531 pub max_training_time_secs: u64,
533}
534
535impl Default for BaselineConfig {
536 fn default() -> Self {
537 Self {
538 task_types: vec![
539 MLTaskType::AnomalyDetection,
540 MLTaskType::EntityMatching,
541 MLTaskType::LinkPrediction,
542 MLTaskType::TimeSeriesForecasting,
543 ],
544 train_ratio: 0.8,
545 seed: 42,
546 run_all_algorithms: false,
547 max_training_time_secs: 300,
548 }
549 }
550}
551
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by the floating-point comparisons below.
    const TOLERANCE: f64 = 0.001;

    /// True when `actual` is within [`TOLERANCE`] of `expected`.
    fn close_to(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    #[test]
    fn test_classification_metrics_from_confusion() {
        // No false positives or false negatives: every ratio is exactly 1.
        let perfect = ClassificationMetrics::from_confusion(100, 100, 0, 0);
        assert!(close_to(perfect.accuracy, 1.0));
        assert!(close_to(perfect.precision, 1.0));
        assert!(close_to(perfect.recall, 1.0));
        assert!(close_to(perfect.f1_score, 1.0));

        // All four cells equal: half of the predictions are correct.
        let balanced = ClassificationMetrics::from_confusion(50, 50, 50, 50);
        assert!(close_to(balanced.accuracy, 0.5));
    }

    #[test]
    fn test_regression_metrics_from_predictions() {
        // Predictions identical to the actual values: zero error, R² of 1.
        let actuals = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let predictions = actuals.clone();

        let metrics = RegressionMetrics::from_predictions(&predictions, &actuals);
        assert!(close_to(metrics.mae, 0.0));
        assert!(close_to(metrics.mse, 0.0));
        assert!(close_to(metrics.r2, 1.0));
    }

    #[test]
    fn test_get_accounting_baseline_tasks() {
        let tasks = get_accounting_baseline_tasks();
        assert!(!tasks.is_empty());

        // Every task type must be represented at least once.
        let required = [
            (
                MLTaskType::AnomalyDetection,
                "Should have anomaly detection tasks",
            ),
            (
                MLTaskType::EntityMatching,
                "Should have entity matching tasks",
            ),
            (
                MLTaskType::LinkPrediction,
                "Should have link prediction tasks",
            ),
            (
                MLTaskType::TimeSeriesForecasting,
                "Should have time series tasks",
            ),
        ];
        for (task_type, message) in required.iter() {
            let present = tasks.iter().any(|task| task.task_type == *task_type);
            assert!(present, "{}", message);
        }
    }

    #[test]
    fn test_baseline_config_default() {
        let BaselineConfig {
            task_types,
            train_ratio,
            ..
        } = BaselineConfig::default();
        assert_eq!(train_ratio, 0.8);
        assert_eq!(task_types.len(), 4);
    }
}