quantrs2_ml/torchquantum/
noise.rs

1//! TorchQuantum Noise-Aware Training
2//!
3//! This module provides noise-aware gradient computation and error-mitigated
4//! expectation values for robust quantum machine learning on noisy hardware.
5//!
6//! ## Key Features
7//!
8//! - **NoiseAwareGradient**: Gradients that account for device noise
9//! - **MitigatedExpectation**: Error-mitigated expectation value computation
10//! - **NoiseModel Integration**: Compatible with various noise models
11//! - **Resilient Training**: Training strategies for noisy quantum devices
12
13use crate::error::{MLError, Result};
14use scirs2_core::ndarray::{Array1, Array2};
15use scirs2_core::Complex64;
16use std::collections::HashMap;
17
18use super::{CType, TQDevice, TQModule, TQParameter};
19
20// ============================================================================
21// Noise Model Types
22// ============================================================================
23
/// Noise channels that act on a single qubit.
///
/// Each variant wraps one `f64`, the probability that parameterizes the
/// channel.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SingleQubitNoiseType {
    /// Depolarizing channel; the payload is the probability `p`.
    Depolarizing(f64),
    /// Amplitude-damping channel; the payload is the decay probability.
    AmplitudeDamping(f64),
    /// Phase-damping channel; the payload is the dephasing probability.
    PhaseDamping(f64),
    /// Bit-flip channel; the payload is the flip probability `p`.
    BitFlip(f64),
    /// Phase-flip channel; the payload is the flip probability `p`.
    PhaseFlip(f64),
}
38
/// Noise channels that act on a pair of qubits.
///
/// Each variant wraps one `f64` that parameterizes the channel.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TwoQubitNoiseType {
    /// Depolarizing noise applied to both qubits.
    Depolarizing(f64),
    /// Dephasing that is correlated between the two qubits.
    CorrelatedDephasing(f64),
    /// Cross-talk error between the two qubits.
    CrossTalk(f64),
}
49
50/// Complete noise model for a quantum device
51#[derive(Debug, Clone)]
52pub struct NoiseModel {
53    /// Single-qubit gate errors per qubit
54    pub single_qubit_errors: HashMap<usize, SingleQubitNoiseType>,
55    /// Two-qubit gate errors per qubit pair
56    pub two_qubit_errors: HashMap<(usize, usize), TwoQubitNoiseType>,
57    /// Readout errors per qubit (probability of bit flip during measurement)
58    pub readout_errors: HashMap<usize, f64>,
59    /// Coherence times (T1, T2) per qubit in microseconds
60    pub coherence_times: HashMap<usize, (f64, f64)>,
61    /// Gate times in microseconds
62    pub gate_times: GateTimes,
63    /// Global noise scale factor
64    pub noise_scale: f64,
65}
66
/// Durations of the basic device operations, all in microseconds.
#[derive(Clone, Debug)]
pub struct GateTimes {
    /// Duration of a single-qubit gate (microseconds).
    pub single_qubit: f64,
    /// Duration of a two-qubit gate (microseconds).
    pub two_qubit: f64,
    /// Duration of a measurement (microseconds).
    pub measurement: f64,
}

impl Default for GateTimes {
    /// Defaults: 50 ns single-qubit gates, 300 ns two-qubit gates and a
    /// 1 µs measurement.
    fn default() -> Self {
        // Values are expressed in microseconds.
        let (single_qubit, two_qubit, measurement) = (0.05, 0.3, 1.0);
        Self {
            single_qubit,
            two_qubit,
            measurement,
        }
    }
}
87
88impl NoiseModel {
89    /// Create a noise-free model
90    pub fn ideal() -> Self {
91        Self {
92            single_qubit_errors: HashMap::new(),
93            two_qubit_errors: HashMap::new(),
94            readout_errors: HashMap::new(),
95            coherence_times: HashMap::new(),
96            gate_times: GateTimes::default(),
97            noise_scale: 0.0,
98        }
99    }
100
101    /// Create uniform depolarizing noise model
102    pub fn uniform_depolarizing(n_qubits: usize, p1: f64, p2: f64) -> Self {
103        let mut model = Self::ideal();
104        model.noise_scale = 1.0;
105
106        for q in 0..n_qubits {
107            model
108                .single_qubit_errors
109                .insert(q, SingleQubitNoiseType::Depolarizing(p1));
110        }
111
112        for q1 in 0..n_qubits {
113            for q2 in (q1 + 1)..n_qubits {
114                model
115                    .two_qubit_errors
116                    .insert((q1, q2), TwoQubitNoiseType::Depolarizing(p2));
117            }
118        }
119
120        model
121    }
122
123    /// Create noise model from IBM backend properties
124    pub fn from_ibm_properties(
125        n_qubits: usize,
126        t1_times: &[f64],
127        t2_times: &[f64],
128        single_gate_errors: &[f64],
129        two_gate_errors: &[(usize, usize, f64)],
130        readout_errors: &[f64],
131    ) -> Self {
132        let mut model = Self::ideal();
133        model.noise_scale = 1.0;
134
135        for q in 0..n_qubits {
136            if q < t1_times.len() && q < t2_times.len() {
137                model.coherence_times.insert(q, (t1_times[q], t2_times[q]));
138            }
139
140            if q < single_gate_errors.len() {
141                model
142                    .single_qubit_errors
143                    .insert(q, SingleQubitNoiseType::Depolarizing(single_gate_errors[q]));
144            }
145
146            if q < readout_errors.len() {
147                model.readout_errors.insert(q, readout_errors[q]);
148            }
149        }
150
151        for (q1, q2, err) in two_gate_errors {
152            model
153                .two_qubit_errors
154                .insert((*q1, *q2), TwoQubitNoiseType::Depolarizing(*err));
155        }
156
157        model
158    }
159
160    /// Get effective single-qubit error rate
161    pub fn effective_single_error(&self, qubit: usize) -> f64 {
162        self.single_qubit_errors
163            .get(&qubit)
164            .map(|e| match e {
165                SingleQubitNoiseType::Depolarizing(p) => *p,
166                SingleQubitNoiseType::AmplitudeDamping(p) => *p,
167                SingleQubitNoiseType::PhaseDamping(p) => *p,
168                SingleQubitNoiseType::BitFlip(p) => *p,
169                SingleQubitNoiseType::PhaseFlip(p) => *p,
170            })
171            .unwrap_or(0.0)
172            * self.noise_scale
173    }
174
175    /// Get effective two-qubit error rate
176    pub fn effective_two_qubit_error(&self, q1: usize, q2: usize) -> f64 {
177        let key = if q1 < q2 { (q1, q2) } else { (q2, q1) };
178        self.two_qubit_errors
179            .get(&key)
180            .map(|e| match e {
181                TwoQubitNoiseType::Depolarizing(p) => *p,
182                TwoQubitNoiseType::CorrelatedDephasing(p) => *p,
183                TwoQubitNoiseType::CrossTalk(p) => *p,
184            })
185            .unwrap_or(0.0)
186            * self.noise_scale
187    }
188}
189
190// ============================================================================
191// Noise-Aware Gradient
192// ============================================================================
193
194/// Configuration for noise-aware gradient computation
195#[derive(Debug, Clone)]
196pub struct NoiseAwareGradientConfig {
197    /// Number of shots for gradient estimation
198    pub shots: usize,
199    /// Parameter shift value (default: π/2)
200    pub shift: f64,
201    /// Whether to use noise model in gradient computation
202    pub include_noise_in_gradient: bool,
203    /// Variance reduction method
204    pub variance_reduction: VarianceReduction,
205    /// Number of repetitions for averaging
206    pub n_repetitions: usize,
207}
208
209impl Default for NoiseAwareGradientConfig {
210    fn default() -> Self {
211        Self {
212            shots: 1000,
213            shift: std::f64::consts::FRAC_PI_2,
214            include_noise_in_gradient: true,
215            variance_reduction: VarianceReduction::None,
216            n_repetitions: 1,
217        }
218    }
219}
220
/// Variance-reduction strategies available for gradient estimation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VarianceReduction {
    /// Do not apply any variance reduction.
    None,
    /// Common random numbers.
    CommonRandomNumbers,
    /// Antithetic variates.
    AntitheticVariates,
    /// Control variates.
    ControlVariates,
}
233
234/// Noise-aware gradient estimator
235///
236/// Computes gradients that account for device noise, providing more
237/// accurate optimization on noisy quantum hardware.
238#[derive(Debug, Clone)]
239pub struct NoiseAwareGradient {
240    /// Noise model
241    pub noise_model: NoiseModel,
242    /// Configuration
243    pub config: NoiseAwareGradientConfig,
244    /// Cached gradient variances
245    gradient_variances: HashMap<String, f64>,
246}
247
248impl NoiseAwareGradient {
249    /// Create new noise-aware gradient estimator
250    pub fn new(noise_model: NoiseModel) -> Self {
251        Self {
252            noise_model,
253            config: NoiseAwareGradientConfig::default(),
254            gradient_variances: HashMap::new(),
255        }
256    }
257
258    /// Create with custom configuration
259    pub fn with_config(noise_model: NoiseModel, config: NoiseAwareGradientConfig) -> Self {
260        Self {
261            noise_model,
262            config,
263            gradient_variances: HashMap::new(),
264        }
265    }
266
267    /// Compute parameter-shift gradient with noise awareness
268    ///
269    /// Returns the gradient estimate and its estimated variance
270    pub fn compute_gradient<F>(
271        &mut self,
272        param_idx: usize,
273        current_params: &[f64],
274        expectation_fn: F,
275    ) -> Result<(f64, f64)>
276    where
277        F: Fn(&[f64]) -> Result<f64>,
278    {
279        let mut params_plus = current_params.to_vec();
280        let mut params_minus = current_params.to_vec();
281
282        params_plus[param_idx] += self.config.shift;
283        params_minus[param_idx] -= self.config.shift;
284
285        let mut gradient_estimates = Vec::with_capacity(self.config.n_repetitions);
286
287        for _ in 0..self.config.n_repetitions {
288            let exp_plus = expectation_fn(&params_plus)?;
289            let exp_minus = expectation_fn(&params_minus)?;
290
291            let gradient = (exp_plus - exp_minus) / (2.0 * self.config.shift.sin());
292            gradient_estimates.push(gradient);
293        }
294
295        // Compute mean and variance
296        let mean = gradient_estimates.iter().sum::<f64>() / gradient_estimates.len() as f64;
297        let variance = if gradient_estimates.len() > 1 {
298            gradient_estimates
299                .iter()
300                .map(|g| (g - mean).powi(2))
301                .sum::<f64>()
302                / (gradient_estimates.len() - 1) as f64
303        } else {
304            0.0
305        };
306
307        // Apply noise correction if configured
308        let corrected_gradient = if self.config.include_noise_in_gradient {
309            self.apply_noise_correction(mean, param_idx)
310        } else {
311            mean
312        };
313
314        Ok((corrected_gradient, variance))
315    }
316
317    /// Apply noise correction to gradient estimate
318    fn apply_noise_correction(&self, gradient: f64, _param_idx: usize) -> f64 {
319        // Simple noise scaling based on average error rate
320        let avg_error: f64 = self
321            .noise_model
322            .single_qubit_errors
323            .values()
324            .map(|e| match e {
325                SingleQubitNoiseType::Depolarizing(p) => *p,
326                SingleQubitNoiseType::AmplitudeDamping(p) => *p,
327                SingleQubitNoiseType::PhaseDamping(p) => *p,
328                SingleQubitNoiseType::BitFlip(p) => *p,
329                SingleQubitNoiseType::PhaseFlip(p) => *p,
330            })
331            .sum::<f64>()
332            / self.noise_model.single_qubit_errors.len().max(1) as f64;
333
334        // Scale gradient to account for noise-induced suppression
335        let scale_factor = 1.0 / (1.0 - 2.0 * avg_error).max(0.1);
336        gradient * scale_factor
337    }
338
339    /// Compute all gradients for a parameter vector
340    pub fn compute_all_gradients<F>(
341        &mut self,
342        params: &[f64],
343        expectation_fn: F,
344    ) -> Result<(Vec<f64>, Vec<f64>)>
345    where
346        F: Fn(&[f64]) -> Result<f64> + Clone,
347    {
348        let mut gradients = Vec::with_capacity(params.len());
349        let mut variances = Vec::with_capacity(params.len());
350
351        for i in 0..params.len() {
352            let (grad, var) = self.compute_gradient(i, params, expectation_fn.clone())?;
353            gradients.push(grad);
354            variances.push(var);
355        }
356
357        Ok((gradients, variances))
358    }
359}
360
361// ============================================================================
362// Mitigated Expectation Value
363// ============================================================================
364
/// Available error-mitigation methods.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MitigationMethod {
    /// Apply no mitigation.
    None,
    /// Zero-noise extrapolation.
    ZNE,
    /// Probabilistic error cancellation.
    PEC,
    /// Readout error mitigation.
    ReadoutMitigation,
    /// Twirling (Pauli or Clifford).
    Twirling,
}
379
380/// Configuration for error-mitigated expectation values
381#[derive(Debug, Clone)]
382pub struct MitigatedExpectationConfig {
383    /// Primary mitigation method
384    pub method: MitigationMethod,
385    /// Number of shots
386    pub shots: usize,
387    /// Scale factors for ZNE
388    pub zne_scale_factors: Vec<f64>,
389    /// Extrapolation method for ZNE
390    pub zne_extrapolation: ZNEExtrapolation,
391    /// Whether to apply readout mitigation
392    pub apply_readout_mitigation: bool,
393}
394
395impl Default for MitigatedExpectationConfig {
396    fn default() -> Self {
397        Self {
398            method: MitigationMethod::ZNE,
399            shots: 4000,
400            zne_scale_factors: vec![1.0, 1.5, 2.0],
401            zne_extrapolation: ZNEExtrapolation::Linear,
402            apply_readout_mitigation: true,
403        }
404    }
405}
406
/// Extrapolation methods available for zero-noise extrapolation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ZNEExtrapolation {
    /// Linear extrapolation.
    Linear,
    /// Polynomial extrapolation.
    Polynomial,
    /// Exponential extrapolation.
    Exponential,
    /// Richardson extrapolation.
    Richardson,
}
419
420/// Error-mitigated expectation value estimator
421#[derive(Debug, Clone)]
422pub struct MitigatedExpectation {
423    /// Noise model
424    pub noise_model: NoiseModel,
425    /// Configuration
426    pub config: MitigatedExpectationConfig,
427    /// Readout calibration matrix (if computed)
428    readout_calibration: Option<Array2<f64>>,
429}
430
431impl MitigatedExpectation {
432    /// Create new mitigated expectation estimator
433    pub fn new(noise_model: NoiseModel) -> Self {
434        Self {
435            noise_model,
436            config: MitigatedExpectationConfig::default(),
437            readout_calibration: None,
438        }
439    }
440
441    /// Create with custom configuration
442    pub fn with_config(noise_model: NoiseModel, config: MitigatedExpectationConfig) -> Self {
443        Self {
444            noise_model,
445            config,
446            readout_calibration: None,
447        }
448    }
449
450    /// Compute error-mitigated expectation value
451    pub fn compute<F>(&self, raw_expectation_fn: F) -> Result<f64>
452    where
453        F: Fn(f64) -> Result<f64>,
454    {
455        match self.config.method {
456            MitigationMethod::None => raw_expectation_fn(1.0),
457            MitigationMethod::ZNE => self.compute_zne(raw_expectation_fn),
458            MitigationMethod::ReadoutMitigation => {
459                let raw = raw_expectation_fn(1.0)?;
460                self.apply_readout_mitigation(raw)
461            }
462            _ => raw_expectation_fn(1.0), // Fallback for unimplemented methods
463        }
464    }
465
466    /// Compute ZNE-mitigated expectation value
467    fn compute_zne<F>(&self, raw_expectation_fn: F) -> Result<f64>
468    where
469        F: Fn(f64) -> Result<f64>,
470    {
471        let mut scaled_values = Vec::with_capacity(self.config.zne_scale_factors.len());
472
473        for &scale in &self.config.zne_scale_factors {
474            let value = raw_expectation_fn(scale)?;
475            scaled_values.push((scale, value));
476        }
477
478        // Extrapolate to zero noise
479        self.extrapolate_to_zero(&scaled_values)
480    }
481
482    /// Extrapolate to zero noise using configured method
483    fn extrapolate_to_zero(&self, points: &[(f64, f64)]) -> Result<f64> {
484        if points.is_empty() {
485            return Err(MLError::InvalidConfiguration(
486                "No data points for extrapolation".to_string(),
487            ));
488        }
489
490        if points.len() == 1 {
491            return Ok(points[0].1);
492        }
493
494        match self.config.zne_extrapolation {
495            ZNEExtrapolation::Linear => {
496                // Linear fit: y = a + b*x, extrapolate to x=0
497                let n = points.len() as f64;
498                let sum_x: f64 = points.iter().map(|(x, _)| x).sum();
499                let sum_y: f64 = points.iter().map(|(_, y)| y).sum();
500                let sum_xy: f64 = points.iter().map(|(x, y)| x * y).sum();
501                let sum_x2: f64 = points.iter().map(|(x, _)| x * x).sum();
502
503                let denom = n * sum_x2 - sum_x * sum_x;
504                if denom.abs() < 1e-10 {
505                    return Ok(sum_y / n);
506                }
507
508                let a = (sum_y * sum_x2 - sum_x * sum_xy) / denom;
509                Ok(a) // y at x=0
510            }
511            ZNEExtrapolation::Exponential => {
512                // Exponential fit: y = a * exp(b*x)
513                // Take log and do linear fit
514                let log_points: Vec<(f64, f64)> = points
515                    .iter()
516                    .filter(|(_, y)| *y > 0.0)
517                    .map(|(x, y)| (*x, y.ln()))
518                    .collect();
519
520                if log_points.is_empty() {
521                    return Ok(points[0].1);
522                }
523
524                let n = log_points.len() as f64;
525                let sum_x: f64 = log_points.iter().map(|(x, _)| x).sum();
526                let sum_y: f64 = log_points.iter().map(|(_, y)| y).sum();
527                let sum_xy: f64 = log_points.iter().map(|(x, y)| x * y).sum();
528                let sum_x2: f64 = log_points.iter().map(|(x, _)| x * x).sum();
529
530                let denom = n * sum_x2 - sum_x * sum_x;
531                if denom.abs() < 1e-10 {
532                    return Ok((sum_y / n).exp());
533                }
534
535                let log_a = (sum_y * sum_x2 - sum_x * sum_xy) / denom;
536                Ok(log_a.exp())
537            }
538            _ => {
539                // Default to linear for other methods
540                let sum_y: f64 = points.iter().map(|(_, y)| y).sum();
541                Ok(sum_y / points.len() as f64)
542            }
543        }
544    }
545
546    /// Apply readout error mitigation
547    fn apply_readout_mitigation(&self, raw_value: f64) -> Result<f64> {
548        // Simple correction based on average readout error
549        let avg_readout_error: f64 = self.noise_model.readout_errors.values().sum::<f64>()
550            / self.noise_model.readout_errors.len().max(1) as f64;
551
552        // Correct for readout bias
553        let corrected = (raw_value - avg_readout_error) / (1.0 - 2.0 * avg_readout_error);
554        Ok(corrected.clamp(-1.0, 1.0))
555    }
556
557    /// Calibrate readout errors
558    pub fn calibrate_readout(&mut self, n_qubits: usize) -> Result<()> {
559        // Build simple diagonal calibration matrix
560        let dim = 1 << n_qubits;
561        let mut cal_matrix = Array2::<f64>::eye(dim);
562
563        // Apply readout errors to diagonal
564        for q in 0..n_qubits {
565            if let Some(&err) = self.noise_model.readout_errors.get(&q) {
566                for i in 0..dim {
567                    let bit = (i >> q) & 1;
568                    if bit == 0 {
569                        cal_matrix[[i, i]] *= 1.0 - err;
570                    } else {
571                        cal_matrix[[i, i]] *= 1.0 - err;
572                    }
573                }
574            }
575        }
576
577        self.readout_calibration = Some(cal_matrix);
578        Ok(())
579    }
580}
581
582// ============================================================================
583// Noise-Aware Training Wrapper
584// ============================================================================
585
586/// Wrapper for noise-aware training of quantum circuits
587#[derive(Debug)]
588pub struct NoiseAwareTrainer {
589    /// Noise-aware gradient estimator
590    pub gradient_estimator: NoiseAwareGradient,
591    /// Mitigated expectation estimator
592    pub expectation_estimator: MitigatedExpectation,
593    /// Training history
594    pub history: TrainingHistory,
595}
596
/// Per-epoch record of a training run.
#[derive(Clone, Debug, Default)]
pub struct TrainingHistory {
    /// Loss value of each epoch.
    pub losses: Vec<f64>,
    /// Gradient norm of each epoch.
    pub gradient_norms: Vec<f64>,
    /// Difference between the mitigated and the raw loss.
    pub mitigation_improvement: Vec<f64>,
}
607
608impl NoiseAwareTrainer {
609    /// Create new noise-aware trainer
610    pub fn new(noise_model: NoiseModel) -> Self {
611        Self {
612            gradient_estimator: NoiseAwareGradient::new(noise_model.clone()),
613            expectation_estimator: MitigatedExpectation::new(noise_model),
614            history: TrainingHistory::default(),
615        }
616    }
617
618    /// Perform one training step
619    pub fn step<F>(&mut self, params: &mut [f64], loss_fn: F, learning_rate: f64) -> Result<f64>
620    where
621        F: Fn(&[f64]) -> Result<f64> + Clone,
622    {
623        // Compute gradients with noise awareness
624        let (gradients, variances) = self
625            .gradient_estimator
626            .compute_all_gradients(params, loss_fn.clone())?;
627
628        // Compute gradient norm
629        let grad_norm: f64 = gradients.iter().map(|g| g * g).sum::<f64>().sqrt();
630        self.history.gradient_norms.push(grad_norm);
631
632        // Update parameters with variance-weighted learning rate
633        for (i, (param, grad)) in params.iter_mut().zip(gradients.iter()).enumerate() {
634            let variance_factor = 1.0 / (1.0 + variances[i].sqrt());
635            *param -= learning_rate * grad * variance_factor;
636        }
637
638        // Compute mitigated loss
639        let loss = self.expectation_estimator.compute(|scale| {
640            // For ZNE, we'd scale the noise here
641            let _ = scale; // Currently not used in simple implementation
642            loss_fn(params)
643        })?;
644
645        self.history.losses.push(loss);
646        Ok(loss)
647    }
648
649    /// Get current training statistics
650    pub fn statistics(&self) -> TrainingStatistics {
651        let n = self.history.losses.len();
652        if n == 0 {
653            return TrainingStatistics::default();
654        }
655
656        let avg_loss = self.history.losses.iter().sum::<f64>() / n as f64;
657        let avg_grad_norm = self.history.gradient_norms.iter().sum::<f64>() / n as f64;
658        let recent_loss = self.history.losses.last().copied().unwrap_or(0.0);
659
660        TrainingStatistics {
661            epochs: n,
662            average_loss: avg_loss,
663            recent_loss,
664            average_gradient_norm: avg_grad_norm,
665            converged: avg_grad_norm < 1e-6,
666        }
667    }
668}
669
/// Summary statistics of a training run.
#[derive(Clone, Debug, Default)]
pub struct TrainingStatistics {
    /// How many epochs have been completed.
    pub epochs: usize,
    /// Mean loss over all epochs.
    pub average_loss: f64,
    /// Loss of the most recent epoch.
    pub recent_loss: f64,
    /// Mean gradient norm.
    pub average_gradient_norm: f64,
    /// Whether training has converged.
    pub converged: bool,
}
684
685// ============================================================================
686// Tests
687// ============================================================================
688
#[cfg(test)]
mod tests {
    use super::*;

    /// The ideal model has a zero noise scale and no single-qubit errors.
    #[test]
    fn test_noise_model_ideal() {
        let ideal = NoiseModel::ideal();
        assert!(ideal.single_qubit_errors.is_empty());
        assert_eq!(ideal.noise_scale, 0.0);
    }

    /// A uniform model over four qubits has four single-qubit entries and
    /// at least one two-qubit entry.
    #[test]
    fn test_noise_model_depolarizing() {
        let uniform = NoiseModel::uniform_depolarizing(4, 0.01, 0.02);
        assert!(!uniform.two_qubit_errors.is_empty());
        assert_eq!(uniform.single_qubit_errors.len(), 4);
    }

    /// The parameter-shift gradient of `sin(x)` at `x = 0.5` is
    /// `cos(0.5) ≈ 0.877`, up to the noise correction.
    #[test]
    fn test_noise_aware_gradient() {
        let mut estimator =
            NoiseAwareGradient::new(NoiseModel::uniform_depolarizing(2, 0.01, 0.02));

        let params = vec![0.5, 0.3];
        let objective = |p: &[f64]| Ok(p[0].sin() + p[1].cos());
        let (grad, var) = estimator
            .compute_gradient(0, &params, objective)
            .expect("Should compute gradient");

        let expected = 0.5_f64.cos();
        assert!((grad - expected).abs() < 0.3); // tolerance covers the noise correction
        assert!(var >= 0.0);
    }

    /// With the default configuration (linear ZNE), `1 - 0.1 * scale`
    /// extrapolates to about `1.0` at zero noise.
    #[test]
    fn test_mitigated_expectation_linear() {
        let estimator = MitigatedExpectation::new(NoiseModel::ideal());

        let mitigated = estimator
            .compute(|scale| Ok(1.0 - 0.1 * scale))
            .expect("Should compute");

        assert!((mitigated - 1.0).abs() < 0.1);
    }

    /// A fresh trainer reports zero epochs and no convergence.
    #[test]
    fn test_training_statistics() {
        let trainer = NoiseAwareTrainer::new(NoiseModel::ideal());
        let stats = trainer.statistics();

        assert!(!stats.converged);
        assert_eq!(stats.epochs, 0);
    }

    /// Mitigation methods compare by variant.
    #[test]
    fn test_mitigation_methods() {
        assert_ne!(MitigationMethod::ZNE, MitigationMethod::PEC);
        assert_eq!(MitigationMethod::ZNE, MitigationMethod::ZNE);
    }

    /// Linear ZNE over the scale factors 1, 2, 3 recovers the intercept of
    /// `y = 1 - 0.1 * x`.
    #[test]
    fn test_zne_extrapolation() {
        let config = MitigatedExpectationConfig {
            method: MitigationMethod::ZNE,
            shots: 1000,
            zne_scale_factors: vec![1.0, 2.0, 3.0],
            zne_extrapolation: ZNEExtrapolation::Linear,
            apply_readout_mitigation: false,
        };
        let estimator = MitigatedExpectation::with_config(NoiseModel::ideal(), config);

        let mitigated = estimator
            .compute(|scale| Ok(1.0 - 0.1 * scale))
            .expect("Should succeed");

        assert!((mitigated - 1.0).abs() < 0.1);
    }
}