quantrs2_ml/
simulator_backends.rs

1//! Simulator backend integration for quantum machine learning
2//!
3//! This module provides unified interfaces to all quantum simulators
4//! available in the QuantRS2 ecosystem, enabling seamless backend
5//! switching for quantum ML algorithms.
6
7use crate::error::{MLError, Result};
8use quantrs2_circuit::prelude::*;
9use quantrs2_core::prelude::*;
10use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
11use scirs2_core::Complex64;
// GpuStateVectorSimulator is intentionally not imported: nothing in this file uses it.
// The GPU backend is defined in `gpu_backend_impl` and re-exported below behind the `gpu` feature.
14use quantrs2_sim::prelude::{MPSSimulator, PauliString, StateVectorSimulator};
15use std::collections::HashMap;
16
17/// Dynamic circuit representation for trait objects
18#[derive(Debug, Clone)]
19pub enum DynamicCircuit {
20    Circuit1(Circuit<1>),
21    Circuit2(Circuit<2>),
22    Circuit4(Circuit<4>),
23    Circuit8(Circuit<8>),
24    Circuit16(Circuit<16>),
25    Circuit32(Circuit<32>),
26    Circuit64(Circuit<64>),
27}
28
29impl DynamicCircuit {
30    /// Create from a generic circuit
31    pub fn from_circuit<const N: usize>(circuit: Circuit<N>) -> Result<Self> {
32        match N {
33            1 => Ok(DynamicCircuit::Circuit1(unsafe {
34                std::mem::transmute(circuit)
35            })),
36            2 => Ok(DynamicCircuit::Circuit2(unsafe {
37                std::mem::transmute(circuit)
38            })),
39            4 => Ok(DynamicCircuit::Circuit4(unsafe {
40                std::mem::transmute(circuit)
41            })),
42            8 => Ok(DynamicCircuit::Circuit8(unsafe {
43                std::mem::transmute(circuit)
44            })),
45            16 => Ok(DynamicCircuit::Circuit16(unsafe {
46                std::mem::transmute(circuit)
47            })),
48            32 => Ok(DynamicCircuit::Circuit32(unsafe {
49                std::mem::transmute(circuit)
50            })),
51            64 => Ok(DynamicCircuit::Circuit64(unsafe {
52                std::mem::transmute(circuit)
53            })),
54            _ => Err(MLError::ValidationError(format!(
55                "Unsupported circuit size: {}",
56                N
57            ))),
58        }
59    }
60
61    /// Get the number of qubits
62    pub fn num_qubits(&self) -> usize {
63        match self {
64            DynamicCircuit::Circuit1(_) => 1,
65            DynamicCircuit::Circuit2(_) => 2,
66            DynamicCircuit::Circuit4(_) => 4,
67            DynamicCircuit::Circuit8(_) => 8,
68            DynamicCircuit::Circuit16(_) => 16,
69            DynamicCircuit::Circuit32(_) => 32,
70            DynamicCircuit::Circuit64(_) => 64,
71        }
72    }
73
74    /// Get the number of gates (placeholder implementation)
75    pub fn num_gates(&self) -> usize {
76        match self {
77            DynamicCircuit::Circuit1(c) => c.gates().len(),
78            DynamicCircuit::Circuit2(c) => c.gates().len(),
79            DynamicCircuit::Circuit4(c) => c.gates().len(),
80            DynamicCircuit::Circuit8(c) => c.gates().len(),
81            DynamicCircuit::Circuit16(c) => c.gates().len(),
82            DynamicCircuit::Circuit32(c) => c.gates().len(),
83            DynamicCircuit::Circuit64(c) => c.gates().len(),
84        }
85    }
86
87    /// Get circuit depth (placeholder implementation)
88    pub fn depth(&self) -> usize {
89        // Simplified depth calculation - just return number of gates for now
90        self.num_gates()
91    }
92
93    /// Get gates (placeholder implementation)
94    pub fn gates(&self) -> Vec<&dyn quantrs2_core::gate::GateOp> {
95        match self {
96            DynamicCircuit::Circuit1(c) => c
97                .gates()
98                .iter()
99                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
100                .collect(),
101            DynamicCircuit::Circuit2(c) => c
102                .gates()
103                .iter()
104                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
105                .collect(),
106            DynamicCircuit::Circuit4(c) => c
107                .gates()
108                .iter()
109                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
110                .collect(),
111            DynamicCircuit::Circuit8(c) => c
112                .gates()
113                .iter()
114                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
115                .collect(),
116            DynamicCircuit::Circuit16(c) => c
117                .gates()
118                .iter()
119                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
120                .collect(),
121            DynamicCircuit::Circuit32(c) => c
122                .gates()
123                .iter()
124                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
125                .collect(),
126            DynamicCircuit::Circuit64(c) => c
127                .gates()
128                .iter()
129                .map(|g| g.as_ref() as &dyn quantrs2_core::gate::GateOp)
130                .collect(),
131        }
132    }
133}
134
135/// Unified simulator backend interface
136pub trait SimulatorBackend: Send + Sync {
137    /// Execute a quantum circuit
138    fn execute_circuit(
139        &self,
140        circuit: &DynamicCircuit,
141        parameters: &[f64],
142        shots: Option<usize>,
143    ) -> Result<SimulationResult>;
144
145    /// Compute expectation value
146    fn expectation_value(
147        &self,
148        circuit: &DynamicCircuit,
149        parameters: &[f64],
150        observable: &Observable,
151    ) -> Result<f64>;
152
153    /// Compute gradients using backend-specific methods
154    fn compute_gradients(
155        &self,
156        circuit: &DynamicCircuit,
157        parameters: &[f64],
158        observable: &Observable,
159        gradient_method: GradientMethod,
160    ) -> Result<Array1<f64>>;
161
162    /// Get backend capabilities
163    fn capabilities(&self) -> BackendCapabilities;
164
165    /// Get backend name
166    fn name(&self) -> &str;
167
168    /// Maximum number of qubits supported
169    fn max_qubits(&self) -> usize;
170
171    /// Check if backend supports noise simulation
172    fn supports_noise(&self) -> bool;
173}
174
175/// Simulation result containing various outputs
176#[derive(Debug, Clone)]
177pub struct SimulationResult {
178    /// Final quantum state (if available)
179    pub state: Option<Array1<Complex64>>,
180    /// Measurement outcomes
181    pub measurements: Option<Array1<usize>>,
182    /// Measurement probabilities
183    pub probabilities: Option<Array1<f64>>,
184    /// Execution metadata
185    pub metadata: HashMap<String, f64>,
186}
187
188/// Observable for expectation value computations
189#[derive(Debug, Clone)]
190pub enum Observable {
191    /// Pauli string observable
192    PauliString(PauliString),
193    /// Pauli Z on specified qubits
194    PauliZ(Vec<usize>),
195    /// Custom Hermitian matrix
196    Matrix(Array2<Complex64>),
197    /// Hamiltonian as sum of Pauli strings
198    Hamiltonian(Vec<(f64, PauliString)>),
199}
200
/// Gradient computation methods.
///
/// The enum is a plain tag, so it is `Copy` and can be compared and hashed; this lets
/// callers branch on or assert the selected method with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GradientMethod {
    /// Parameter shift rule
    ParameterShift,
    /// Finite differences
    FiniteDifference,
    /// Adjoint differentiation (if supported)
    Adjoint,
    /// Stochastic parameter shift
    StochasticParameterShift,
}
213
/// Backend capabilities.
///
/// Plain data describing what a backend reports about itself. `Default` yields zero
/// sizes and all flags `false`; values compare field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BackendCapabilities {
    /// Maximum qubits
    pub max_qubits: usize,
    /// Supports noise simulation
    pub noise_simulation: bool,
    /// Supports GPU acceleration
    pub gpu_acceleration: bool,
    /// Supports distributed computation
    pub distributed: bool,
    /// Supports adjoint gradients
    pub adjoint_gradients: bool,
    /// Memory requirements per qubit (bytes)
    pub memory_per_qubit: usize,
}
230
231/// Statevector simulator backend
232#[derive(Debug)]
233pub struct StatevectorBackend {
234    /// Internal simulator
235    simulator: StateVectorSimulator,
236    /// Maximum qubits
237    max_qubits: usize,
238}
239
240impl StatevectorBackend {
241    /// Create new statevector backend
242    pub fn new(max_qubits: usize) -> Self {
243        Self {
244            simulator: StateVectorSimulator::new(),
245            max_qubits,
246        }
247    }
248}
249
250impl SimulatorBackend for StatevectorBackend {
251    fn execute_circuit(
252        &self,
253        circuit: &DynamicCircuit,
254        _parameters: &[f64],
255        _shots: Option<usize>,
256    ) -> Result<SimulationResult> {
257        match circuit {
258            DynamicCircuit::Circuit1(c) => {
259                let state = self.simulator.run(c)?;
260                let probabilities = state
261                    .amplitudes()
262                    .iter()
263                    .map(|c| c.norm_sqr())
264                    .collect::<Vec<_>>();
265                Ok(SimulationResult {
266                    state: None, // TODO: Convert Register to Array
267                    measurements: None,
268                    probabilities: Some(probabilities.into()),
269                    metadata: HashMap::new(),
270                })
271            }
272            DynamicCircuit::Circuit2(c) => {
273                let state = self.simulator.run(c)?;
274                let probabilities = state
275                    .amplitudes()
276                    .iter()
277                    .map(|c| c.norm_sqr())
278                    .collect::<Vec<_>>();
279                Ok(SimulationResult {
280                    state: None, // TODO: Convert Register to Array
281                    measurements: None,
282                    probabilities: Some(probabilities.into()),
283                    metadata: HashMap::new(),
284                })
285            }
286            // Add other circuit sizes as needed
287            _ => Err(MLError::ValidationError(
288                "Unsupported circuit size".to_string(),
289            )),
290        }
291    }
292
293    fn expectation_value(
294        &self,
295        circuit: &DynamicCircuit,
296        _parameters: &[f64],
297        observable: &Observable,
298    ) -> Result<f64> {
299        match circuit {
300            DynamicCircuit::Circuit1(c) => {
301                let _state = self.simulator.run(c)?;
302                // TODO: Convert Register to Array for expectation computation
303                Ok(0.0)
304            }
305            DynamicCircuit::Circuit2(c) => {
306                let _state = self.simulator.run(c)?;
307                // TODO: Convert Register to Array for expectation computation
308                Ok(0.0)
309            }
310            // Add other circuit sizes as needed
311            _ => Err(MLError::ValidationError(
312                "Unsupported circuit size".to_string(),
313            )),
314        }
315    }
316
317    fn compute_gradients(
318        &self,
319        circuit: &DynamicCircuit,
320        _parameters: &[f64],
321        _observable: &Observable,
322        _gradient_method: GradientMethod,
323    ) -> Result<Array1<f64>> {
324        // Placeholder implementation
325        match circuit {
326            DynamicCircuit::Circuit1(_) => Ok(Array1::zeros(1)),
327            DynamicCircuit::Circuit2(_) => Ok(Array1::zeros(1)),
328            _ => Err(MLError::ValidationError(
329                "Unsupported circuit size".to_string(),
330            )),
331        }
332    }
333
334    /// Get backend capabilities
335    fn capabilities(&self) -> BackendCapabilities {
336        self.capabilities()
337    }
338
339    /// Get backend name
340    fn name(&self) -> &str {
341        "StatevectorBackend"
342    }
343
344    /// Maximum number of qubits supported
345    fn max_qubits(&self) -> usize {
346        self.capabilities().max_qubits
347    }
348
349    /// Check if backend supports noise simulation
350    fn supports_noise(&self) -> bool {
351        self.capabilities().noise_simulation
352    }
353}
354
355impl StatevectorBackend {
356    /// Helper method to compute expectation values
357    fn compute_expectation(
358        &self,
359        state: &Array1<Complex64>,
360        observable: &Observable,
361    ) -> Result<f64> {
362        match observable {
363            Observable::PauliString(pauli) => {
364                // Placeholder implementation - compute expectation value manually
365                Ok(0.0) // TODO: Implement proper Pauli expectation value computation
366            }
367            Observable::PauliZ(_qubits) => {
368                // Placeholder implementation for Pauli Z expectation value
369                Ok(0.0) // TODO: Implement proper Pauli Z expectation value computation
370            }
371            Observable::Matrix(matrix) => {
372                // Compute <ψ|H|ψ>
373                let amplitudes = state;
374                let result = amplitudes
375                    .iter()
376                    .enumerate()
377                    .map(|(i, &amp)| {
378                        amplitudes
379                            .iter()
380                            .enumerate()
381                            .map(|(j, &amp2)| amp.conj() * matrix[[i, j]] * amp2)
382                            .sum::<Complex64>()
383                    })
384                    .sum::<Complex64>();
385                Ok(result.re)
386            }
387            Observable::Hamiltonian(terms) => {
388                let mut expectation = 0.0;
389                for (coeff, pauli) in terms {
390                    expectation += coeff * 0.0; // TODO: Implement proper Pauli expectation value
391                }
392                Ok(expectation)
393            }
394        }
395    }
396
397    fn max_qubits(&self) -> usize {
398        self.max_qubits
399    }
400
401    fn supports_noise(&self) -> bool {
402        false
403    }
404}
405
406/// Matrix Product State (MPS) simulator backend
407pub struct MPSBackend {
408    /// Internal MPS simulator
409    simulator: MPSSimulator,
410    /// Bond dimension
411    bond_dimension: usize,
412    /// Maximum qubits
413    max_qubits: usize,
414}
415
416impl MPSBackend {
417    /// Create new MPS backend
418    pub fn new(bond_dimension: usize, max_qubits: usize) -> Self {
419        Self {
420            simulator: MPSSimulator::new(bond_dimension),
421            bond_dimension,
422            max_qubits,
423        }
424    }
425}
426
427impl SimulatorBackend for MPSBackend {
428    fn execute_circuit(
429        &self,
430        circuit: &DynamicCircuit,
431        _parameters: &[f64],
432        _shots: Option<usize>,
433    ) -> Result<SimulationResult> {
434        // MPS implementation depends on circuit size
435        match circuit {
436            DynamicCircuit::Circuit1(c) => {
437                // For small circuits, use basic MPS simulation
438                Ok(SimulationResult {
439                    state: None, // MPS doesn't expose full state
440                    measurements: None,
441                    probabilities: None,
442                    metadata: {
443                        let mut meta = HashMap::new();
444                        meta.insert("bond_dimension".to_string(), self.bond_dimension as f64);
445                        meta.insert("num_qubits".to_string(), 1.0);
446                        meta
447                    },
448                })
449            }
450            DynamicCircuit::Circuit2(c) => Ok(SimulationResult {
451                state: None,
452                measurements: None,
453                probabilities: None,
454                metadata: {
455                    let mut meta = HashMap::new();
456                    meta.insert("bond_dimension".to_string(), self.bond_dimension as f64);
457                    meta.insert("num_qubits".to_string(), 2.0);
458                    meta
459                },
460            }),
461            _ => {
462                // For larger circuits, need proper MPS simulation
463                Ok(SimulationResult {
464                    state: None,
465                    measurements: None,
466                    probabilities: None,
467                    metadata: {
468                        let mut meta = HashMap::new();
469                        meta.insert("bond_dimension".to_string(), self.bond_dimension as f64);
470                        meta.insert("num_qubits".to_string(), circuit.num_qubits() as f64);
471                        meta
472                    },
473                })
474            }
475        }
476    }
477
478    fn expectation_value(
479        &self,
480        circuit: &DynamicCircuit,
481        _parameters: &[f64],
482        observable: &Observable,
483    ) -> Result<f64> {
484        match observable {
485            Observable::PauliString(_pauli) => {
486                // Would compute expectation using MPS for any circuit size
487                Ok(0.0) // Placeholder implementation
488            }
489            Observable::PauliZ(_qubits) => {
490                // Would compute Z expectation using MPS
491                Ok(0.0) // Placeholder implementation
492            }
493            Observable::Hamiltonian(terms) => {
494                let mut expectation = 0.0;
495                for (coeff, _pauli) in terms {
496                    // Would compute each term using MPS
497                    expectation += coeff * 0.0; // Placeholder
498                }
499                Ok(expectation)
500            }
501            Observable::Matrix(_) => Err(MLError::NotSupported(
502                "Matrix observables not supported for MPS backend".to_string(),
503            )),
504        }
505    }
506
507    fn compute_gradients(
508        &self,
509        circuit: &DynamicCircuit,
510        parameters: &[f64],
511        observable: &Observable,
512        gradient_method: GradientMethod,
513    ) -> Result<Array1<f64>> {
514        match gradient_method {
515            GradientMethod::ParameterShift => {
516                self.parameter_shift_gradients_dynamic(circuit, parameters, observable)
517            }
518            _ => Err(MLError::NotSupported(
519                "Only parameter shift gradients supported for MPS backend".to_string(),
520            )),
521        }
522    }
523
524    fn capabilities(&self) -> BackendCapabilities {
525        BackendCapabilities {
526            max_qubits: self.max_qubits,
527            noise_simulation: false,
528            gpu_acceleration: false,
529            distributed: false,
530            adjoint_gradients: false,
531            memory_per_qubit: self.bond_dimension * self.bond_dimension * 16, // D^2 * 16 bytes
532        }
533    }
534
535    fn name(&self) -> &str {
536        "mps"
537    }
538
539    fn max_qubits(&self) -> usize {
540        self.max_qubits
541    }
542
543    fn supports_noise(&self) -> bool {
544        false
545    }
546}
547
548impl MPSBackend {
549    fn parameter_shift_gradients_dynamic(
550        &self,
551        circuit: &DynamicCircuit,
552        parameters: &[f64],
553        observable: &Observable,
554    ) -> Result<Array1<f64>> {
555        let shift = std::f64::consts::PI / 2.0;
556        let mut gradients = Array1::zeros(parameters.len());
557
558        for i in 0..parameters.len() {
559            let mut params_plus = parameters.to_vec();
560            params_plus[i] += shift;
561            let val_plus = self.expectation_value(circuit, &params_plus, observable)?;
562
563            let mut params_minus = parameters.to_vec();
564            params_minus[i] -= shift;
565            let val_minus = self.expectation_value(circuit, &params_minus, observable)?;
566
567            gradients[i] = (val_plus - val_minus) / 2.0;
568        }
569
570        Ok(gradients)
571    }
572}
573
574// GPU backend is now implemented in gpu_backend_impl module
575#[cfg(feature = "gpu")]
576pub use crate::gpu_backend_impl::GPUBackend;
577
578// SimulatorBackend implementation for GPUBackend is in gpu_backend_impl.rs
579
580/// Enum for different backend types (avoids dyn compatibility issues)
581pub enum Backend {
582    Statevector(StatevectorBackend),
583    MPS(MPSBackend),
584    #[cfg(feature = "gpu")]
585    GPU(GPUBackend),
586}
587
588impl SimulatorBackend for Backend {
589    fn execute_circuit(
590        &self,
591        circuit: &DynamicCircuit,
592        parameters: &[f64],
593        shots: Option<usize>,
594    ) -> Result<SimulationResult> {
595        match self {
596            Backend::Statevector(backend) => backend.execute_circuit(circuit, parameters, shots),
597            Backend::MPS(backend) => backend.execute_circuit(circuit, parameters, shots),
598            #[cfg(feature = "gpu")]
599            Backend::GPU(backend) => backend.execute_circuit(circuit, parameters, shots),
600        }
601    }
602
603    fn expectation_value(
604        &self,
605        circuit: &DynamicCircuit,
606        parameters: &[f64],
607        observable: &Observable,
608    ) -> Result<f64> {
609        match self {
610            Backend::Statevector(backend) => {
611                backend.expectation_value(circuit, parameters, observable)
612            }
613            Backend::MPS(backend) => backend.expectation_value(circuit, parameters, observable),
614            #[cfg(feature = "gpu")]
615            Backend::GPU(backend) => backend.expectation_value(circuit, parameters, observable),
616        }
617    }
618
619    fn compute_gradients(
620        &self,
621        circuit: &DynamicCircuit,
622        parameters: &[f64],
623        observable: &Observable,
624        gradient_method: GradientMethod,
625    ) -> Result<Array1<f64>> {
626        match self {
627            Backend::Statevector(backend) => {
628                backend.compute_gradients(circuit, parameters, observable, gradient_method)
629            }
630            Backend::MPS(backend) => {
631                backend.compute_gradients(circuit, parameters, observable, gradient_method)
632            }
633            #[cfg(feature = "gpu")]
634            Backend::GPU(backend) => {
635                backend.compute_gradients(circuit, parameters, observable, gradient_method)
636            }
637        }
638    }
639
640    fn capabilities(&self) -> BackendCapabilities {
641        match self {
642            Backend::Statevector(backend) => backend.capabilities(),
643            Backend::MPS(backend) => backend.capabilities(),
644            #[cfg(feature = "gpu")]
645            Backend::GPU(backend) => backend.capabilities(),
646        }
647    }
648
649    fn name(&self) -> &str {
650        match self {
651            Backend::Statevector(backend) => backend.name(),
652            Backend::MPS(backend) => backend.name(),
653            #[cfg(feature = "gpu")]
654            Backend::GPU(backend) => backend.name(),
655        }
656    }
657
658    fn max_qubits(&self) -> usize {
659        match self {
660            Backend::Statevector(backend) => backend.max_qubits(),
661            Backend::MPS(backend) => backend.max_qubits(),
662            #[cfg(feature = "gpu")]
663            Backend::GPU(backend) => backend.max_qubits(),
664        }
665    }
666
667    fn supports_noise(&self) -> bool {
668        match self {
669            Backend::Statevector(backend) => backend.supports_noise(),
670            Backend::MPS(backend) => backend.supports_noise(),
671            #[cfg(feature = "gpu")]
672            Backend::GPU(backend) => backend.supports_noise(),
673        }
674    }
675}
676
677/// Backend manager for automatic backend selection
678pub struct BackendManager {
679    /// Available backends
680    backends: HashMap<String, Backend>,
681    /// Current backend
682    current_backend: Option<String>,
683    /// Backend selection strategy
684    selection_strategy: BackendSelectionStrategy,
685}
686
/// Backend selection strategies.
///
/// Strategies compare by value (including the name carried by `Manual`), so callers and
/// tests can check which strategy is configured with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BackendSelectionStrategy {
    /// Use fastest backend for given problem size
    Fastest,
    /// Use most memory-efficient backend
    MemoryEfficient,
    /// Use most accurate backend
    MostAccurate,
    /// User-specified backend, identified by its registered name
    Manual(String),
}
699
700impl BackendManager {
701    /// Create a new backend manager
702    pub fn new() -> Self {
703        Self {
704            backends: HashMap::new(),
705            current_backend: None,
706            selection_strategy: BackendSelectionStrategy::Fastest,
707        }
708    }
709
710    /// Register a backend
711    pub fn register_backend(&mut self, name: impl Into<String>, backend: Backend) {
712        self.backends.insert(name.into(), backend);
713    }
714
715    /// Set selection strategy
716    pub fn set_strategy(&mut self, strategy: BackendSelectionStrategy) {
717        self.selection_strategy = strategy;
718    }
719
720    /// Select optimal backend for given problem
721    pub fn select_backend(&mut self, num_qubits: usize, shots: Option<usize>) -> Result<()> {
722        let backend_name = match &self.selection_strategy {
723            BackendSelectionStrategy::Fastest => self.select_fastest_backend(num_qubits, shots)?,
724            BackendSelectionStrategy::MemoryEfficient => {
725                self.select_memory_efficient_backend(num_qubits)?
726            }
727            BackendSelectionStrategy::MostAccurate => {
728                self.select_most_accurate_backend(num_qubits)?
729            }
730            BackendSelectionStrategy::Manual(name) => name.clone(),
731        };
732
733        self.current_backend = Some(backend_name);
734        Ok(())
735    }
736
737    /// Execute circuit using selected backend
738    pub fn execute_circuit<const N: usize>(
739        &self,
740        circuit: &Circuit<N>,
741        parameters: &[f64],
742        shots: Option<usize>,
743    ) -> Result<SimulationResult> {
744        if let Some(ref backend_name) = self.current_backend {
745            if let Some(backend) = self.backends.get(backend_name) {
746                let dynamic_circuit = DynamicCircuit::from_circuit(circuit.clone())?;
747                backend.execute_circuit(&dynamic_circuit, parameters, shots)
748            } else {
749                Err(MLError::InvalidConfiguration(format!(
750                    "Backend '{}' not found",
751                    backend_name
752                )))
753            }
754        } else {
755            Err(MLError::InvalidConfiguration(
756                "No backend selected".to_string(),
757            ))
758        }
759    }
760
761    /// Get current backend
762    pub fn current_backend(&self) -> Option<&Backend> {
763        self.current_backend
764            .as_ref()
765            .and_then(|name| self.backends.get(name))
766    }
767
768    /// List available backends
769    pub fn list_backends(&self) -> Vec<(String, BackendCapabilities)> {
770        self.backends
771            .iter()
772            .map(|(name, backend)| (name.clone(), backend.capabilities()))
773            .collect()
774    }
775
776    fn select_fastest_backend(&self, num_qubits: usize, _shots: Option<usize>) -> Result<String> {
777        // Simple heuristic: GPU for large circuits, MPS for very large, statevector for small
778        if num_qubits <= 20 {
779            Ok("statevector".to_string())
780        } else if num_qubits <= 50 && self.backends.contains_key("gpu") {
781            Ok("gpu".to_string())
782        } else if self.backends.contains_key("mps") {
783            Ok("mps".to_string())
784        } else {
785            Err(MLError::InvalidConfiguration(
786                "No suitable backend for problem size".to_string(),
787            ))
788        }
789    }
790
791    fn select_memory_efficient_backend(&self, num_qubits: usize) -> Result<String> {
792        if num_qubits > 30 && self.backends.contains_key("mps") {
793            Ok("mps".to_string())
794        } else {
795            Ok("statevector".to_string())
796        }
797    }
798
799    fn select_most_accurate_backend(&self, _num_qubits: usize) -> Result<String> {
800        // Statevector is most accurate
801        Ok("statevector".to_string())
802    }
803}
804
805/// Helper functions for backend management
806pub mod backend_utils {
807    use super::*;
808
809    /// Create default backend manager with all available backends
810    pub fn create_default_manager() -> BackendManager {
811        let mut manager = BackendManager::new();
812
813        // Register statevector backend
814        manager.register_backend(
815            "statevector",
816            Backend::Statevector(StatevectorBackend::new(25)),
817        );
818
819        // Register MPS backend
820        manager.register_backend("mps", Backend::MPS(MPSBackend::new(64, 100)));
821
822        // Register GPU backend if available
823        #[cfg(feature = "gpu")]
824        {
825            if let Ok(gpu_backend) = GPUBackend::new(0, 30) {
826                manager.register_backend("gpu", Backend::GPU(gpu_backend));
827            }
828        }
829
830        manager
831    }
832
833    /// Benchmark backends for given problem
834    pub fn benchmark_backends<const N: usize>(
835        manager: &BackendManager,
836        circuit: &Circuit<N>,
837        parameters: &[f64],
838    ) -> Result<HashMap<String, f64>> {
839        let mut results = HashMap::new();
840
841        for (backend_name, _) in manager.list_backends() {
842            let start = std::time::Instant::now();
843
844            // Would execute circuit multiple times for accurate timing
845            let _result = manager.execute_circuit(circuit, parameters, None)?;
846
847            let duration = start.elapsed().as_secs_f64();
848            results.insert(backend_name, duration);
849        }
850
851        Ok(results)
852    }
853}
854
#[cfg(test)]
mod tests {
    use super::*;

    /// Name, qubit limit and noise flag of a freshly built statevector backend.
    #[test]
    #[ignore]
    fn test_statevector_backend() {
        let statevector = StatevectorBackend::new(10);

        assert_eq!(statevector.name(), "statevector");
        assert_eq!(statevector.max_qubits(), 10);
        assert!(!statevector.supports_noise());
    }

    /// Name, qubit limit and capability flags of a freshly built MPS backend.
    #[test]
    fn test_mps_backend() {
        let mps = MPSBackend::new(64, 50);

        assert_eq!(mps.name(), "mps");
        assert_eq!(mps.max_qubits(), 50);

        let capabilities = mps.capabilities();
        assert!(!capabilities.adjoint_gradients);
        assert!(!capabilities.gpu_acceleration);
    }

    /// A registered backend shows up in the manager's listing under its registered name.
    #[test]
    #[ignore] // Temporarily disabled due to stack overflow issue
    fn test_backend_manager() {
        let mut manager = BackendManager::new();
        manager.register_backend("test", Backend::Statevector(StatevectorBackend::new(10)));

        let listed = manager.list_backends();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].0, "test");
    }

    /// The `Fastest` strategy finds a backend for both a small and a larger problem size.
    #[test]
    fn test_backend_selection() {
        let mut manager = backend_utils::create_default_manager();
        manager.set_strategy(BackendSelectionStrategy::Fastest);

        for num_qubits in [15, 35] {
            let outcome = manager.select_backend(num_qubits, None);
            assert!(outcome.is_ok());
        }
    }
}
899        assert!(result.is_ok());
900    }
901}