Skip to main content

oxirs_vec/
automl_optimization.rs

1//! AutoML for Embedding Optimization - Version 1.2 Feature
2//!
3//! This module implements comprehensive AutoML capabilities for automatically
4//! optimizing embedding configurations, model selection, hyperparameters, and
5//! vector search performance. It provides intelligent automation for finding
6//! the best embedding strategies for specific datasets and use cases.
7
8use crate::{
9    advanced_analytics::VectorAnalyticsEngine,
10    benchmarking::{BenchmarkConfig, BenchmarkSuite},
11    embeddings::EmbeddingStrategy,
12    similarity::SimilarityMetric,
13    VectorStore,
14};
15
16use anyhow::{Context, Result};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::{Arc, RwLock};
20use std::time::{Duration, Instant};
21use tokio::sync::Mutex;
22use tracing::{info, span, warn, Level};
23use uuid::Uuid;
24
/// AutoML optimization configuration
///
/// Controls the overall search budget, the metrics being optimized, the
/// hyperparameter search space, and the resource limits a run may consume.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoMLConfig {
    /// Maximum optimization time budget (wall-clock; the trial loop stops
    /// once this is exceeded)
    pub max_optimization_time: Duration,
    /// Number of trials per configuration
    pub trials_per_config: usize,
    /// Evaluation metrics to optimize (also used for Pareto-dominance checks)
    pub optimization_metrics: Vec<OptimizationMetric>,
    /// Search space for hyperparameters
    pub search_space: SearchSpace,
    /// Cross-validation folds used when scoring each trial
    pub cross_validation_folds: usize,
    /// Early stopping patience: consecutive non-improving trials tolerated
    /// before the search halts
    pub early_stopping_patience: usize,
    /// Enable parallel optimization
    pub enable_parallel_optimization: bool,
    /// Resource constraints applied during optimization
    pub resource_constraints: ResourceConstraints,
}
45
46impl Default for AutoMLConfig {
47    fn default() -> Self {
48        Self {
49            max_optimization_time: Duration::from_secs(3600), // 1 hour
50            trials_per_config: 5,
51            optimization_metrics: vec![
52                OptimizationMetric::Accuracy,
53                OptimizationMetric::Latency,
54                OptimizationMetric::MemoryUsage,
55            ],
56            search_space: SearchSpace::default(),
57            cross_validation_folds: 5,
58            early_stopping_patience: 10,
59            enable_parallel_optimization: true,
60            resource_constraints: ResourceConstraints::default(),
61        }
62    }
63}
64
/// Metrics to optimize during AutoML
///
/// For `Latency` and `MemoryUsage` lower values are better; for the
/// remaining variants higher values are better (see the Pareto-dominance
/// logic in `compute_pareto_frontier`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum OptimizationMetric {
    /// Search accuracy (recall@k, precision@k)
    Accuracy,
    /// Query latency (milliseconds in this module's evaluations)
    Latency,
    /// Memory usage
    MemoryUsage,
    /// Throughput (queries per second)
    Throughput,
    /// Index build time
    IndexBuildTime,
    /// Storage efficiency
    StorageEfficiency,
    /// Embedding quality
    EmbeddingQuality,
}
83
/// Hyperparameter search space configuration
///
/// The grid search enumerates the cross-product of the first five fields;
/// `index_parameters` holds per-index-type candidates.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchSpace {
    /// Embedding strategies to evaluate
    pub embedding_strategies: Vec<EmbeddingStrategy>,
    /// Vector dimensions to test
    pub vector_dimensions: Vec<usize>,
    /// Similarity metrics to evaluate
    pub similarity_metrics: Vec<SimilarityMetric>,
    /// Index parameters (HNSW / IVF / PQ candidates)
    pub index_parameters: IndexParameterSpace,
    /// Learning rate ranges for trainable embeddings
    pub learning_rates: Vec<f32>,
    /// Batch sizes for processing
    pub batch_sizes: Vec<usize>,
}
100
101impl Default for SearchSpace {
102    fn default() -> Self {
103        Self {
104            embedding_strategies: vec![
105                EmbeddingStrategy::TfIdf,
106                EmbeddingStrategy::SentenceTransformer,
107                EmbeddingStrategy::Custom("default".to_string()),
108            ],
109            vector_dimensions: vec![128, 256, 512, 768, 1024],
110            similarity_metrics: vec![
111                SimilarityMetric::Cosine,
112                SimilarityMetric::Euclidean,
113                SimilarityMetric::DotProduct,
114            ],
115            index_parameters: IndexParameterSpace::default(),
116            learning_rates: vec![0.001, 0.01, 0.1],
117            batch_sizes: vec![32, 64, 128, 256],
118        }
119    }
120}
121
/// Index-specific parameter search space
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexParameterSpace {
    /// HNSW specific parameters
    // m: max bi-directional links per node; ef_*: candidate-list sizes at
    // build/search time (conventional HNSW meaning — confirm against the
    // index implementation)
    pub hnsw_m: Vec<usize>,
    pub hnsw_ef_construction: Vec<usize>,
    pub hnsw_ef_search: Vec<usize>,
    /// IVF specific parameters
    // nlist: number of coarse clusters; nprobe: clusters probed per query
    pub ivf_nlist: Vec<usize>,
    pub ivf_nprobe: Vec<usize>,
    /// PQ specific parameters
    // m: number of sub-quantizers; nbits: bits per sub-quantizer code
    pub pq_m: Vec<usize>,
    pub pq_nbits: Vec<usize>,
}
136
137impl Default for IndexParameterSpace {
138    fn default() -> Self {
139        Self {
140            hnsw_m: vec![16, 32, 64],
141            hnsw_ef_construction: vec![100, 200, 400],
142            hnsw_ef_search: vec![50, 100, 200],
143            ivf_nlist: vec![100, 1000, 4096],
144            ivf_nprobe: vec![1, 10, 50],
145            pq_m: vec![8, 16, 32],
146            pq_nbits: vec![4, 8],
147        }
148    }
149}
150
/// Resource constraints for optimization
///
/// NOTE(review): these limits are carried in the config but the visible
/// optimizer code does not yet enforce them — confirm before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceConstraints {
    /// Maximum memory usage in bytes
    pub max_memory_bytes: usize,
    /// Maximum CPU cores to use
    pub max_cpu_cores: usize,
    /// Maximum GPU memory if available (`None` = no GPU assumed)
    pub max_gpu_memory_bytes: Option<usize>,
    /// Maximum disk usage for indices
    pub max_disk_usage_bytes: usize,
}
163
164impl Default for ResourceConstraints {
165    fn default() -> Self {
166        Self {
167            max_memory_bytes: 8 * 1024 * 1024 * 1024, // 8GB
168            max_cpu_cores: 4,
169            max_gpu_memory_bytes: None,
170            max_disk_usage_bytes: 50 * 1024 * 1024 * 1024, // 50GB
171        }
172    }
173}
174
/// AutoML optimization configuration for a specific trial
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationTrial {
    /// Unique identifier (UUID v4 string) for this trial
    pub trial_id: String,
    /// Embedding strategy under evaluation
    pub embedding_strategy: EmbeddingStrategy,
    /// Dimensionality of the produced vectors
    pub vector_dimension: usize,
    /// Similarity metric used for search
    pub similarity_metric: SimilarityMetric,
    /// Index type and its numeric parameters
    pub index_config: IndexConfiguration,
    /// Free-form numeric hyperparameters (e.g. "learning_rate", "batch_size")
    pub hyperparameters: HashMap<String, f32>,
    /// Creation time, seconds since the Unix epoch
    pub timestamp: u64,
}
186
/// Index configuration for optimization trials
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexConfiguration {
    /// Index family identifier (e.g. "hnsw")
    pub index_type: String,
    /// Numeric index parameters keyed by name (e.g. "m", "ef_search")
    pub parameters: HashMap<String, f32>,
}
193
/// Result of an AutoML optimization trial
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrialResult {
    /// The configuration that was evaluated
    pub trial: OptimizationTrial,
    /// Measured/estimated values per optimization metric (empty on failure)
    pub metrics: HashMap<OptimizationMetric, f32>,
    /// Per-fold recall scores from cross-validation
    pub cross_validation_scores: Vec<f32>,
    /// Wall-clock time spent indexing the training data
    pub training_time: Duration,
    /// Wall-clock time spent on the final evaluation pass
    pub evaluation_time: Duration,
    /// Estimated peak memory in bytes (analytic estimate, not measured)
    pub memory_peak_usage: usize,
    /// Error description when the trial failed, `None` otherwise
    pub error_message: Option<String>,
    /// Whether the trial ran to completion
    pub success: bool,
}
206
/// AutoML optimization results summary
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoMLResults {
    /// Configuration of the best-scoring trial
    pub best_configuration: OptimizationTrial,
    /// Metric values achieved by the best trial
    pub best_metrics: HashMap<OptimizationMetric, f32>,
    /// Non-dominated trials under the configured optimization metrics
    pub pareto_frontier: Vec<TrialResult>,
    /// Every trial result, in execution order (including failures)
    pub optimization_history: Vec<TrialResult>,
    /// Total wall-clock time of the optimization run
    pub total_optimization_time: Duration,
    /// Number of trials executed
    pub trials_completed: usize,
    /// (trial index, best-score-so-far) pairs for convergence plotting
    pub improvement_curve: Vec<(usize, f32)>,
}
218
/// AutoML optimizer for embedding and vector search configurations
pub struct AutoMLOptimizer {
    /// Search budget, metrics, and search-space settings
    config: AutoMLConfig,
    /// All completed trial results; std RwLock, shared across threads
    trial_history: Arc<RwLock<Vec<TrialResult>>>,
    /// Best trial seen so far (by primary score), if any
    best_trial: Arc<RwLock<Option<TrialResult>>>,
    /// Mutable run state; tokio Mutex so it can be held across `.await`
    optimization_state: Arc<Mutex<OptimizationState>>,
    #[allow(dead_code)]
    analytics_engine: Arc<Mutex<VectorAnalyticsEngine>>,
    #[allow(dead_code)]
    benchmark_engine: Arc<Mutex<BenchmarkSuite>>,
}
230
/// Internal optimization state
#[derive(Debug)]
struct OptimizationState {
    #[allow(dead_code)]
    current_trial: usize,
    /// Consecutive trials without improvement; reset when a new best is found
    early_stopping_counter: usize,
    /// Best primary score observed so far
    best_score: f32,
    #[allow(dead_code)]
    pareto_frontier: Vec<TrialResult>,
    /// Trials currently executing, keyed by trial_id, with their start times
    active_trials: HashMap<String, Instant>,
}
242
243impl AutoMLOptimizer {
244    /// Create a new AutoML optimizer
245    pub fn new(config: AutoMLConfig) -> Result<Self> {
246        Ok(Self {
247            config,
248            trial_history: Arc::new(RwLock::new(Vec::new())),
249            best_trial: Arc::new(RwLock::new(None)),
250            optimization_state: Arc::new(Mutex::new(OptimizationState {
251                current_trial: 0,
252                early_stopping_counter: 0,
253                best_score: f32::NEG_INFINITY,
254                pareto_frontier: Vec::new(),
255                active_trials: HashMap::new(),
256            })),
257            analytics_engine: Arc::new(Mutex::new(VectorAnalyticsEngine::new())),
258            benchmark_engine: Arc::new(Mutex::new(BenchmarkSuite::new(BenchmarkConfig::default()))),
259        })
260    }
261
262    /// Create optimizer with default configuration
263    pub fn with_default_config() -> Result<Self> {
264        Self::new(AutoMLConfig::default())
265    }
266
267    /// Optimize embedding configuration for given dataset
268    pub async fn optimize_embeddings(
269        &self,
270        training_data: &[(String, String)], // (id, content) pairs
271        validation_data: &[(String, String)],
272        test_queries: &[(String, Vec<String>)], // (query, relevant_docs)
273    ) -> Result<AutoMLResults> {
274        let span = span!(Level::INFO, "automl_optimization");
275        let _enter = span.enter();
276
277        info!(
278            "Starting AutoML optimization for {} training samples",
279            training_data.len()
280        );
281
282        let start_time = Instant::now();
283        let optimization_state = self.optimization_state.lock().await;
284
285        // Generate optimization trials
286        let trials = self.generate_optimization_trials()?;
287        info!("Generated {} optimization trials", trials.len());
288
289        drop(optimization_state);
290
291        let mut results = Vec::new();
292        let mut best_score = f32::NEG_INFINITY;
293
294        // Execute optimization trials
295        for (i, trial) in trials.iter().enumerate() {
296            if start_time.elapsed() > self.config.max_optimization_time {
297                warn!("Optimization time budget exceeded, stopping early");
298                break;
299            }
300
301            info!(
302                "Executing trial {}/{}: {}",
303                i + 1,
304                trials.len(),
305                trial.trial_id
306            );
307
308            match self
309                .execute_trial(trial, training_data, validation_data, test_queries)
310                .await
311            {
312                Ok(trial_result) => {
313                    // Check for improvement
314                    let primary_score = self.compute_primary_score(&trial_result.metrics);
315
316                    if primary_score > best_score {
317                        best_score = primary_score;
318                        {
319                            let mut best_trial = self
320                                .best_trial
321                                .write()
322                                .expect("best_trial lock should not be poisoned");
323                            *best_trial = Some(trial_result.clone());
324                        } // Drop the mutex guard before await
325
326                        // Reset early stopping counter
327                        let mut state = self.optimization_state.lock().await;
328                        state.early_stopping_counter = 0;
329                        state.best_score = best_score;
330                    } else {
331                        let mut state = self.optimization_state.lock().await;
332                        state.early_stopping_counter += 1;
333
334                        // Check early stopping
335                        if state.early_stopping_counter >= self.config.early_stopping_patience {
336                            info!(
337                                "Early stopping triggered after {} trials without improvement",
338                                self.config.early_stopping_patience
339                            );
340                            break;
341                        }
342                    }
343
344                    results.push(trial_result);
345                }
346                Err(e) => {
347                    warn!("Trial {} failed: {}", trial.trial_id, e);
348                    results.push(TrialResult {
349                        trial: trial.clone(),
350                        metrics: HashMap::new(),
351                        cross_validation_scores: Vec::new(),
352                        training_time: Duration::from_secs(0),
353                        evaluation_time: Duration::from_secs(0),
354                        memory_peak_usage: 0,
355                        error_message: Some(e.to_string()),
356                        success: false,
357                    });
358                }
359            }
360        }
361
362        // Store results in history
363        {
364            let mut history = self
365                .trial_history
366                .write()
367                .expect("trial_history lock should not be poisoned");
368            history.extend(results.clone());
369        }
370
371        // Generate final results
372        let best_trial = self
373            .best_trial
374            .read()
375            .expect("best_trial lock should not be poisoned");
376        let best_configuration = best_trial
377            .as_ref()
378            .map(|r| r.trial.clone())
379            .unwrap_or_else(|| trials[0].clone());
380
381        let best_metrics = best_trial
382            .as_ref()
383            .map(|r| r.metrics.clone())
384            .unwrap_or_default();
385
386        let pareto_frontier = self.compute_pareto_frontier(&results);
387        let improvement_curve = self.compute_improvement_curve(&results);
388
389        Ok(AutoMLResults {
390            best_configuration,
391            best_metrics,
392            pareto_frontier,
393            optimization_history: results,
394            total_optimization_time: start_time.elapsed(),
395            trials_completed: trials.len(),
396            improvement_curve,
397        })
398    }
399
400    /// Generate optimization trials based on search space
401    fn generate_optimization_trials(&self) -> Result<Vec<OptimizationTrial>> {
402        let mut trials = Vec::new();
403
404        // Grid search over key parameters
405        for embedding_strategy in &self.config.search_space.embedding_strategies {
406            for &vector_dimension in &self.config.search_space.vector_dimensions {
407                for similarity_metric in &self.config.search_space.similarity_metrics {
408                    for &learning_rate in &self.config.search_space.learning_rates {
409                        for &batch_size in &self.config.search_space.batch_sizes {
410                            let trial = OptimizationTrial {
411                                trial_id: Uuid::new_v4().to_string(),
412                                embedding_strategy: embedding_strategy.clone(),
413                                vector_dimension,
414                                similarity_metric: *similarity_metric,
415                                index_config: self.generate_index_config()?,
416                                hyperparameters: {
417                                    let mut params = HashMap::new();
418                                    params.insert("learning_rate".to_string(), learning_rate);
419                                    params.insert("batch_size".to_string(), batch_size as f32);
420                                    params
421                                },
422                                timestamp: std::time::SystemTime::now()
423                                    .duration_since(std::time::UNIX_EPOCH)
424                                    .unwrap_or_default()
425                                    .as_secs(),
426                            };
427                            trials.push(trial);
428                        }
429                    }
430                }
431            }
432        }
433
434        // Add random search trials for exploration
435        for _ in 0..20 {
436            trials.push(self.generate_random_trial()?);
437        }
438
439        Ok(trials)
440    }
441
    /// Execute a single optimization trial
    ///
    /// Builds a vector store from the trial configuration, indexes the
    /// training corpus, runs cross-validation, performs a final metric
    /// evaluation, and returns a successful `TrialResult`. While running,
    /// the trial is tracked in `optimization_state.active_trials`.
    async fn execute_trial(
        &self,
        trial: &OptimizationTrial,
        training_data: &[(String, String)],
        validation_data: &[(String, String)],
        test_queries: &[(String, Vec<String>)],
    ) -> Result<TrialResult> {
        let trial_start = Instant::now();

        // Record trial start (scoped block so the async mutex guard is
        // released before the long-running work below)
        {
            let mut state = self.optimization_state.lock().await;
            state
                .active_trials
                .insert(trial.trial_id.clone(), trial_start);
        }

        // Create vector store with trial configuration
        let mut vector_store = self.create_vector_store_for_trial(trial)?;

        // Training phase: index every (id, content) pair
        let training_start = Instant::now();
        for (id, content) in training_data {
            vector_store
                .index_resource(id.clone(), content)
                .context("Failed to index training data")?;
        }
        let training_time = training_start.elapsed();

        // Cross-validation evaluation
        let cv_scores = self
            .perform_cross_validation(&vector_store, validation_data, test_queries)
            .await?;

        // Final evaluation
        let eval_start = Instant::now();
        let metrics = self
            .evaluate_trial_performance(&vector_store, test_queries, trial)
            .await?;
        let evaluation_time = eval_start.elapsed();

        // Estimate memory usage (simplified analytic model, not a measurement)
        let memory_peak_usage = self.estimate_memory_usage(&vector_store, trial)?;

        // Clean up trial state
        // NOTE(review): if any `?` above returns early, the trial_id is left
        // behind in `active_trials` — consider a guard/cleanup on error.
        {
            let mut state = self.optimization_state.lock().await;
            state.active_trials.remove(&trial.trial_id);
        }

        Ok(TrialResult {
            trial: trial.clone(),
            metrics,
            cross_validation_scores: cv_scores,
            training_time,
            evaluation_time,
            memory_peak_usage,
            error_message: None,
            success: true,
        })
    }
504
505    /// Perform cross-validation for trial evaluation
506    async fn perform_cross_validation(
507        &self,
508        vector_store: &VectorStore,
509        validation_data: &[(String, String)],
510        test_queries: &[(String, Vec<String>)],
511    ) -> Result<Vec<f32>> {
512        let fold_size = validation_data.len() / self.config.cross_validation_folds;
513        let mut cv_scores = Vec::new();
514
515        for fold in 0..self.config.cross_validation_folds {
516            let _start_idx = fold * fold_size;
517            let _end_idx = if fold == self.config.cross_validation_folds - 1 {
518                validation_data.len()
519            } else {
520                (fold + 1) * fold_size
521            };
522
523            // Use this fold as test set
524            let fold_queries: Vec<_> = test_queries
525                .iter()
526                .filter(|(query_id, _)| {
527                    // Simple hash-based assignment
528                    let hash = query_id.chars().map(|c| c as u32).sum::<u32>() as usize;
529                    let fold_idx = hash % self.config.cross_validation_folds;
530                    fold_idx == fold
531                })
532                .cloned()
533                .collect();
534
535            if fold_queries.is_empty() {
536                cv_scores.push(0.0);
537                continue;
538            }
539
540            // Evaluate on this fold
541            let mut total_recall = 0.0;
542            for (query, relevant_docs) in &fold_queries {
543                let search_results = vector_store.similarity_search(query, 10)?;
544                let retrieved_docs: Vec<String> =
545                    search_results.iter().map(|r| r.0.clone()).collect();
546
547                let recall = self.compute_recall(&retrieved_docs, relevant_docs);
548                total_recall += recall;
549            }
550
551            let avg_recall = total_recall / fold_queries.len() as f32;
552            cv_scores.push(avg_recall);
553        }
554
555        Ok(cv_scores)
556    }
557
558    /// Evaluate trial performance metrics
559    async fn evaluate_trial_performance(
560        &self,
561        vector_store: &VectorStore,
562        test_queries: &[(String, Vec<String>)],
563        trial: &OptimizationTrial,
564    ) -> Result<HashMap<OptimizationMetric, f32>> {
565        let mut metrics = HashMap::new();
566
567        // Accuracy metrics
568        let mut total_recall = 0.0;
569        let mut total_precision = 0.0;
570        let mut total_latency = 0.0;
571
572        for (query, relevant_docs) in test_queries {
573            let query_start = Instant::now();
574            let search_results = vector_store.similarity_search(query, 10)?;
575            let query_latency = query_start.elapsed().as_millis() as f32;
576
577            total_latency += query_latency;
578
579            let retrieved_docs: Vec<String> = search_results.iter().map(|r| r.0.clone()).collect();
580
581            let recall = self.compute_recall(&retrieved_docs, relevant_docs);
582            let precision = self.compute_precision(&retrieved_docs, relevant_docs);
583
584            total_recall += recall;
585            total_precision += precision;
586        }
587
588        let num_queries = test_queries.len() as f32;
589        metrics.insert(
590            OptimizationMetric::Accuracy,
591            (total_recall + total_precision) / (2.0 * num_queries),
592        );
593        metrics.insert(OptimizationMetric::Latency, total_latency / num_queries);
594
595        // Throughput (queries per second)
596        let avg_latency_seconds = (total_latency / num_queries) / 1000.0;
597        metrics.insert(OptimizationMetric::Throughput, 1.0 / avg_latency_seconds);
598
599        // Memory and storage efficiency (simplified estimates)
600        metrics.insert(
601            OptimizationMetric::MemoryUsage,
602            (trial.vector_dimension as f32) * 4.0,
603        ); // 4 bytes per f32
604        metrics.insert(
605            OptimizationMetric::StorageEfficiency,
606            1.0 / (trial.vector_dimension as f32).log2(),
607        );
608
609        // Index build time (estimated from hyperparameters)
610        let build_time_estimate = match trial.embedding_strategy {
611            EmbeddingStrategy::TfIdf => 100.0,
612            EmbeddingStrategy::SentenceTransformer => 1000.0,
613            _ => 500.0,
614        };
615        metrics.insert(OptimizationMetric::IndexBuildTime, build_time_estimate);
616
617        // Embedding quality (simplified metric)
618        let embedding_quality = 1.0 - (1.0 / (trial.vector_dimension as f32).sqrt());
619        metrics.insert(OptimizationMetric::EmbeddingQuality, embedding_quality);
620
621        Ok(metrics)
622    }
623
624    /// Generate a random trial for exploration
625    fn generate_random_trial(&self) -> Result<OptimizationTrial> {
626        use std::collections::hash_map::DefaultHasher;
627        use std::hash::{Hash, Hasher};
628
629        let mut hasher = DefaultHasher::new();
630        Uuid::new_v4().hash(&mut hasher);
631        let random_seed = hasher.finish();
632
633        let search_space = &self.config.search_space;
634
635        Ok(OptimizationTrial {
636            trial_id: Uuid::new_v4().to_string(),
637            embedding_strategy: search_space.embedding_strategies
638                [(random_seed % search_space.embedding_strategies.len() as u64) as usize]
639                .clone(),
640            vector_dimension: search_space.vector_dimensions
641                [((random_seed >> 8) % search_space.vector_dimensions.len() as u64) as usize],
642            similarity_metric: search_space.similarity_metrics
643                [((random_seed >> 16) % search_space.similarity_metrics.len() as u64) as usize],
644            index_config: self.generate_index_config()?,
645            hyperparameters: {
646                let mut params = HashMap::new();
647                let lr_idx =
648                    ((random_seed >> 24) % search_space.learning_rates.len() as u64) as usize;
649                let bs_idx = ((random_seed >> 32) % search_space.batch_sizes.len() as u64) as usize;
650                params.insert(
651                    "learning_rate".to_string(),
652                    search_space.learning_rates[lr_idx],
653                );
654                params.insert(
655                    "batch_size".to_string(),
656                    search_space.batch_sizes[bs_idx] as f32,
657                );
658                params
659            },
660            timestamp: std::time::SystemTime::now()
661                .duration_since(std::time::UNIX_EPOCH)
662                .unwrap_or_default()
663                .as_secs(),
664        })
665    }
666
667    /// Generate index configuration for trial
668    fn generate_index_config(&self) -> Result<IndexConfiguration> {
669        Ok(IndexConfiguration {
670            index_type: "hnsw".to_string(),
671            parameters: {
672                let mut params = HashMap::new();
673                params.insert("m".to_string(), 16.0);
674                params.insert("ef_construction".to_string(), 200.0);
675                params.insert("ef_search".to_string(), 100.0);
676                params
677            },
678        })
679    }
680
681    /// Create vector store for trial configuration
682    fn create_vector_store_for_trial(&self, trial: &OptimizationTrial) -> Result<VectorStore> {
683        VectorStore::with_embedding_strategy(trial.embedding_strategy.clone())
684            .context("Failed to create vector store for trial")
685    }
686
687    /// Compute primary optimization score
688    fn compute_primary_score(&self, metrics: &HashMap<OptimizationMetric, f32>) -> f32 {
689        let mut score = 0.0;
690
691        // Weighted combination of metrics
692        if let Some(&accuracy) = metrics.get(&OptimizationMetric::Accuracy) {
693            score += accuracy * 0.4;
694        }
695        if let Some(&latency) = metrics.get(&OptimizationMetric::Latency) {
696            // Lower latency is better, so invert
697            score += (1.0 / (1.0 + latency / 1000.0)) * 0.3;
698        }
699        if let Some(&throughput) = metrics.get(&OptimizationMetric::Throughput) {
700            score += (throughput / 100.0).min(1.0) * 0.3;
701        }
702
703        score
704    }
705
    /// Compute Pareto frontier from trial results
    ///
    /// A successful trial is on the frontier if no other successful trial
    /// weakly dominates it: at least as good on every configured metric and
    /// strictly better on at least one. Latency and MemoryUsage are
    /// minimized; all other metrics are maximized. Missing metric values
    /// default to 0.0 on both sides. O(n^2) in the number of results.
    fn compute_pareto_frontier(&self, results: &[TrialResult]) -> Vec<TrialResult> {
        let mut frontier = Vec::new();

        for result in results {
            // Failed trials can neither dominate nor join the frontier
            if !result.success {
                continue;
            }

            let is_dominated = results.iter().any(|other| {
                // A trial never dominates itself; failures are ignored
                if !other.success || other.trial.trial_id == result.trial.trial_id {
                    return false;
                }

                // Check if other dominates result:
                // better_in_all = other is at least as good everywhere;
                // better_in_some = other is strictly better somewhere
                let mut better_in_all = true;
                let mut better_in_some = false;

                for metric in &self.config.optimization_metrics {
                    let result_val = result.metrics.get(metric).unwrap_or(&0.0);
                    let other_val = other.metrics.get(metric).unwrap_or(&0.0);

                    match metric {
                        OptimizationMetric::Latency | OptimizationMetric::MemoryUsage => {
                            // Lower is better
                            if other_val > result_val {
                                better_in_all = false;
                            } else if other_val < result_val {
                                better_in_some = true;
                            }
                        }
                        _ => {
                            // Higher is better
                            if other_val < result_val {
                                better_in_all = false;
                            } else if other_val > result_val {
                                better_in_some = true;
                            }
                        }
                    }
                }

                better_in_all && better_in_some
            });

            if !is_dominated {
                frontier.push(result.clone());
            }
        }

        frontier
    }
758
759    /// Compute improvement curve over trials
760    fn compute_improvement_curve(&self, results: &[TrialResult]) -> Vec<(usize, f32)> {
761        let mut curve = Vec::new();
762        let mut best_score = f32::NEG_INFINITY;
763
764        for (i, result) in results.iter().enumerate() {
765            if result.success {
766                let score = self.compute_primary_score(&result.metrics);
767                if score > best_score {
768                    best_score = score;
769                }
770            }
771            curve.push((i, best_score));
772        }
773
774        curve
775    }
776
777    /// Estimate memory usage for trial
778    fn estimate_memory_usage(
779        &self,
780        _vector_store: &VectorStore,
781        trial: &OptimizationTrial,
782    ) -> Result<usize> {
783        // Simplified memory estimation
784        let base_memory = 100 * 1024 * 1024; // 100MB base
785        let vector_memory = trial.vector_dimension * 4; // 4 bytes per f32
786        let index_overhead = vector_memory / 2; // 50% overhead for index structures
787
788        Ok(base_memory + vector_memory + index_overhead)
789    }
790
791    /// Compute recall@k metric
792    fn compute_recall(&self, retrieved: &[String], relevant: &[String]) -> f32 {
793        if relevant.is_empty() {
794            return 1.0;
795        }
796
797        let relevant_set: std::collections::HashSet<_> = relevant.iter().collect();
798        let retrieved_relevant = retrieved
799            .iter()
800            .filter(|doc| relevant_set.contains(doc))
801            .count();
802
803        retrieved_relevant as f32 / relevant.len() as f32
804    }
805
806    /// Compute precision@k metric
807    fn compute_precision(&self, retrieved: &[String], relevant: &[String]) -> f32 {
808        if retrieved.is_empty() {
809            return 0.0;
810        }
811
812        let relevant_set: std::collections::HashSet<_> = relevant.iter().collect();
813        let retrieved_relevant = retrieved
814            .iter()
815            .filter(|doc| relevant_set.contains(doc))
816            .count();
817
818        retrieved_relevant as f32 / retrieved.len() as f32
819    }
820
821    /// Get optimization statistics
822    pub fn get_optimization_statistics(&self) -> AutoMLStatistics {
823        let history = self
824            .trial_history
825            .read()
826            .expect("trial_history lock should not be poisoned");
827        let best_trial = self
828            .best_trial
829            .read()
830            .expect("best_trial lock should not be poisoned");
831
832        let total_trials = history.len();
833        let successful_trials = history.iter().filter(|r| r.success).count();
834        let average_trial_time = if !history.is_empty() {
835            history
836                .iter()
837                .map(|r| r.training_time + r.evaluation_time)
838                .sum::<Duration>()
839                .as_secs_f32()
840                / history.len() as f32
841        } else {
842            0.0
843        };
844
845        let best_score = best_trial
846            .as_ref()
847            .map(|r| self.compute_primary_score(&r.metrics))
848            .unwrap_or(0.0);
849
850        AutoMLStatistics {
851            total_trials,
852            successful_trials,
853            best_score,
854            average_trial_time,
855            optimization_metrics: self.config.optimization_metrics.clone(),
856            search_space_size: self.estimate_search_space_size(),
857        }
858    }
859
860    fn estimate_search_space_size(&self) -> usize {
861        let space = &self.config.search_space;
862        space.embedding_strategies.len()
863            * space.vector_dimensions.len()
864            * space.similarity_metrics.len()
865            * space.learning_rates.len()
866            * space.batch_sizes.len()
867    }
868}
869
/// AutoML optimization statistics
///
/// Aggregate snapshot of an optimization run, produced by
/// `AutoMLOptimizer::get_optimization_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoMLStatistics {
    /// Total number of trials recorded in the history (successful or not).
    pub total_trials: usize,
    /// Number of recorded trials that completed successfully.
    pub successful_trials: usize,
    /// Primary score of the best trial so far; 0.0 when no best trial exists.
    pub best_score: f32,
    /// Mean (training + evaluation) time per trial, in seconds; 0.0 with no history.
    pub average_trial_time: f32,
    /// Metrics the optimizer was configured to optimize.
    pub optimization_metrics: Vec<OptimizationMetric>,
    /// Estimated number of distinct configurations in the search space.
    pub search_space_size: usize,
}
880
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration exposes the documented tuning knobs.
    #[test]
    fn test_automl_config_creation() {
        let config = AutoMLConfig::default();
        assert_eq!(config.cross_validation_folds, 5);
        assert_eq!(config.early_stopping_patience, 10);
        assert!(config.enable_parallel_optimization);
    }

    /// A default search space must offer at least one option per dimension.
    #[test]
    fn test_search_space_default() {
        let search_space = SearchSpace::default();
        assert!(!search_space.embedding_strategies.is_empty());
        assert!(!search_space.vector_dimensions.is_empty());
        assert!(!search_space.similarity_metrics.is_empty());
    }

    /// Construction with defaults should never fail.
    #[tokio::test]
    async fn test_automl_optimizer_creation() {
        assert!(AutoMLOptimizer::with_default_config().is_ok());
    }

    /// Generated trials must be non-empty and carry unique identifiers.
    #[test]
    fn test_trial_generation() {
        let optimizer = AutoMLOptimizer::with_default_config().unwrap();
        let trials = optimizer.generate_optimization_trials().unwrap();
        assert!(!trials.is_empty());

        let mut seen_ids = std::collections::HashSet::new();
        for trial in &trials {
            // `insert` returns false when the id was already present.
            assert!(seen_ids.insert(trial.trial_id.clone()));
        }
    }

    /// End-to-end run with a tiny time budget should complete cleanly,
    /// even if no trials finish inside the budget.
    #[tokio::test]
    async fn test_optimization_with_sample_data() {
        let _optimizer = AutoMLOptimizer::with_default_config().unwrap();

        let training_data: Vec<(String, String)> = [
            ("doc1", "artificial intelligence machine learning"),
            ("doc2", "deep learning neural networks"),
            ("doc3", "natural language processing"),
        ]
        .iter()
        .map(|(id, text)| (id.to_string(), text.to_string()))
        .collect();

        let validation_data: Vec<(String, String)> = [
            ("doc4", "computer vision image recognition"),
            ("doc5", "reinforcement learning algorithms"),
        ]
        .iter()
        .map(|(id, text)| (id.to_string(), text.to_string()))
        .collect();

        let test_queries = vec![
            (
                "ai query".to_string(),
                vec!["doc1".to_string(), "doc2".to_string()],
            ),
            ("nlp query".to_string(), vec!["doc3".to_string()]),
        ];

        // Keep the optimization budget tiny so the test finishes quickly.
        let config = AutoMLConfig {
            max_optimization_time: Duration::from_secs(1),
            trials_per_config: 1,
            ..Default::default()
        };

        let optimizer = AutoMLOptimizer::new(config).unwrap();
        let results = optimizer
            .optimize_embeddings(&training_data, &validation_data, &test_queries)
            .await;

        assert!(results.is_ok());
    }

    /// Two mutually non-dominating trials (one faster, one more accurate)
    /// should both land on the Pareto frontier.
    #[test]
    fn test_pareto_frontier_computation() {
        let optimizer = AutoMLOptimizer::with_default_config().unwrap();

        let fast_trial = OptimizationTrial {
            trial_id: "trial1".to_string(),
            embedding_strategy: EmbeddingStrategy::TfIdf,
            vector_dimension: 128,
            similarity_metric: SimilarityMetric::Cosine,
            index_config: optimizer.generate_index_config().unwrap(),
            hyperparameters: HashMap::new(),
            timestamp: 0,
        };

        let accurate_trial = OptimizationTrial {
            trial_id: "trial2".to_string(),
            embedding_strategy: EmbeddingStrategy::SentenceTransformer,
            vector_dimension: 256,
            similarity_metric: SimilarityMetric::Euclidean,
            index_config: optimizer.generate_index_config().unwrap(),
            hyperparameters: HashMap::new(),
            timestamp: 0,
        };

        // trial1: lower accuracy but lower latency; trial2: the reverse.
        let fast_metrics = HashMap::from([
            (OptimizationMetric::Accuracy, 0.8),
            (OptimizationMetric::Latency, 100.0),
        ]);
        let accurate_metrics = HashMap::from([
            (OptimizationMetric::Accuracy, 0.9),
            (OptimizationMetric::Latency, 200.0),
        ]);

        let results = vec![
            TrialResult {
                trial: fast_trial,
                metrics: fast_metrics,
                cross_validation_scores: vec![0.8],
                training_time: Duration::from_secs(1),
                evaluation_time: Duration::from_secs(1),
                memory_peak_usage: 1000,
                error_message: None,
                success: true,
            },
            TrialResult {
                trial: accurate_trial,
                metrics: accurate_metrics,
                cross_validation_scores: vec![0.9],
                training_time: Duration::from_secs(2),
                evaluation_time: Duration::from_secs(1),
                memory_peak_usage: 2000,
                error_message: None,
                success: true,
            },
        ];

        let frontier = optimizer.compute_pareto_frontier(&results);
        assert_eq!(frontier.len(), 2); // Both trials should be on frontier
    }

    /// Overlap {doc1, doc3}: 2 of 3 relevant retrieved, 2 of 3 retrieved relevant.
    #[test]
    fn test_recall_precision_computation() {
        let optimizer = AutoMLOptimizer::with_default_config().unwrap();

        let retrieved = ["doc1", "doc2", "doc3"].map(String::from).to_vec();
        let relevant = ["doc1", "doc3", "doc4"].map(String::from).to_vec();

        assert_eq!(optimizer.compute_recall(&retrieved, &relevant), 2.0 / 3.0);
        assert_eq!(optimizer.compute_precision(&retrieved, &relevant), 2.0 / 3.0);
    }

    /// A freshly created optimizer has an empty history but a non-trivial
    /// search space.
    #[test]
    fn test_optimization_statistics() {
        let stats = AutoMLOptimizer::with_default_config()
            .unwrap()
            .get_optimization_statistics();

        assert_eq!(stats.total_trials, 0);
        assert_eq!(stats.successful_trials, 0);
        assert!(stats.search_space_size > 0);
    }
}