quantrs2_ml/torchquantum/
tensor_network.rs

1//! TorchQuantum Tensor Network Backend
2//!
3//! This module provides tensor network simulation backend for TorchQuantum circuits,
4//! enabling efficient simulation of large circuits with limited entanglement.
5//!
6//! ## Key Features
7//!
8//! - **MPS Backend**: Matrix Product State representation for 1D circuits
9//! - **PEPS Backend**: Projected Entangled Pair States for 2D circuits
10//! - **Automatic Bond Dimension Management**: Adaptive truncation based on fidelity
11//! - **Gradient Support**: Tensor network compatible gradient computation
12
13use crate::error::{MLError, Result};
14use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
15use scirs2_core::Complex64;
16use std::collections::HashMap;
17
18use super::{CType, TQDevice, TQModule, TQParameter};
19
20// ============================================================================
21// Tensor Network Configuration
22// ============================================================================
23
24/// Configuration for tensor network simulation
25#[derive(Debug, Clone)]
26pub struct TensorNetworkConfig {
27    /// Maximum bond dimension for truncation
28    pub max_bond_dim: usize,
29    /// Truncation threshold (singular values below this are discarded)
30    pub truncation_threshold: f64,
31    /// Whether to use canonical form
32    pub use_canonical_form: bool,
33    /// Compression method
34    pub compression: CompressionMethod,
35}
36
37impl Default for TensorNetworkConfig {
38    fn default() -> Self {
39        Self {
40            max_bond_dim: 64,
41            truncation_threshold: 1e-12,
42            use_canonical_form: true,
43            compression: CompressionMethod::SVD,
44        }
45    }
46}
47
48/// Compression methods for tensor network
49#[derive(Debug, Clone, Copy, PartialEq, Eq)]
50pub enum CompressionMethod {
51    /// Singular Value Decomposition
52    SVD,
53    /// QR decomposition
54    QR,
55    /// Variational compression
56    Variational,
57}
58
59// ============================================================================
60// MPS Tensor (Single Site)
61// ============================================================================
62
63/// Single site tensor in MPS representation
64///
65/// Shape: (bond_left, physical_dim, bond_right)
66#[derive(Debug, Clone)]
67pub struct MPSTensor {
68    /// Tensor data: shape (bond_left, physical_dim, bond_right)
69    pub data: Array3<CType>,
70    /// Site index
71    pub site: usize,
72}
73
74impl MPSTensor {
75    /// Create new MPS tensor
76    pub fn new(data: Array3<CType>, site: usize) -> Self {
77        Self { data, site }
78    }
79
80    /// Get bond dimensions
81    pub fn bond_dims(&self) -> (usize, usize) {
82        let shape = self.data.shape();
83        (shape[0], shape[2])
84    }
85
86    /// Get physical dimension
87    pub fn physical_dim(&self) -> usize {
88        self.data.shape()[1]
89    }
90
91    /// Contract with another tensor (right multiplication)
92    pub fn contract_right(&self, other: &MPSTensor) -> Array3<CType> {
93        let (d_left, phys_a, d_mid) = (
94            self.data.shape()[0],
95            self.data.shape()[1],
96            self.data.shape()[2],
97        );
98        let (_d_mid2, phys_b, d_right) = (
99            other.data.shape()[0],
100            other.data.shape()[1],
101            other.data.shape()[2],
102        );
103
104        // Contract over the shared bond dimension
105        let mut result = Array3::<CType>::zeros((d_left, phys_a * phys_b, d_right));
106
107        for i in 0..d_left {
108            for j in 0..phys_a {
109                for k in 0..d_mid {
110                    for l in 0..phys_b {
111                        for m in 0..d_right {
112                            let combined_phys = j * phys_b + l;
113                            result[[i, combined_phys, m]] +=
114                                self.data[[i, j, k]] * other.data[[k, l, m]];
115                        }
116                    }
117                }
118            }
119        }
120
121        result
122    }
123}
124
125// ============================================================================
126// Matrix Product State
127// ============================================================================
128
129/// Matrix Product State representation of quantum state
130#[derive(Debug, Clone)]
131pub struct MatrixProductState {
132    /// MPS tensors for each site
133    pub tensors: Vec<MPSTensor>,
134    /// Number of qubits
135    pub n_qubits: usize,
136    /// Configuration
137    pub config: TensorNetworkConfig,
138    /// Normalization factor
139    pub norm: f64,
140}
141
142impl MatrixProductState {
143    /// Create MPS from computational basis state (e.g., |00...0>)
144    pub fn from_computational_basis(n_qubits: usize, state: usize) -> Self {
145        let config = TensorNetworkConfig::default();
146        let mut tensors = Vec::with_capacity(n_qubits);
147
148        for site in 0..n_qubits {
149            // Each tensor is (1, 2, 1) for product states
150            let mut data = Array3::<CType>::zeros((1, 2, 1));
151            let bit = (state >> (n_qubits - 1 - site)) & 1;
152            data[[0, bit, 0]] = Complex64::new(1.0, 0.0);
153            tensors.push(MPSTensor::new(data, site));
154        }
155
156        Self {
157            tensors,
158            n_qubits,
159            config,
160            norm: 1.0,
161        }
162    }
163
164    /// Create MPS from TQDevice state
165    pub fn from_tq_device(qdev: &TQDevice) -> Result<Self> {
166        // Get state vector
167        let states = qdev.get_states_1d();
168        let state_vec: Vec<CType> = states.row(0).iter().cloned().collect();
169
170        Self::from_state_vector(&state_vec, qdev.n_wires)
171    }
172
173    /// Create MPS from state vector using SVD decomposition
174    pub fn from_state_vector(state_vec: &[CType], n_qubits: usize) -> Result<Self> {
175        let config = TensorNetworkConfig::default();
176        let dim = 1 << n_qubits;
177
178        if state_vec.len() != dim {
179            return Err(MLError::InvalidConfiguration(format!(
180                "State vector size {} doesn't match 2^{} = {}",
181                state_vec.len(),
182                n_qubits,
183                dim
184            )));
185        }
186
187        // For simplicity, use direct tensor construction for small systems
188        // For larger systems, use SVD decomposition
189        let mut tensors = Vec::with_capacity(n_qubits);
190
191        if n_qubits <= 4 {
192            // Direct construction for small systems
193            for site in 0..n_qubits {
194                let bond_left = 1.min(1 << site);
195                let bond_right = 1.min(1 << (n_qubits - site - 1));
196                let mut data = Array3::<CType>::zeros((bond_left, 2, bond_right));
197
198                // Fill tensor based on state amplitudes
199                for idx in 0..dim {
200                    let bit = (idx >> (n_qubits - 1 - site)) & 1;
201                    let left_idx = (idx >> (n_qubits - site)) % bond_left;
202                    let right_idx = idx % bond_right;
203                    data[[left_idx, bit, right_idx]] += state_vec[idx];
204                }
205
206                tensors.push(MPSTensor::new(data, site));
207            }
208        } else {
209            // SVD-based construction for larger systems
210            let mut remaining = Array2::<CType>::from_shape_vec((1, dim), state_vec.to_vec())
211                .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
212
213            for site in 0..n_qubits {
214                let rows = remaining.nrows();
215                let cols = remaining.ncols();
216                let new_cols = cols / 2;
217
218                // Clone and reshape to separate the physical index
219                let reshaped = remaining
220                    .clone()
221                    .into_shape_with_order((rows * 2, new_cols))
222                    .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
223
224                // For last site, no SVD needed
225                if site == n_qubits - 1 {
226                    let mut data = Array3::<CType>::zeros((rows, 2, 1));
227                    for i in 0..rows {
228                        for j in 0..2 {
229                            data[[i, j, 0]] = reshaped[[i * 2 + j, 0]];
230                        }
231                    }
232                    tensors.push(MPSTensor::new(data, site));
233                } else {
234                    // Simple truncation (for production, use proper SVD)
235                    let bond_dim = (rows * 2).min(config.max_bond_dim).min(new_cols);
236                    let mut data = Array3::<CType>::zeros((rows, 2, bond_dim));
237
238                    for i in 0..rows {
239                        for j in 0..2 {
240                            for k in 0..bond_dim {
241                                if i * 2 + j < rows * 2 && k < new_cols {
242                                    data[[i, j, k]] = reshaped[[i * 2 + j, k]];
243                                }
244                            }
245                        }
246                    }
247
248                    tensors.push(MPSTensor::new(data, site));
249
250                    // Prepare for next iteration
251                    remaining = Array2::<CType>::zeros((bond_dim, new_cols));
252                    for i in 0..bond_dim.min(rows * 2) {
253                        for j in 0..new_cols {
254                            remaining[[i.min(bond_dim - 1), j]] = reshaped[[i, j]];
255                        }
256                    }
257                }
258            }
259        }
260
261        Ok(Self {
262            tensors,
263            n_qubits,
264            config,
265            norm: 1.0,
266        })
267    }
268
269    /// Apply single-qubit gate to MPS
270    pub fn apply_single_qubit_gate(&mut self, site: usize, gate: &Array2<CType>) -> Result<()> {
271        if site >= self.n_qubits {
272            return Err(MLError::InvalidConfiguration(format!(
273                "Site {} out of range for {} qubits",
274                site, self.n_qubits
275            )));
276        }
277
278        let tensor = &mut self.tensors[site];
279        let (d_left, _phys, d_right) = (
280            tensor.data.shape()[0],
281            tensor.data.shape()[1],
282            tensor.data.shape()[2],
283        );
284
285        let mut new_data = Array3::<CType>::zeros((d_left, 2, d_right));
286
287        for i in 0..d_left {
288            for k in 0..d_right {
289                let old_0 = tensor.data[[i, 0, k]];
290                let old_1 = tensor.data[[i, 1, k]];
291                new_data[[i, 0, k]] = gate[[0, 0]] * old_0 + gate[[0, 1]] * old_1;
292                new_data[[i, 1, k]] = gate[[1, 0]] * old_0 + gate[[1, 1]] * old_1;
293            }
294        }
295
296        tensor.data = new_data;
297        Ok(())
298    }
299
300    /// Apply two-qubit gate to MPS (with truncation)
301    pub fn apply_two_qubit_gate(
302        &mut self,
303        site1: usize,
304        site2: usize,
305        gate: &Array2<CType>,
306    ) -> Result<()> {
307        // Ensure sites are adjacent for efficient application
308        if site1.abs_diff(site2) != 1 {
309            return Err(MLError::InvalidConfiguration(
310                "Two-qubit gates on non-adjacent sites require SWAP operations".to_string(),
311            ));
312        }
313
314        let (left_site, right_site) = if site1 < site2 {
315            (site1, site2)
316        } else {
317            (site2, site1)
318        };
319
320        // Contract the two tensors
321        let left_tensor = &self.tensors[left_site];
322        let right_tensor = &self.tensors[right_site];
323
324        let d_left = left_tensor.data.shape()[0];
325        let d_mid = left_tensor.data.shape()[2];
326        let d_right = right_tensor.data.shape()[2];
327
328        // Contract and apply gate
329        let mut contracted = Array3::<CType>::zeros((d_left, 4, d_right));
330
331        for i in 0..d_left {
332            for k in 0..d_mid {
333                for m in 0..d_right {
334                    for j1 in 0..2 {
335                        for j2 in 0..2 {
336                            let combined = j1 * 2 + j2;
337                            contracted[[i, combined, m]] +=
338                                left_tensor.data[[i, j1, k]] * right_tensor.data[[k, j2, m]];
339                        }
340                    }
341                }
342            }
343        }
344
345        // Apply gate
346        let mut gated = Array3::<CType>::zeros((d_left, 4, d_right));
347        for i in 0..d_left {
348            for m in 0..d_right {
349                for out_idx in 0..4 {
350                    for in_idx in 0..4 {
351                        gated[[i, out_idx, m]] +=
352                            gate[[out_idx, in_idx]] * contracted[[i, in_idx, m]];
353                    }
354                }
355            }
356        }
357
358        // Split back into two tensors (simplified truncation)
359        let new_bond = d_mid.min(self.config.max_bond_dim);
360
361        let mut new_left = Array3::<CType>::zeros((d_left, 2, new_bond));
362        let mut new_right = Array3::<CType>::zeros((new_bond, 2, d_right));
363
364        // Simple split (for production, use SVD)
365        for i in 0..d_left {
366            for j1 in 0..2 {
367                for k in 0..new_bond {
368                    for j2 in 0..2 {
369                        for m in 0..d_right {
370                            let combined = j1 * 2 + j2;
371                            // Distribute amplitude
372                            new_left[[i, j1, k]] += gated[[i, combined, m]]
373                                * Complex64::new(1.0 / (new_bond * d_right) as f64, 0.0);
374                            new_right[[k, j2, m]] += gated[[i, combined, m]]
375                                * Complex64::new(1.0 / (d_left * 2) as f64, 0.0);
376                        }
377                    }
378                }
379            }
380        }
381
382        self.tensors[left_site] = MPSTensor::new(new_left, left_site);
383        self.tensors[right_site] = MPSTensor::new(new_right, right_site);
384
385        Ok(())
386    }
387
388    /// Get the state vector representation
389    pub fn to_state_vector(&self) -> Result<Vec<CType>> {
390        let dim = 1 << self.n_qubits;
391        let mut state = vec![Complex64::new(0.0, 0.0); dim];
392
393        // Contract all tensors
394        for idx in 0..dim {
395            let mut amp = Complex64::new(1.0, 0.0);
396
397            for site in 0..self.n_qubits {
398                let bit = (idx >> (self.n_qubits - 1 - site)) & 1;
399                // For product states, just multiply the diagonal elements
400                amp *= self.tensors[site].data[[0, bit, 0]];
401            }
402
403            state[idx] = amp;
404        }
405
406        Ok(state)
407    }
408
409    /// Compute overlap with another MPS
410    pub fn overlap(&self, other: &MatrixProductState) -> Result<CType> {
411        if self.n_qubits != other.n_qubits {
412            return Err(MLError::InvalidConfiguration(
413                "MPS qubit counts don't match".to_string(),
414            ));
415        }
416
417        // Contract from left to right
418        let mut transfer = Array2::<CType>::eye(1);
419
420        for site in 0..self.n_qubits {
421            let t1 = &self.tensors[site];
422            let t2 = &other.tensors[site];
423
424            let d1_left = t1.data.shape()[0];
425            let d1_right = t1.data.shape()[2];
426            let d2_left = t2.data.shape()[0];
427            let d2_right = t2.data.shape()[2];
428
429            let mut new_transfer = Array2::<CType>::zeros((d1_right, d2_right));
430
431            for i1 in 0..d1_left {
432                for i2 in 0..d2_left {
433                    for j in 0..2 {
434                        for k1 in 0..d1_right {
435                            for k2 in 0..d2_right {
436                                new_transfer[[k1, k2]] += transfer
437                                    [[i1.min(transfer.nrows() - 1), i2.min(transfer.ncols() - 1)]]
438                                    * t1.data[[i1, j, k1]].conj()
439                                    * t2.data[[i2, j, k2]];
440                            }
441                        }
442                    }
443                }
444            }
445
446            transfer = new_transfer;
447        }
448
449        Ok(transfer[[0, 0]])
450    }
451
452    /// Get total bond dimension (max across all bonds)
453    pub fn max_bond_dim(&self) -> usize {
454        self.tensors
455            .iter()
456            .map(|t| t.bond_dims().1)
457            .max()
458            .unwrap_or(1)
459    }
460}
461
462// ============================================================================
463// TQ Tensor Network Backend
464// ============================================================================
465
466/// TorchQuantum Tensor Network Backend
467///
468/// Provides MPS/PEPS simulation backend for TorchQuantum circuits.
469#[derive(Debug, Clone)]
470pub struct TQTensorNetworkBackend {
471    /// MPS representation of the state
472    pub mps: Option<MatrixProductState>,
473    /// Number of qubits
474    pub n_wires: usize,
475    /// Configuration
476    pub config: TensorNetworkConfig,
477    /// Static mode flag
478    pub static_mode: bool,
479    /// Gate cache for static mode
480    pub gate_cache: HashMap<String, Array2<CType>>,
481}
482
483impl TQTensorNetworkBackend {
484    /// Create new tensor network backend
485    pub fn new(n_wires: usize) -> Self {
486        Self {
487            mps: Some(MatrixProductState::from_computational_basis(n_wires, 0)),
488            n_wires,
489            config: TensorNetworkConfig::default(),
490            static_mode: false,
491            gate_cache: HashMap::new(),
492        }
493    }
494
495    /// Create with custom configuration
496    pub fn with_config(n_wires: usize, config: TensorNetworkConfig) -> Self {
497        let mut mps = MatrixProductState::from_computational_basis(n_wires, 0);
498        mps.config = config.clone();
499
500        Self {
501            mps: Some(mps),
502            n_wires,
503            config,
504            static_mode: false,
505            gate_cache: HashMap::new(),
506        }
507    }
508
509    /// Reset to |0...0> state
510    pub fn reset(&mut self) {
511        self.mps = Some(MatrixProductState::from_computational_basis(
512            self.n_wires,
513            0,
514        ));
515        self.mps.as_mut().map(|m| m.config = self.config.clone());
516    }
517
518    /// Apply single-qubit gate
519    pub fn apply_gate(&mut self, site: usize, gate: &Array2<CType>) -> Result<()> {
520        if let Some(ref mut mps) = self.mps {
521            mps.apply_single_qubit_gate(site, gate)
522        } else {
523            Err(MLError::InvalidConfiguration(
524                "MPS not initialized".to_string(),
525            ))
526        }
527    }
528
529    /// Apply two-qubit gate
530    pub fn apply_two_qubit_gate(
531        &mut self,
532        site1: usize,
533        site2: usize,
534        gate: &Array2<CType>,
535    ) -> Result<()> {
536        if let Some(ref mut mps) = self.mps {
537            mps.apply_two_qubit_gate(site1, site2, gate)
538        } else {
539            Err(MLError::InvalidConfiguration(
540                "MPS not initialized".to_string(),
541            ))
542        }
543    }
544
545    /// Get state vector (contracts MPS)
546    pub fn get_state_vector(&self) -> Result<Vec<CType>> {
547        if let Some(ref mps) = self.mps {
548            mps.to_state_vector()
549        } else {
550            Err(MLError::InvalidConfiguration(
551                "MPS not initialized".to_string(),
552            ))
553        }
554    }
555
556    /// Get expectation value of observable
557    pub fn expectation_value(&self, observable: &Array2<CType>, sites: &[usize]) -> Result<f64> {
558        // For single-qubit observables, contract efficiently
559        if sites.len() == 1 && observable.nrows() == 2 {
560            if let Some(ref mps) = self.mps {
561                let site = sites[0];
562                let tensor = &mps.tensors[site];
563
564                // <O> = sum_{ij} O_{ij} * rho_{ji}
565                // where rho is the reduced density matrix at this site
566                let mut exp_val = Complex64::new(0.0, 0.0);
567
568                for i in 0..2 {
569                    for j in 0..2 {
570                        // Simplified: assume product state for now
571                        let rho_ji = tensor.data[[0, j, 0]].conj() * tensor.data[[0, i, 0]];
572                        exp_val += observable[[i, j]] * rho_ji;
573                    }
574                }
575
576                return Ok(exp_val.re);
577            }
578        }
579
580        Err(MLError::NotSupported(
581            "Multi-site observables not yet implemented for MPS".to_string(),
582        ))
583    }
584
585    /// Get current bond dimension
586    pub fn bond_dimension(&self) -> usize {
587        self.mps.as_ref().map(|m| m.max_bond_dim()).unwrap_or(0)
588    }
589
590    /// Convert to TQDevice for compatibility
591    pub fn to_tq_device(&self) -> Result<TQDevice> {
592        let state_vec = self.get_state_vector()?;
593        let mut qdev = TQDevice::new(self.n_wires);
594
595        // Set state from vector
596        use scirs2_core::ndarray::{ArrayD, IxDyn};
597        let mut shape = vec![1usize]; // batch size 1
598        shape.extend(vec![2; self.n_wires]);
599
600        let states = ArrayD::from_shape_vec(IxDyn(&shape), state_vec)
601            .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
602        qdev.set_states(states);
603
604        Ok(qdev)
605    }
606
607    /// Create from TQDevice
608    pub fn from_tq_device(qdev: &TQDevice) -> Result<Self> {
609        let mps = MatrixProductState::from_tq_device(qdev)?;
610        Ok(Self {
611            n_wires: qdev.n_wires,
612            mps: Some(mps),
613            config: TensorNetworkConfig::default(),
614            static_mode: false,
615            gate_cache: HashMap::new(),
616        })
617    }
618}
619
620impl TQModule for TQTensorNetworkBackend {
621    fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
622        // Backend doesn't have a forward pass - it IS the state
623        Ok(())
624    }
625
626    fn parameters(&self) -> Vec<TQParameter> {
627        Vec::new()
628    }
629
630    fn n_wires(&self) -> Option<usize> {
631        Some(self.n_wires)
632    }
633
634    fn set_n_wires(&mut self, n_wires: usize) {
635        self.n_wires = n_wires;
636        self.reset();
637    }
638
639    fn is_static_mode(&self) -> bool {
640        self.static_mode
641    }
642
643    fn static_on(&mut self) {
644        self.static_mode = true;
645    }
646
647    fn static_off(&mut self) {
648        self.static_mode = false;
649        self.gate_cache.clear();
650    }
651
652    fn name(&self) -> &str {
653        "TQTensorNetworkBackend"
654    }
655}
656
657// ============================================================================
658// Tests
659// ============================================================================
660
661#[cfg(test)]
662mod tests {
663    use super::*;
664
665    #[test]
666    fn test_mps_creation() {
667        let mps = MatrixProductState::from_computational_basis(4, 0);
668        assert_eq!(mps.n_qubits, 4);
669        assert_eq!(mps.tensors.len(), 4);
670    }
671
672    #[test]
673    fn test_mps_state_vector() {
674        let mps = MatrixProductState::from_computational_basis(2, 0);
675        let state = mps.to_state_vector().expect("Should succeed");
676        assert_eq!(state.len(), 4);
677        assert!((state[0].re - 1.0).abs() < 1e-10);
678        for i in 1..4 {
679            assert!(state[i].norm() < 1e-10);
680        }
681    }
682
683    #[test]
684    fn test_tensor_network_backend() {
685        let backend = TQTensorNetworkBackend::new(3);
686        assert_eq!(backend.n_wires, 3);
687        assert!(backend.mps.is_some());
688    }
689
690    #[test]
691    fn test_single_qubit_gate_application() {
692        let mut backend = TQTensorNetworkBackend::new(2);
693
694        // Apply X gate
695        let x_gate = Array2::from_shape_vec(
696            (2, 2),
697            vec![
698                Complex64::new(0.0, 0.0),
699                Complex64::new(1.0, 0.0),
700                Complex64::new(1.0, 0.0),
701                Complex64::new(0.0, 0.0),
702            ],
703        )
704        .expect("Should create matrix");
705
706        backend.apply_gate(0, &x_gate).expect("Should apply gate");
707
708        let state = backend.get_state_vector().expect("Should get state");
709        // |10> state
710        assert!(state[0].norm() < 1e-10);
711        assert!(state[1].norm() < 1e-10);
712        assert!((state[2].re - 1.0).abs() < 1e-10);
713        assert!(state[3].norm() < 1e-10);
714    }
715
716    #[test]
717    fn test_config_defaults() {
718        let config = TensorNetworkConfig::default();
719        assert_eq!(config.max_bond_dim, 64);
720        assert_eq!(config.compression, CompressionMethod::SVD);
721    }
722}