quantrs2_circuit/ml_optimization.rs

//! Machine Learning-based circuit optimization
//!
//! This module provides ML-driven optimization techniques for quantum circuits,
//! including reinforcement learning for gate scheduling, neural networks for
//! pattern recognition, and automated hyperparameter tuning.

use crate::builder::Circuit;
use crate::dag::{circuit_to_dag, CircuitDag};
use crate::scirs2_integration::{AnalysisResult, SciRS2CircuitAnalyzer};
use quantrs2_core::{
    error::{QuantRS2Error, QuantRS2Result},
    gate::GateOp,
    qubit::QubitId,
};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

/// ML optimization strategy
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MLStrategy {
    /// Reinforcement learning for gate scheduling
    ReinforcementLearning {
        /// Q-learning rate (alpha)
        learning_rate: f64,
        /// Discount factor (gamma)
        discount_factor: f64,
        /// Exploration rate (epsilon)
        exploration_rate: f64,
        /// Number of training episodes
        episodes: usize,
    },
    /// Neural network for pattern recognition
    NeuralNetwork {
        /// Network architecture (layer sizes)
        architecture: Vec<usize>,
        /// Learning rate
        learning_rate: f64,
        /// Number of training epochs
        epochs: usize,
        /// Batch size
        batch_size: usize,
    },
    /// Genetic algorithm for optimization
    GeneticAlgorithm {
        /// Population size
        population_size: usize,
        /// Number of generations
        generations: usize,
        /// Mutation rate
        mutation_rate: f64,
        /// Selection pressure
        selection_pressure: f64,
    },
    /// Bayesian optimization for hyperparameter tuning
    BayesianOptimization {
        /// Number of initial random samples
        initial_samples: usize,
        /// Number of optimization iterations
        iterations: usize,
        /// Acquisition function
        acquisition: AcquisitionFunction,
    },
}

/// Acquisition functions for Bayesian optimization
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AcquisitionFunction {
    ExpectedImprovement,
    UpperConfidenceBound { beta: f64 },
    ProbabilityOfImprovement,
}
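
// A minimal sketch (not part of the crate API) of how an acquisition score
// could be computed from a surrogate model's predicted mean and standard
// deviation. Only `UpperConfidenceBound` has a closed form that avoids the
// Gaussian CDF/PDF, so the other variants are left unevaluated here; the
// function name and signature are illustrative assumptions.
#[allow(dead_code)]
fn acquisition_score(acquisition: &AcquisitionFunction, mean: f64, std_dev: f64) -> Option<f64> {
    match acquisition {
        // Larger `beta` weights exploration (uncertainty) over exploitation (mean).
        AcquisitionFunction::UpperConfidenceBound { beta } => Some(mean + *beta * std_dev),
        // ExpectedImprovement and ProbabilityOfImprovement additionally need the
        // Gaussian CDF/PDF of the surrogate posterior and the incumbent best value.
        _ => None,
    }
}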

/// Circuit representation for ML algorithms
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLCircuitRepresentation {
    /// Feature vector representing the circuit
    pub features: Vec<f64>,
    /// Gate sequence encoding
    pub gate_sequence: Vec<usize>,
    /// Adjacency matrix
    pub adjacency_matrix: Vec<Vec<f64>>,
    /// Qubit connectivity
    pub qubit_connectivity: Vec<Vec<bool>>,
    /// Circuit metrics
    pub metrics: CircuitMetrics,
}

/// Circuit metrics for ML training
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitMetrics {
    /// Circuit depth
    pub depth: usize,
    /// Gate count
    pub gate_count: usize,
    /// Two-qubit gate count
    pub two_qubit_gate_count: usize,
    /// Entanglement measure
    pub entanglement_measure: f64,
    /// Critical path length
    pub critical_path_length: usize,
    /// Parallelization potential
    pub parallelization_potential: f64,
}

/// ML-based optimizer
pub struct MLCircuitOptimizer {
    /// Optimization strategy
    strategy: MLStrategy,
    /// Feature extractor
    feature_extractor: Arc<Mutex<FeatureExtractor>>,
    /// Model storage
    models: Arc<Mutex<HashMap<String, MLModel>>>,
    /// Training data
    training_data: Arc<Mutex<Vec<TrainingExample>>>,
    /// Configuration
    config: MLOptimizerConfig,
}

/// ML optimizer configuration
#[derive(Debug, Clone)]
pub struct MLOptimizerConfig {
    /// Enable feature caching
    pub cache_features: bool,
    /// Maximum training examples to keep
    pub max_training_examples: usize,
    /// Model update frequency
    pub model_update_frequency: usize,
    /// Enable parallel training
    pub parallel_training: bool,
    /// Feature selection threshold
    pub feature_selection_threshold: f64,
}

impl Default for MLOptimizerConfig {
    fn default() -> Self {
        Self {
            cache_features: true,
            max_training_examples: 10000,
            model_update_frequency: 100,
            parallel_training: true,
            feature_selection_threshold: 0.01,
        }
    }
}

/// Training example for supervised learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
    /// Input circuit representation
    pub input: MLCircuitRepresentation,
    /// Target optimization result
    pub target: OptimizationTarget,
    /// Quality score
    pub score: f64,
    /// Metadata
    pub metadata: HashMap<String, String>,
}

/// Optimization target
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OptimizationTarget {
    /// Minimize circuit depth
    MinimizeDepth { target_depth: usize },
    /// Minimize gate count
    MinimizeGates { target_count: usize },
    /// Maximize parallelization
    MaximizeParallelization { target_parallel_fraction: f64 },
    /// Custom objective
    Custom {
        objective: String,
        target_value: f64,
    },
}
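
// A minimal sketch (not part of the crate API) of how a `TrainingExample`
// score could be derived by comparing observed `CircuitMetrics` against an
// `OptimizationTarget`. The 0.0..=1.0 scoring formulas below are illustrative
// assumptions, not the scheme used by the optimizer.
#[allow(dead_code)]
fn score_against_target(metrics: &CircuitMetrics, target: &OptimizationTarget) -> f64 {
    match target {
        // 1.0 once the target is met or beaten, decaying towards 0.0 otherwise.
        OptimizationTarget::MinimizeDepth { target_depth } => {
            (*target_depth as f64 / metrics.depth.max(1) as f64).min(1.0)
        }
        OptimizationTarget::MinimizeGates { target_count } => {
            (*target_count as f64 / metrics.gate_count.max(1) as f64).min(1.0)
        }
        OptimizationTarget::MaximizeParallelization {
            target_parallel_fraction,
        } => (metrics.parallelization_potential / target_parallel_fraction.max(f64::EPSILON))
            .min(1.0),
        // Custom objectives need an external evaluator; no score is assumed here.
        OptimizationTarget::Custom { .. } => 0.0,
    }
}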

/// Feature extractor for circuits
pub struct FeatureExtractor {
    /// `SciRS2` analyzer for graph features
    analyzer: SciRS2CircuitAnalyzer,
    /// Feature cache
    cache: HashMap<String, Vec<f64>>,
    /// Feature importance weights
    feature_weights: Vec<f64>,
}

impl Default for FeatureExtractor {
    fn default() -> Self {
        Self::new()
    }
}

impl FeatureExtractor {
    /// Create a new feature extractor
    #[must_use]
    pub fn new() -> Self {
        Self {
            analyzer: SciRS2CircuitAnalyzer::new(),
            cache: HashMap::new(),
            feature_weights: Vec::new(),
        }
    }

    /// Extract features from a circuit
    pub fn extract_features<const N: usize>(
        &mut self,
        circuit: &Circuit<N>,
    ) -> QuantRS2Result<Vec<f64>> {
        // Generate cache key
        let cache_key = self.generate_cache_key(circuit);

        // Check cache
        if let Some(features) = self.cache.get(&cache_key) {
            return Ok(features.clone());
        }

        // Extract features
        let mut features = Vec::new();

        // Basic circuit features
        features.extend(self.extract_basic_features(circuit));

        // Graph-based features using SciRS2
        features.extend(self.extract_graph_features(circuit)?);

        // Gate pattern features
        features.extend(self.extract_pattern_features(circuit));

        // Entanglement features
        features.extend(self.extract_entanglement_features(circuit));

        // Cache the result
        self.cache.insert(cache_key, features.clone());

        Ok(features)
    }

    /// Extract basic circuit features
    fn extract_basic_features<const N: usize>(&self, circuit: &Circuit<N>) -> Vec<f64> {
        let gates = circuit.gates();
        let gate_count = gates.len() as f64;
        // Depth is approximated by the gate count in this simplified feature set.
        let depth = gate_count;

        // Gate type distribution
        let mut gate_types = HashMap::new();
        for gate in gates {
            *gate_types.entry(gate.name().to_string()).or_insert(0.0) += 1.0;
        }

        let h_count = gate_types.get("H").copied().unwrap_or(0.0);
        let cnot_count = gate_types.get("CNOT").copied().unwrap_or(0.0);
        // CNOT is treated as the only two-qubit gate in this simplified count.
        let single_qubit_count = gate_count - cnot_count;

        vec![
            depth,
            gate_count,
            single_qubit_count,
            cnot_count,
            h_count,
            depth / gate_count.max(1.0),      // Density
            cnot_count / gate_count.max(1.0), // Two-qubit fraction
            N as f64,                         // Number of qubits
        ]
    }

    /// Extract graph-based features using `SciRS2`
    fn extract_graph_features<const N: usize>(
        &mut self,
        circuit: &Circuit<N>,
    ) -> QuantRS2Result<Vec<f64>> {
        let analysis = self.analyzer.analyze_circuit(circuit)?;

        let metrics = &analysis.metrics;

        Ok(vec![
            metrics.num_nodes as f64,
            metrics.num_edges as f64,
            metrics.density,
            metrics.clustering_coefficient,
            metrics.connected_components as f64,
            metrics.diameter.unwrap_or(0) as f64,
            metrics.average_path_length.unwrap_or(0.0),
            analysis.communities.len() as f64,
            analysis.critical_paths.len() as f64,
        ])
    }

    /// Extract gate pattern features
    fn extract_pattern_features<const N: usize>(&self, circuit: &Circuit<N>) -> Vec<f64> {
        let gates = circuit.gates();
        let mut features = Vec::new();

        // Sequential patterns over adjacent gate pairs
        let mut h_cnot_patterns = 0.0;
        let mut cnot_chains = 0.0;

        // `windows(2)` always yields exactly two elements, so no length check is needed.
        for window in gates.windows(2) {
            let gate1 = &window[0];
            let gate2 = &window[1];

            if gate1.name() == "H" && gate2.name() == "CNOT" {
                h_cnot_patterns += 1.0;
            }

            if gate1.name() == "CNOT" && gate2.name() == "CNOT" {
                cnot_chains += 1.0;
            }
        }

        features.push(h_cnot_patterns);
        features.push(cnot_chains);

        // Qubit usage patterns
        let mut qubit_usage = vec![0.0; N];
        for gate in gates {
            for qubit in gate.qubits() {
                qubit_usage[qubit.id() as usize] += 1.0;
            }
        }

        let max_usage: f64 = qubit_usage.iter().fold(0.0_f64, |a, &b| a.max(b));
        let min_usage = qubit_usage.iter().fold(f64::INFINITY, |a, &b| a.min(b));
        let avg_usage = qubit_usage.iter().sum::<f64>() / N as f64;

        features.push(max_usage);
        features.push(min_usage);
        features.push(avg_usage);
        features.push(max_usage - min_usage); // Usage spread (max - min)

        features
    }

    /// Extract entanglement-related features
    fn extract_entanglement_features<const N: usize>(&self, circuit: &Circuit<N>) -> Vec<f64> {
        let gates = circuit.gates();
        let mut features = Vec::new();

        // Entangling gate distribution
        let cnot_gates: Vec<_> = gates.iter().filter(|gate| gate.name() == "CNOT").collect();

        // Connectivity graph
        let mut connectivity = vec![vec![false; N]; N];
        for gate in &cnot_gates {
            let qubits = gate.qubits();
            if qubits.len() == 2 {
                let q1 = qubits[0].id() as usize;
                let q2 = qubits[1].id() as usize;
                connectivity[q1][q2] = true;
                connectivity[q2][q1] = true;
            }
        }

        // Calculate connectivity features
        let total_connections: f64 = connectivity
            .iter()
            .flat_map(|row| row.iter())
            .map(|&connected| if connected { 1.0 } else { 0.0 })
            .sum();

        // Each undirected connection appears twice in the symmetric matrix, so
        // N * (N - 1) ordered pairs is the right normalizer; guard against
        // division by zero for trivially small N.
        let max_connections = N.saturating_sub(1) * N;
        let connectivity_ratio = if max_connections > 0 {
            total_connections / max_connections as f64
        } else {
            0.0
        };

        features.push(cnot_gates.len() as f64);
        features.push(connectivity_ratio);

        // Star connectivity (one qubit connected to many)
        let max_degree = connectivity
            .iter()
            .map(|row| row.iter().filter(|&&c| c).count())
            .max()
            .unwrap_or(0) as f64;

        features.push(max_degree);

        features
    }

    /// Generate cache key for a circuit
    fn generate_cache_key<const N: usize>(&self, circuit: &Circuit<N>) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();

        N.hash(&mut hasher);
        circuit.gates().len().hash(&mut hasher);

        for gate in circuit.gates() {
            gate.name().hash(&mut hasher);
            for qubit in gate.qubits() {
                qubit.id().hash(&mut hasher);
            }
        }

        format!("{:x}", hasher.finish())
    }
}

/// ML model abstraction
#[derive(Debug, Clone)]
pub enum MLModel {
    /// Linear regression model
    LinearRegression { weights: Vec<f64>, bias: f64 },
    /// Neural network model
    NeuralNetwork {
        layers: Vec<Layer>,
        learning_rate: f64,
    },
    /// Q-learning model
    QLearning {
        q_table: HashMap<String, HashMap<String, f64>>,
        learning_rate: f64,
        discount_factor: f64,
    },
    /// Random forest model
    RandomForest {
        trees: Vec<DecisionTree>,
        num_trees: usize,
    },
}

/// Neural network layer
#[derive(Debug, Clone)]
pub struct Layer {
    pub weights: Vec<Vec<f64>>,
    pub biases: Vec<f64>,
    pub activation: ActivationFunction,
}

/// Activation functions
#[derive(Debug, Clone)]
pub enum ActivationFunction {
    ReLU,
    Sigmoid,
    Tanh,
    Linear,
}
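
// A minimal sketch (not part of the crate API) of how a single `Layer` could be
// evaluated during inference: each output neuron is activation(w . x + b). The
// helper names are illustrative assumptions; this simplified module does not
// implement the actual forward/backward passes.
#[allow(dead_code)]
fn apply_activation(activation: &ActivationFunction, x: f64) -> f64 {
    match activation {
        ActivationFunction::ReLU => x.max(0.0),
        ActivationFunction::Sigmoid => 1.0 / (1.0 + (-x).exp()),
        ActivationFunction::Tanh => x.tanh(),
        ActivationFunction::Linear => x,
    }
}

#[allow(dead_code)]
fn layer_forward(layer: &Layer, input: &[f64]) -> Vec<f64> {
    layer
        .weights
        .iter()
        .zip(layer.biases.iter())
        .map(|(row, &bias)| {
            // Weighted sum of the inputs plus bias, followed by the layer's activation.
            let pre_activation: f64 =
                row.iter().zip(input.iter()).map(|(w, x)| w * x).sum::<f64>() + bias;
            apply_activation(&layer.activation, pre_activation)
        })
        .collect()
}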

/// Decision tree for random forest
#[derive(Debug, Clone)]
pub struct DecisionTree {
    pub nodes: Vec<TreeNode>,
    pub root: usize,
}

/// Tree node
#[derive(Debug, Clone)]
pub struct TreeNode {
    pub feature_index: Option<usize>,
    pub threshold: Option<f64>,
    pub value: Option<f64>,
    pub left_child: Option<usize>,
    pub right_child: Option<usize>,
}
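
// A minimal sketch (not part of the crate API) of how a `DecisionTree` with the
// node layout above could be evaluated against a feature vector: leaf nodes
// carry `value`, internal nodes route on `feature_index` and `threshold`. The
// function name is an illustrative assumption; nodes with missing fields yield `None`.
#[allow(dead_code)]
fn predict_tree(tree: &DecisionTree, features: &[f64]) -> Option<f64> {
    let mut index = tree.root;
    loop {
        let node = tree.nodes.get(index)?;
        // Leaf node: return its stored prediction.
        if let Some(value) = node.value {
            return Some(value);
        }
        // Internal node: compare the selected feature against the threshold.
        let feature = node.feature_index.and_then(|i| features.get(i).copied())?;
        index = if feature <= node.threshold? {
            node.left_child?
        } else {
            node.right_child?
        };
    }
}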

impl MLCircuitOptimizer {
    /// Create a new ML optimizer
    #[must_use]
    pub fn new(strategy: MLStrategy) -> Self {
        Self {
            strategy,
            feature_extractor: Arc::new(Mutex::new(FeatureExtractor::new())),
            models: Arc::new(Mutex::new(HashMap::new())),
            training_data: Arc::new(Mutex::new(Vec::new())),
            config: MLOptimizerConfig::default(),
        }
    }

    /// Create optimizer with custom configuration
    #[must_use]
    pub fn with_config(strategy: MLStrategy, config: MLOptimizerConfig) -> Self {
        Self {
            strategy,
            feature_extractor: Arc::new(Mutex::new(FeatureExtractor::new())),
            models: Arc::new(Mutex::new(HashMap::new())),
            training_data: Arc::new(Mutex::new(Vec::new())),
            config,
        }
    }

    /// Optimize a circuit using ML
    pub fn optimize<const N: usize>(
        &mut self,
        circuit: &Circuit<N>,
    ) -> QuantRS2Result<MLOptimizationResult<N>> {
        let start_time = Instant::now();

        // Extract features
        let features = {
            let mut extractor = self.feature_extractor.lock().map_err(|e| {
                QuantRS2Error::RuntimeError(format!("Failed to lock feature extractor: {e}"))
            })?;
            extractor.extract_features(circuit)?
        };

        // Apply optimization based on strategy
        let optimized_circuit = match &self.strategy {
            MLStrategy::ReinforcementLearning { .. } => {
                self.optimize_with_rl(circuit, &features)?
            }
            MLStrategy::NeuralNetwork { .. } => self.optimize_with_nn(circuit, &features)?,
            MLStrategy::GeneticAlgorithm { .. } => self.optimize_with_ga(circuit, &features)?,
            MLStrategy::BayesianOptimization { .. } => {
                self.optimize_with_bayesian(circuit, &features)?
            }
        };

        let optimization_time = start_time.elapsed();

        // Compute metrics before moving the optimized circuit into the result.
        let improvement_metrics = self.calculate_improvement_metrics(circuit, &optimized_circuit);

        Ok(MLOptimizationResult {
            original_circuit: circuit.clone(),
            optimized_circuit,
            features,
            optimization_time,
            improvement_metrics,
            strategy_used: self.strategy.clone(),
        })
    }

    /// Optimize using reinforcement learning
    fn optimize_with_rl<const N: usize>(
        &self,
        circuit: &Circuit<N>,
        features: &[f64],
    ) -> QuantRS2Result<Circuit<N>> {
        // Simplified RL optimization
        // In a full implementation, this would:
        // 1. Define state space (circuit configuration)
        // 2. Define action space (gate reorderings, fusions, etc.)
        // 3. Train Q-learning agent
        // 4. Apply learned policy
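        //
        // For reference, the tabular Q-learning update such an agent would
        // apply, using the `learning_rate` and `discount_factor` parameters of
        // `MLStrategy::ReinforcementLearning`, is:
        //   Q(s, a) <- Q(s, a)
        //       + learning_rate * (reward + discount_factor * max_a' Q(s', a') - Q(s, a))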

        // For now, return a simple optimization
        let mut optimized = circuit.clone();

        // Simple gate reordering based on features
        if features.len() > 6 {
            let cnot_fraction = features[6]; // Two-qubit gate fraction
            if cnot_fraction > 0.5 {
                // High CNOT fraction - could benefit from reordering
                // This is a placeholder for actual RL optimization
            }
        }

        Ok(optimized)
    }

    /// Optimize using neural network
    fn optimize_with_nn<const N: usize>(
        &self,
        circuit: &Circuit<N>,
        features: &[f64],
    ) -> QuantRS2Result<Circuit<N>> {
        // Simplified NN optimization
        // In a full implementation, this would:
        // 1. Use trained NN to predict optimal gate sequences
        // 2. Apply predicted transformations
        // 3. Validate and refine results
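        //
        // Inference with the dense layers stored in `MLModel::NeuralNetwork`
        // would map the extracted feature vector through activation(W * x + b)
        // at each `Layer` (see the `layer_forward` sketch above).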

        Ok(circuit.clone())
    }

    /// Optimize using genetic algorithm
    fn optimize_with_ga<const N: usize>(
        &self,
        circuit: &Circuit<N>,
        features: &[f64],
    ) -> QuantRS2Result<Circuit<N>> {
        // Simplified GA optimization
        // In a full implementation, this would:
        // 1. Create initial population of circuit variants
        // 2. Evaluate fitness of each variant
        // 3. Apply selection, crossover, and mutation
        // 4. Evolve over multiple generations
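        //
        // A natural fitness function for step 2 would combine the quantities
        // already tracked in `ImprovementMetrics` (e.g. depth_reduction and
        // gate_reduction), weighted according to the chosen `OptimizationTarget`.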

        Ok(circuit.clone())
    }

    /// Optimize using Bayesian optimization
    fn optimize_with_bayesian<const N: usize>(
        &self,
        circuit: &Circuit<N>,
        features: &[f64],
    ) -> QuantRS2Result<Circuit<N>> {
        // Simplified Bayesian optimization
        // In a full implementation, this would:
        // 1. Build Gaussian process model of optimization landscape
        // 2. Use acquisition function to select next optimization point
        // 3. Iteratively improve based on observed results
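        //
        // Step 2 would rank candidates with the configured `AcquisitionFunction`;
        // for example, `UpperConfidenceBound { beta }` scores a candidate as
        // mean + beta * std_dev of the surrogate posterior (see the
        // `acquisition_score` sketch above).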

        Ok(circuit.clone())
    }

    /// Calculate improvement metrics
    fn calculate_improvement_metrics<const N: usize>(
        &self,
        original: &Circuit<N>,
        optimized: &Circuit<N>,
    ) -> ImprovementMetrics {
        // Depth is approximated by gate count in this simplified implementation,
        // so depth_reduction and gate_reduction currently coincide.
        let original_depth = original.gates().len();
        let optimized_depth = optimized.gates().len();
        let original_gates = original.gates().len();
        let optimized_gates = optimized.gates().len();

        ImprovementMetrics {
            // `.max(1)` guards against division by zero for empty circuits.
            depth_reduction: (original_depth as f64 - optimized_depth as f64)
                / original_depth.max(1) as f64,
            gate_reduction: (original_gates as f64 - optimized_gates as f64)
                / original_gates.max(1) as f64,
            compilation_speedup: 1.0,  // Placeholder
            fidelity_improvement: 0.0, // Placeholder
        }
    }

    /// Add training example
    pub fn add_training_example(&mut self, example: TrainingExample) {
        let mut data = self
            .training_data
            .lock()
            .expect("Training data mutex poisoned");
        data.push(example);

        // Maintain maximum size by dropping the oldest example.
        // Note: `remove(0)` shifts all remaining elements (O(n)); a `VecDeque`
        // with `pop_front` would avoid this if the buffer grows large.
        if data.len() > self.config.max_training_examples {
            data.remove(0);
        }
    }

    /// Train models with current data
    pub fn train_models(&mut self) -> QuantRS2Result<()> {
        let data = {
            let training_data = self.training_data.lock().map_err(|e| {
                QuantRS2Error::RuntimeError(format!("Failed to lock training data: {e}"))
            })?;
            training_data.clone()
        };

        if data.is_empty() {
            return Err(QuantRS2Error::InvalidInput(
                "No training data available".to_string(),
            ));
        }

        // Train based on strategy
        let strategy = self.strategy.clone();
        match strategy {
            MLStrategy::NeuralNetwork {
                architecture,
                learning_rate,
                epochs,
                batch_size,
            } => {
                self.train_neural_network(&data, &architecture, learning_rate, epochs, batch_size)?;
            }
            MLStrategy::ReinforcementLearning {
                learning_rate,
                discount_factor,
                ..
            } => {
                self.train_rl_model(&data, learning_rate, discount_factor)?;
            }
            _ => {
                // Other strategies would be implemented here
            }
        }

        Ok(())
    }

    /// Train neural network model
    fn train_neural_network(
        &self,
        data: &[TrainingExample],
        architecture: &[usize],
        learning_rate: f64,
        epochs: usize,
        batch_size: usize,
    ) -> QuantRS2Result<()> {
        // Simplified NN training: only the network structure is constructed here.
        // A full implementation would run `epochs` epochs of minibatch
        // backpropagation over `data` with the given `batch_size`.

        let input_size = data.first().map_or(0, |ex| ex.input.features.len());

        // Create network layers
        let mut layers = Vec::new();
        let mut prev_size = input_size;

        for &layer_size in architecture {
            // Constant placeholder weights; a real implementation would randomize these.
            let weights = vec![vec![0.1; prev_size]; layer_size];
            let biases = vec![0.0; layer_size];

            layers.push(Layer {
                weights,
                biases,
                activation: ActivationFunction::ReLU,
            });

            prev_size = layer_size;
        }

        let model = MLModel::NeuralNetwork {
            layers,
            learning_rate,
        };

        let mut models = self
            .models
            .lock()
            .map_err(|e| QuantRS2Error::RuntimeError(format!("Failed to lock models: {e}")))?;
        models.insert("neural_network".to_string(), model);

        Ok(())
    }

    /// Train reinforcement learning model
    fn train_rl_model(
        &self,
        data: &[TrainingExample],
        learning_rate: f64,
        discount_factor: f64,
    ) -> QuantRS2Result<()> {
        // Simplified Q-learning training
        let model = MLModel::QLearning {
            q_table: HashMap::new(),
            learning_rate,
            discount_factor,
        };

        let mut models = self
            .models
            .lock()
            .map_err(|e| QuantRS2Error::RuntimeError(format!("Failed to lock models: {e}")))?;
        models.insert("q_learning".to_string(), model);

        Ok(())
    }
}

/// ML optimization result
#[derive(Debug, Clone)]
pub struct MLOptimizationResult<const N: usize> {
    /// Original circuit
    pub original_circuit: Circuit<N>,
    /// Optimized circuit
    pub optimized_circuit: Circuit<N>,
    /// Extracted features
    pub features: Vec<f64>,
    /// Time taken for optimization
    pub optimization_time: Duration,
    /// Improvement metrics
    pub improvement_metrics: ImprovementMetrics,
    /// Strategy used
    pub strategy_used: MLStrategy,
}

/// Improvement metrics
#[derive(Debug, Clone)]
pub struct ImprovementMetrics {
    /// Relative depth reduction
    pub depth_reduction: f64,
    /// Relative gate count reduction
    pub gate_reduction: f64,
    /// Compilation speedup factor
    pub compilation_speedup: f64,
    /// Fidelity improvement
    pub fidelity_improvement: f64,
}

#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::multi::CNOT;
    use quantrs2_core::gate::single::Hadamard;

    #[test]
    fn test_feature_extraction() {
        let mut extractor = FeatureExtractor::new();

        let mut circuit = Circuit::<2>::new();
        circuit
            .add_gate(Hadamard { target: QubitId(0) })
            .expect("Failed to add Hadamard gate");
        circuit
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(1),
            })
            .expect("Failed to add CNOT gate");

        let features = extractor
            .extract_features(&circuit)
            .expect("Failed to extract features");
        assert!(!features.is_empty());
        assert!(features.len() > 10); // Should have multiple feature categories
    }

    #[test]
    fn test_ml_optimizer_creation() {
        let strategy = MLStrategy::NeuralNetwork {
            architecture: vec![64, 32, 16],
            learning_rate: 0.001,
            epochs: 100,
            batch_size: 32,
        };

        let optimizer = MLCircuitOptimizer::new(strategy);
        assert!(matches!(
            optimizer.strategy,
            MLStrategy::NeuralNetwork { .. }
        ));
    }

    #[test]
    fn test_training_example() {
        let features = vec![1.0, 2.0, 3.0, 4.0];
        let target = OptimizationTarget::MinimizeDepth { target_depth: 5 };

        let representation = MLCircuitRepresentation {
            features,
            gate_sequence: vec![0, 1, 2],
            adjacency_matrix: vec![vec![0.0, 1.0], vec![1.0, 0.0]],
            qubit_connectivity: vec![vec![false, true], vec![true, false]],
            metrics: CircuitMetrics {
                depth: 3,
                gate_count: 3,
                two_qubit_gate_count: 1,
                entanglement_measure: 0.5,
                critical_path_length: 3,
                parallelization_potential: 0.3,
            },
        };

        let example = TrainingExample {
            input: representation,
            target,
            score: 0.8,
            metadata: HashMap::new(),
        };

        assert!(example.score > 0.0);
    }

    #[test]
    fn test_ml_strategies() {
        let rl_strategy = MLStrategy::ReinforcementLearning {
            learning_rate: 0.1,
            discount_factor: 0.9,
            exploration_rate: 0.1,
            episodes: 1000,
        };

        let nn_strategy = MLStrategy::NeuralNetwork {
            architecture: vec![32, 16, 8],
            learning_rate: 0.001,
            epochs: 50,
            batch_size: 16,
        };

        assert!(matches!(
            rl_strategy,
            MLStrategy::ReinforcementLearning { .. }
        ));
        assert!(matches!(nn_strategy, MLStrategy::NeuralNetwork { .. }));
    }

    #[test]
    fn test_circuit_representation() {
        let metrics = CircuitMetrics {
            depth: 5,
            gate_count: 10,
            two_qubit_gate_count: 3,
            entanglement_measure: 0.7,
            critical_path_length: 5,
            parallelization_potential: 0.4,
        };

        let representation = MLCircuitRepresentation {
            features: vec![1.0, 2.0, 3.0],
            gate_sequence: vec![0, 1, 2, 1, 0],
            adjacency_matrix: vec![vec![0.0; 3]; 3],
            qubit_connectivity: vec![vec![false; 3]; 3],
            metrics,
        };

        assert_eq!(representation.metrics.depth, 5);
        assert_eq!(representation.gate_sequence.len(), 5);
    }
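
    // Illustrative extra test (an editorial addition): when the "optimized"
    // circuit is identical to the original, the relative improvement metrics
    // should be zero. Uses only gate constructors already used above.
    #[test]
    fn test_improvement_metrics_identity() {
        let strategy = MLStrategy::GeneticAlgorithm {
            population_size: 10,
            generations: 5,
            mutation_rate: 0.05,
            selection_pressure: 1.5,
        };
        let optimizer = MLCircuitOptimizer::new(strategy);

        let mut circuit = Circuit::<2>::new();
        circuit
            .add_gate(Hadamard { target: QubitId(0) })
            .expect("Failed to add Hadamard gate");
        circuit
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(1),
            })
            .expect("Failed to add CNOT gate");

        let metrics = optimizer.calculate_improvement_metrics(&circuit, &circuit);
        assert!(metrics.depth_reduction.abs() < f64::EPSILON);
        assert!(metrics.gate_reduction.abs() < f64::EPSILON);
    }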
}