1use crate::{
9 advanced_analytics::VectorAnalyticsEngine,
10 benchmarking::{BenchmarkConfig, BenchmarkSuite},
11 embeddings::EmbeddingStrategy,
12 similarity::SimilarityMetric,
13 VectorStore,
14};
15
16use anyhow::{Context, Result};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::{Arc, RwLock};
20use std::time::{Duration, Instant};
21use tokio::sync::Mutex;
22use tracing::{info, span, warn, Level};
23use uuid::Uuid;
24
/// Configuration governing a full AutoML optimization run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoMLConfig {
    /// Wall-clock budget for the whole run; the trial loop stops once exceeded.
    pub max_optimization_time: Duration,
    /// Trials per candidate configuration (not read by the visible code — TODO confirm).
    pub trials_per_config: usize,
    /// Metrics considered when computing the Pareto frontier and statistics.
    pub optimization_metrics: Vec<OptimizationMetric>,
    /// Candidate values for each tunable dimension of the search.
    pub search_space: SearchSpace,
    /// Number of folds used by cross-validation scoring.
    pub cross_validation_folds: usize,
    /// Consecutive non-improving trials tolerated before stopping early.
    pub early_stopping_patience: usize,
    /// Whether trials may run in parallel (not read by the visible code — TODO confirm).
    pub enable_parallel_optimization: bool,
    /// Resource limits for trials (not enforced by the visible code — TODO confirm).
    pub resource_constraints: ResourceConstraints,
}
45
46impl Default for AutoMLConfig {
47 fn default() -> Self {
48 Self {
49 max_optimization_time: Duration::from_secs(3600), trials_per_config: 5,
51 optimization_metrics: vec![
52 OptimizationMetric::Accuracy,
53 OptimizationMetric::Latency,
54 OptimizationMetric::MemoryUsage,
55 ],
56 search_space: SearchSpace::default(),
57 cross_validation_folds: 5,
58 early_stopping_patience: 10,
59 enable_parallel_optimization: true,
60 resource_constraints: ResourceConstraints::default(),
61 }
62 }
63}
64
/// Quality and cost dimensions on which a trial is scored.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum OptimizationMetric {
    /// Mean of recall and precision over the test queries (higher is better).
    Accuracy,
    /// Average query latency in milliseconds (lower is better).
    Latency,
    /// Estimated memory footprint in bytes (lower is better).
    MemoryUsage,
    /// Queries per second derived from average latency (higher is better).
    Throughput,
    /// Estimated index build time, a per-strategy constant (lower is better).
    IndexBuildTime,
    /// Inverse-log-dimension storage score (higher is better).
    StorageEfficiency,
    /// Heuristic quality score growing with vector dimension (higher is better).
    EmbeddingQuality,
}
83
/// Candidate values for every tunable dimension of the optimization search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchSpace {
    /// Embedding strategies to try.
    pub embedding_strategies: Vec<EmbeddingStrategy>,
    /// Candidate vector dimensionalities.
    pub vector_dimensions: Vec<usize>,
    /// Candidate similarity metrics.
    pub similarity_metrics: Vec<SimilarityMetric>,
    /// Candidate index parameters (not consulted by `generate_index_config` — TODO confirm intent).
    pub index_parameters: IndexParameterSpace,
    /// Candidate learning rates (stored into trial hyperparameters).
    pub learning_rates: Vec<f32>,
    /// Candidate batch sizes (stored into trial hyperparameters).
    pub batch_sizes: Vec<usize>,
}
100
101impl Default for SearchSpace {
102 fn default() -> Self {
103 Self {
104 embedding_strategies: vec![
105 EmbeddingStrategy::TfIdf,
106 EmbeddingStrategy::SentenceTransformer,
107 EmbeddingStrategy::Custom("default".to_string()),
108 ],
109 vector_dimensions: vec![128, 256, 512, 768, 1024],
110 similarity_metrics: vec![
111 SimilarityMetric::Cosine,
112 SimilarityMetric::Euclidean,
113 SimilarityMetric::DotProduct,
114 ],
115 index_parameters: IndexParameterSpace::default(),
116 learning_rates: vec![0.001, 0.01, 0.1],
117 batch_sizes: vec![32, 64, 128, 256],
118 }
119 }
120}
121
/// Candidate values for index-structure parameters (HNSW / IVF / PQ naming
/// per the field prefixes; the concrete index semantics live elsewhere).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexParameterSpace {
    /// Candidates for the HNSW `m` parameter.
    pub hnsw_m: Vec<usize>,
    /// Candidates for HNSW `ef_construction`.
    pub hnsw_ef_construction: Vec<usize>,
    /// Candidates for HNSW `ef_search`.
    pub hnsw_ef_search: Vec<usize>,
    /// Candidates for IVF `nlist`.
    pub ivf_nlist: Vec<usize>,
    /// Candidates for IVF `nprobe`.
    pub ivf_nprobe: Vec<usize>,
    /// Candidates for PQ `m` (sub-quantizer count).
    pub pq_m: Vec<usize>,
    /// Candidates for PQ bits per code.
    pub pq_nbits: Vec<usize>,
}
136
137impl Default for IndexParameterSpace {
138 fn default() -> Self {
139 Self {
140 hnsw_m: vec![16, 32, 64],
141 hnsw_ef_construction: vec![100, 200, 400],
142 hnsw_ef_search: vec![50, 100, 200],
143 ivf_nlist: vec![100, 1000, 4096],
144 ivf_nprobe: vec![1, 10, 50],
145 pq_m: vec![8, 16, 32],
146 pq_nbits: vec![4, 8],
147 }
148 }
149}
150
/// Hard resource limits for optimization trials.
/// NOTE(review): not enforced anywhere in the visible code — confirm callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceConstraints {
    /// Maximum resident memory, in bytes.
    pub max_memory_bytes: usize,
    /// Maximum CPU cores a trial may use.
    pub max_cpu_cores: usize,
    /// Maximum GPU memory in bytes; `None` means no GPU budget.
    pub max_gpu_memory_bytes: Option<usize>,
    /// Maximum on-disk footprint, in bytes.
    pub max_disk_usage_bytes: usize,
}
163
164impl Default for ResourceConstraints {
165 fn default() -> Self {
166 Self {
167 max_memory_bytes: 8 * 1024 * 1024 * 1024, max_cpu_cores: 4,
169 max_gpu_memory_bytes: None,
170 max_disk_usage_bytes: 50 * 1024 * 1024 * 1024, }
172 }
173}
174
/// One concrete point in the search space, ready to be executed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationTrial {
    /// Unique id (UUID v4 string) used to track the trial.
    pub trial_id: String,
    /// Embedding strategy under test.
    pub embedding_strategy: EmbeddingStrategy,
    /// Vector dimensionality under test.
    pub vector_dimension: usize,
    /// Similarity metric under test.
    pub similarity_metric: SimilarityMetric,
    /// Index type and parameters for this trial.
    pub index_config: IndexConfiguration,
    /// Free-form numeric hyperparameters (e.g. "learning_rate", "batch_size").
    pub hyperparameters: HashMap<String, f32>,
    /// Creation time as seconds since the Unix epoch.
    pub timestamp: u64,
}
186
/// Index type plus its numeric parameters (e.g. "hnsw" with m/ef values).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfiguration {
    /// Index family identifier, e.g. "hnsw".
    pub index_type: String,
    /// Parameter name to value, all stored as f32.
    pub parameters: HashMap<String, f32>,
}
193
/// Outcome of executing a single trial.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrialResult {
    /// The configuration that was executed.
    pub trial: OptimizationTrial,
    /// Measured metric values; empty when the trial failed.
    pub metrics: HashMap<OptimizationMetric, f32>,
    /// Per-fold recall scores from cross-validation.
    pub cross_validation_scores: Vec<f32>,
    /// Time spent indexing the training corpus.
    pub training_time: Duration,
    /// Time spent running the evaluation queries.
    pub evaluation_time: Duration,
    /// Estimated peak memory in bytes (heuristic, see `estimate_memory_usage`).
    pub memory_peak_usage: usize,
    /// Error text when the trial failed, `None` on success.
    pub error_message: Option<String>,
    /// Whether the trial completed without error.
    pub success: bool,
}
206
/// Summary returned by `AutoMLOptimizer::optimize_embeddings`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoMLResults {
    /// Best configuration found (falls back to the first generated trial
    /// when no trial succeeded).
    pub best_configuration: OptimizationTrial,
    /// Metrics of the best trial; empty when no trial succeeded.
    pub best_metrics: HashMap<OptimizationMetric, f32>,
    /// Successful trials not dominated on the configured metrics.
    pub pareto_frontier: Vec<TrialResult>,
    /// Every executed trial, successful or not, in execution order.
    pub optimization_history: Vec<TrialResult>,
    /// Total wall-clock time of the run.
    pub total_optimization_time: Duration,
    /// Number of trials executed.
    pub trials_completed: usize,
    /// Best-so-far primary score after each trial, as (index, score).
    pub improvement_curve: Vec<(usize, f32)>,
}
218
/// Orchestrates AutoML search over embedding/index configurations.
pub struct AutoMLOptimizer {
    // Search settings and budgets for the run.
    config: AutoMLConfig,
    // All trial results, appended at the end of each optimization run.
    trial_history: Arc<RwLock<Vec<TrialResult>>>,
    // Best trial so far, ranked by `compute_primary_score`.
    best_trial: Arc<RwLock<Option<TrialResult>>>,
    // Mutable loop state: early-stopping counters and in-flight trials.
    optimization_state: Arc<Mutex<OptimizationState>>,
    // Held but not used by the visible code — presumably for future analytics.
    #[allow(dead_code)]
    analytics_engine: Arc<Mutex<VectorAnalyticsEngine>>,
    // Held but not used by the visible code — presumably for future benchmarks.
    #[allow(dead_code)]
    benchmark_engine: Arc<Mutex<BenchmarkSuite>>,
}
230
/// Internal, lock-protected state of the optimization loop.
#[derive(Debug)]
struct OptimizationState {
    // Index of the current trial (not updated by the visible code).
    #[allow(dead_code)]
    current_trial: usize,
    // Consecutive trials without improvement; reset on a new best.
    early_stopping_counter: usize,
    // Best primary score observed so far.
    best_score: f32,
    // Pareto frontier cache (not updated by the visible code).
    #[allow(dead_code)]
    pareto_frontier: Vec<TrialResult>,
    // Trials currently executing, keyed by trial id, with their start times.
    active_trials: HashMap<String, Instant>,
}
242
243impl AutoMLOptimizer {
244 pub fn new(config: AutoMLConfig) -> Result<Self> {
246 Ok(Self {
247 config,
248 trial_history: Arc::new(RwLock::new(Vec::new())),
249 best_trial: Arc::new(RwLock::new(None)),
250 optimization_state: Arc::new(Mutex::new(OptimizationState {
251 current_trial: 0,
252 early_stopping_counter: 0,
253 best_score: f32::NEG_INFINITY,
254 pareto_frontier: Vec::new(),
255 active_trials: HashMap::new(),
256 })),
257 analytics_engine: Arc::new(Mutex::new(VectorAnalyticsEngine::new())),
258 benchmark_engine: Arc::new(Mutex::new(BenchmarkSuite::new(BenchmarkConfig::default()))),
259 })
260 }
261
262 pub fn with_default_config() -> Result<Self> {
264 Self::new(AutoMLConfig::default())
265 }
266
267 pub async fn optimize_embeddings(
269 &self,
270 training_data: &[(String, String)], validation_data: &[(String, String)],
272 test_queries: &[(String, Vec<String>)], ) -> Result<AutoMLResults> {
274 let span = span!(Level::INFO, "automl_optimization");
275 let _enter = span.enter();
276
277 info!(
278 "Starting AutoML optimization for {} training samples",
279 training_data.len()
280 );
281
282 let start_time = Instant::now();
283 let optimization_state = self.optimization_state.lock().await;
284
285 let trials = self.generate_optimization_trials()?;
287 info!("Generated {} optimization trials", trials.len());
288
289 drop(optimization_state);
290
291 let mut results = Vec::new();
292 let mut best_score = f32::NEG_INFINITY;
293
294 for (i, trial) in trials.iter().enumerate() {
296 if start_time.elapsed() > self.config.max_optimization_time {
297 warn!("Optimization time budget exceeded, stopping early");
298 break;
299 }
300
301 info!(
302 "Executing trial {}/{}: {}",
303 i + 1,
304 trials.len(),
305 trial.trial_id
306 );
307
308 match self
309 .execute_trial(trial, training_data, validation_data, test_queries)
310 .await
311 {
312 Ok(trial_result) => {
313 let primary_score = self.compute_primary_score(&trial_result.metrics);
315
316 if primary_score > best_score {
317 best_score = primary_score;
318 {
319 let mut best_trial = self
320 .best_trial
321 .write()
322 .expect("best_trial lock should not be poisoned");
323 *best_trial = Some(trial_result.clone());
324 } let mut state = self.optimization_state.lock().await;
328 state.early_stopping_counter = 0;
329 state.best_score = best_score;
330 } else {
331 let mut state = self.optimization_state.lock().await;
332 state.early_stopping_counter += 1;
333
334 if state.early_stopping_counter >= self.config.early_stopping_patience {
336 info!(
337 "Early stopping triggered after {} trials without improvement",
338 self.config.early_stopping_patience
339 );
340 break;
341 }
342 }
343
344 results.push(trial_result);
345 }
346 Err(e) => {
347 warn!("Trial {} failed: {}", trial.trial_id, e);
348 results.push(TrialResult {
349 trial: trial.clone(),
350 metrics: HashMap::new(),
351 cross_validation_scores: Vec::new(),
352 training_time: Duration::from_secs(0),
353 evaluation_time: Duration::from_secs(0),
354 memory_peak_usage: 0,
355 error_message: Some(e.to_string()),
356 success: false,
357 });
358 }
359 }
360 }
361
362 {
364 let mut history = self
365 .trial_history
366 .write()
367 .expect("trial_history lock should not be poisoned");
368 history.extend(results.clone());
369 }
370
371 let best_trial = self
373 .best_trial
374 .read()
375 .expect("best_trial lock should not be poisoned");
376 let best_configuration = best_trial
377 .as_ref()
378 .map(|r| r.trial.clone())
379 .unwrap_or_else(|| trials[0].clone());
380
381 let best_metrics = best_trial
382 .as_ref()
383 .map(|r| r.metrics.clone())
384 .unwrap_or_default();
385
386 let pareto_frontier = self.compute_pareto_frontier(&results);
387 let improvement_curve = self.compute_improvement_curve(&results);
388
389 Ok(AutoMLResults {
390 best_configuration,
391 best_metrics,
392 pareto_frontier,
393 optimization_history: results,
394 total_optimization_time: start_time.elapsed(),
395 trials_completed: trials.len(),
396 improvement_curve,
397 })
398 }
399
400 fn generate_optimization_trials(&self) -> Result<Vec<OptimizationTrial>> {
402 let mut trials = Vec::new();
403
404 for embedding_strategy in &self.config.search_space.embedding_strategies {
406 for &vector_dimension in &self.config.search_space.vector_dimensions {
407 for similarity_metric in &self.config.search_space.similarity_metrics {
408 for &learning_rate in &self.config.search_space.learning_rates {
409 for &batch_size in &self.config.search_space.batch_sizes {
410 let trial = OptimizationTrial {
411 trial_id: Uuid::new_v4().to_string(),
412 embedding_strategy: embedding_strategy.clone(),
413 vector_dimension,
414 similarity_metric: *similarity_metric,
415 index_config: self.generate_index_config()?,
416 hyperparameters: {
417 let mut params = HashMap::new();
418 params.insert("learning_rate".to_string(), learning_rate);
419 params.insert("batch_size".to_string(), batch_size as f32);
420 params
421 },
422 timestamp: std::time::SystemTime::now()
423 .duration_since(std::time::UNIX_EPOCH)
424 .unwrap_or_default()
425 .as_secs(),
426 };
427 trials.push(trial);
428 }
429 }
430 }
431 }
432 }
433
434 for _ in 0..20 {
436 trials.push(self.generate_random_trial()?);
437 }
438
439 Ok(trials)
440 }
441
442 async fn execute_trial(
444 &self,
445 trial: &OptimizationTrial,
446 training_data: &[(String, String)],
447 validation_data: &[(String, String)],
448 test_queries: &[(String, Vec<String>)],
449 ) -> Result<TrialResult> {
450 let trial_start = Instant::now();
451
452 {
454 let mut state = self.optimization_state.lock().await;
455 state
456 .active_trials
457 .insert(trial.trial_id.clone(), trial_start);
458 }
459
460 let mut vector_store = self.create_vector_store_for_trial(trial)?;
462
463 let training_start = Instant::now();
465 for (id, content) in training_data {
466 vector_store
467 .index_resource(id.clone(), content)
468 .context("Failed to index training data")?;
469 }
470 let training_time = training_start.elapsed();
471
472 let cv_scores = self
474 .perform_cross_validation(&vector_store, validation_data, test_queries)
475 .await?;
476
477 let eval_start = Instant::now();
479 let metrics = self
480 .evaluate_trial_performance(&vector_store, test_queries, trial)
481 .await?;
482 let evaluation_time = eval_start.elapsed();
483
484 let memory_peak_usage = self.estimate_memory_usage(&vector_store, trial)?;
486
487 {
489 let mut state = self.optimization_state.lock().await;
490 state.active_trials.remove(&trial.trial_id);
491 }
492
493 Ok(TrialResult {
494 trial: trial.clone(),
495 metrics,
496 cross_validation_scores: cv_scores,
497 training_time,
498 evaluation_time,
499 memory_peak_usage,
500 error_message: None,
501 success: true,
502 })
503 }
504
505 async fn perform_cross_validation(
507 &self,
508 vector_store: &VectorStore,
509 validation_data: &[(String, String)],
510 test_queries: &[(String, Vec<String>)],
511 ) -> Result<Vec<f32>> {
512 let fold_size = validation_data.len() / self.config.cross_validation_folds;
513 let mut cv_scores = Vec::new();
514
515 for fold in 0..self.config.cross_validation_folds {
516 let _start_idx = fold * fold_size;
517 let _end_idx = if fold == self.config.cross_validation_folds - 1 {
518 validation_data.len()
519 } else {
520 (fold + 1) * fold_size
521 };
522
523 let fold_queries: Vec<_> = test_queries
525 .iter()
526 .filter(|(query_id, _)| {
527 let hash = query_id.chars().map(|c| c as u32).sum::<u32>() as usize;
529 let fold_idx = hash % self.config.cross_validation_folds;
530 fold_idx == fold
531 })
532 .cloned()
533 .collect();
534
535 if fold_queries.is_empty() {
536 cv_scores.push(0.0);
537 continue;
538 }
539
540 let mut total_recall = 0.0;
542 for (query, relevant_docs) in &fold_queries {
543 let search_results = vector_store.similarity_search(query, 10)?;
544 let retrieved_docs: Vec<String> =
545 search_results.iter().map(|r| r.0.clone()).collect();
546
547 let recall = self.compute_recall(&retrieved_docs, relevant_docs);
548 total_recall += recall;
549 }
550
551 let avg_recall = total_recall / fold_queries.len() as f32;
552 cv_scores.push(avg_recall);
553 }
554
555 Ok(cv_scores)
556 }
557
558 async fn evaluate_trial_performance(
560 &self,
561 vector_store: &VectorStore,
562 test_queries: &[(String, Vec<String>)],
563 trial: &OptimizationTrial,
564 ) -> Result<HashMap<OptimizationMetric, f32>> {
565 let mut metrics = HashMap::new();
566
567 let mut total_recall = 0.0;
569 let mut total_precision = 0.0;
570 let mut total_latency = 0.0;
571
572 for (query, relevant_docs) in test_queries {
573 let query_start = Instant::now();
574 let search_results = vector_store.similarity_search(query, 10)?;
575 let query_latency = query_start.elapsed().as_millis() as f32;
576
577 total_latency += query_latency;
578
579 let retrieved_docs: Vec<String> = search_results.iter().map(|r| r.0.clone()).collect();
580
581 let recall = self.compute_recall(&retrieved_docs, relevant_docs);
582 let precision = self.compute_precision(&retrieved_docs, relevant_docs);
583
584 total_recall += recall;
585 total_precision += precision;
586 }
587
588 let num_queries = test_queries.len() as f32;
589 metrics.insert(
590 OptimizationMetric::Accuracy,
591 (total_recall + total_precision) / (2.0 * num_queries),
592 );
593 metrics.insert(OptimizationMetric::Latency, total_latency / num_queries);
594
595 let avg_latency_seconds = (total_latency / num_queries) / 1000.0;
597 metrics.insert(OptimizationMetric::Throughput, 1.0 / avg_latency_seconds);
598
599 metrics.insert(
601 OptimizationMetric::MemoryUsage,
602 (trial.vector_dimension as f32) * 4.0,
603 ); metrics.insert(
605 OptimizationMetric::StorageEfficiency,
606 1.0 / (trial.vector_dimension as f32).log2(),
607 );
608
609 let build_time_estimate = match trial.embedding_strategy {
611 EmbeddingStrategy::TfIdf => 100.0,
612 EmbeddingStrategy::SentenceTransformer => 1000.0,
613 _ => 500.0,
614 };
615 metrics.insert(OptimizationMetric::IndexBuildTime, build_time_estimate);
616
617 let embedding_quality = 1.0 - (1.0 / (trial.vector_dimension as f32).sqrt());
619 metrics.insert(OptimizationMetric::EmbeddingQuality, embedding_quality);
620
621 Ok(metrics)
622 }
623
624 fn generate_random_trial(&self) -> Result<OptimizationTrial> {
626 use std::collections::hash_map::DefaultHasher;
627 use std::hash::{Hash, Hasher};
628
629 let mut hasher = DefaultHasher::new();
630 Uuid::new_v4().hash(&mut hasher);
631 let random_seed = hasher.finish();
632
633 let search_space = &self.config.search_space;
634
635 Ok(OptimizationTrial {
636 trial_id: Uuid::new_v4().to_string(),
637 embedding_strategy: search_space.embedding_strategies
638 [(random_seed % search_space.embedding_strategies.len() as u64) as usize]
639 .clone(),
640 vector_dimension: search_space.vector_dimensions
641 [((random_seed >> 8) % search_space.vector_dimensions.len() as u64) as usize],
642 similarity_metric: search_space.similarity_metrics
643 [((random_seed >> 16) % search_space.similarity_metrics.len() as u64) as usize],
644 index_config: self.generate_index_config()?,
645 hyperparameters: {
646 let mut params = HashMap::new();
647 let lr_idx =
648 ((random_seed >> 24) % search_space.learning_rates.len() as u64) as usize;
649 let bs_idx = ((random_seed >> 32) % search_space.batch_sizes.len() as u64) as usize;
650 params.insert(
651 "learning_rate".to_string(),
652 search_space.learning_rates[lr_idx],
653 );
654 params.insert(
655 "batch_size".to_string(),
656 search_space.batch_sizes[bs_idx] as f32,
657 );
658 params
659 },
660 timestamp: std::time::SystemTime::now()
661 .duration_since(std::time::UNIX_EPOCH)
662 .unwrap_or_default()
663 .as_secs(),
664 })
665 }
666
667 fn generate_index_config(&self) -> Result<IndexConfiguration> {
669 Ok(IndexConfiguration {
670 index_type: "hnsw".to_string(),
671 parameters: {
672 let mut params = HashMap::new();
673 params.insert("m".to_string(), 16.0);
674 params.insert("ef_construction".to_string(), 200.0);
675 params.insert("ef_search".to_string(), 100.0);
676 params
677 },
678 })
679 }
680
681 fn create_vector_store_for_trial(&self, trial: &OptimizationTrial) -> Result<VectorStore> {
683 VectorStore::with_embedding_strategy(trial.embedding_strategy.clone())
684 .context("Failed to create vector store for trial")
685 }
686
687 fn compute_primary_score(&self, metrics: &HashMap<OptimizationMetric, f32>) -> f32 {
689 let mut score = 0.0;
690
691 if let Some(&accuracy) = metrics.get(&OptimizationMetric::Accuracy) {
693 score += accuracy * 0.4;
694 }
695 if let Some(&latency) = metrics.get(&OptimizationMetric::Latency) {
696 score += (1.0 / (1.0 + latency / 1000.0)) * 0.3;
698 }
699 if let Some(&throughput) = metrics.get(&OptimizationMetric::Throughput) {
700 score += (throughput / 100.0).min(1.0) * 0.3;
701 }
702
703 score
704 }
705
    /// Returns every successful trial that is not Pareto-dominated by another
    /// successful trial on the configured `optimization_metrics`.
    ///
    /// A trial is dominated when some other trial is at least as good on every
    /// metric and strictly better on at least one. Latency and MemoryUsage are
    /// minimized; every other metric is maximized. Missing metric values are
    /// treated as 0.0. Runs in O(n^2) over `results`.
    fn compute_pareto_frontier(&self, results: &[TrialResult]) -> Vec<TrialResult> {
        let mut frontier = Vec::new();

        for result in results {
            // Failed trials never enter the frontier.
            if !result.success {
                continue;
            }

            let is_dominated = results.iter().any(|other| {
                // A trial cannot dominate itself; failed trials dominate nothing.
                if !other.success || other.trial.trial_id == result.trial.trial_id {
                    return false;
                }

                // better_in_all: `other` is at least as good on every metric.
                // better_in_some: `other` is strictly better on at least one.
                let mut better_in_all = true;
                let mut better_in_some = false;

                for metric in &self.config.optimization_metrics {
                    let result_val = result.metrics.get(metric).unwrap_or(&0.0);
                    let other_val = other.metrics.get(metric).unwrap_or(&0.0);

                    match metric {
                        // Cost metrics: lower values are better.
                        OptimizationMetric::Latency | OptimizationMetric::MemoryUsage => {
                            if other_val > result_val {
                                better_in_all = false;
                            } else if other_val < result_val {
                                better_in_some = true;
                            }
                        }
                        // Quality metrics: higher values are better.
                        _ => {
                            if other_val < result_val {
                                better_in_all = false;
                            } else if other_val > result_val {
                                better_in_some = true;
                            }
                        }
                    }
                }

                better_in_all && better_in_some
            });

            if !is_dominated {
                frontier.push(result.clone());
            }
        }

        frontier
    }
758
759 fn compute_improvement_curve(&self, results: &[TrialResult]) -> Vec<(usize, f32)> {
761 let mut curve = Vec::new();
762 let mut best_score = f32::NEG_INFINITY;
763
764 for (i, result) in results.iter().enumerate() {
765 if result.success {
766 let score = self.compute_primary_score(&result.metrics);
767 if score > best_score {
768 best_score = score;
769 }
770 }
771 curve.push((i, best_score));
772 }
773
774 curve
775 }
776
777 fn estimate_memory_usage(
779 &self,
780 _vector_store: &VectorStore,
781 trial: &OptimizationTrial,
782 ) -> Result<usize> {
783 let base_memory = 100 * 1024 * 1024; let vector_memory = trial.vector_dimension * 4; let index_overhead = vector_memory / 2; Ok(base_memory + vector_memory + index_overhead)
789 }
790
791 fn compute_recall(&self, retrieved: &[String], relevant: &[String]) -> f32 {
793 if relevant.is_empty() {
794 return 1.0;
795 }
796
797 let relevant_set: std::collections::HashSet<_> = relevant.iter().collect();
798 let retrieved_relevant = retrieved
799 .iter()
800 .filter(|doc| relevant_set.contains(doc))
801 .count();
802
803 retrieved_relevant as f32 / relevant.len() as f32
804 }
805
806 fn compute_precision(&self, retrieved: &[String], relevant: &[String]) -> f32 {
808 if retrieved.is_empty() {
809 return 0.0;
810 }
811
812 let relevant_set: std::collections::HashSet<_> = relevant.iter().collect();
813 let retrieved_relevant = retrieved
814 .iter()
815 .filter(|doc| relevant_set.contains(doc))
816 .count();
817
818 retrieved_relevant as f32 / retrieved.len() as f32
819 }
820
821 pub fn get_optimization_statistics(&self) -> AutoMLStatistics {
823 let history = self
824 .trial_history
825 .read()
826 .expect("trial_history lock should not be poisoned");
827 let best_trial = self
828 .best_trial
829 .read()
830 .expect("best_trial lock should not be poisoned");
831
832 let total_trials = history.len();
833 let successful_trials = history.iter().filter(|r| r.success).count();
834 let average_trial_time = if !history.is_empty() {
835 history
836 .iter()
837 .map(|r| r.training_time + r.evaluation_time)
838 .sum::<Duration>()
839 .as_secs_f32()
840 / history.len() as f32
841 } else {
842 0.0
843 };
844
845 let best_score = best_trial
846 .as_ref()
847 .map(|r| self.compute_primary_score(&r.metrics))
848 .unwrap_or(0.0);
849
850 AutoMLStatistics {
851 total_trials,
852 successful_trials,
853 best_score,
854 average_trial_time,
855 optimization_metrics: self.config.optimization_metrics.clone(),
856 search_space_size: self.estimate_search_space_size(),
857 }
858 }
859
860 fn estimate_search_space_size(&self) -> usize {
861 let space = &self.config.search_space;
862 space.embedding_strategies.len()
863 * space.vector_dimensions.len()
864 * space.similarity_metrics.len()
865 * space.learning_rates.len()
866 * space.batch_sizes.len()
867 }
868}
869
/// Progress summary produced by `AutoMLOptimizer::get_optimization_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoMLStatistics {
    /// Total trials recorded in the history.
    pub total_trials: usize,
    /// Trials that completed successfully.
    pub successful_trials: usize,
    /// Primary score of the best trial, or 0.0 when none succeeded.
    pub best_score: f32,
    /// Mean (training + evaluation) time per trial, in seconds.
    pub average_trial_time: f32,
    /// Metrics the optimizer was configured to consider.
    pub optimization_metrics: Vec<OptimizationMetric>,
    /// Size of the exhaustive grid (product of candidate counts).
    pub search_space_size: usize,
}
880
881#[cfg(test)]
882mod tests {
883 use super::*;
884 use anyhow::Result;
885
886 #[test]
887 fn test_automl_config_creation() {
888 let config = AutoMLConfig::default();
889 assert_eq!(config.cross_validation_folds, 5);
890 assert_eq!(config.early_stopping_patience, 10);
891 assert!(config.enable_parallel_optimization);
892 }
893
894 #[test]
895 fn test_search_space_default() {
896 let search_space = SearchSpace::default();
897 assert!(!search_space.embedding_strategies.is_empty());
898 assert!(!search_space.vector_dimensions.is_empty());
899 assert!(!search_space.similarity_metrics.is_empty());
900 }
901
902 #[tokio::test]
903 async fn test_automl_optimizer_creation() {
904 let optimizer = AutoMLOptimizer::with_default_config();
905 assert!(optimizer.is_ok());
906 }
907
908 #[test]
909 fn test_trial_generation() -> Result<()> {
910 let optimizer = AutoMLOptimizer::with_default_config()?;
911 let trials = optimizer.generate_optimization_trials()?;
912 assert!(!trials.is_empty());
913
914 let mut trial_ids = std::collections::HashSet::new();
916 for trial in &trials {
917 assert!(trial_ids.insert(trial.trial_id.clone()));
918 }
919 Ok(())
920 }
921
922 #[tokio::test]
923 async fn test_optimization_with_sample_data() -> Result<()> {
924 let _optimizer = AutoMLOptimizer::with_default_config()?;
925
926 let training_data = vec![
927 (
928 "doc1".to_string(),
929 "artificial intelligence machine learning".to_string(),
930 ),
931 (
932 "doc2".to_string(),
933 "deep learning neural networks".to_string(),
934 ),
935 (
936 "doc3".to_string(),
937 "natural language processing".to_string(),
938 ),
939 ];
940
941 let validation_data = vec![
942 (
943 "doc4".to_string(),
944 "computer vision image recognition".to_string(),
945 ),
946 (
947 "doc5".to_string(),
948 "reinforcement learning algorithms".to_string(),
949 ),
950 ];
951
952 let test_queries = vec![
953 (
954 "ai query".to_string(),
955 vec!["doc1".to_string(), "doc2".to_string()],
956 ),
957 ("nlp query".to_string(), vec!["doc3".to_string()]),
958 ];
959
960 let config = AutoMLConfig {
962 max_optimization_time: Duration::from_secs(1),
963 trials_per_config: 1,
964 ..Default::default()
965 };
966
967 let optimizer = AutoMLOptimizer::new(config)?;
968 let results = optimizer
969 .optimize_embeddings(&training_data, &validation_data, &test_queries)
970 .await;
971
972 assert!(results.is_ok());
974 Ok(())
975 }
976
977 #[test]
978 fn test_pareto_frontier_computation() -> Result<()> {
979 let optimizer = AutoMLOptimizer::with_default_config()?;
980
981 let trial1 = OptimizationTrial {
982 trial_id: "trial1".to_string(),
983 embedding_strategy: EmbeddingStrategy::TfIdf,
984 vector_dimension: 128,
985 similarity_metric: SimilarityMetric::Cosine,
986 index_config: optimizer.generate_index_config()?,
987 hyperparameters: HashMap::new(),
988 timestamp: 0,
989 };
990
991 let trial2 = OptimizationTrial {
992 trial_id: "trial2".to_string(),
993 embedding_strategy: EmbeddingStrategy::SentenceTransformer,
994 vector_dimension: 256,
995 similarity_metric: SimilarityMetric::Euclidean,
996 index_config: optimizer.generate_index_config()?,
997 hyperparameters: HashMap::new(),
998 timestamp: 0,
999 };
1000
1001 let mut metrics1 = HashMap::new();
1002 metrics1.insert(OptimizationMetric::Accuracy, 0.8);
1003 metrics1.insert(OptimizationMetric::Latency, 100.0);
1004
1005 let mut metrics2 = HashMap::new();
1006 metrics2.insert(OptimizationMetric::Accuracy, 0.9);
1007 metrics2.insert(OptimizationMetric::Latency, 200.0);
1008
1009 let results = vec![
1010 TrialResult {
1011 trial: trial1,
1012 metrics: metrics1,
1013 cross_validation_scores: vec![0.8],
1014 training_time: Duration::from_secs(1),
1015 evaluation_time: Duration::from_secs(1),
1016 memory_peak_usage: 1000,
1017 error_message: None,
1018 success: true,
1019 },
1020 TrialResult {
1021 trial: trial2,
1022 metrics: metrics2,
1023 cross_validation_scores: vec![0.9],
1024 training_time: Duration::from_secs(2),
1025 evaluation_time: Duration::from_secs(1),
1026 memory_peak_usage: 2000,
1027 error_message: None,
1028 success: true,
1029 },
1030 ];
1031
1032 let frontier = optimizer.compute_pareto_frontier(&results);
1033 assert_eq!(frontier.len(), 2); Ok(())
1035 }
1036
1037 #[test]
1038 fn test_recall_precision_computation() -> Result<()> {
1039 let optimizer = AutoMLOptimizer::with_default_config()?;
1040
1041 let retrieved = vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()];
1042 let relevant = vec!["doc1".to_string(), "doc3".to_string(), "doc4".to_string()];
1043
1044 let recall = optimizer.compute_recall(&retrieved, &relevant);
1045 let precision = optimizer.compute_precision(&retrieved, &relevant);
1046
1047 assert_eq!(recall, 2.0 / 3.0); assert_eq!(precision, 2.0 / 3.0); Ok(())
1050 }
1051
1052 #[test]
1053 fn test_optimization_statistics() -> Result<()> {
1054 let optimizer = AutoMLOptimizer::with_default_config()?;
1055 let stats = optimizer.get_optimization_statistics();
1056
1057 assert_eq!(stats.total_trials, 0);
1058 assert_eq!(stats.successful_trials, 0);
1059 assert!(stats.search_space_size > 0);
1060 Ok(())
1061 }
1062}