1use serde::{Deserialize, Serialize};
4use std::time::Duration;
5
6#[derive(Debug, Clone, Serialize, Deserialize, Default)]
8pub struct AutoMLConfig {
9 pub enabled: bool,
11 pub tasks: Vec<AutoMLTask>,
13 pub model_selection: ModelSelectionConfig,
15 pub hyperparameter_optimization: HyperparameterOptimizationConfig,
17 pub deployment: ModelDeploymentConfig,
19}
20
21#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
23pub enum AutoMLTask {
24 AnomalyDetection,
25 ForecastingOptimization,
26 AlertClassification,
27 ResourcePrediction,
28 CostOptimization,
29 PerformanceTuning,
30 Custom(String),
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ModelSelectionConfig {
36 pub criteria: Vec<SelectionCriterion>,
38 pub metrics: Vec<EvaluationMetric>,
40 pub cross_validation: AutoMLCrossValidationConfig,
42 pub model_families: Vec<ModelFamily>,
44}
45
46#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
48pub enum SelectionCriterion {
49 Accuracy,
50 Precision,
51 Recall,
52 F1Score,
53 AUC,
54 Speed,
55 MemoryUsage,
56 Interpretability,
57 Custom(String),
58}
59
60#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
62pub enum EvaluationMetric {
63 Accuracy,
64 Precision,
65 Recall,
66 F1Score,
67 RocAuc,
68 MAE,
69 MSE,
70 RMSE,
71 Custom(String),
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct AutoMLCrossValidationConfig {
77 pub method: CrossValidationMethod,
79 pub folds: u32,
81 pub stratified: bool,
83 pub seed: Option<u64>,
85}
86
87#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
89pub enum CrossValidationMethod {
90 KFold,
91 StratifiedKFold,
92 TimeSeriesSplit,
93 Custom(String),
94}
95
96#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
98pub enum ModelFamily {
99 LinearModels,
100 TreeModels,
101 EnsembleMethods,
102 NeuralNetworks,
103 SupportVectorMachines,
104 NaiveBayes,
105 KNearestNeighbors,
106 Custom(String),
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct HyperparameterOptimizationConfig {
112 pub method: OptimizationMethod,
114 pub search_space: SearchSpaceConfig,
116 pub budget: OptimizationBudget,
118 pub early_stopping: EarlyStoppingConfig,
120}
121
122#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
124pub enum OptimizationMethod {
125 RandomSearch,
126 GridSearch,
127 BayesianOptimization,
128 GeneticAlgorithm,
129 ParticleSwarmOptimization,
130 Custom(String),
131}
132
133#[derive(Debug, Clone, Serialize, Deserialize, Default)]
135pub struct SearchSpaceConfig {
136 pub parameters: Vec<ParameterDefinition>,
138 pub constraints: Vec<ParameterConstraint>,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct ParameterDefinition {
145 pub name: String,
147 pub param_type: ParameterType,
149 pub default: Option<String>,
151 pub description: String,
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
157pub enum ParameterType {
158 Integer { min: i64, max: i64 },
159 Float { min: f64, max: f64 },
160 Categorical { values: Vec<String> },
161 Boolean,
162 Custom(String),
163}
164
165#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct ParameterConstraint {
168 pub constraint_type: ConstraintType,
170 pub parameters: Vec<String>,
172 pub expression: String,
174}
175
176#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
178pub enum ConstraintType {
179 LinearEquality,
180 LinearInequality,
181 NonLinear,
182 Conditional,
183 Custom(String),
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct OptimizationBudget {
189 pub max_evaluations: u32,
191 pub max_time: Duration,
193 pub max_parallel: u32,
195 pub resources: ResourceLimits,
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct ResourceLimits {
202 pub cpu_cores: Option<u32>,
204 pub memory_gb: Option<u32>,
206 pub gpu_count: Option<u32>,
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct EarlyStoppingConfig {
213 pub enabled: bool,
215 pub patience: u32,
217 pub min_improvement: f64,
219 pub evaluation_frequency: u32,
221}
222
223#[derive(Debug, Clone, Serialize, Deserialize)]
225pub struct ModelDeploymentConfig {
226 pub strategy: DeploymentStrategy,
228 pub versioning: VersioningConfig,
230 pub monitoring: ModelMonitoringConfig,
232 pub rollback: RollbackConfig,
234}
235
236#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
238pub enum DeploymentStrategy {
239 BlueGreen,
240 Canary,
241 RollingUpdate,
242 Immediate,
243 Custom(String),
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct VersioningConfig {
249 pub scheme: VersionScheme,
251 pub registry: ModelRegistryConfig,
253 pub artifacts: ArtifactStorageConfig,
255}
256
257#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
259pub enum VersionScheme {
260 Semantic,
261 Sequential,
262 Timestamp,
263 Custom(String),
264}
265
266#[derive(Debug, Clone, Serialize, Deserialize)]
268pub struct ModelRegistryConfig {
269 pub registry_type: RegistryType,
271 pub connection: RegistryConnection,
273 pub metadata: MetadataConfig,
275}
276
277#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
279pub enum RegistryType {
280 MLflow,
281 ModelDB,
282 KubeflowPipelines,
283 Custom(String),
284}
285
286#[derive(Debug, Clone, Serialize, Deserialize)]
288pub struct RegistryConnection {
289 pub url: String,
291 pub auth: AuthConfig,
293 pub timeout: Duration,
295}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
299pub struct AuthConfig {
300 pub method: AuthMethod,
302 pub credentials: std::collections::HashMap<String, String>,
304}
305
306#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
308pub enum AuthMethod {
309 Token,
310 BasicAuth,
311 OAuth2,
312 Certificate,
313 Custom(String),
314}
315
316#[derive(Debug, Clone, Serialize, Deserialize)]
318pub struct MetadataConfig {
319 pub track_hyperparameters: bool,
321 pub track_metrics: bool,
323 pub track_artifacts: bool,
325 pub custom_fields: Vec<MetadataField>,
327}
328
329#[derive(Debug, Clone, Serialize, Deserialize)]
331pub struct MetadataField {
332 pub name: String,
334 pub field_type: MetadataFieldType,
336 pub required: bool,
338}
339
340#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
342pub enum MetadataFieldType {
343 String,
344 Number,
345 Boolean,
346 Date,
347 JSON,
348 Custom(String),
349}
350
351#[derive(Debug, Clone, Serialize, Deserialize)]
353pub struct ArtifactStorageConfig {
354 pub backend: StorageBackend,
356 pub path: String,
358 pub compression: CompressionConfig,
360}
361
362#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
364pub enum StorageBackend {
365 Local,
366 S3,
367 GCS,
368 Azure,
369 HDFS,
370 Custom(String),
371}
372
373#[derive(Debug, Clone, Serialize, Deserialize)]
375pub struct CompressionConfig {
376 pub enabled: bool,
378 pub algorithm: CompressionAlgorithm,
380 pub level: u8,
382}
383
384#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
386pub enum CompressionAlgorithm {
387 Gzip,
388 Lz4,
389 Zstd,
390 Bzip2,
391 Custom(String),
392}
393
394#[derive(Debug, Clone, Serialize, Deserialize, Default)]
396pub struct ModelMonitoringConfig {
397 pub performance: PerformanceMonitoringConfig,
399 pub data_drift: DataDriftConfig,
401 pub model_drift: ModelDriftConfig,
403 pub alerts: ModelAlertConfig,
405}
406
407#[derive(Debug, Clone, Serialize, Deserialize)]
409pub struct PerformanceMonitoringConfig {
410 pub metrics: Vec<PerformanceMetric>,
412 pub frequency: Duration,
414 pub thresholds: PerformanceThresholds,
416}
417
418#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
420pub enum PerformanceMetric {
421 Accuracy,
422 Precision,
423 Recall,
424 F1Score,
425 Latency,
426 Throughput,
427 MemoryUsage,
428 CPUUsage,
429 Custom(String),
430}
431
432#[derive(Debug, Clone, Serialize, Deserialize)]
434pub struct PerformanceThresholds {
435 pub accuracy: Option<f64>,
437 pub latency_ms: Option<f64>,
439 pub throughput: Option<f64>,
441 pub memory_mb: Option<f64>,
443}
444
445#[derive(Debug, Clone, Serialize, Deserialize)]
447pub struct DataDriftConfig {
448 pub enabled: bool,
450 pub methods: Vec<DriftDetectionMethod>,
452 pub frequency: Duration,
454 pub reference_window: Duration,
456}
457
458#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
460pub enum DriftDetectionMethod {
461 KolmogorovSmirnov,
462 ChiSquare,
463 PopulationStabilityIndex,
464 JensenShannonDivergence,
465 Custom(String),
466}
467
468#[derive(Debug, Clone, Serialize, Deserialize)]
470pub struct ModelDriftConfig {
471 pub enabled: bool,
473 pub threshold: f64,
475 pub window: Duration,
477 pub retrain_trigger: RetrainTrigger,
479}
480
481#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
483pub enum RetrainTrigger {
484 PerformanceDegradation,
485 DataDrift,
486 TimeBased,
487 Manual,
488 Custom(String),
489}
490
491#[derive(Debug, Clone, Serialize, Deserialize)]
493pub struct ModelAlertConfig {
494 pub types: Vec<ModelAlertType>,
496 pub channels: Vec<String>,
498 pub thresholds: AlertThresholds,
500}
501
502#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
504pub enum ModelAlertType {
505 PerformanceDegradation,
506 DataDrift,
507 ModelDrift,
508 PredictionBias,
509 ServiceUnavailable,
510 Custom(String),
511}
512
513#[derive(Debug, Clone, Serialize, Deserialize)]
515pub struct AlertThresholds {
516 pub performance_degradation: f64,
518 pub data_drift: f64,
520 pub model_drift: f64,
522 pub bias: f64,
524}
525
526#[derive(Debug, Clone, Serialize, Deserialize)]
528pub struct RollbackConfig {
529 pub auto_rollback: bool,
531 pub triggers: Vec<RollbackTrigger>,
533 pub strategy: RollbackStrategy,
535 pub timeout: Duration,
537}
538
539#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
541pub enum RollbackTrigger {
542 PerformanceDegradation,
543 ErrorRateIncrease,
544 LatencyIncrease,
545 Manual,
546 Custom(String),
547}
548
549#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
551pub enum RollbackStrategy {
552 PreviousVersion,
553 StableVersion,
554 SpecificVersion,
555 Custom(String),
556}
557
558impl Default for ModelSelectionConfig {
559 fn default() -> Self {
560 Self {
561 criteria: vec![SelectionCriterion::Accuracy],
562 metrics: vec![EvaluationMetric::Accuracy],
563 cross_validation: AutoMLCrossValidationConfig::default(),
564 model_families: vec![ModelFamily::LinearModels, ModelFamily::TreeModels],
565 }
566 }
567}
568
569impl Default for AutoMLCrossValidationConfig {
570 fn default() -> Self {
571 Self {
572 method: CrossValidationMethod::KFold,
573 folds: 5,
574 stratified: true,
575 seed: Some(42),
576 }
577 }
578}
579
580impl Default for HyperparameterOptimizationConfig {
581 fn default() -> Self {
582 Self {
583 method: OptimizationMethod::BayesianOptimization,
584 search_space: SearchSpaceConfig::default(),
585 budget: OptimizationBudget::default(),
586 early_stopping: EarlyStoppingConfig::default(),
587 }
588 }
589}
590
591impl Default for OptimizationBudget {
592 fn default() -> Self {
593 Self {
594 max_evaluations: 100,
595 max_time: Duration::from_secs(3600), max_parallel: 4,
597 resources: ResourceLimits::default(),
598 }
599 }
600}
601
602impl Default for ResourceLimits {
603 fn default() -> Self {
604 Self {
605 cpu_cores: Some(4),
606 memory_gb: Some(8),
607 gpu_count: None,
608 }
609 }
610}
611
612impl Default for EarlyStoppingConfig {
613 fn default() -> Self {
614 Self {
615 enabled: true,
616 patience: 10,
617 min_improvement: 0.001,
618 evaluation_frequency: 5,
619 }
620 }
621}
622
623impl Default for ModelDeploymentConfig {
624 fn default() -> Self {
625 Self {
626 strategy: DeploymentStrategy::RollingUpdate,
627 versioning: VersioningConfig::default(),
628 monitoring: ModelMonitoringConfig::default(),
629 rollback: RollbackConfig::default(),
630 }
631 }
632}
633
634impl Default for VersioningConfig {
635 fn default() -> Self {
636 Self {
637 scheme: VersionScheme::Semantic,
638 registry: ModelRegistryConfig::default(),
639 artifacts: ArtifactStorageConfig::default(),
640 }
641 }
642}
643
644impl Default for ModelRegistryConfig {
645 fn default() -> Self {
646 Self {
647 registry_type: RegistryType::MLflow,
648 connection: RegistryConnection::default(),
649 metadata: MetadataConfig::default(),
650 }
651 }
652}
653
654impl Default for RegistryConnection {
655 fn default() -> Self {
656 Self {
657 url: "http://localhost:5000".to_string(),
658 auth: AuthConfig::default(),
659 timeout: Duration::from_secs(30),
660 }
661 }
662}
663
664impl Default for AuthConfig {
665 fn default() -> Self {
666 Self {
667 method: AuthMethod::Token,
668 credentials: std::collections::HashMap::new(),
669 }
670 }
671}
672
673impl Default for MetadataConfig {
674 fn default() -> Self {
675 Self {
676 track_hyperparameters: true,
677 track_metrics: true,
678 track_artifacts: true,
679 custom_fields: vec![],
680 }
681 }
682}
683
684impl Default for ArtifactStorageConfig {
685 fn default() -> Self {
686 Self {
687 backend: StorageBackend::Local,
688 path: "./models".to_string(),
689 compression: CompressionConfig::default(),
690 }
691 }
692}
693
694impl Default for CompressionConfig {
695 fn default() -> Self {
696 Self {
697 enabled: true,
698 algorithm: CompressionAlgorithm::Gzip,
699 level: 6,
700 }
701 }
702}
703
704impl Default for PerformanceMonitoringConfig {
705 fn default() -> Self {
706 Self {
707 metrics: vec![PerformanceMetric::Accuracy, PerformanceMetric::Latency],
708 frequency: Duration::from_secs(300), thresholds: PerformanceThresholds::default(),
710 }
711 }
712}
713
714impl Default for PerformanceThresholds {
715 fn default() -> Self {
716 Self {
717 accuracy: Some(0.85),
718 latency_ms: Some(1000.0),
719 throughput: Some(100.0),
720 memory_mb: Some(1024.0),
721 }
722 }
723}
724
725impl Default for DataDriftConfig {
726 fn default() -> Self {
727 Self {
728 enabled: false,
729 methods: vec![DriftDetectionMethod::KolmogorovSmirnov],
730 frequency: Duration::from_secs(3600), reference_window: Duration::from_secs(86400 * 7), }
733 }
734}
735
736impl Default for ModelDriftConfig {
737 fn default() -> Self {
738 Self {
739 enabled: false,
740 threshold: 0.1,
741 window: Duration::from_secs(86400), retrain_trigger: RetrainTrigger::PerformanceDegradation,
743 }
744 }
745}
746
747impl Default for ModelAlertConfig {
748 fn default() -> Self {
749 Self {
750 types: vec![ModelAlertType::PerformanceDegradation],
751 channels: vec!["email".to_string()],
752 thresholds: AlertThresholds::default(),
753 }
754 }
755}
756
757impl Default for AlertThresholds {
758 fn default() -> Self {
759 Self {
760 performance_degradation: 0.05,
761 data_drift: 0.1,
762 model_drift: 0.1,
763 bias: 0.05,
764 }
765 }
766}
767
768impl Default for RollbackConfig {
769 fn default() -> Self {
770 Self {
771 auto_rollback: false,
772 triggers: vec![RollbackTrigger::PerformanceDegradation],
773 strategy: RollbackStrategy::PreviousVersion,
774 timeout: Duration::from_secs(300), }
776 }
777}