1use crate::{
9 advanced_analytics::VectorAnalyticsEngine,
10 benchmarking::{BenchmarkConfig, BenchmarkSuite},
11 embeddings::EmbeddingStrategy,
12 similarity::SimilarityMetric,
13 VectorStore,
14};
15
16use anyhow::{Context, Result};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::{Arc, RwLock};
20use std::time::{Duration, Instant};
21use tokio::sync::Mutex;
22use tracing::{info, span, warn, Level};
23use uuid::Uuid;
24
/// Configuration controlling an AutoML optimization run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoMLConfig {
    /// Wall-clock budget for the whole optimization loop; remaining trials
    /// are skipped once it is exceeded.
    pub max_optimization_time: Duration,
    /// Trials to run per candidate configuration.
    /// NOTE(review): not read anywhere in this file — confirm intended usage.
    pub trials_per_config: usize,
    /// Metrics considered during optimization (also drive Pareto dominance).
    pub optimization_metrics: Vec<OptimizationMetric>,
    /// Discrete hyperparameter grid to enumerate.
    pub search_space: SearchSpace,
    /// Number of folds used for cross-validation.
    pub cross_validation_folds: usize,
    /// Consecutive non-improving trials tolerated before stopping early.
    pub early_stopping_patience: usize,
    /// Whether trials may run in parallel.
    /// NOTE(review): not read anywhere in this file — confirm intended usage.
    pub enable_parallel_optimization: bool,
    /// Hard resource limits for the optimization process.
    pub resource_constraints: ResourceConstraints,
}
45
46impl Default for AutoMLConfig {
47 fn default() -> Self {
48 Self {
49 max_optimization_time: Duration::from_secs(3600), trials_per_config: 5,
51 optimization_metrics: vec![
52 OptimizationMetric::Accuracy,
53 OptimizationMetric::Latency,
54 OptimizationMetric::MemoryUsage,
55 ],
56 search_space: SearchSpace::default(),
57 cross_validation_folds: 5,
58 early_stopping_patience: 10,
59 enable_parallel_optimization: true,
60 resource_constraints: ResourceConstraints::default(),
61 }
62 }
63}
64
/// Metrics the optimizer can score and trade off against each other.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum OptimizationMetric {
    /// Mean of average recall and average precision over the test queries
    /// (higher is better).
    Accuracy,
    /// Average query latency in milliseconds (lower is better).
    Latency,
    /// Estimated memory footprint in bytes (lower is better).
    MemoryUsage,
    /// Queries per second derived from average latency (higher is better).
    Throughput,
    /// Heuristic index build-time estimate per embedding strategy.
    IndexBuildTime,
    /// Inverse-log storage score derived from the vector dimension.
    StorageEfficiency,
    /// Heuristic embedding quality score derived from dimensionality.
    EmbeddingQuality,
}
83
/// Discrete hyperparameter grid enumerated by the optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchSpace {
    /// Candidate embedding strategies.
    pub embedding_strategies: Vec<EmbeddingStrategy>,
    /// Candidate vector dimensionalities.
    pub vector_dimensions: Vec<usize>,
    /// Candidate similarity metrics.
    pub similarity_metrics: Vec<SimilarityMetric>,
    /// Candidate index parameters.
    /// NOTE(review): trial generation currently fixes the index config and
    /// does not consume these — confirm intended usage.
    pub index_parameters: IndexParameterSpace,
    /// Candidate learning rates.
    pub learning_rates: Vec<f32>,
    /// Candidate batch sizes.
    pub batch_sizes: Vec<usize>,
}
100
101impl Default for SearchSpace {
102 fn default() -> Self {
103 Self {
104 embedding_strategies: vec![
105 EmbeddingStrategy::TfIdf,
106 EmbeddingStrategy::SentenceTransformer,
107 EmbeddingStrategy::Custom("default".to_string()),
108 ],
109 vector_dimensions: vec![128, 256, 512, 768, 1024],
110 similarity_metrics: vec![
111 SimilarityMetric::Cosine,
112 SimilarityMetric::Euclidean,
113 SimilarityMetric::DotProduct,
114 ],
115 index_parameters: IndexParameterSpace::default(),
116 learning_rates: vec![0.001, 0.01, 0.1],
117 batch_sizes: vec![32, 64, 128, 256],
118 }
119 }
120}
121
/// Candidate values for ANN index parameters (HNSW, IVF, PQ).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexParameterSpace {
    /// HNSW graph connectivity (`M`).
    pub hnsw_m: Vec<usize>,
    /// HNSW construction-time beam width (`ef_construction`).
    pub hnsw_ef_construction: Vec<usize>,
    /// HNSW query-time beam width (`ef_search`).
    pub hnsw_ef_search: Vec<usize>,
    /// IVF number of coarse clusters (`nlist`).
    pub ivf_nlist: Vec<usize>,
    /// IVF clusters probed per query (`nprobe`).
    pub ivf_nprobe: Vec<usize>,
    /// PQ number of sub-quantizers.
    pub pq_m: Vec<usize>,
    /// PQ bits per sub-quantizer code.
    pub pq_nbits: Vec<usize>,
}
136
137impl Default for IndexParameterSpace {
138 fn default() -> Self {
139 Self {
140 hnsw_m: vec![16, 32, 64],
141 hnsw_ef_construction: vec![100, 200, 400],
142 hnsw_ef_search: vec![50, 100, 200],
143 ivf_nlist: vec![100, 1000, 4096],
144 ivf_nprobe: vec![1, 10, 50],
145 pq_m: vec![8, 16, 32],
146 pq_nbits: vec![4, 8],
147 }
148 }
149}
150
/// Hard resource limits for the optimization process.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceConstraints {
    /// Maximum memory usage, in bytes.
    pub max_memory_bytes: usize,
    /// Maximum number of CPU cores to use.
    pub max_cpu_cores: usize,
    /// Maximum GPU memory, in bytes.
    /// NOTE(review): presumably `None` means "no GPU" — confirm semantics.
    pub max_gpu_memory_bytes: Option<usize>,
    /// Maximum disk usage, in bytes.
    pub max_disk_usage_bytes: usize,
}
163
164impl Default for ResourceConstraints {
165 fn default() -> Self {
166 Self {
167 max_memory_bytes: 8 * 1024 * 1024 * 1024, max_cpu_cores: 4,
169 max_gpu_memory_bytes: None,
170 max_disk_usage_bytes: 50 * 1024 * 1024 * 1024, }
172 }
173}
174
/// A single candidate configuration to be trained and evaluated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationTrial {
    /// Unique identifier (UUID v4) for this trial.
    pub trial_id: String,
    /// Embedding strategy under evaluation.
    pub embedding_strategy: EmbeddingStrategy,
    /// Dimensionality of the embedding vectors.
    pub vector_dimension: usize,
    /// Similarity metric used for search.
    pub similarity_metric: SimilarityMetric,
    /// Index type and its parameters.
    pub index_config: IndexConfiguration,
    /// Scalar hyperparameters, keyed by name ("learning_rate", "batch_size").
    pub hyperparameters: HashMap<String, f32>,
    /// Unix timestamp (seconds) at which the trial was generated.
    pub timestamp: u64,
}
186
/// An index family plus its numeric parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfiguration {
    /// Index family name, e.g. "hnsw".
    pub index_type: String,
    /// Parameter name → value map (e.g. "m", "ef_construction").
    pub parameters: HashMap<String, f32>,
}
193
/// Outcome of executing one [`OptimizationTrial`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrialResult {
    /// The configuration that was evaluated.
    pub trial: OptimizationTrial,
    /// Measured metric values (empty when the trial failed).
    pub metrics: HashMap<OptimizationMetric, f32>,
    /// Per-fold cross-validation scores.
    pub cross_validation_scores: Vec<f32>,
    /// Time spent indexing the training data.
    pub training_time: Duration,
    /// Time spent running the evaluation queries.
    pub evaluation_time: Duration,
    /// Estimated peak memory usage, in bytes.
    pub memory_peak_usage: usize,
    /// Error description when `success` is false.
    pub error_message: Option<String>,
    /// Whether the trial completed without error.
    pub success: bool,
}
206
/// Aggregate outcome of an optimization run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoMLResults {
    /// The highest-scoring trial configuration.
    pub best_configuration: OptimizationTrial,
    /// Metrics achieved by the best configuration.
    pub best_metrics: HashMap<OptimizationMetric, f32>,
    /// Non-dominated successful trials across the configured metrics.
    pub pareto_frontier: Vec<TrialResult>,
    /// Every trial result recorded during the run, in execution order.
    pub optimization_history: Vec<TrialResult>,
    /// Wall-clock duration of the whole run.
    pub total_optimization_time: Duration,
    /// Number of trials reported by the run.
    pub trials_completed: usize,
    /// (trial index, best score so far) pairs for plotting progress.
    pub improvement_curve: Vec<(usize, f32)>,
}
218
/// Orchestrates trial generation, execution, and result tracking.
pub struct AutoMLOptimizer {
    /// Run configuration (budget, search space, stopping criteria).
    config: AutoMLConfig,
    /// All trial results recorded so far.
    trial_history: Arc<RwLock<Vec<TrialResult>>>,
    /// Best trial found so far, if any.
    best_trial: Arc<RwLock<Option<TrialResult>>>,
    /// Mutable loop state; async mutex so it can be held across `.await`.
    optimization_state: Arc<Mutex<OptimizationState>>,
    /// Analytics engine (constructed but not yet used in this file).
    #[allow(dead_code)]
    analytics_engine: Arc<Mutex<VectorAnalyticsEngine>>,
    /// Benchmark suite (constructed but not yet used in this file).
    #[allow(dead_code)]
    benchmark_engine: Arc<Mutex<BenchmarkSuite>>,
}
230
/// Internal mutable state of the optimization loop.
#[derive(Debug)]
struct OptimizationState {
    /// Index of the current trial (never updated in this file).
    #[allow(dead_code)]
    current_trial: usize,
    /// Consecutive trials without an improvement of the best score.
    early_stopping_counter: usize,
    /// Best primary score observed so far.
    best_score: f32,
    /// Cached Pareto frontier (never updated in this file).
    #[allow(dead_code)]
    pareto_frontier: Vec<TrialResult>,
    /// Trial id → start time for trials currently executing.
    active_trials: HashMap<String, Instant>,
}
242
243impl AutoMLOptimizer {
244 pub fn new(config: AutoMLConfig) -> Result<Self> {
246 Ok(Self {
247 config,
248 trial_history: Arc::new(RwLock::new(Vec::new())),
249 best_trial: Arc::new(RwLock::new(None)),
250 optimization_state: Arc::new(Mutex::new(OptimizationState {
251 current_trial: 0,
252 early_stopping_counter: 0,
253 best_score: f32::NEG_INFINITY,
254 pareto_frontier: Vec::new(),
255 active_trials: HashMap::new(),
256 })),
257 analytics_engine: Arc::new(Mutex::new(VectorAnalyticsEngine::new())),
258 benchmark_engine: Arc::new(Mutex::new(BenchmarkSuite::new(BenchmarkConfig::default()))),
259 })
260 }
261
262 pub fn with_default_config() -> Result<Self> {
264 Self::new(AutoMLConfig::default())
265 }
266
267 pub async fn optimize_embeddings(
269 &self,
270 training_data: &[(String, String)], validation_data: &[(String, String)],
272 test_queries: &[(String, Vec<String>)], ) -> Result<AutoMLResults> {
274 let span = span!(Level::INFO, "automl_optimization");
275 let _enter = span.enter();
276
277 info!(
278 "Starting AutoML optimization for {} training samples",
279 training_data.len()
280 );
281
282 let start_time = Instant::now();
283 let optimization_state = self.optimization_state.lock().await;
284
285 let trials = self.generate_optimization_trials()?;
287 info!("Generated {} optimization trials", trials.len());
288
289 drop(optimization_state);
290
291 let mut results = Vec::new();
292 let mut best_score = f32::NEG_INFINITY;
293
294 for (i, trial) in trials.iter().enumerate() {
296 if start_time.elapsed() > self.config.max_optimization_time {
297 warn!("Optimization time budget exceeded, stopping early");
298 break;
299 }
300
301 info!(
302 "Executing trial {}/{}: {}",
303 i + 1,
304 trials.len(),
305 trial.trial_id
306 );
307
308 match self
309 .execute_trial(trial, training_data, validation_data, test_queries)
310 .await
311 {
312 Ok(trial_result) => {
313 let primary_score = self.compute_primary_score(&trial_result.metrics);
315
316 if primary_score > best_score {
317 best_score = primary_score;
318 {
319 let mut best_trial = self
320 .best_trial
321 .write()
322 .expect("best_trial lock should not be poisoned");
323 *best_trial = Some(trial_result.clone());
324 } let mut state = self.optimization_state.lock().await;
328 state.early_stopping_counter = 0;
329 state.best_score = best_score;
330 } else {
331 let mut state = self.optimization_state.lock().await;
332 state.early_stopping_counter += 1;
333
334 if state.early_stopping_counter >= self.config.early_stopping_patience {
336 info!(
337 "Early stopping triggered after {} trials without improvement",
338 self.config.early_stopping_patience
339 );
340 break;
341 }
342 }
343
344 results.push(trial_result);
345 }
346 Err(e) => {
347 warn!("Trial {} failed: {}", trial.trial_id, e);
348 results.push(TrialResult {
349 trial: trial.clone(),
350 metrics: HashMap::new(),
351 cross_validation_scores: Vec::new(),
352 training_time: Duration::from_secs(0),
353 evaluation_time: Duration::from_secs(0),
354 memory_peak_usage: 0,
355 error_message: Some(e.to_string()),
356 success: false,
357 });
358 }
359 }
360 }
361
362 {
364 let mut history = self
365 .trial_history
366 .write()
367 .expect("trial_history lock should not be poisoned");
368 history.extend(results.clone());
369 }
370
371 let best_trial = self
373 .best_trial
374 .read()
375 .expect("best_trial lock should not be poisoned");
376 let best_configuration = best_trial
377 .as_ref()
378 .map(|r| r.trial.clone())
379 .unwrap_or_else(|| trials[0].clone());
380
381 let best_metrics = best_trial
382 .as_ref()
383 .map(|r| r.metrics.clone())
384 .unwrap_or_default();
385
386 let pareto_frontier = self.compute_pareto_frontier(&results);
387 let improvement_curve = self.compute_improvement_curve(&results);
388
389 Ok(AutoMLResults {
390 best_configuration,
391 best_metrics,
392 pareto_frontier,
393 optimization_history: results,
394 total_optimization_time: start_time.elapsed(),
395 trials_completed: trials.len(),
396 improvement_curve,
397 })
398 }
399
400 fn generate_optimization_trials(&self) -> Result<Vec<OptimizationTrial>> {
402 let mut trials = Vec::new();
403
404 for embedding_strategy in &self.config.search_space.embedding_strategies {
406 for &vector_dimension in &self.config.search_space.vector_dimensions {
407 for similarity_metric in &self.config.search_space.similarity_metrics {
408 for &learning_rate in &self.config.search_space.learning_rates {
409 for &batch_size in &self.config.search_space.batch_sizes {
410 let trial = OptimizationTrial {
411 trial_id: Uuid::new_v4().to_string(),
412 embedding_strategy: embedding_strategy.clone(),
413 vector_dimension,
414 similarity_metric: *similarity_metric,
415 index_config: self.generate_index_config()?,
416 hyperparameters: {
417 let mut params = HashMap::new();
418 params.insert("learning_rate".to_string(), learning_rate);
419 params.insert("batch_size".to_string(), batch_size as f32);
420 params
421 },
422 timestamp: std::time::SystemTime::now()
423 .duration_since(std::time::UNIX_EPOCH)
424 .unwrap_or_default()
425 .as_secs(),
426 };
427 trials.push(trial);
428 }
429 }
430 }
431 }
432 }
433
434 for _ in 0..20 {
436 trials.push(self.generate_random_trial()?);
437 }
438
439 Ok(trials)
440 }
441
442 async fn execute_trial(
444 &self,
445 trial: &OptimizationTrial,
446 training_data: &[(String, String)],
447 validation_data: &[(String, String)],
448 test_queries: &[(String, Vec<String>)],
449 ) -> Result<TrialResult> {
450 let trial_start = Instant::now();
451
452 {
454 let mut state = self.optimization_state.lock().await;
455 state
456 .active_trials
457 .insert(trial.trial_id.clone(), trial_start);
458 }
459
460 let mut vector_store = self.create_vector_store_for_trial(trial)?;
462
463 let training_start = Instant::now();
465 for (id, content) in training_data {
466 vector_store
467 .index_resource(id.clone(), content)
468 .context("Failed to index training data")?;
469 }
470 let training_time = training_start.elapsed();
471
472 let cv_scores = self
474 .perform_cross_validation(&vector_store, validation_data, test_queries)
475 .await?;
476
477 let eval_start = Instant::now();
479 let metrics = self
480 .evaluate_trial_performance(&vector_store, test_queries, trial)
481 .await?;
482 let evaluation_time = eval_start.elapsed();
483
484 let memory_peak_usage = self.estimate_memory_usage(&vector_store, trial)?;
486
487 {
489 let mut state = self.optimization_state.lock().await;
490 state.active_trials.remove(&trial.trial_id);
491 }
492
493 Ok(TrialResult {
494 trial: trial.clone(),
495 metrics,
496 cross_validation_scores: cv_scores,
497 training_time,
498 evaluation_time,
499 memory_peak_usage,
500 error_message: None,
501 success: true,
502 })
503 }
504
505 async fn perform_cross_validation(
507 &self,
508 vector_store: &VectorStore,
509 validation_data: &[(String, String)],
510 test_queries: &[(String, Vec<String>)],
511 ) -> Result<Vec<f32>> {
512 let fold_size = validation_data.len() / self.config.cross_validation_folds;
513 let mut cv_scores = Vec::new();
514
515 for fold in 0..self.config.cross_validation_folds {
516 let _start_idx = fold * fold_size;
517 let _end_idx = if fold == self.config.cross_validation_folds - 1 {
518 validation_data.len()
519 } else {
520 (fold + 1) * fold_size
521 };
522
523 let fold_queries: Vec<_> = test_queries
525 .iter()
526 .filter(|(query_id, _)| {
527 let hash = query_id.chars().map(|c| c as u32).sum::<u32>() as usize;
529 let fold_idx = hash % self.config.cross_validation_folds;
530 fold_idx == fold
531 })
532 .cloned()
533 .collect();
534
535 if fold_queries.is_empty() {
536 cv_scores.push(0.0);
537 continue;
538 }
539
540 let mut total_recall = 0.0;
542 for (query, relevant_docs) in &fold_queries {
543 let search_results = vector_store.similarity_search(query, 10)?;
544 let retrieved_docs: Vec<String> =
545 search_results.iter().map(|r| r.0.clone()).collect();
546
547 let recall = self.compute_recall(&retrieved_docs, relevant_docs);
548 total_recall += recall;
549 }
550
551 let avg_recall = total_recall / fold_queries.len() as f32;
552 cv_scores.push(avg_recall);
553 }
554
555 Ok(cv_scores)
556 }
557
558 async fn evaluate_trial_performance(
560 &self,
561 vector_store: &VectorStore,
562 test_queries: &[(String, Vec<String>)],
563 trial: &OptimizationTrial,
564 ) -> Result<HashMap<OptimizationMetric, f32>> {
565 let mut metrics = HashMap::new();
566
567 let mut total_recall = 0.0;
569 let mut total_precision = 0.0;
570 let mut total_latency = 0.0;
571
572 for (query, relevant_docs) in test_queries {
573 let query_start = Instant::now();
574 let search_results = vector_store.similarity_search(query, 10)?;
575 let query_latency = query_start.elapsed().as_millis() as f32;
576
577 total_latency += query_latency;
578
579 let retrieved_docs: Vec<String> = search_results.iter().map(|r| r.0.clone()).collect();
580
581 let recall = self.compute_recall(&retrieved_docs, relevant_docs);
582 let precision = self.compute_precision(&retrieved_docs, relevant_docs);
583
584 total_recall += recall;
585 total_precision += precision;
586 }
587
588 let num_queries = test_queries.len() as f32;
589 metrics.insert(
590 OptimizationMetric::Accuracy,
591 (total_recall + total_precision) / (2.0 * num_queries),
592 );
593 metrics.insert(OptimizationMetric::Latency, total_latency / num_queries);
594
595 let avg_latency_seconds = (total_latency / num_queries) / 1000.0;
597 metrics.insert(OptimizationMetric::Throughput, 1.0 / avg_latency_seconds);
598
599 metrics.insert(
601 OptimizationMetric::MemoryUsage,
602 (trial.vector_dimension as f32) * 4.0,
603 ); metrics.insert(
605 OptimizationMetric::StorageEfficiency,
606 1.0 / (trial.vector_dimension as f32).log2(),
607 );
608
609 let build_time_estimate = match trial.embedding_strategy {
611 EmbeddingStrategy::TfIdf => 100.0,
612 EmbeddingStrategy::SentenceTransformer => 1000.0,
613 _ => 500.0,
614 };
615 metrics.insert(OptimizationMetric::IndexBuildTime, build_time_estimate);
616
617 let embedding_quality = 1.0 - (1.0 / (trial.vector_dimension as f32).sqrt());
619 metrics.insert(OptimizationMetric::EmbeddingQuality, embedding_quality);
620
621 Ok(metrics)
622 }
623
624 fn generate_random_trial(&self) -> Result<OptimizationTrial> {
626 use std::collections::hash_map::DefaultHasher;
627 use std::hash::{Hash, Hasher};
628
629 let mut hasher = DefaultHasher::new();
630 Uuid::new_v4().hash(&mut hasher);
631 let random_seed = hasher.finish();
632
633 let search_space = &self.config.search_space;
634
635 Ok(OptimizationTrial {
636 trial_id: Uuid::new_v4().to_string(),
637 embedding_strategy: search_space.embedding_strategies
638 [(random_seed % search_space.embedding_strategies.len() as u64) as usize]
639 .clone(),
640 vector_dimension: search_space.vector_dimensions
641 [((random_seed >> 8) % search_space.vector_dimensions.len() as u64) as usize],
642 similarity_metric: search_space.similarity_metrics
643 [((random_seed >> 16) % search_space.similarity_metrics.len() as u64) as usize],
644 index_config: self.generate_index_config()?,
645 hyperparameters: {
646 let mut params = HashMap::new();
647 let lr_idx =
648 ((random_seed >> 24) % search_space.learning_rates.len() as u64) as usize;
649 let bs_idx = ((random_seed >> 32) % search_space.batch_sizes.len() as u64) as usize;
650 params.insert(
651 "learning_rate".to_string(),
652 search_space.learning_rates[lr_idx],
653 );
654 params.insert(
655 "batch_size".to_string(),
656 search_space.batch_sizes[bs_idx] as f32,
657 );
658 params
659 },
660 timestamp: std::time::SystemTime::now()
661 .duration_since(std::time::UNIX_EPOCH)
662 .unwrap_or_default()
663 .as_secs(),
664 })
665 }
666
667 fn generate_index_config(&self) -> Result<IndexConfiguration> {
669 Ok(IndexConfiguration {
670 index_type: "hnsw".to_string(),
671 parameters: {
672 let mut params = HashMap::new();
673 params.insert("m".to_string(), 16.0);
674 params.insert("ef_construction".to_string(), 200.0);
675 params.insert("ef_search".to_string(), 100.0);
676 params
677 },
678 })
679 }
680
681 fn create_vector_store_for_trial(&self, trial: &OptimizationTrial) -> Result<VectorStore> {
683 VectorStore::with_embedding_strategy(trial.embedding_strategy.clone())
684 .context("Failed to create vector store for trial")
685 }
686
687 fn compute_primary_score(&self, metrics: &HashMap<OptimizationMetric, f32>) -> f32 {
689 let mut score = 0.0;
690
691 if let Some(&accuracy) = metrics.get(&OptimizationMetric::Accuracy) {
693 score += accuracy * 0.4;
694 }
695 if let Some(&latency) = metrics.get(&OptimizationMetric::Latency) {
696 score += (1.0 / (1.0 + latency / 1000.0)) * 0.3;
698 }
699 if let Some(&throughput) = metrics.get(&OptimizationMetric::Throughput) {
700 score += (throughput / 100.0).min(1.0) * 0.3;
701 }
702
703 score
704 }
705
706 fn compute_pareto_frontier(&self, results: &[TrialResult]) -> Vec<TrialResult> {
708 let mut frontier = Vec::new();
709
710 for result in results {
711 if !result.success {
712 continue;
713 }
714
715 let is_dominated = results.iter().any(|other| {
716 if !other.success || other.trial.trial_id == result.trial.trial_id {
717 return false;
718 }
719
720 let mut better_in_all = true;
722 let mut better_in_some = false;
723
724 for metric in &self.config.optimization_metrics {
725 let result_val = result.metrics.get(metric).unwrap_or(&0.0);
726 let other_val = other.metrics.get(metric).unwrap_or(&0.0);
727
728 match metric {
729 OptimizationMetric::Latency | OptimizationMetric::MemoryUsage => {
730 if other_val > result_val {
732 better_in_all = false;
733 } else if other_val < result_val {
734 better_in_some = true;
735 }
736 }
737 _ => {
738 if other_val < result_val {
740 better_in_all = false;
741 } else if other_val > result_val {
742 better_in_some = true;
743 }
744 }
745 }
746 }
747
748 better_in_all && better_in_some
749 });
750
751 if !is_dominated {
752 frontier.push(result.clone());
753 }
754 }
755
756 frontier
757 }
758
759 fn compute_improvement_curve(&self, results: &[TrialResult]) -> Vec<(usize, f32)> {
761 let mut curve = Vec::new();
762 let mut best_score = f32::NEG_INFINITY;
763
764 for (i, result) in results.iter().enumerate() {
765 if result.success {
766 let score = self.compute_primary_score(&result.metrics);
767 if score > best_score {
768 best_score = score;
769 }
770 }
771 curve.push((i, best_score));
772 }
773
774 curve
775 }
776
777 fn estimate_memory_usage(
779 &self,
780 _vector_store: &VectorStore,
781 trial: &OptimizationTrial,
782 ) -> Result<usize> {
783 let base_memory = 100 * 1024 * 1024; let vector_memory = trial.vector_dimension * 4; let index_overhead = vector_memory / 2; Ok(base_memory + vector_memory + index_overhead)
789 }
790
791 fn compute_recall(&self, retrieved: &[String], relevant: &[String]) -> f32 {
793 if relevant.is_empty() {
794 return 1.0;
795 }
796
797 let relevant_set: std::collections::HashSet<_> = relevant.iter().collect();
798 let retrieved_relevant = retrieved
799 .iter()
800 .filter(|doc| relevant_set.contains(doc))
801 .count();
802
803 retrieved_relevant as f32 / relevant.len() as f32
804 }
805
806 fn compute_precision(&self, retrieved: &[String], relevant: &[String]) -> f32 {
808 if retrieved.is_empty() {
809 return 0.0;
810 }
811
812 let relevant_set: std::collections::HashSet<_> = relevant.iter().collect();
813 let retrieved_relevant = retrieved
814 .iter()
815 .filter(|doc| relevant_set.contains(doc))
816 .count();
817
818 retrieved_relevant as f32 / retrieved.len() as f32
819 }
820
821 pub fn get_optimization_statistics(&self) -> AutoMLStatistics {
823 let history = self
824 .trial_history
825 .read()
826 .expect("trial_history lock should not be poisoned");
827 let best_trial = self
828 .best_trial
829 .read()
830 .expect("best_trial lock should not be poisoned");
831
832 let total_trials = history.len();
833 let successful_trials = history.iter().filter(|r| r.success).count();
834 let average_trial_time = if !history.is_empty() {
835 history
836 .iter()
837 .map(|r| r.training_time + r.evaluation_time)
838 .sum::<Duration>()
839 .as_secs_f32()
840 / history.len() as f32
841 } else {
842 0.0
843 };
844
845 let best_score = best_trial
846 .as_ref()
847 .map(|r| self.compute_primary_score(&r.metrics))
848 .unwrap_or(0.0);
849
850 AutoMLStatistics {
851 total_trials,
852 successful_trials,
853 best_score,
854 average_trial_time,
855 optimization_metrics: self.config.optimization_metrics.clone(),
856 search_space_size: self.estimate_search_space_size(),
857 }
858 }
859
860 fn estimate_search_space_size(&self) -> usize {
861 let space = &self.config.search_space;
862 space.embedding_strategies.len()
863 * space.vector_dimensions.len()
864 * space.similarity_metrics.len()
865 * space.learning_rates.len()
866 * space.batch_sizes.len()
867 }
868}
869
/// Summary statistics over the recorded trial history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoMLStatistics {
    /// Total number of recorded trials.
    pub total_trials: usize,
    /// Number of trials that completed successfully.
    pub successful_trials: usize,
    /// Primary score of the best trial (0.0 when none exists).
    pub best_score: f32,
    /// Mean of (training + evaluation) time per trial, in seconds.
    pub average_trial_time: f32,
    /// Metrics the optimizer was configured to track.
    pub optimization_metrics: Vec<OptimizationMetric>,
    /// Size of the exhaustive grid portion of the search space.
    pub search_space_size: usize,
}
880
#[cfg(test)]
mod tests {
    use super::*;

    // Default config must expose the documented default values.
    #[test]
    fn test_automl_config_creation() {
        let config = AutoMLConfig::default();
        assert_eq!(config.cross_validation_folds, 5);
        assert_eq!(config.early_stopping_patience, 10);
        assert!(config.enable_parallel_optimization);
    }

    // The default search space must be non-degenerate on every axis.
    #[test]
    fn test_search_space_default() {
        let search_space = SearchSpace::default();
        assert!(!search_space.embedding_strategies.is_empty());
        assert!(!search_space.vector_dimensions.is_empty());
        assert!(!search_space.similarity_metrics.is_empty());
    }

    #[tokio::test]
    async fn test_automl_optimizer_creation() {
        let optimizer = AutoMLOptimizer::with_default_config();
        assert!(optimizer.is_ok());
    }

    #[test]
    fn test_trial_generation() {
        let optimizer = AutoMLOptimizer::with_default_config().unwrap();
        let trials = optimizer.generate_optimization_trials().unwrap();
        assert!(!trials.is_empty());

        // Every generated trial must carry a unique id; HashSet::insert
        // returns false on a duplicate.
        let mut trial_ids = std::collections::HashSet::new();
        for trial in &trials {
            assert!(trial_ids.insert(trial.trial_id.clone()));
        }
    }

    // End-to-end smoke test with a tiny corpus and a one-second budget.
    #[tokio::test]
    async fn test_optimization_with_sample_data() {
        let _optimizer = AutoMLOptimizer::with_default_config().unwrap();

        let training_data = vec![
            (
                "doc1".to_string(),
                "artificial intelligence machine learning".to_string(),
            ),
            (
                "doc2".to_string(),
                "deep learning neural networks".to_string(),
            ),
            (
                "doc3".to_string(),
                "natural language processing".to_string(),
            ),
        ];

        let validation_data = vec![
            (
                "doc4".to_string(),
                "computer vision image recognition".to_string(),
            ),
            (
                "doc5".to_string(),
                "reinforcement learning algorithms".to_string(),
            ),
        ];

        let test_queries = vec![
            (
                "ai query".to_string(),
                vec!["doc1".to_string(), "doc2".to_string()],
            ),
            ("nlp query".to_string(), vec!["doc3".to_string()]),
        ];

        // Keep the run short so the test terminates quickly.
        let config = AutoMLConfig {
            max_optimization_time: Duration::from_secs(1),
            trials_per_config: 1,
            ..Default::default()
        };

        let optimizer = AutoMLOptimizer::new(config).unwrap();
        let results = optimizer
            .optimize_embeddings(&training_data, &validation_data, &test_queries)
            .await;

        assert!(results.is_ok());
    }

    #[test]
    fn test_pareto_frontier_computation() {
        let optimizer = AutoMLOptimizer::with_default_config().unwrap();

        let trial1 = OptimizationTrial {
            trial_id: "trial1".to_string(),
            embedding_strategy: EmbeddingStrategy::TfIdf,
            vector_dimension: 128,
            similarity_metric: SimilarityMetric::Cosine,
            index_config: optimizer.generate_index_config().unwrap(),
            hyperparameters: HashMap::new(),
            timestamp: 0,
        };

        let trial2 = OptimizationTrial {
            trial_id: "trial2".to_string(),
            embedding_strategy: EmbeddingStrategy::SentenceTransformer,
            vector_dimension: 256,
            similarity_metric: SimilarityMetric::Euclidean,
            index_config: optimizer.generate_index_config().unwrap(),
            hyperparameters: HashMap::new(),
            timestamp: 0,
        };

        // trial1: lower (better) latency; trial2: higher (better) accuracy.
        let mut metrics1 = HashMap::new();
        metrics1.insert(OptimizationMetric::Accuracy, 0.8);
        metrics1.insert(OptimizationMetric::Latency, 100.0);

        let mut metrics2 = HashMap::new();
        metrics2.insert(OptimizationMetric::Accuracy, 0.9);
        metrics2.insert(OptimizationMetric::Latency, 200.0);

        let results = vec![
            TrialResult {
                trial: trial1,
                metrics: metrics1,
                cross_validation_scores: vec![0.8],
                training_time: Duration::from_secs(1),
                evaluation_time: Duration::from_secs(1),
                memory_peak_usage: 1000,
                error_message: None,
                success: true,
            },
            TrialResult {
                trial: trial2,
                metrics: metrics2,
                cross_validation_scores: vec![0.9],
                training_time: Duration::from_secs(2),
                evaluation_time: Duration::from_secs(1),
                memory_peak_usage: 2000,
                error_message: None,
                success: true,
            },
        ];

        let frontier = optimizer.compute_pareto_frontier(&results);
        // Neither trial dominates the other, so both stay on the frontier.
        assert_eq!(frontier.len(), 2);
    }

    #[test]
    fn test_recall_precision_computation() {
        let optimizer = AutoMLOptimizer::with_default_config().unwrap();

        let retrieved = vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()];
        let relevant = vec!["doc1".to_string(), "doc3".to_string(), "doc4".to_string()];

        let recall = optimizer.compute_recall(&retrieved, &relevant);
        let precision = optimizer.compute_precision(&retrieved, &relevant);

        // 2 of 3 relevant docs retrieved; 2 of 3 retrieved docs relevant.
        assert_eq!(recall, 2.0 / 3.0);
        assert_eq!(precision, 2.0 / 3.0);
    }

    #[test]
    fn test_optimization_statistics() {
        let optimizer = AutoMLOptimizer::with_default_config().unwrap();
        let stats = optimizer.get_optimization_statistics();

        assert_eq!(stats.total_trials, 0);
        assert_eq!(stats.successful_trials, 0);
        assert!(stats.search_space_size > 0);
    }
}