quantrs2_ml/torchquantum/
mod.rs

1//! TorchQuantum-compatible API for quantum machine learning
2//!
3//! This module provides a Pure Rust implementation compatible with TorchQuantum's API,
4//! enabling seamless migration from PyTorch-based quantum ML workflows.
5//!
6//! ## Key Features
7//!
8//! - **QuantumModule**: Base trait for quantum modules (similar to PyTorch's nn.Module)
9//! - **QuantumDevice**: Quantum state vector container with batch support
10//! - **Operators**: Parameterized quantum gates with automatic differentiation support
11//! - **Encoders**: Various encoding schemes (angle, amplitude, phase)
12//! - **Measurements**: Expectation values, sampling, and observable measurements
13//! - **Layers**: Pre-built quantum layer templates (Barren, Farhi, Maxwell, etc.)
14//!
15//! ## TorchQuantum Compatibility
16//!
17//! This module mirrors TorchQuantum's API patterns:
18//! - `tq.QuantumModule` → `TQModule`
19//! - `tq.QuantumDevice` → `TQDevice`
20//! - `tq.Operator` → `TQOperator`
21//! - `tq.encoding.*` → `encoding::*`
22//! - `tq.measurement.*` → `measurement::*`
23
24use crate::error::{MLError, Result};
25use scirs2_core::ndarray::{Array1, Array2, ArrayD, IxDyn};
26use scirs2_core::Complex64;
27use std::f64::consts::PI;
28
29// Sub-modules
30pub mod ansatz;
31pub mod autograd;
32pub mod conv;
33pub mod encoding;
34pub mod functional;
35pub mod gates;
36pub mod layer;
37pub mod measurement;
38pub mod noise;
39pub mod pooling;
40pub mod tensor_network;
41
42// ============================================================================
43// Core Types and Constants
44// ============================================================================
45
46/// Complex data type for quantum states (matches TorchQuantum's C_DTYPE)
47pub type CType = Complex64;
48
/// Real scalar used for gate parameters
/// (the counterpart of TorchQuantum's `F_DTYPE`).
pub type FType = f64;
51
/// How many wires an operation acts on.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WiresEnum {
    /// No fixed wire count: the operation accepts any number of wires.
    AnyWires,
    /// The operation always spans every wire of the device.
    AllWires,
    /// The operation acts on exactly this many wires.
    Fixed(usize),
}
62
/// How many parameters an operation takes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NParamsEnum {
    /// No fixed parameter count.
    AnyNParams,
    /// Exactly this many parameters.
    Fixed(usize),
}
71
72// ============================================================================
73// TQModule Trait - Core abstraction for quantum modules
74// ============================================================================
75
76/// Base trait for all TorchQuantum-compatible quantum modules
77///
78/// This trait mirrors TorchQuantum's `QuantumModule` class, providing:
79/// - Forward pass execution
80/// - Parameter management
81/// - Static/dynamic mode switching
82/// - Noise model support
83pub trait TQModule: Send + Sync {
84    /// Execute the forward pass on the quantum device
85    fn forward(&mut self, qdev: &mut TQDevice) -> Result<()>;
86
87    /// Execute forward pass with optional input data (for encoders)
88    fn forward_with_input(&mut self, qdev: &mut TQDevice, _x: Option<&Array2<f64>>) -> Result<()> {
89        self.forward(qdev)
90    }
91
92    /// Get all trainable parameters
93    fn parameters(&self) -> Vec<TQParameter>;
94
95    /// Get number of wires this module operates on
96    fn n_wires(&self) -> Option<usize>;
97
98    /// Set number of wires
99    fn set_n_wires(&mut self, n_wires: usize);
100
101    /// Check if module is in static mode
102    fn is_static_mode(&self) -> bool;
103
104    /// Enable static mode for graph optimization
105    fn static_on(&mut self);
106
107    /// Disable static mode
108    fn static_off(&mut self);
109
110    /// Get the unitary matrix representation (if available)
111    fn get_unitary(&self) -> Option<Array2<CType>> {
112        None
113    }
114
115    /// Module name for debugging
116    fn name(&self) -> &str;
117
118    /// Zero gradients of all parameters
119    fn zero_grad(&mut self) {
120        // Default implementation - override for modules with parameters
121    }
122
123    /// Set training mode
124    fn train(&mut self, _mode: bool) {
125        // Default implementation
126    }
127
128    /// Check if in training mode
129    fn training(&self) -> bool {
130        true
131    }
132}
133
134// ============================================================================
135// TQParameter - Trainable parameter wrapper
136// ============================================================================
137
138/// Quantum parameter wrapper (similar to TorchQuantum's parameter handling)
139#[derive(Debug, Clone)]
140pub struct TQParameter {
141    /// Parameter values
142    pub data: ArrayD<f64>,
143    /// Parameter name
144    pub name: String,
145    /// Whether parameter requires gradient
146    pub requires_grad: bool,
147    /// Gradient values (if computed)
148    pub grad: Option<ArrayD<f64>>,
149}
150
151impl TQParameter {
152    /// Create new trainable parameter
153    pub fn new(data: ArrayD<f64>, name: impl Into<String>) -> Self {
154        Self {
155            data,
156            name: name.into(),
157            requires_grad: true,
158            grad: None,
159        }
160    }
161
162    /// Create parameter without gradients
163    pub fn no_grad(data: ArrayD<f64>, name: impl Into<String>) -> Self {
164        Self {
165            data,
166            name: name.into(),
167            requires_grad: false,
168            grad: None,
169        }
170    }
171
172    /// Get parameter shape
173    pub fn shape(&self) -> &[usize] {
174        self.data.shape()
175    }
176
177    /// Get number of elements
178    pub fn numel(&self) -> usize {
179        self.data.len()
180    }
181
182    /// Zero the gradient
183    pub fn zero_grad(&mut self) {
184        self.grad = None;
185    }
186
187    /// Initialize with uniform random values in [-pi, pi]
188    pub fn init_uniform_pi(&mut self) {
189        for elem in self.data.iter_mut() {
190            *elem = (fastrand::f64() * 2.0 - 1.0) * PI;
191        }
192    }
193
194    /// Initialize with constant value
195    pub fn init_constant(&mut self, value: f64) {
196        for elem in self.data.iter_mut() {
197            *elem = value;
198        }
199    }
200}
201
202// ============================================================================
203// TQDevice - Quantum device with state vector
204// ============================================================================
205
206/// Quantum device containing the quantum state vector
207///
208/// This struct mirrors TorchQuantum's `QuantumDevice` class, providing:
209/// - Multi-dimensional state tensor representation
210/// - Batch support for parallel execution
211/// - State reset and cloning operations
212#[derive(Debug, Clone)]
213pub struct TQDevice {
214    /// Number of qubits
215    pub n_wires: usize,
216    /// Device name
217    pub device_name: String,
218    /// Batch size
219    pub bsz: usize,
220    /// Quantum state vector (batched, multi-dimensional)
221    pub states: ArrayD<CType>,
222    /// Whether to record operations
223    pub record_op: bool,
224    /// Operation history
225    pub op_history: Vec<OpHistoryEntry>,
226}
227
/// One recorded operation in a [`TQDevice`]'s history.
#[derive(Clone, Debug)]
pub struct OpHistoryEntry {
    /// Name of the applied gate.
    pub name: String,
    /// Wires the gate was applied to.
    pub wires: Vec<usize>,
    /// Gate parameters, or `None` for parameter-free gates.
    pub params: Option<Vec<f64>>,
    /// `true` if the inverse of the gate was applied.
    pub inverse: bool,
    /// `true` if the gate's parameters are trainable.
    pub trainable: bool,
}
242
243impl TQDevice {
244    /// Create new quantum device
245    pub fn new(n_wires: usize) -> Self {
246        Self::with_batch_size(n_wires, 1)
247    }
248
249    /// Create quantum device with batch size
250    pub fn with_batch_size(n_wires: usize, bsz: usize) -> Self {
251        // Initialize state vector |0...0>
252        let state_size = 1 << n_wires; // 2^n_wires
253        let mut state_data = vec![CType::new(0.0, 0.0); state_size * bsz];
254        // Set |0...0> amplitude to 1 for each batch
255        for b in 0..bsz {
256            state_data[b * state_size] = CType::new(1.0, 0.0);
257        }
258
259        // Shape: [bsz, 2, 2, ..., 2] (n_wires times)
260        let mut shape = vec![bsz];
261        shape.extend(vec![2; n_wires]);
262
263        let states = ArrayD::from_shape_vec(IxDyn(&shape), state_data)
264            .unwrap_or_else(|_| ArrayD::zeros(IxDyn(&shape)));
265
266        Self {
267            n_wires,
268            device_name: "default".to_string(),
269            bsz,
270            states,
271            record_op: false,
272            op_history: Vec::new(),
273        }
274    }
275
276    /// Reset to |0...0> state
277    pub fn reset_states(&mut self, bsz: usize) {
278        self.bsz = bsz;
279        let state_size = 1 << self.n_wires;
280        let mut state_data = vec![CType::new(0.0, 0.0); state_size * bsz];
281        for b in 0..bsz {
282            state_data[b * state_size] = CType::new(1.0, 0.0);
283        }
284
285        let mut shape = vec![bsz];
286        shape.extend(vec![2; self.n_wires]);
287        self.states = ArrayD::from_shape_vec(IxDyn(&shape), state_data)
288            .unwrap_or_else(|_| ArrayD::zeros(IxDyn(&shape)));
289    }
290
291    /// Reset to identity matrix (useful for computing unitaries)
292    pub fn reset_identity_states(&mut self) {
293        let state_size = 1 << self.n_wires;
294        self.bsz = state_size;
295
296        let mut state_data = vec![CType::new(0.0, 0.0); state_size * state_size];
297        // Set diagonal elements to 1
298        for i in 0..state_size {
299            state_data[i * state_size + i] = CType::new(1.0, 0.0);
300        }
301
302        let mut shape = vec![state_size];
303        shape.extend(vec![2; self.n_wires]);
304        self.states = ArrayD::from_shape_vec(IxDyn(&shape), state_data)
305            .unwrap_or_else(|_| ArrayD::zeros(IxDyn(&shape)));
306    }
307
308    /// Reset to equal superposition state
309    pub fn reset_all_eq_states(&mut self, bsz: usize) {
310        self.bsz = bsz;
311        let state_size = 1 << self.n_wires;
312        let amplitude = 1.0 / (state_size as f64).sqrt();
313        let state_data = vec![CType::new(amplitude, 0.0); state_size * bsz];
314
315        let mut shape = vec![bsz];
316        shape.extend(vec![2; self.n_wires]);
317        self.states = ArrayD::from_shape_vec(IxDyn(&shape), state_data)
318            .unwrap_or_else(|_| ArrayD::zeros(IxDyn(&shape)));
319    }
320
321    /// Clone states from another device
322    pub fn clone_states(&mut self, other: &TQDevice) {
323        self.states = other.states.clone();
324        self.bsz = other.bsz;
325    }
326
327    /// Set states directly
328    pub fn set_states(&mut self, states: ArrayD<CType>) {
329        self.bsz = states.shape()[0];
330        self.states = states;
331    }
332
333    /// Get states as 1D vectors (shape: [bsz, 2^n_wires])
334    pub fn get_states_1d(&self) -> Array2<CType> {
335        let state_size = 1 << self.n_wires;
336        let flat: Vec<CType> = self.states.iter().cloned().collect();
337        Array2::from_shape_vec((self.bsz, state_size), flat)
338            .unwrap_or_else(|_| Array2::zeros((self.bsz, state_size)))
339    }
340
341    /// Get probabilities (|amplitude|^2) as 1D vectors
342    pub fn get_probs_1d(&self) -> Array2<f64> {
343        let states_1d = self.get_states_1d();
344        states_1d.mapv(|c| c.norm_sqr())
345    }
346
347    /// Record an operation in history
348    pub fn record_operation(&mut self, entry: OpHistoryEntry) {
349        if self.record_op {
350            self.op_history.push(entry);
351        }
352    }
353
354    /// Clear operation history
355    pub fn reset_op_history(&mut self) {
356        self.op_history.clear();
357    }
358
359    /// Apply a single-qubit gate matrix to specified wire
360    pub fn apply_single_qubit_gate(&mut self, wire: usize, matrix: &Array2<CType>) -> Result<()> {
361        if wire >= self.n_wires {
362            return Err(MLError::InvalidConfiguration(format!(
363                "Wire {} out of range for {} qubits",
364                wire, self.n_wires
365            )));
366        }
367
368        let state_size = 1 << self.n_wires;
369        let states_1d = self.get_states_1d();
370        let mut new_states = states_1d.clone();
371
372        for batch in 0..self.bsz {
373            for i in 0..state_size {
374                // Find the pair of indices that differ only at position `wire`
375                let bit = (i >> (self.n_wires - 1 - wire)) & 1;
376                if bit == 0 {
377                    let j = i | (1 << (self.n_wires - 1 - wire));
378                    let amp0 = states_1d[[batch, i]];
379                    let amp1 = states_1d[[batch, j]];
380                    new_states[[batch, i]] = matrix[[0, 0]] * amp0 + matrix[[0, 1]] * amp1;
381                    new_states[[batch, j]] = matrix[[1, 0]] * amp0 + matrix[[1, 1]] * amp1;
382                }
383            }
384        }
385
386        // Reshape back to multi-dimensional
387        let flat: Vec<CType> = new_states.iter().cloned().collect();
388        let mut shape = vec![self.bsz];
389        shape.extend(vec![2; self.n_wires]);
390        self.states = ArrayD::from_shape_vec(IxDyn(&shape), flat)
391            .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
392
393        Ok(())
394    }
395
396    /// Apply a two-qubit gate matrix to specified wires
397    pub fn apply_two_qubit_gate(
398        &mut self,
399        wire0: usize,
400        wire1: usize,
401        matrix: &Array2<CType>,
402    ) -> Result<()> {
403        if wire0 >= self.n_wires || wire1 >= self.n_wires {
404            return Err(MLError::InvalidConfiguration(format!(
405                "Wires ({}, {}) out of range for {} qubits",
406                wire0, wire1, self.n_wires
407            )));
408        }
409
410        let state_size = 1 << self.n_wires;
411        let states_1d = self.get_states_1d();
412        let mut new_states = states_1d.clone();
413
414        let pos0 = self.n_wires - 1 - wire0;
415        let pos1 = self.n_wires - 1 - wire1;
416
417        for batch in 0..self.bsz {
418            let mut visited = vec![false; state_size];
419
420            for i in 0..state_size {
421                if visited[i] {
422                    continue;
423                }
424
425                // Get the 4 indices for the 2-qubit subspace
426                // Base index (both bits = 0)
427                let base = i & !(1 << pos0) & !(1 << pos1);
428
429                let indices = [
430                    base,                             // 00
431                    base | (1 << pos1),               // 01
432                    base | (1 << pos0),               // 10
433                    base | (1 << pos0) | (1 << pos1), // 11
434                ];
435
436                let amps: Vec<CType> = indices.iter().map(|&idx| states_1d[[batch, idx]]).collect();
437
438                for (row, &idx) in indices.iter().enumerate() {
439                    let mut new_amp = CType::new(0.0, 0.0);
440                    for (col, &amp) in amps.iter().enumerate() {
441                        new_amp += matrix[[row, col]] * amp;
442                    }
443                    new_states[[batch, idx]] = new_amp;
444                    visited[idx] = true;
445                }
446            }
447        }
448
449        // Reshape back
450        let flat: Vec<CType> = new_states.iter().cloned().collect();
451        let mut shape = vec![self.bsz];
452        shape.extend(vec![2; self.n_wires]);
453        self.states = ArrayD::from_shape_vec(IxDyn(&shape), flat)
454            .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
455
456        Ok(())
457    }
458
459    /// Apply a multi-qubit gate matrix to specified wires (n-qubit gate)
460    pub fn apply_multi_qubit_gate(
461        &mut self,
462        wires: &[usize],
463        matrix: &Array2<CType>,
464    ) -> Result<()> {
465        let n_qubits = wires.len();
466
467        // Validate wires
468        for &wire in wires {
469            if wire >= self.n_wires {
470                return Err(MLError::InvalidConfiguration(format!(
471                    "Wire {} out of range for {} qubits",
472                    wire, self.n_wires
473                )));
474            }
475        }
476
477        // Expected matrix dimension: 2^n_qubits x 2^n_qubits
478        let gate_dim = 1 << n_qubits;
479        if matrix.nrows() != gate_dim || matrix.ncols() != gate_dim {
480            return Err(MLError::InvalidConfiguration(format!(
481                "Gate matrix must be {}x{} for {}-qubit gate",
482                gate_dim, gate_dim, n_qubits
483            )));
484        }
485
486        let state_size = 1 << self.n_wires;
487        let states_1d = self.get_states_1d();
488        let mut new_states = states_1d.clone();
489
490        // Pre-compute bit positions for the wires (in reversed order for state indexing)
491        let positions: Vec<usize> = wires.iter().map(|&w| self.n_wires - 1 - w).collect();
492
493        // Create mask to identify which bits correspond to the gate qubits
494        let mut wire_mask: usize = 0;
495        for &pos in &positions {
496            wire_mask |= 1 << pos;
497        }
498
499        for batch in 0..self.bsz {
500            let mut visited = vec![false; state_size];
501
502            for base_idx in 0..state_size {
503                if visited[base_idx] {
504                    continue;
505                }
506
507                // Get base index with all gate qubit bits cleared
508                let base = base_idx & !wire_mask;
509
510                // Generate all 2^n indices for the gate subspace
511                let mut indices = Vec::with_capacity(gate_dim);
512                for gate_idx in 0..gate_dim {
513                    let mut idx = base;
514                    // Set bits according to gate_idx
515                    for (bit_pos, &pos) in positions.iter().enumerate() {
516                        if (gate_idx >> (n_qubits - 1 - bit_pos)) & 1 == 1 {
517                            idx |= 1 << pos;
518                        }
519                    }
520                    indices.push(idx);
521                }
522
523                // Get current amplitudes
524                let amps: Vec<CType> = indices.iter().map(|&idx| states_1d[[batch, idx]]).collect();
525
526                // Apply matrix
527                for (row, &idx) in indices.iter().enumerate() {
528                    let mut new_amp = CType::new(0.0, 0.0);
529                    for (col, &amp) in amps.iter().enumerate() {
530                        new_amp += matrix[[row, col]] * amp;
531                    }
532                    new_states[[batch, idx]] = new_amp;
533                    visited[idx] = true;
534                }
535            }
536        }
537
538        // Reshape back
539        let flat: Vec<CType> = new_states.iter().cloned().collect();
540        let mut shape = vec![self.bsz];
541        shape.extend(vec![2; self.n_wires]);
542        self.states = ArrayD::from_shape_vec(IxDyn(&shape), flat)
543            .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
544
545        Ok(())
546    }
547}
548
549// ============================================================================
550// TQOperator - Base quantum operator
551// ============================================================================
552
553/// Base quantum operator trait
554pub trait TQOperator: TQModule {
555    /// Number of wires this operator acts on
556    fn num_wires(&self) -> WiresEnum;
557
558    /// Number of parameters
559    fn num_params(&self) -> NParamsEnum;
560
561    /// Get the unitary matrix for given parameters
562    fn get_matrix(&self, params: Option<&[f64]>) -> Array2<CType>;
563
564    /// Get eigenvalues (if applicable)
565    fn get_eigvals(&self, _params: Option<&[f64]>) -> Option<Array1<CType>> {
566        None
567    }
568
569    /// Apply the operator to a quantum device
570    fn apply(&mut self, qdev: &mut TQDevice, wires: &[usize]) -> Result<()>;
571
572    /// Apply with specific parameters
573    fn apply_with_params(
574        &mut self,
575        qdev: &mut TQDevice,
576        wires: &[usize],
577        params: Option<&[f64]>,
578    ) -> Result<()>;
579
580    /// Whether this operator has trainable parameters
581    fn has_params(&self) -> bool;
582
583    /// Whether parameters are trainable
584    fn trainable(&self) -> bool;
585
586    /// Get/set inverse flag
587    fn inverse(&self) -> bool;
588    fn set_inverse(&mut self, inverse: bool);
589}
590
591// ============================================================================
592// TQModuleList - Container for modules
593// ============================================================================
594
595/// Container for a list of TQModules (similar to PyTorch's ModuleList)
596pub struct TQModuleList {
597    modules: Vec<Box<dyn TQModule>>,
598    static_mode: bool,
599}
600
601impl TQModuleList {
602    /// Create empty module list
603    pub fn new() -> Self {
604        Self {
605            modules: Vec::new(),
606            static_mode: false,
607        }
608    }
609
610    /// Add a module to the list
611    pub fn append(&mut self, module: Box<dyn TQModule>) {
612        self.modules.push(module);
613    }
614
615    /// Get number of modules
616    pub fn len(&self) -> usize {
617        self.modules.len()
618    }
619
620    /// Check if empty
621    pub fn is_empty(&self) -> bool {
622        self.modules.is_empty()
623    }
624
625    /// Get module at index
626    pub fn get(&self, index: usize) -> Option<&Box<dyn TQModule>> {
627        self.modules.get(index)
628    }
629
630    /// Get mutable module at index
631    pub fn get_mut(&mut self, index: usize) -> Option<&mut Box<dyn TQModule>> {
632        self.modules.get_mut(index)
633    }
634
635    /// Iterate over modules
636    pub fn iter(&self) -> impl Iterator<Item = &Box<dyn TQModule>> {
637        self.modules.iter()
638    }
639
640    /// Iterate mutably over modules
641    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Box<dyn TQModule>> {
642        self.modules.iter_mut()
643    }
644}
645
646impl Default for TQModuleList {
647    fn default() -> Self {
648        Self::new()
649    }
650}
651
652impl TQModule for TQModuleList {
653    fn forward(&mut self, qdev: &mut TQDevice) -> Result<()> {
654        for module in &mut self.modules {
655            module.forward(qdev)?;
656        }
657        Ok(())
658    }
659
660    fn parameters(&self) -> Vec<TQParameter> {
661        self.modules.iter().flat_map(|m| m.parameters()).collect()
662    }
663
664    fn n_wires(&self) -> Option<usize> {
665        self.modules.first().and_then(|m| m.n_wires())
666    }
667
668    fn set_n_wires(&mut self, n_wires: usize) {
669        for module in &mut self.modules {
670            module.set_n_wires(n_wires);
671        }
672    }
673
674    fn is_static_mode(&self) -> bool {
675        self.static_mode
676    }
677
678    fn static_on(&mut self) {
679        self.static_mode = true;
680        for module in &mut self.modules {
681            module.static_on();
682        }
683    }
684
685    fn static_off(&mut self) {
686        self.static_mode = false;
687        for module in &mut self.modules {
688            module.static_off();
689        }
690    }
691
692    fn name(&self) -> &str {
693        "ModuleList"
694    }
695
696    fn zero_grad(&mut self) {
697        for module in &mut self.modules {
698            module.zero_grad();
699        }
700    }
701}
702
703// ============================================================================
704// Prelude - Convenient re-exports
705// ============================================================================
706
707pub mod prelude {
708    //! Convenient re-exports for TorchQuantum-compatible API
709
710    pub use super::{
711        CType, FType, NParamsEnum, OpHistoryEntry, TQDevice, TQModule, TQModuleList, TQOperator,
712        TQParameter, WiresEnum,
713    };
714
715    // Gates
716    pub use super::gates::{
717        // Single-qubit gates
718        TQHadamard,
719        TQPauliX,
720        TQPauliY,
721        TQPauliZ,
722        TQRx,
723        TQRy,
724        TQRz,
725        // Two-qubit gates
726        TQCNOT,
727        // Controlled rotation gates
728        TQCRX,
729        TQCRY,
730        TQCRZ,
731        TQCZ,
732        // Parameterized two-qubit gates
733        TQRXX,
734        TQRYY,
735        TQRZX,
736        TQRZZ,
737        TQS,
738        TQSWAP,
739        TQSX,
740        TQT,
741        TQU1,
742        TQU2,
743        TQU3,
744    };
745
746    // Encoding
747    pub use super::encoding::{
748        EncodingOp, TQAmplitudeEncoder, TQEncoder, TQGeneralEncoder, TQPhaseEncoder, TQStateEncoder,
749    };
750
751    // Measurement
752    pub use super::measurement::{
753        expval_joint_analytical, expval_joint_sampling, gen_bitstrings, measure, TQMeasureAll,
754    };
755
756    // Layers
757    pub use super::layer::{
758        TQBarrenLayer, TQFarhiLayer, TQLayerConfig, TQMaxwellLayer, TQOp1QAllLayer, TQOp2QAllLayer,
759        TQRXYZCXLayer, TQSethLayer, TQStrongEntanglingLayer,
760    };
761
762    // Autograd
763    pub use super::autograd::{
764        gradient_norm, gradient_statistics, ClippingStatistics, ClippingStrategy,
765        GradientAccumulator, GradientCheckResult, GradientChecker, GradientClipper,
766        GradientStatistics, ParameterGroup, ParameterGroupManager, ParameterRegistry,
767        ParameterStatistics,
768    };
769
770    // Ansatz templates
771    pub use super::ansatz::{
772        EfficientSU2Layer, EntanglementPattern, RealAmplitudesLayer, TwoLocalLayer,
773    };
774
775    // Convolutional layers
776    pub use super::conv::{QConv1D, QConv2D};
777
778    // Pooling layers
779    pub use super::pooling::{QAvgPool, QMaxPool};
780
781    // Tensor network backend
782    pub use super::tensor_network::{
783        CompressionMethod, MPSTensor, MatrixProductState, TQTensorNetworkBackend,
784        TensorNetworkConfig,
785    };
786
787    // Noise-aware training
788    pub use super::noise::{
789        GateTimes, MitigatedExpectation, MitigatedExpectationConfig, MitigationMethod,
790        NoiseAwareGradient, NoiseAwareGradientConfig, NoiseAwareTrainer, NoiseModel,
791        SingleQubitNoiseType, TrainingHistory, TrainingStatistics, TwoQubitNoiseType,
792        VarianceReduction, ZNEExtrapolation,
793    };
794}
795
796// ============================================================================
797// Tests
798// ============================================================================
799
#[cfg(test)]
mod tests {
    use super::prelude::*;
    use std::f64::consts::PI;

    /// Absolute tolerance for every probability comparison.
    const TOL: f64 = 1e-10;

    /// Assert that `actual` lies within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Assert that batch entry 0 of `qdev` has exactly the listed basis-state probabilities.
    fn assert_probs(qdev: &TQDevice, expected: &[f64]) {
        let probs = qdev.get_probs_1d();
        for (idx, &p) in expected.iter().enumerate() {
            assert_close(probs[[0, idx]], p);
        }
    }

    #[test]
    fn test_tq_device_creation() {
        let qdev = TQDevice::new(4);
        assert_eq!(qdev.n_wires, 4);
        assert_eq!(qdev.bsz, 1);

        // A fresh device starts in |0000>.
        let mut expected = [0.0; 1 << 4];
        expected[0] = 1.0;
        assert_probs(&qdev, &expected);
    }

    #[test]
    fn test_tq_device_reset() {
        let mut qdev = TQDevice::new(2);
        qdev.reset_all_eq_states(1);

        // Uniform superposition over 2 qubits: 1/4 each.
        assert_probs(&qdev, &[0.25; 4]);
    }

    #[test]
    fn test_tq_parameter() {
        use scirs2_core::ndarray::{ArrayD, IxDyn};

        let mut param = TQParameter::new(ArrayD::zeros(IxDyn(&[2, 3])), "test");
        assert_eq!(param.shape(), &[2, 3]);
        assert_eq!(param.numel(), 6);

        param.init_constant(1.5);
        assert!(param.data.iter().all(|v| (v - 1.5).abs() < TOL));
    }

    #[test]
    fn test_hadamard_gate() {
        let mut qdev = TQDevice::new(1);
        TQHadamard::new()
            .apply(&mut qdev, &[0])
            .expect("Hadamard should succeed");

        assert_probs(&qdev, &[0.5, 0.5]);
    }

    #[test]
    fn test_pauli_x_gate() {
        let mut qdev = TQDevice::new(1);
        TQPauliX::new()
            .apply(&mut qdev, &[0])
            .expect("PauliX should succeed");

        assert_probs(&qdev, &[0.0, 1.0]);
    }

    #[test]
    fn test_rx_gate() {
        let mut qdev = TQDevice::new(1);
        let mut rx = TQRx::new(true, false);

        // RX(pi) equals X up to a global phase.
        rx.apply_with_params(&mut qdev, &[0], Some(&[PI]))
            .expect("RX should succeed");

        assert_probs(&qdev, &[0.0, 1.0]);
    }

    #[test]
    fn test_cnot_gate() {
        let mut qdev = TQDevice::new(2);

        // X on the control, then CNOT: |00> -> |10> -> |11>.
        TQPauliX::new()
            .apply(&mut qdev, &[0])
            .expect("X should succeed");
        TQCNOT::new()
            .apply(&mut qdev, &[0, 1])
            .expect("CNOT should succeed");

        // Order: |00>, |01>, |10>, |11>
        assert_probs(&qdev, &[0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_bell_state() {
        let mut qdev = TQDevice::new(2);

        TQHadamard::new()
            .apply(&mut qdev, &[0])
            .expect("H should succeed");
        TQCNOT::new()
            .apply(&mut qdev, &[0, 1])
            .expect("CNOT should succeed");

        // Bell state (|00> + |11>) / sqrt(2).
        assert_probs(&qdev, &[0.5, 0.0, 0.0, 0.5]);
    }

    #[test]
    fn test_module_list() {
        let _qdev = TQDevice::new(2);
        let mut modules = TQModuleList::new();

        modules.append(Box::new(TQHadamard::new()));
        modules.append(Box::new(TQPauliX::new()));

        assert_eq!(modules.len(), 2);
        assert!(!modules.is_empty());
    }
}