1use crate::{
9 advanced_analytics::VectorAnalyticsEngine,
10 benchmarking::{BenchmarkConfig, BenchmarkSuite},
11 embeddings::EmbeddingStrategy,
12 similarity::SimilarityMetric,
13 VectorStore,
14};
15
16use anyhow::{Context, Result};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::{Arc, RwLock};
20use std::time::{Duration, Instant};
21use tokio::sync::Mutex;
22use tracing::{info, span, warn, Level};
23use uuid::Uuid;
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct AutoMLConfig {
28 pub max_optimization_time: Duration,
30 pub trials_per_config: usize,
32 pub optimization_metrics: Vec<OptimizationMetric>,
34 pub search_space: SearchSpace,
36 pub cross_validation_folds: usize,
38 pub early_stopping_patience: usize,
40 pub enable_parallel_optimization: bool,
42 pub resource_constraints: ResourceConstraints,
44}
45
46impl Default for AutoMLConfig {
47 fn default() -> Self {
48 Self {
49 max_optimization_time: Duration::from_secs(3600), trials_per_config: 5,
51 optimization_metrics: vec![
52 OptimizationMetric::Accuracy,
53 OptimizationMetric::Latency,
54 OptimizationMetric::MemoryUsage,
55 ],
56 search_space: SearchSpace::default(),
57 cross_validation_folds: 5,
58 early_stopping_patience: 10,
59 enable_parallel_optimization: true,
60 resource_constraints: ResourceConstraints::default(),
61 }
62 }
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
67pub enum OptimizationMetric {
68 Accuracy,
70 Latency,
72 MemoryUsage,
74 Throughput,
76 IndexBuildTime,
78 StorageEfficiency,
80 EmbeddingQuality,
82}
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct SearchSpace {
87 pub embedding_strategies: Vec<EmbeddingStrategy>,
89 pub vector_dimensions: Vec<usize>,
91 pub similarity_metrics: Vec<SimilarityMetric>,
93 pub index_parameters: IndexParameterSpace,
95 pub learning_rates: Vec<f32>,
97 pub batch_sizes: Vec<usize>,
99}
100
101impl Default for SearchSpace {
102 fn default() -> Self {
103 Self {
104 embedding_strategies: vec![
105 EmbeddingStrategy::TfIdf,
106 EmbeddingStrategy::SentenceTransformer,
107 EmbeddingStrategy::Custom("default".to_string()),
108 ],
109 vector_dimensions: vec![128, 256, 512, 768, 1024],
110 similarity_metrics: vec![
111 SimilarityMetric::Cosine,
112 SimilarityMetric::Euclidean,
113 SimilarityMetric::DotProduct,
114 ],
115 index_parameters: IndexParameterSpace::default(),
116 learning_rates: vec![0.001, 0.01, 0.1],
117 batch_sizes: vec![32, 64, 128, 256],
118 }
119 }
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct IndexParameterSpace {
125 pub hnsw_m: Vec<usize>,
127 pub hnsw_ef_construction: Vec<usize>,
128 pub hnsw_ef_search: Vec<usize>,
129 pub ivf_nlist: Vec<usize>,
131 pub ivf_nprobe: Vec<usize>,
132 pub pq_m: Vec<usize>,
134 pub pq_nbits: Vec<usize>,
135}
136
137impl Default for IndexParameterSpace {
138 fn default() -> Self {
139 Self {
140 hnsw_m: vec![16, 32, 64],
141 hnsw_ef_construction: vec![100, 200, 400],
142 hnsw_ef_search: vec![50, 100, 200],
143 ivf_nlist: vec![100, 1000, 4096],
144 ivf_nprobe: vec![1, 10, 50],
145 pq_m: vec![8, 16, 32],
146 pq_nbits: vec![4, 8],
147 }
148 }
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
153pub struct ResourceConstraints {
154 pub max_memory_bytes: usize,
156 pub max_cpu_cores: usize,
158 pub max_gpu_memory_bytes: Option<usize>,
160 pub max_disk_usage_bytes: usize,
162}
163
164impl Default for ResourceConstraints {
165 fn default() -> Self {
166 Self {
167 max_memory_bytes: 8 * 1024 * 1024 * 1024, max_cpu_cores: 4,
169 max_gpu_memory_bytes: None,
170 max_disk_usage_bytes: 50 * 1024 * 1024 * 1024, }
172 }
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct OptimizationTrial {
178 pub trial_id: String,
179 pub embedding_strategy: EmbeddingStrategy,
180 pub vector_dimension: usize,
181 pub similarity_metric: SimilarityMetric,
182 pub index_config: IndexConfiguration,
183 pub hyperparameters: HashMap<String, f32>,
184 pub timestamp: u64,
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct IndexConfiguration {
190 pub index_type: String,
191 pub parameters: HashMap<String, f32>,
192}
193
194#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct TrialResult {
197 pub trial: OptimizationTrial,
198 pub metrics: HashMap<OptimizationMetric, f32>,
199 pub cross_validation_scores: Vec<f32>,
200 pub training_time: Duration,
201 pub evaluation_time: Duration,
202 pub memory_peak_usage: usize,
203 pub error_message: Option<String>,
204 pub success: bool,
205}
206
207#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct AutoMLResults {
210 pub best_configuration: OptimizationTrial,
211 pub best_metrics: HashMap<OptimizationMetric, f32>,
212 pub pareto_frontier: Vec<TrialResult>,
213 pub optimization_history: Vec<TrialResult>,
214 pub total_optimization_time: Duration,
215 pub trials_completed: usize,
216 pub improvement_curve: Vec<(usize, f32)>,
217}
218
219pub struct AutoMLOptimizer {
221 config: AutoMLConfig,
222 trial_history: Arc<RwLock<Vec<TrialResult>>>,
223 best_trial: Arc<RwLock<Option<TrialResult>>>,
224 optimization_state: Arc<Mutex<OptimizationState>>,
225 #[allow(dead_code)]
226 analytics_engine: Arc<Mutex<VectorAnalyticsEngine>>,
227 #[allow(dead_code)]
228 benchmark_engine: Arc<Mutex<BenchmarkSuite>>,
229}
230
231#[derive(Debug)]
233struct OptimizationState {
234 #[allow(dead_code)]
235 current_trial: usize,
236 early_stopping_counter: usize,
237 best_score: f32,
238 #[allow(dead_code)]
239 pareto_frontier: Vec<TrialResult>,
240 active_trials: HashMap<String, Instant>,
241}
242
243impl AutoMLOptimizer {
244 pub fn new(config: AutoMLConfig) -> Result<Self> {
246 Ok(Self {
247 config,
248 trial_history: Arc::new(RwLock::new(Vec::new())),
249 best_trial: Arc::new(RwLock::new(None)),
250 optimization_state: Arc::new(Mutex::new(OptimizationState {
251 current_trial: 0,
252 early_stopping_counter: 0,
253 best_score: f32::NEG_INFINITY,
254 pareto_frontier: Vec::new(),
255 active_trials: HashMap::new(),
256 })),
257 analytics_engine: Arc::new(Mutex::new(VectorAnalyticsEngine::new())),
258 benchmark_engine: Arc::new(Mutex::new(BenchmarkSuite::new(BenchmarkConfig::default()))),
259 })
260 }
261
262 pub fn with_default_config() -> Result<Self> {
264 Self::new(AutoMLConfig::default())
265 }
266
267 pub async fn optimize_embeddings(
269 &self,
270 training_data: &[(String, String)], validation_data: &[(String, String)],
272 test_queries: &[(String, Vec<String>)], ) -> Result<AutoMLResults> {
274 let span = span!(Level::INFO, "automl_optimization");
275 let _enter = span.enter();
276
277 info!(
278 "Starting AutoML optimization for {} training samples",
279 training_data.len()
280 );
281
282 let start_time = Instant::now();
283 let optimization_state = self.optimization_state.lock().await;
284
285 let trials = self.generate_optimization_trials()?;
287 info!("Generated {} optimization trials", trials.len());
288
289 drop(optimization_state);
290
291 let mut results = Vec::new();
292 let mut best_score = f32::NEG_INFINITY;
293
294 for (i, trial) in trials.iter().enumerate() {
296 if start_time.elapsed() > self.config.max_optimization_time {
297 warn!("Optimization time budget exceeded, stopping early");
298 break;
299 }
300
301 info!(
302 "Executing trial {}/{}: {}",
303 i + 1,
304 trials.len(),
305 trial.trial_id
306 );
307
308 match self
309 .execute_trial(trial, training_data, validation_data, test_queries)
310 .await
311 {
312 Ok(trial_result) => {
313 let primary_score = self.compute_primary_score(&trial_result.metrics);
315
316 if primary_score > best_score {
317 best_score = primary_score;
318 {
319 let mut best_trial = self.best_trial.write().unwrap();
320 *best_trial = Some(trial_result.clone());
321 } let mut state = self.optimization_state.lock().await;
325 state.early_stopping_counter = 0;
326 state.best_score = best_score;
327 } else {
328 let mut state = self.optimization_state.lock().await;
329 state.early_stopping_counter += 1;
330
331 if state.early_stopping_counter >= self.config.early_stopping_patience {
333 info!(
334 "Early stopping triggered after {} trials without improvement",
335 self.config.early_stopping_patience
336 );
337 break;
338 }
339 }
340
341 results.push(trial_result);
342 }
343 Err(e) => {
344 warn!("Trial {} failed: {}", trial.trial_id, e);
345 results.push(TrialResult {
346 trial: trial.clone(),
347 metrics: HashMap::new(),
348 cross_validation_scores: Vec::new(),
349 training_time: Duration::from_secs(0),
350 evaluation_time: Duration::from_secs(0),
351 memory_peak_usage: 0,
352 error_message: Some(e.to_string()),
353 success: false,
354 });
355 }
356 }
357 }
358
359 {
361 let mut history = self.trial_history.write().unwrap();
362 history.extend(results.clone());
363 }
364
365 let best_trial = self.best_trial.read().unwrap();
367 let best_configuration = best_trial
368 .as_ref()
369 .map(|r| r.trial.clone())
370 .unwrap_or_else(|| trials[0].clone());
371
372 let best_metrics = best_trial
373 .as_ref()
374 .map(|r| r.metrics.clone())
375 .unwrap_or_default();
376
377 let pareto_frontier = self.compute_pareto_frontier(&results);
378 let improvement_curve = self.compute_improvement_curve(&results);
379
380 Ok(AutoMLResults {
381 best_configuration,
382 best_metrics,
383 pareto_frontier,
384 optimization_history: results,
385 total_optimization_time: start_time.elapsed(),
386 trials_completed: trials.len(),
387 improvement_curve,
388 })
389 }
390
391 fn generate_optimization_trials(&self) -> Result<Vec<OptimizationTrial>> {
393 let mut trials = Vec::new();
394
395 for embedding_strategy in &self.config.search_space.embedding_strategies {
397 for &vector_dimension in &self.config.search_space.vector_dimensions {
398 for similarity_metric in &self.config.search_space.similarity_metrics {
399 for &learning_rate in &self.config.search_space.learning_rates {
400 for &batch_size in &self.config.search_space.batch_sizes {
401 let trial = OptimizationTrial {
402 trial_id: Uuid::new_v4().to_string(),
403 embedding_strategy: embedding_strategy.clone(),
404 vector_dimension,
405 similarity_metric: *similarity_metric,
406 index_config: self.generate_index_config()?,
407 hyperparameters: {
408 let mut params = HashMap::new();
409 params.insert("learning_rate".to_string(), learning_rate);
410 params.insert("batch_size".to_string(), batch_size as f32);
411 params
412 },
413 timestamp: std::time::SystemTime::now()
414 .duration_since(std::time::UNIX_EPOCH)
415 .unwrap_or_default()
416 .as_secs(),
417 };
418 trials.push(trial);
419 }
420 }
421 }
422 }
423 }
424
425 for _ in 0..20 {
427 trials.push(self.generate_random_trial()?);
428 }
429
430 Ok(trials)
431 }
432
433 async fn execute_trial(
435 &self,
436 trial: &OptimizationTrial,
437 training_data: &[(String, String)],
438 validation_data: &[(String, String)],
439 test_queries: &[(String, Vec<String>)],
440 ) -> Result<TrialResult> {
441 let trial_start = Instant::now();
442
443 {
445 let mut state = self.optimization_state.lock().await;
446 state
447 .active_trials
448 .insert(trial.trial_id.clone(), trial_start);
449 }
450
451 let mut vector_store = self.create_vector_store_for_trial(trial)?;
453
454 let training_start = Instant::now();
456 for (id, content) in training_data {
457 vector_store
458 .index_resource(id.clone(), content)
459 .context("Failed to index training data")?;
460 }
461 let training_time = training_start.elapsed();
462
463 let cv_scores = self
465 .perform_cross_validation(&vector_store, validation_data, test_queries)
466 .await?;
467
468 let eval_start = Instant::now();
470 let metrics = self
471 .evaluate_trial_performance(&vector_store, test_queries, trial)
472 .await?;
473 let evaluation_time = eval_start.elapsed();
474
475 let memory_peak_usage = self.estimate_memory_usage(&vector_store, trial)?;
477
478 {
480 let mut state = self.optimization_state.lock().await;
481 state.active_trials.remove(&trial.trial_id);
482 }
483
484 Ok(TrialResult {
485 trial: trial.clone(),
486 metrics,
487 cross_validation_scores: cv_scores,
488 training_time,
489 evaluation_time,
490 memory_peak_usage,
491 error_message: None,
492 success: true,
493 })
494 }
495
496 async fn perform_cross_validation(
498 &self,
499 vector_store: &VectorStore,
500 validation_data: &[(String, String)],
501 test_queries: &[(String, Vec<String>)],
502 ) -> Result<Vec<f32>> {
503 let fold_size = validation_data.len() / self.config.cross_validation_folds;
504 let mut cv_scores = Vec::new();
505
506 for fold in 0..self.config.cross_validation_folds {
507 let _start_idx = fold * fold_size;
508 let _end_idx = if fold == self.config.cross_validation_folds - 1 {
509 validation_data.len()
510 } else {
511 (fold + 1) * fold_size
512 };
513
514 let fold_queries: Vec<_> = test_queries
516 .iter()
517 .filter(|(query_id, _)| {
518 let hash = query_id.chars().map(|c| c as u32).sum::<u32>() as usize;
520 let fold_idx = hash % self.config.cross_validation_folds;
521 fold_idx == fold
522 })
523 .cloned()
524 .collect();
525
526 if fold_queries.is_empty() {
527 cv_scores.push(0.0);
528 continue;
529 }
530
531 let mut total_recall = 0.0;
533 for (query, relevant_docs) in &fold_queries {
534 let search_results = vector_store.similarity_search(query, 10)?;
535 let retrieved_docs: Vec<String> =
536 search_results.iter().map(|r| r.0.clone()).collect();
537
538 let recall = self.compute_recall(&retrieved_docs, relevant_docs);
539 total_recall += recall;
540 }
541
542 let avg_recall = total_recall / fold_queries.len() as f32;
543 cv_scores.push(avg_recall);
544 }
545
546 Ok(cv_scores)
547 }
548
549 async fn evaluate_trial_performance(
551 &self,
552 vector_store: &VectorStore,
553 test_queries: &[(String, Vec<String>)],
554 trial: &OptimizationTrial,
555 ) -> Result<HashMap<OptimizationMetric, f32>> {
556 let mut metrics = HashMap::new();
557
558 let mut total_recall = 0.0;
560 let mut total_precision = 0.0;
561 let mut total_latency = 0.0;
562
563 for (query, relevant_docs) in test_queries {
564 let query_start = Instant::now();
565 let search_results = vector_store.similarity_search(query, 10)?;
566 let query_latency = query_start.elapsed().as_millis() as f32;
567
568 total_latency += query_latency;
569
570 let retrieved_docs: Vec<String> = search_results.iter().map(|r| r.0.clone()).collect();
571
572 let recall = self.compute_recall(&retrieved_docs, relevant_docs);
573 let precision = self.compute_precision(&retrieved_docs, relevant_docs);
574
575 total_recall += recall;
576 total_precision += precision;
577 }
578
579 let num_queries = test_queries.len() as f32;
580 metrics.insert(
581 OptimizationMetric::Accuracy,
582 (total_recall + total_precision) / (2.0 * num_queries),
583 );
584 metrics.insert(OptimizationMetric::Latency, total_latency / num_queries);
585
586 let avg_latency_seconds = (total_latency / num_queries) / 1000.0;
588 metrics.insert(OptimizationMetric::Throughput, 1.0 / avg_latency_seconds);
589
590 metrics.insert(
592 OptimizationMetric::MemoryUsage,
593 (trial.vector_dimension as f32) * 4.0,
594 ); metrics.insert(
596 OptimizationMetric::StorageEfficiency,
597 1.0 / (trial.vector_dimension as f32).log2(),
598 );
599
600 let build_time_estimate = match trial.embedding_strategy {
602 EmbeddingStrategy::TfIdf => 100.0,
603 EmbeddingStrategy::SentenceTransformer => 1000.0,
604 _ => 500.0,
605 };
606 metrics.insert(OptimizationMetric::IndexBuildTime, build_time_estimate);
607
608 let embedding_quality = 1.0 - (1.0 / (trial.vector_dimension as f32).sqrt());
610 metrics.insert(OptimizationMetric::EmbeddingQuality, embedding_quality);
611
612 Ok(metrics)
613 }
614
615 fn generate_random_trial(&self) -> Result<OptimizationTrial> {
617 use std::collections::hash_map::DefaultHasher;
618 use std::hash::{Hash, Hasher};
619
620 let mut hasher = DefaultHasher::new();
621 Uuid::new_v4().hash(&mut hasher);
622 let random_seed = hasher.finish();
623
624 let search_space = &self.config.search_space;
625
626 Ok(OptimizationTrial {
627 trial_id: Uuid::new_v4().to_string(),
628 embedding_strategy: search_space.embedding_strategies
629 [(random_seed % search_space.embedding_strategies.len() as u64) as usize]
630 .clone(),
631 vector_dimension: search_space.vector_dimensions
632 [((random_seed >> 8) % search_space.vector_dimensions.len() as u64) as usize],
633 similarity_metric: search_space.similarity_metrics
634 [((random_seed >> 16) % search_space.similarity_metrics.len() as u64) as usize],
635 index_config: self.generate_index_config()?,
636 hyperparameters: {
637 let mut params = HashMap::new();
638 let lr_idx =
639 ((random_seed >> 24) % search_space.learning_rates.len() as u64) as usize;
640 let bs_idx = ((random_seed >> 32) % search_space.batch_sizes.len() as u64) as usize;
641 params.insert(
642 "learning_rate".to_string(),
643 search_space.learning_rates[lr_idx],
644 );
645 params.insert(
646 "batch_size".to_string(),
647 search_space.batch_sizes[bs_idx] as f32,
648 );
649 params
650 },
651 timestamp: std::time::SystemTime::now()
652 .duration_since(std::time::UNIX_EPOCH)
653 .unwrap_or_default()
654 .as_secs(),
655 })
656 }
657
658 fn generate_index_config(&self) -> Result<IndexConfiguration> {
660 Ok(IndexConfiguration {
661 index_type: "hnsw".to_string(),
662 parameters: {
663 let mut params = HashMap::new();
664 params.insert("m".to_string(), 16.0);
665 params.insert("ef_construction".to_string(), 200.0);
666 params.insert("ef_search".to_string(), 100.0);
667 params
668 },
669 })
670 }
671
672 fn create_vector_store_for_trial(&self, trial: &OptimizationTrial) -> Result<VectorStore> {
674 VectorStore::with_embedding_strategy(trial.embedding_strategy.clone())
675 .context("Failed to create vector store for trial")
676 }
677
678 fn compute_primary_score(&self, metrics: &HashMap<OptimizationMetric, f32>) -> f32 {
680 let mut score = 0.0;
681
682 if let Some(&accuracy) = metrics.get(&OptimizationMetric::Accuracy) {
684 score += accuracy * 0.4;
685 }
686 if let Some(&latency) = metrics.get(&OptimizationMetric::Latency) {
687 score += (1.0 / (1.0 + latency / 1000.0)) * 0.3;
689 }
690 if let Some(&throughput) = metrics.get(&OptimizationMetric::Throughput) {
691 score += (throughput / 100.0).min(1.0) * 0.3;
692 }
693
694 score
695 }
696
697 fn compute_pareto_frontier(&self, results: &[TrialResult]) -> Vec<TrialResult> {
699 let mut frontier = Vec::new();
700
701 for result in results {
702 if !result.success {
703 continue;
704 }
705
706 let is_dominated = results.iter().any(|other| {
707 if !other.success || other.trial.trial_id == result.trial.trial_id {
708 return false;
709 }
710
711 let mut better_in_all = true;
713 let mut better_in_some = false;
714
715 for metric in &self.config.optimization_metrics {
716 let result_val = result.metrics.get(metric).unwrap_or(&0.0);
717 let other_val = other.metrics.get(metric).unwrap_or(&0.0);
718
719 match metric {
720 OptimizationMetric::Latency | OptimizationMetric::MemoryUsage => {
721 if other_val > result_val {
723 better_in_all = false;
724 } else if other_val < result_val {
725 better_in_some = true;
726 }
727 }
728 _ => {
729 if other_val < result_val {
731 better_in_all = false;
732 } else if other_val > result_val {
733 better_in_some = true;
734 }
735 }
736 }
737 }
738
739 better_in_all && better_in_some
740 });
741
742 if !is_dominated {
743 frontier.push(result.clone());
744 }
745 }
746
747 frontier
748 }
749
750 fn compute_improvement_curve(&self, results: &[TrialResult]) -> Vec<(usize, f32)> {
752 let mut curve = Vec::new();
753 let mut best_score = f32::NEG_INFINITY;
754
755 for (i, result) in results.iter().enumerate() {
756 if result.success {
757 let score = self.compute_primary_score(&result.metrics);
758 if score > best_score {
759 best_score = score;
760 }
761 }
762 curve.push((i, best_score));
763 }
764
765 curve
766 }
767
768 fn estimate_memory_usage(
770 &self,
771 _vector_store: &VectorStore,
772 trial: &OptimizationTrial,
773 ) -> Result<usize> {
774 let base_memory = 100 * 1024 * 1024; let vector_memory = trial.vector_dimension * 4; let index_overhead = vector_memory / 2; Ok(base_memory + vector_memory + index_overhead)
780 }
781
782 fn compute_recall(&self, retrieved: &[String], relevant: &[String]) -> f32 {
784 if relevant.is_empty() {
785 return 1.0;
786 }
787
788 let relevant_set: std::collections::HashSet<_> = relevant.iter().collect();
789 let retrieved_relevant = retrieved
790 .iter()
791 .filter(|doc| relevant_set.contains(doc))
792 .count();
793
794 retrieved_relevant as f32 / relevant.len() as f32
795 }
796
797 fn compute_precision(&self, retrieved: &[String], relevant: &[String]) -> f32 {
799 if retrieved.is_empty() {
800 return 0.0;
801 }
802
803 let relevant_set: std::collections::HashSet<_> = relevant.iter().collect();
804 let retrieved_relevant = retrieved
805 .iter()
806 .filter(|doc| relevant_set.contains(doc))
807 .count();
808
809 retrieved_relevant as f32 / retrieved.len() as f32
810 }
811
812 pub fn get_optimization_statistics(&self) -> AutoMLStatistics {
814 let history = self.trial_history.read().unwrap();
815 let best_trial = self.best_trial.read().unwrap();
816
817 let total_trials = history.len();
818 let successful_trials = history.iter().filter(|r| r.success).count();
819 let average_trial_time = if !history.is_empty() {
820 history
821 .iter()
822 .map(|r| r.training_time + r.evaluation_time)
823 .sum::<Duration>()
824 .as_secs_f32()
825 / history.len() as f32
826 } else {
827 0.0
828 };
829
830 let best_score = best_trial
831 .as_ref()
832 .map(|r| self.compute_primary_score(&r.metrics))
833 .unwrap_or(0.0);
834
835 AutoMLStatistics {
836 total_trials,
837 successful_trials,
838 best_score,
839 average_trial_time,
840 optimization_metrics: self.config.optimization_metrics.clone(),
841 search_space_size: self.estimate_search_space_size(),
842 }
843 }
844
845 fn estimate_search_space_size(&self) -> usize {
846 let space = &self.config.search_space;
847 space.embedding_strategies.len()
848 * space.vector_dimensions.len()
849 * space.similarity_metrics.len()
850 * space.learning_rates.len()
851 * space.batch_sizes.len()
852 }
853}
854
855#[derive(Debug, Clone, Serialize, Deserialize)]
857pub struct AutoMLStatistics {
858 pub total_trials: usize,
859 pub successful_trials: usize,
860 pub best_score: f32,
861 pub average_trial_time: f32,
862 pub optimization_metrics: Vec<OptimizationMetric>,
863 pub search_space_size: usize,
864}
865
#[cfg(test)]
mod tests {
    use super::*;

    /// Owned `(id, text)` pair.
    fn pair(id: &str, text: &str) -> (String, String) {
        (id.to_string(), text.to_string())
    }

    /// Trial fixture with empty hyperparameters and a zero timestamp.
    fn make_trial(
        optimizer: &AutoMLOptimizer,
        id: &str,
        embedding_strategy: EmbeddingStrategy,
        vector_dimension: usize,
        similarity_metric: SimilarityMetric,
    ) -> OptimizationTrial {
        OptimizationTrial {
            trial_id: id.to_string(),
            embedding_strategy,
            vector_dimension,
            similarity_metric,
            index_config: optimizer.generate_index_config().unwrap(),
            hyperparameters: HashMap::new(),
            timestamp: 0,
        }
    }

    /// Successful result fixture with a one-second evaluation time.
    fn successful_result(
        trial: OptimizationTrial,
        metrics: HashMap<OptimizationMetric, f32>,
        cv_score: f32,
        training_secs: u64,
        memory_peak_usage: usize,
    ) -> TrialResult {
        TrialResult {
            trial,
            metrics,
            cross_validation_scores: vec![cv_score],
            training_time: Duration::from_secs(training_secs),
            evaluation_time: Duration::from_secs(1),
            memory_peak_usage,
            error_message: None,
            success: true,
        }
    }

    #[test]
    fn test_automl_config_creation() {
        let config = AutoMLConfig::default();
        assert_eq!(config.cross_validation_folds, 5);
        assert_eq!(config.early_stopping_patience, 10);
        assert!(config.enable_parallel_optimization);
    }

    #[test]
    fn test_search_space_default() {
        let space = SearchSpace::default();
        assert!(!space.embedding_strategies.is_empty());
        assert!(!space.vector_dimensions.is_empty());
        assert!(!space.similarity_metrics.is_empty());
    }

    #[tokio::test]
    async fn test_automl_optimizer_creation() {
        assert!(AutoMLOptimizer::with_default_config().is_ok());
    }

    #[test]
    fn test_trial_generation() {
        let optimizer = AutoMLOptimizer::with_default_config().unwrap();
        let trials = optimizer.generate_optimization_trials().unwrap();
        assert!(!trials.is_empty());

        // Every generated trial must carry a distinct identifier.
        let mut seen = std::collections::HashSet::new();
        assert!(trials.iter().all(|t| seen.insert(t.trial_id.clone())));
    }

    #[tokio::test]
    async fn test_optimization_with_sample_data() {
        let _optimizer = AutoMLOptimizer::with_default_config().unwrap();

        let training_data = vec![
            pair("doc1", "artificial intelligence machine learning"),
            pair("doc2", "deep learning neural networks"),
            pair("doc3", "natural language processing"),
        ];
        let validation_data = vec![
            pair("doc4", "computer vision image recognition"),
            pair("doc5", "reinforcement learning algorithms"),
        ];
        let test_queries = vec![
            (
                "ai query".to_string(),
                vec!["doc1".to_string(), "doc2".to_string()],
            ),
            ("nlp query".to_string(), vec!["doc3".to_string()]),
        ];

        // Keep the run short: one-second budget.
        let optimizer = AutoMLOptimizer::new(AutoMLConfig {
            max_optimization_time: Duration::from_secs(1),
            trials_per_config: 1,
            ..Default::default()
        })
        .unwrap();

        let outcome = optimizer
            .optimize_embeddings(&training_data, &validation_data, &test_queries)
            .await;
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_pareto_frontier_computation() {
        let optimizer = AutoMLOptimizer::with_default_config().unwrap();

        let first = make_trial(
            &optimizer,
            "trial1",
            EmbeddingStrategy::TfIdf,
            128,
            SimilarityMetric::Cosine,
        );
        let second = make_trial(
            &optimizer,
            "trial2",
            EmbeddingStrategy::SentenceTransformer,
            256,
            SimilarityMetric::Euclidean,
        );

        let first_metrics: HashMap<OptimizationMetric, f32> = vec![
            (OptimizationMetric::Accuracy, 0.8),
            (OptimizationMetric::Latency, 100.0),
        ]
        .into_iter()
        .collect();
        let second_metrics: HashMap<OptimizationMetric, f32> = vec![
            (OptimizationMetric::Accuracy, 0.9),
            (OptimizationMetric::Latency, 200.0),
        ]
        .into_iter()
        .collect();

        let results = vec![
            successful_result(first, first_metrics, 0.8, 1, 1000),
            successful_result(second, second_metrics, 0.9, 2, 2000),
        ];

        // Each trial wins on one metric, so neither dominates the other.
        let frontier = optimizer.compute_pareto_frontier(&results);
        assert_eq!(frontier.len(), 2);
    }

    #[test]
    fn test_recall_precision_computation() {
        let optimizer = AutoMLOptimizer::with_default_config().unwrap();

        let retrieved: Vec<String> = ["doc1", "doc2", "doc3"].iter().map(|s| s.to_string()).collect();
        let relevant: Vec<String> = ["doc1", "doc3", "doc4"].iter().map(|s| s.to_string()).collect();

        // Two of three relevant documents were retrieved, and two of three retrieved are relevant.
        assert_eq!(optimizer.compute_recall(&retrieved, &relevant), 2.0 / 3.0);
        assert_eq!(optimizer.compute_precision(&retrieved, &relevant), 2.0 / 3.0);
    }

    #[test]
    fn test_optimization_statistics() {
        let stats = AutoMLOptimizer::with_default_config()
            .unwrap()
            .get_optimization_statistics();

        assert_eq!(stats.total_trials, 0);
        assert_eq!(stats.successful_trials, 0);
        assert!(stats.search_space_size > 0);
    }
}