Skip to main content

quantrs2_ml/
simulator_backends.rs

1//! Simulator backend integration for quantum machine learning
2//!
3//! This module provides unified interfaces to all quantum simulators
4//! available in the QuantRS2 ecosystem, enabling seamless backend
5//! switching for quantum ML algorithms.
6
7use crate::error::{MLError, Result};
8use quantrs2_circuit::prelude::*;
9use quantrs2_core::prelude::*;
10use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
11use scirs2_core::Complex64;
// `GpuStateVectorSimulator` is not imported: nothing in this file uses it directly.
// The `GPUBackend` type is re-exported from `gpu_backend_impl` further down (behind the `gpu` feature).
14use quantrs2_sim::prelude::{MPSSimulator, PauliString, StateVectorSimulator};
15use std::collections::HashMap;
16
17/// Dynamic circuit representation for trait objects
18#[derive(Debug, Clone)]
19pub enum DynamicCircuit {
20    Circuit1(Circuit<1>),
21    Circuit2(Circuit<2>),
22    Circuit4(Circuit<4>),
23    Circuit8(Circuit<8>),
24    Circuit16(Circuit<16>),
25    Circuit32(Circuit<32>),
26    Circuit64(Circuit<64>),
27}
28
29impl DynamicCircuit {
30    /// Create from a generic circuit
31    pub fn from_circuit<const N: usize>(circuit: Circuit<N>) -> Result<Self> {
32        match N {
33            1 => Ok(DynamicCircuit::Circuit1(unsafe {
34                std::mem::transmute(circuit)
35            })),
36            2 => Ok(DynamicCircuit::Circuit2(unsafe {
37                std::mem::transmute(circuit)
38            })),
39            4 => Ok(DynamicCircuit::Circuit4(unsafe {
40                std::mem::transmute(circuit)
41            })),
42            8 => Ok(DynamicCircuit::Circuit8(unsafe {
43                std::mem::transmute(circuit)
44            })),
45            16 => Ok(DynamicCircuit::Circuit16(unsafe {
46                std::mem::transmute(circuit)
47            })),
48            32 => Ok(DynamicCircuit::Circuit32(unsafe {
49                std::mem::transmute(circuit)
50            })),
51            64 => Ok(DynamicCircuit::Circuit64(unsafe {
52                std::mem::transmute(circuit)
53            })),
54            _ => Err(MLError::ValidationError(format!(
55                "Unsupported circuit size: {}",
56                N
57            ))),
58        }
59    }
60
61    /// Get the number of qubits
62    pub fn num_qubits(&self) -> usize {
63        match self {
64            DynamicCircuit::Circuit1(_) => 1,
65            DynamicCircuit::Circuit2(_) => 2,
66            DynamicCircuit::Circuit4(_) => 4,
67            DynamicCircuit::Circuit8(_) => 8,
68            DynamicCircuit::Circuit16(_) => 16,
69            DynamicCircuit::Circuit32(_) => 32,
70            DynamicCircuit::Circuit64(_) => 64,
71        }
72    }
73
74    /// Get the number of gates (placeholder implementation)
75    pub fn num_gates(&self) -> usize {
76        match self {
77            DynamicCircuit::Circuit1(c) => c.gates().len(),
78            DynamicCircuit::Circuit2(c) => c.gates().len(),
79            DynamicCircuit::Circuit4(c) => c.gates().len(),
80            DynamicCircuit::Circuit8(c) => c.gates().len(),
81            DynamicCircuit::Circuit16(c) => c.gates().len(),
82            DynamicCircuit::Circuit32(c) => c.gates().len(),
83            DynamicCircuit::Circuit64(c) => c.gates().len(),
84        }
85    }
86
87    /// Get circuit depth (placeholder implementation)
88    pub fn depth(&self) -> usize {
89        // Simplified depth calculation - just return number of gates for now
90        self.num_gates()
91    }
92
93    /// Get gates (placeholder implementation)
94    pub fn gates(&self) -> Vec<&dyn quantrs2_core::gate::GateOp> {
95        match self {
96            DynamicCircuit::Circuit1(c) => c
97                .gates()
98                .iter()
99                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
100                .collect(),
101            DynamicCircuit::Circuit2(c) => c
102                .gates()
103                .iter()
104                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
105                .collect(),
106            DynamicCircuit::Circuit4(c) => c
107                .gates()
108                .iter()
109                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
110                .collect(),
111            DynamicCircuit::Circuit8(c) => c
112                .gates()
113                .iter()
114                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
115                .collect(),
116            DynamicCircuit::Circuit16(c) => c
117                .gates()
118                .iter()
119                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
120                .collect(),
121            DynamicCircuit::Circuit32(c) => c
122                .gates()
123                .iter()
124                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
125                .collect(),
126            DynamicCircuit::Circuit64(c) => c
127                .gates()
128                .iter()
129                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
130                .collect(),
131        }
132    }
133}
134
135/// Unified simulator backend interface
136pub trait SimulatorBackend: Send + Sync {
137    /// Execute a quantum circuit
138    fn execute_circuit(
139        &self,
140        circuit: &DynamicCircuit,
141        parameters: &[f64],
142        shots: Option<usize>,
143    ) -> Result<SimulationResult>;
144
145    /// Compute expectation value
146    fn expectation_value(
147        &self,
148        circuit: &DynamicCircuit,
149        parameters: &[f64],
150        observable: &Observable,
151    ) -> Result<f64>;
152
153    /// Compute gradients using backend-specific methods
154    fn compute_gradients(
155        &self,
156        circuit: &DynamicCircuit,
157        parameters: &[f64],
158        observable: &Observable,
159        gradient_method: GradientMethod,
160    ) -> Result<Array1<f64>>;
161
162    /// Get backend capabilities
163    fn capabilities(&self) -> BackendCapabilities;
164
165    /// Get backend name
166    fn name(&self) -> &str;
167
168    /// Maximum number of qubits supported
169    fn max_qubits(&self) -> usize;
170
171    /// Check if backend supports noise simulation
172    fn supports_noise(&self) -> bool;
173}
174
175/// Simulation result containing various outputs
176#[derive(Debug, Clone)]
177pub struct SimulationResult {
178    /// Final quantum state (if available)
179    pub state: Option<Array1<Complex64>>,
180    /// Measurement outcomes
181    pub measurements: Option<Array1<usize>>,
182    /// Measurement probabilities
183    pub probabilities: Option<Array1<f64>>,
184    /// Execution metadata
185    pub metadata: HashMap<String, f64>,
186}
187
188/// Observable for expectation value computations
189#[derive(Debug, Clone)]
190pub enum Observable {
191    /// Pauli string observable
192    PauliString(PauliString),
193    /// Pauli Z on specified qubits
194    PauliZ(Vec<usize>),
195    /// Custom Hermitian matrix
196    Matrix(Array2<Complex64>),
197    /// Hamiltonian as sum of Pauli strings
198    Hamiltonian(Vec<(f64, PauliString)>),
199}
200
/// Rule used to differentiate an expectation value with respect to circuit parameters.
///
/// Fieldless and `Copy`. It also derives `PartialEq`/`Eq`/`Hash`, so callers can compare
/// methods or use them as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GradientMethod {
    /// Parameter-shift rule: `(f(θ + π/2) − f(θ − π/2)) / 2` for each parameter.
    ParameterShift,
    /// Numerical finite differences.
    FiniteDifference,
    /// Adjoint (reverse-mode) differentiation, where a backend supports it.
    Adjoint,
    /// Stochastic variant of the parameter-shift rule.
    StochasticParameterShift,
}
213
/// Feature flags and resource figures a backend advertises.
///
/// `Default` describes a zero-qubit backend with every capability switched off.
#[derive(Clone, Debug, Default)]
pub struct BackendCapabilities {
    /// Largest number of qubits the backend accepts.
    pub max_qubits: usize,
    /// Whether noisy simulation is available.
    pub noise_simulation: bool,
    /// Whether execution is GPU-accelerated.
    pub gpu_acceleration: bool,
    /// Whether work can be distributed across machines/processes.
    pub distributed: bool,
    /// Whether adjoint-mode gradients are implemented.
    pub adjoint_gradients: bool,
    /// Estimated memory requirement per qubit, in bytes, as computed by each backend.
    pub memory_per_qubit: usize,
}
230
231/// Statevector simulator backend
232#[derive(Debug)]
233pub struct StatevectorBackend {
234    /// Internal simulator
235    simulator: StateVectorSimulator,
236    /// Maximum qubits
237    max_qubits: usize,
238}
239
240impl StatevectorBackend {
241    /// Create new statevector backend
242    pub fn new(max_qubits: usize) -> Self {
243        Self {
244            simulator: StateVectorSimulator::new(),
245            max_qubits,
246        }
247    }
248}
249
250impl SimulatorBackend for StatevectorBackend {
251    fn execute_circuit(
252        &self,
253        circuit: &DynamicCircuit,
254        _parameters: &[f64],
255        _shots: Option<usize>,
256    ) -> Result<SimulationResult> {
257        /// Convert a Register's amplitude slice into an `Array1<Complex64>`.
258        fn register_to_array(amplitudes: &[Complex64]) -> Array1<Complex64> {
259            Array1::from_vec(amplitudes.to_vec())
260        }
261
262        macro_rules! run_circuit {
263            ($c:expr) => {{
264                let state = self.simulator.run($c)?;
265                let amps = register_to_array(state.amplitudes());
266                let probabilities: Vec<f64> = amps.iter().map(|c| c.norm_sqr()).collect();
267                Ok(SimulationResult {
268                    state: Some(amps),
269                    measurements: None,
270                    probabilities: Some(Array1::from_vec(probabilities)),
271                    metadata: HashMap::new(),
272                })
273            }};
274        }
275
276        match circuit {
277            DynamicCircuit::Circuit1(c) => run_circuit!(c),
278            DynamicCircuit::Circuit2(c) => run_circuit!(c),
279            DynamicCircuit::Circuit4(c) => run_circuit!(c),
280            DynamicCircuit::Circuit8(c) => run_circuit!(c),
281            DynamicCircuit::Circuit16(c) => run_circuit!(c),
282            DynamicCircuit::Circuit32(c) => run_circuit!(c),
283            DynamicCircuit::Circuit64(c) => run_circuit!(c),
284        }
285    }
286
287    fn expectation_value(
288        &self,
289        circuit: &DynamicCircuit,
290        _parameters: &[f64],
291        observable: &Observable,
292    ) -> Result<f64> {
293        macro_rules! run_and_expect {
294            ($c:expr) => {{
295                let state = self.simulator.run($c)?;
296                let amps = Array1::from_vec(state.amplitudes().to_vec());
297                self.compute_expectation(&amps, observable)
298            }};
299        }
300
301        match circuit {
302            DynamicCircuit::Circuit1(c) => run_and_expect!(c),
303            DynamicCircuit::Circuit2(c) => run_and_expect!(c),
304            DynamicCircuit::Circuit4(c) => run_and_expect!(c),
305            DynamicCircuit::Circuit8(c) => run_and_expect!(c),
306            DynamicCircuit::Circuit16(c) => run_and_expect!(c),
307            DynamicCircuit::Circuit32(c) => run_and_expect!(c),
308            DynamicCircuit::Circuit64(c) => run_and_expect!(c),
309        }
310    }
311
312    fn compute_gradients(
313        &self,
314        circuit: &DynamicCircuit,
315        _parameters: &[f64],
316        _observable: &Observable,
317        _gradient_method: GradientMethod,
318    ) -> Result<Array1<f64>> {
319        // Placeholder implementation
320        match circuit {
321            DynamicCircuit::Circuit1(_) => Ok(Array1::zeros(1)),
322            DynamicCircuit::Circuit2(_) => Ok(Array1::zeros(1)),
323            _ => Err(MLError::ValidationError(
324                "Unsupported circuit size".to_string(),
325            )),
326        }
327    }
328
329    /// Get backend capabilities
330    fn capabilities(&self) -> BackendCapabilities {
331        BackendCapabilities {
332            max_qubits: self.max_qubits,
333            noise_simulation: false,
334            gpu_acceleration: false,
335            distributed: false,
336            adjoint_gradients: false,
337            memory_per_qubit: 16, // 16 bytes per amplitude (Complex64)
338        }
339    }
340
341    /// Get backend name
342    fn name(&self) -> &str {
343        "statevector"
344    }
345
346    /// Maximum number of qubits supported
347    fn max_qubits(&self) -> usize {
348        self.max_qubits
349    }
350
351    /// Check if backend supports noise simulation
352    fn supports_noise(&self) -> bool {
353        false
354    }
355}
356
357impl StatevectorBackend {
358    /// Compute `<ψ|P|ψ>` for a single-qubit Pauli on qubit index `qubit_idx`.
359    ///
360    /// The statevector has `2^n` entries.  The basis states are indexed by integers where
361    /// bit `qubit_idx` (LSB = qubit 0) selects the computational basis state of that qubit.
362    ///
363    /// - `Z_i`: `<Z_i>` = Σ_{j: bit i is 0} |ψ_j|² - Σ_{j: bit i is 1} |ψ_j|²
364    /// - `X_i`: `<X_i>` = 2 · Re[ Σ_{j: bit i is 0} ψ_j* · ψ_{j ⊕ (1 << i)} ]
365    /// - `Y_i`: `<Y_i>` = 2 · Im[ Σ_{j: bit i is 0} ψ_j* · ψ_{j ⊕ (1 << i)} ]  (sign convention: Y = [[0,-i],[i,0]])
366    /// - `I_i`:  1.0
367    fn pauli_expectation_single(
368        &self,
369        state: &Array1<Complex64>,
370        pauli: char,
371        qubit_idx: usize,
372    ) -> Result<f64> {
373        let dim = state.len();
374        if dim == 0 {
375            return Err(MLError::ValidationError("Empty statevector".to_string()));
376        }
377        // Check dim is a power of 2.
378        if dim & (dim - 1) != 0 {
379            return Err(MLError::ValidationError(format!(
380                "Statevector dimension {dim} is not a power of 2"
381            )));
382        }
383        let n = dim.trailing_zeros() as usize; // number of qubits
384        if qubit_idx >= n {
385            return Err(MLError::ValidationError(format!(
386                "Qubit index {qubit_idx} out of range for {n}-qubit state"
387            )));
388        }
389
390        let bit = 1usize << qubit_idx;
391
392        match pauli {
393            'I' => Ok(1.0),
394            'Z' => {
395                let mut expectation = 0.0_f64;
396                for (j, amp) in state.iter().enumerate() {
397                    let prob = amp.norm_sqr();
398                    if j & bit == 0 {
399                        expectation += prob; // eigenvalue +1
400                    } else {
401                        expectation -= prob; // eigenvalue -1
402                    }
403                }
404                Ok(expectation)
405            }
406            'X' => {
407                // <X_i> = 2 Re[ Σ_{j: bit i=0} ψ_j* · ψ_{j ^ bit} ]
408                let mut sum = Complex64::new(0.0, 0.0);
409                for (j, amp) in state.iter().enumerate() {
410                    if j & bit == 0 {
411                        let partner = j ^ bit;
412                        if partner < dim {
413                            sum += amp.conj() * state[partner];
414                        }
415                    }
416                }
417                Ok(2.0 * sum.re)
418            }
419            'Y' => {
420                // Y = [[0,-i],[i,0]]; <Y_i> = 2 Im[ Σ_{j: bit i=0} ψ_j* · ψ_{j ^ bit} ]
421                // Because Y = -i·σ^+ + i·σ^-, the expectation is purely real:
422                // <Y> = Im[2 · Σ_{j: bit=0} conj(ψ_j) * ψ_{j^bit}]
423                let mut sum = Complex64::new(0.0, 0.0);
424                for (j, amp) in state.iter().enumerate() {
425                    if j & bit == 0 {
426                        let partner = j ^ bit;
427                        if partner < dim {
428                            sum += amp.conj() * state[partner];
429                        }
430                    }
431                }
432                Ok(2.0 * sum.im)
433            }
434            _ => Err(MLError::ValidationError(format!(
435                "Unknown Pauli operator '{pauli}'"
436            ))),
437        }
438    }
439
440    /// Helper method to compute expectation values `<ψ|O|ψ>`.
441    fn compute_expectation(
442        &self,
443        state: &Array1<Complex64>,
444        observable: &Observable,
445    ) -> Result<f64> {
446        match observable {
447            Observable::PauliString(pauli_string) => {
448                use quantrs2_sim::prelude::PauliOperator;
449
450                let dim = state.len();
451                // Apply the Pauli string to |ψ⟩: result_vec = P|ψ⟩
452                // Then <ψ|P|ψ> = Re[<ψ|result_vec>]  (expectation is real for Hermitian P)
453                let mut result_vec = state.clone();
454
455                for (qubit_idx, pauli_op) in pauli_string.operators.iter().enumerate() {
456                    let bit = 1usize << qubit_idx;
457                    match pauli_op {
458                        PauliOperator::I => {} // identity — no change
459                        PauliOperator::Z => {
460                            // Z flips sign for basis states where qubit i = 1
461                            for j in 0..dim {
462                                if j & bit != 0 {
463                                    result_vec[j] = -result_vec[j];
464                                }
465                            }
466                        }
467                        PauliOperator::X => {
468                            // X bit-flips qubit i: swap amplitude pairs (j, j^bit)
469                            for j in 0..dim {
470                                if j & bit == 0 {
471                                    let partner = j ^ bit;
472                                    if partner < dim {
473                                        result_vec.swap(j, partner);
474                                    }
475                                }
476                            }
477                        }
478                        PauliOperator::Y => {
479                            // Y = [[0,-i],[i,0]]:  |0⟩ → i|1⟩,  |1⟩ → -i|0⟩
480                            let mut new_vec = result_vec.clone();
481                            for j in 0..dim {
482                                if j & bit == 0 {
483                                    let partner = j ^ bit;
484                                    if partner < dim {
485                                        let orig_0 = result_vec[j];
486                                        let orig_1 = result_vec[partner];
487                                        new_vec[j] = Complex64::new(0.0, -1.0) * orig_1;
488                                        new_vec[partner] = Complex64::new(0.0, 1.0) * orig_0;
489                                    }
490                                }
491                            }
492                            result_vec = new_vec;
493                        }
494                    }
495                }
496
497                // Apply the overall coefficient from the PauliString.
498                // Then <ψ|P|ψ> = Re[ coeff · Σ_j conj(ψ_j) * (P|ψ>)_j ]
499                let inner: Complex64 = state
500                    .iter()
501                    .zip(result_vec.iter())
502                    .map(|(&a, &b)| a.conj() * b)
503                    .sum();
504                Ok((pauli_string.coefficient * inner).re)
505            }
506            Observable::PauliZ(qubits) => {
507                // Product of Z expectations on the given qubits.
508                // For a single-qubit problem: <Z_i> as defined above.
509                // For multiple qubits: < ⊗_i Z_i > using the combined bit-flip parity.
510                let dim = state.len();
511                let mut expectation = 0.0_f64;
512                for (j, amp) in state.iter().enumerate() {
513                    // Parity of bits at qubit positions listed in `qubits`.
514                    let parity: u32 = qubits
515                        .iter()
516                        .map(|&q| if j & (1 << q) != 0 { 1u32 } else { 0u32 })
517                        .sum::<u32>()
518                        % 2;
519                    let eigenvalue = if parity == 0 { 1.0 } else { -1.0 };
520                    expectation += eigenvalue * amp.norm_sqr();
521                }
522                Ok(expectation)
523            }
524            Observable::Matrix(matrix) => {
525                // Compute <ψ|H|ψ> = Σ_{i,j} ψ_i* H_{ij} ψ_j
526                let result: Complex64 = state
527                    .iter()
528                    .enumerate()
529                    .map(|(i, &amp_i)| {
530                        state
531                            .iter()
532                            .enumerate()
533                            .map(|(j, &amp_j)| amp_i.conj() * matrix[[i, j]] * amp_j)
534                            .sum::<Complex64>()
535                    })
536                    .sum();
537                Ok(result.re)
538            }
539            Observable::Hamiltonian(terms) => {
540                // H = Σ_k c_k P_k  →  <H> = Σ_k c_k <P_k>
541                let mut expectation = 0.0_f64;
542                for (coeff, pauli_string) in terms {
543                    let term_exp = self.compute_expectation(
544                        state,
545                        &Observable::PauliString(pauli_string.clone()),
546                    )?;
547                    expectation += coeff * term_exp;
548                }
549                Ok(expectation)
550            }
551        }
552    }
553
554    fn max_qubits(&self) -> usize {
555        self.max_qubits
556    }
557
558    fn supports_noise(&self) -> bool {
559        false
560    }
561}
562
563/// Matrix Product State (MPS) simulator backend
564pub struct MPSBackend {
565    /// Internal MPS simulator
566    simulator: MPSSimulator,
567    /// Bond dimension
568    bond_dimension: usize,
569    /// Maximum qubits
570    max_qubits: usize,
571}
572
573impl MPSBackend {
574    /// Create new MPS backend
575    pub fn new(bond_dimension: usize, max_qubits: usize) -> Self {
576        Self {
577            simulator: MPSSimulator::new(bond_dimension),
578            bond_dimension,
579            max_qubits,
580        }
581    }
582}
583
584impl SimulatorBackend for MPSBackend {
585    fn execute_circuit(
586        &self,
587        circuit: &DynamicCircuit,
588        _parameters: &[f64],
589        _shots: Option<usize>,
590    ) -> Result<SimulationResult> {
591        // MPS implementation depends on circuit size
592        match circuit {
593            DynamicCircuit::Circuit1(c) => {
594                // For small circuits, use basic MPS simulation
595                Ok(SimulationResult {
596                    state: None, // MPS doesn't expose full state
597                    measurements: None,
598                    probabilities: None,
599                    metadata: {
600                        let mut meta = HashMap::new();
601                        meta.insert("bond_dimension".to_string(), self.bond_dimension as f64);
602                        meta.insert("num_qubits".to_string(), 1.0);
603                        meta
604                    },
605                })
606            }
607            DynamicCircuit::Circuit2(c) => Ok(SimulationResult {
608                state: None,
609                measurements: None,
610                probabilities: None,
611                metadata: {
612                    let mut meta = HashMap::new();
613                    meta.insert("bond_dimension".to_string(), self.bond_dimension as f64);
614                    meta.insert("num_qubits".to_string(), 2.0);
615                    meta
616                },
617            }),
618            _ => {
619                // For larger circuits, need proper MPS simulation
620                Ok(SimulationResult {
621                    state: None,
622                    measurements: None,
623                    probabilities: None,
624                    metadata: {
625                        let mut meta = HashMap::new();
626                        meta.insert("bond_dimension".to_string(), self.bond_dimension as f64);
627                        meta.insert("num_qubits".to_string(), circuit.num_qubits() as f64);
628                        meta
629                    },
630                })
631            }
632        }
633    }
634
635    fn expectation_value(
636        &self,
637        circuit: &DynamicCircuit,
638        _parameters: &[f64],
639        observable: &Observable,
640    ) -> Result<f64> {
641        match observable {
642            Observable::PauliString(_pauli) => {
643                // Would compute expectation using MPS for any circuit size
644                Ok(0.0) // Placeholder implementation
645            }
646            Observable::PauliZ(_qubits) => {
647                // Would compute Z expectation using MPS
648                Ok(0.0) // Placeholder implementation
649            }
650            Observable::Hamiltonian(terms) => {
651                let mut expectation = 0.0;
652                for (coeff, _pauli) in terms {
653                    // Would compute each term using MPS
654                    expectation += coeff * 0.0; // Placeholder
655                }
656                Ok(expectation)
657            }
658            Observable::Matrix(_) => Err(MLError::NotSupported(
659                "Matrix observables not supported for MPS backend".to_string(),
660            )),
661        }
662    }
663
664    fn compute_gradients(
665        &self,
666        circuit: &DynamicCircuit,
667        parameters: &[f64],
668        observable: &Observable,
669        gradient_method: GradientMethod,
670    ) -> Result<Array1<f64>> {
671        match gradient_method {
672            GradientMethod::ParameterShift => {
673                self.parameter_shift_gradients_dynamic(circuit, parameters, observable)
674            }
675            _ => Err(MLError::NotSupported(
676                "Only parameter shift gradients supported for MPS backend".to_string(),
677            )),
678        }
679    }
680
681    fn capabilities(&self) -> BackendCapabilities {
682        BackendCapabilities {
683            max_qubits: self.max_qubits,
684            noise_simulation: false,
685            gpu_acceleration: false,
686            distributed: false,
687            adjoint_gradients: false,
688            memory_per_qubit: self.bond_dimension * self.bond_dimension * 16, // D^2 * 16 bytes
689        }
690    }
691
692    fn name(&self) -> &str {
693        "mps"
694    }
695
696    fn max_qubits(&self) -> usize {
697        self.max_qubits
698    }
699
700    fn supports_noise(&self) -> bool {
701        false
702    }
703}
704
705impl MPSBackend {
706    fn parameter_shift_gradients_dynamic(
707        &self,
708        circuit: &DynamicCircuit,
709        parameters: &[f64],
710        observable: &Observable,
711    ) -> Result<Array1<f64>> {
712        let shift = std::f64::consts::PI / 2.0;
713        let mut gradients = Array1::zeros(parameters.len());
714
715        for i in 0..parameters.len() {
716            let mut params_plus = parameters.to_vec();
717            params_plus[i] += shift;
718            let val_plus = self.expectation_value(circuit, &params_plus, observable)?;
719
720            let mut params_minus = parameters.to_vec();
721            params_minus[i] -= shift;
722            let val_minus = self.expectation_value(circuit, &params_minus, observable)?;
723
724            gradients[i] = (val_plus - val_minus) / 2.0;
725        }
726
727        Ok(gradients)
728    }
729}
730
731// GPU backend is now implemented in gpu_backend_impl module
732#[cfg(feature = "gpu")]
733pub use crate::gpu_backend_impl::GPUBackend;
734
735// SimulatorBackend implementation for GPUBackend is in gpu_backend_impl.rs
736
737/// Enum for different backend types (avoids dyn compatibility issues)
738pub enum Backend {
739    Statevector(StatevectorBackend),
740    MPS(MPSBackend),
741    #[cfg(feature = "gpu")]
742    GPU(GPUBackend),
743}
744
745impl SimulatorBackend for Backend {
746    fn execute_circuit(
747        &self,
748        circuit: &DynamicCircuit,
749        parameters: &[f64],
750        shots: Option<usize>,
751    ) -> Result<SimulationResult> {
752        match self {
753            Backend::Statevector(backend) => backend.execute_circuit(circuit, parameters, shots),
754            Backend::MPS(backend) => backend.execute_circuit(circuit, parameters, shots),
755            #[cfg(feature = "gpu")]
756            Backend::GPU(backend) => backend.execute_circuit(circuit, parameters, shots),
757        }
758    }
759
760    fn expectation_value(
761        &self,
762        circuit: &DynamicCircuit,
763        parameters: &[f64],
764        observable: &Observable,
765    ) -> Result<f64> {
766        match self {
767            Backend::Statevector(backend) => {
768                backend.expectation_value(circuit, parameters, observable)
769            }
770            Backend::MPS(backend) => backend.expectation_value(circuit, parameters, observable),
771            #[cfg(feature = "gpu")]
772            Backend::GPU(backend) => backend.expectation_value(circuit, parameters, observable),
773        }
774    }
775
776    fn compute_gradients(
777        &self,
778        circuit: &DynamicCircuit,
779        parameters: &[f64],
780        observable: &Observable,
781        gradient_method: GradientMethod,
782    ) -> Result<Array1<f64>> {
783        match self {
784            Backend::Statevector(backend) => {
785                backend.compute_gradients(circuit, parameters, observable, gradient_method)
786            }
787            Backend::MPS(backend) => {
788                backend.compute_gradients(circuit, parameters, observable, gradient_method)
789            }
790            #[cfg(feature = "gpu")]
791            Backend::GPU(backend) => {
792                backend.compute_gradients(circuit, parameters, observable, gradient_method)
793            }
794        }
795    }
796
797    fn capabilities(&self) -> BackendCapabilities {
798        match self {
799            Backend::Statevector(backend) => backend.capabilities(),
800            Backend::MPS(backend) => backend.capabilities(),
801            #[cfg(feature = "gpu")]
802            Backend::GPU(backend) => backend.capabilities(),
803        }
804    }
805
806    fn name(&self) -> &str {
807        match self {
808            Backend::Statevector(backend) => backend.name(),
809            Backend::MPS(backend) => backend.name(),
810            #[cfg(feature = "gpu")]
811            Backend::GPU(backend) => backend.name(),
812        }
813    }
814
815    fn max_qubits(&self) -> usize {
816        match self {
817            Backend::Statevector(backend) => backend.max_qubits(),
818            Backend::MPS(backend) => backend.max_qubits(),
819            #[cfg(feature = "gpu")]
820            Backend::GPU(backend) => backend.max_qubits(),
821        }
822    }
823
824    fn supports_noise(&self) -> bool {
825        match self {
826            Backend::Statevector(backend) => backend.supports_noise(),
827            Backend::MPS(backend) => backend.supports_noise(),
828            #[cfg(feature = "gpu")]
829            Backend::GPU(backend) => backend.supports_noise(),
830        }
831    }
832}
833
834/// Backend manager for automatic backend selection
835pub struct BackendManager {
836    /// Available backends
837    backends: HashMap<String, Backend>,
838    /// Current backend
839    current_backend: Option<String>,
840    /// Backend selection strategy
841    selection_strategy: BackendSelectionStrategy,
842}
843
/// Policy a [`BackendManager`] uses to choose among its registered backends.
#[derive(Clone, Debug)]
pub enum BackendSelectionStrategy {
    /// Prefer the backend expected to run fastest for the problem size.
    Fastest,
    /// Prefer the backend with the smallest memory footprint.
    MemoryEfficient,
    /// Prefer the backend expected to produce the most accurate results.
    MostAccurate,
    /// Always use the backend registered under this name.
    Manual(String),
}
856
857impl BackendManager {
858    /// Create a new backend manager
859    pub fn new() -> Self {
860        Self {
861            backends: HashMap::new(),
862            current_backend: None,
863            selection_strategy: BackendSelectionStrategy::Fastest,
864        }
865    }
866
867    /// Register a backend
868    pub fn register_backend(&mut self, name: impl Into<String>, backend: Backend) {
869        self.backends.insert(name.into(), backend);
870    }
871
872    /// Set selection strategy
873    pub fn set_strategy(&mut self, strategy: BackendSelectionStrategy) {
874        self.selection_strategy = strategy;
875    }
876
877    /// Select optimal backend for given problem
878    pub fn select_backend(&mut self, num_qubits: usize, shots: Option<usize>) -> Result<()> {
879        let backend_name = match &self.selection_strategy {
880            BackendSelectionStrategy::Fastest => self.select_fastest_backend(num_qubits, shots)?,
881            BackendSelectionStrategy::MemoryEfficient => {
882                self.select_memory_efficient_backend(num_qubits)?
883            }
884            BackendSelectionStrategy::MostAccurate => {
885                self.select_most_accurate_backend(num_qubits)?
886            }
887            BackendSelectionStrategy::Manual(name) => name.clone(),
888        };
889
890        self.current_backend = Some(backend_name);
891        Ok(())
892    }
893
894    /// Execute circuit using selected backend
895    pub fn execute_circuit<const N: usize>(
896        &self,
897        circuit: &Circuit<N>,
898        parameters: &[f64],
899        shots: Option<usize>,
900    ) -> Result<SimulationResult> {
901        if let Some(ref backend_name) = self.current_backend {
902            if let Some(backend) = self.backends.get(backend_name) {
903                let dynamic_circuit = DynamicCircuit::from_circuit(circuit.clone())?;
904                backend.execute_circuit(&dynamic_circuit, parameters, shots)
905            } else {
906                Err(MLError::InvalidConfiguration(format!(
907                    "Backend '{}' not found",
908                    backend_name
909                )))
910            }
911        } else {
912            Err(MLError::InvalidConfiguration(
913                "No backend selected".to_string(),
914            ))
915        }
916    }
917
918    /// Get current backend
919    pub fn current_backend(&self) -> Option<&Backend> {
920        self.current_backend
921            .as_ref()
922            .and_then(|name| self.backends.get(name))
923    }
924
925    /// List available backends
926    pub fn list_backends(&self) -> Vec<(String, BackendCapabilities)> {
927        self.backends
928            .iter()
929            .map(|(name, backend)| (name.clone(), backend.capabilities()))
930            .collect()
931    }
932
933    fn select_fastest_backend(&self, num_qubits: usize, _shots: Option<usize>) -> Result<String> {
934        // Simple heuristic: GPU for large circuits, MPS for very large, statevector for small
935        if num_qubits <= 20 {
936            Ok("statevector".to_string())
937        } else if num_qubits <= 50 && self.backends.contains_key("gpu") {
938            Ok("gpu".to_string())
939        } else if self.backends.contains_key("mps") {
940            Ok("mps".to_string())
941        } else {
942            Err(MLError::InvalidConfiguration(
943                "No suitable backend for problem size".to_string(),
944            ))
945        }
946    }
947
948    fn select_memory_efficient_backend(&self, num_qubits: usize) -> Result<String> {
949        if num_qubits > 30 && self.backends.contains_key("mps") {
950            Ok("mps".to_string())
951        } else {
952            Ok("statevector".to_string())
953        }
954    }
955
956    fn select_most_accurate_backend(&self, _num_qubits: usize) -> Result<String> {
957        // Statevector is most accurate
958        Ok("statevector".to_string())
959    }
960}
961
962/// Helper functions for backend management
963pub mod backend_utils {
964    use super::*;
965
966    /// Create default backend manager with all available backends
967    pub fn create_default_manager() -> BackendManager {
968        let mut manager = BackendManager::new();
969
970        // Register statevector backend
971        manager.register_backend(
972            "statevector",
973            Backend::Statevector(StatevectorBackend::new(25)),
974        );
975
976        // Register MPS backend
977        manager.register_backend("mps", Backend::MPS(MPSBackend::new(64, 100)));
978
979        // Register GPU backend if available
980        #[cfg(feature = "gpu")]
981        {
982            if let Ok(gpu_backend) = GPUBackend::new(0, 30) {
983                manager.register_backend("gpu", Backend::GPU(gpu_backend));
984            }
985        }
986
987        manager
988    }
989
990    /// Benchmark backends for given problem
991    pub fn benchmark_backends<const N: usize>(
992        manager: &BackendManager,
993        circuit: &Circuit<N>,
994        parameters: &[f64],
995    ) -> Result<HashMap<String, f64>> {
996        let mut results = HashMap::new();
997
998        for (backend_name, _) in manager.list_backends() {
999            let start = std::time::Instant::now();
1000
1001            // Would execute circuit multiple times for accurate timing
1002            let _result = manager.execute_circuit(circuit, parameters, None)?;
1003
1004            let duration = start.elapsed().as_secs_f64();
1005            results.insert(backend_name, duration);
1006        }
1007
1008        Ok(results)
1009    }
1010}
1011
#[cfg(test)]
mod tests {
    use super::*;

    /// A statevector backend reports its name, its qubit limit and no noise support.
    #[test]
    fn test_statevector_backend() {
        let sv = StatevectorBackend::new(10);

        assert_eq!(sv.name(), "statevector");
        assert_eq!(sv.max_qubits(), 10);
        assert!(!sv.supports_noise());
    }

    /// An MPS backend reports its name and qubit limit, and advertises
    /// neither adjoint gradients nor GPU acceleration.
    #[test]
    fn test_mps_backend() {
        let mps = MPSBackend::new(64, 50);

        assert_eq!(mps.name(), "mps");
        assert_eq!(mps.max_qubits(), 50);

        let capabilities = mps.capabilities();
        assert!(!capabilities.adjoint_gradients);
        assert!(!capabilities.gpu_acceleration);
    }

    /// A registered backend shows up in the listing under its registered name.
    #[test]
    fn test_backend_manager() {
        let mut manager = BackendManager::new();
        manager.register_backend("test", Backend::Statevector(StatevectorBackend::new(10)));

        let listed = manager.list_backends();
        assert_eq!(listed.len(), 1);

        let (registered_name, _capabilities) = &listed[0];
        assert_eq!(registered_name.as_str(), "test");
    }

    /// The `Fastest` strategy finds a backend for both small and larger problems
    /// on the default manager.
    #[test]
    fn test_backend_selection() {
        let mut manager = backend_utils::create_default_manager();
        manager.set_strategy(BackendSelectionStrategy::Fastest);

        for num_qubits in [15, 35] {
            assert!(manager.select_backend(num_qubits, None).is_ok());
        }
    }
}