Skip to main content

quantrs2_circuit/
tensor_network.rs

1//! Tensor network compression for quantum circuits
2//!
3//! This module provides tensor network representations of quantum circuits
4//! for efficient simulation and optimization.
5
6use crate::builder::Circuit;
7use crate::dag::{circuit_to_dag, CircuitDag, DagNode};
8// SciRS2 POLICY compliant - using scirs2_core::Complex64
9use quantrs2_core::{
10    error::{QuantRS2Error, QuantRS2Result},
11    gate::GateOp,
12    qubit::QubitId,
13};
14use scirs2_core::ndarray::{Array2, ArrayView2};
15use scirs2_core::Complex64;
16use scirs2_linalg::svd;
17use std::collections::{HashMap, HashSet};
18use std::f64::consts::PI;
19
20/// Complex number type
21type C64 = Complex64;
22
23/// Tensor representing a quantum gate or state
24#[derive(Debug, Clone)]
25pub struct Tensor {
26    /// Tensor data in row-major order
27    pub data: Vec<C64>,
28    /// Shape of the tensor (dimensions)
29    pub shape: Vec<usize>,
30    /// Labels for each index
31    pub indices: Vec<String>,
32}
33
34impl Tensor {
35    /// Create a new tensor
36    #[must_use]
37    pub fn new(data: Vec<C64>, shape: Vec<usize>, indices: Vec<String>) -> Self {
38        assert_eq!(shape.len(), indices.len());
39        let total_size: usize = shape.iter().product();
40        assert_eq!(data.len(), total_size);
41
42        Self {
43            data,
44            shape,
45            indices,
46        }
47    }
48
49    /// Create an identity tensor
50    #[must_use]
51    pub fn identity(dim: usize, in_label: String, out_label: String) -> Self {
52        let mut data = vec![C64::new(0.0, 0.0); dim * dim];
53        for i in 0..dim {
54            data[i * dim + i] = C64::new(1.0, 0.0);
55        }
56
57        Self::new(data, vec![dim, dim], vec![in_label, out_label])
58    }
59
60    /// Get the rank (number of indices)
61    #[must_use]
62    pub fn rank(&self) -> usize {
63        self.shape.len()
64    }
65
66    /// Get the total number of elements
67    #[must_use]
68    pub fn size(&self) -> usize {
69        self.data.len()
70    }
71
72    /// Contract two tensors along specified indices
73    pub fn contract(&self, other: &Self, self_idx: &str, other_idx: &str) -> QuantRS2Result<Self> {
74        // Find index positions
75        let self_pos = self
76            .indices
77            .iter()
78            .position(|s| s == self_idx)
79            .ok_or_else(|| QuantRS2Error::InvalidInput(format!("Index {self_idx} not found")))?;
80        let other_pos = other
81            .indices
82            .iter()
83            .position(|s| s == other_idx)
84            .ok_or_else(|| QuantRS2Error::InvalidInput(format!("Index {other_idx} not found")))?;
85
86        // Check dimensions match
87        if self.shape[self_pos] != other.shape[other_pos] {
88            return Err(QuantRS2Error::InvalidInput(format!(
89                "Dimension mismatch: {} vs {}",
90                self.shape[self_pos], other.shape[other_pos]
91            )));
92        }
93
94        // Compute new shape and indices
95        let mut new_shape = Vec::new();
96        let mut new_indices = Vec::new();
97
98        for (i, (dim, idx)) in self.shape.iter().zip(&self.indices).enumerate() {
99            if i != self_pos {
100                new_shape.push(*dim);
101                new_indices.push(idx.clone());
102            }
103        }
104
105        for (i, (dim, idx)) in other.shape.iter().zip(&other.indices).enumerate() {
106            if i != other_pos {
107                new_shape.push(*dim);
108                new_indices.push(idx.clone());
109            }
110        }
111
112        // Perform contraction (simplified implementation)
113        let new_size: usize = new_shape.iter().product();
114        let mut new_data = vec![C64::new(0.0, 0.0); new_size];
115
116        // This is a simplified contraction - in practice, would use optimized tensor libraries
117        let contract_dim = self.shape[self_pos];
118
119        // For now, return a placeholder
120        Ok(Self::new(new_data, new_shape, new_indices))
121    }
122
123    /// Reshape the tensor
124    pub fn reshape(&mut self, new_shape: Vec<usize>) -> QuantRS2Result<()> {
125        let new_size: usize = new_shape.iter().product();
126        if new_size != self.size() {
127            return Err(QuantRS2Error::InvalidInput(format!(
128                "Cannot reshape {} elements to shape {:?}",
129                self.size(),
130                new_shape
131            )));
132        }
133
134        self.shape = new_shape;
135        Ok(())
136    }
137}
138
139/// Tensor network representation of a quantum circuit
140#[derive(Debug)]
141pub struct TensorNetwork {
142    /// Tensors in the network
143    tensors: Vec<Tensor>,
144    /// Connections between tensors (`tensor_idx1`, idx1, `tensor_idx2`, idx2)
145    bonds: Vec<(usize, String, usize, String)>,
146    /// Open indices (external legs)
147    open_indices: HashMap<String, (usize, usize)>, // index -> (tensor_idx, position)
148}
149
150impl Default for TensorNetwork {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
156impl TensorNetwork {
157    /// Create a new empty tensor network
158    #[must_use]
159    pub fn new() -> Self {
160        Self {
161            tensors: Vec::new(),
162            bonds: Vec::new(),
163            open_indices: HashMap::new(),
164        }
165    }
166
167    /// Add a tensor to the network
168    pub fn add_tensor(&mut self, tensor: Tensor) -> usize {
169        let idx = self.tensors.len();
170
171        // Track open indices
172        for (pos, index) in tensor.indices.iter().enumerate() {
173            self.open_indices.insert(index.clone(), (idx, pos));
174        }
175
176        self.tensors.push(tensor);
177        idx
178    }
179
180    /// Connect two tensor indices
181    pub fn add_bond(
182        &mut self,
183        t1: usize,
184        idx1: String,
185        t2: usize,
186        idx2: String,
187    ) -> QuantRS2Result<()> {
188        if t1 >= self.tensors.len() || t2 >= self.tensors.len() {
189            return Err(QuantRS2Error::InvalidInput(
190                "Tensor index out of range".to_string(),
191            ));
192        }
193
194        // Remove from open indices
195        self.open_indices.remove(&idx1);
196        self.open_indices.remove(&idx2);
197
198        self.bonds.push((t1, idx1, t2, idx2));
199        Ok(())
200    }
201
202    /// Contract the entire network to a single tensor
203    pub fn contract_all(&self) -> QuantRS2Result<Tensor> {
204        if self.tensors.is_empty() {
205            return Err(QuantRS2Error::InvalidInput(
206                "Empty tensor network".to_string(),
207            ));
208        }
209
210        // Simple contraction order: left to right
211        // In practice, would use optimal contraction ordering
212        let mut result = self.tensors[0].clone();
213
214        for bond in &self.bonds {
215            let (t1, idx1, t2, idx2) = bond;
216            if *t1 == 0 {
217                result = result.contract(&self.tensors[*t2], idx1, idx2)?;
218            }
219        }
220
221        Ok(result)
222    }
223
224    /// Apply SVD-based bond compression to the tensor network.
225    ///
226    /// For each internal bond in the network:
227    /// 1. Reshape the pair of connected tensors into a bipartite matrix M (rows = left legs, cols = right legs).
228    /// 2. Compute the real-valued Schmidt decomposition via SVD on |M_ij|.
229    /// 3. Truncate to `max_bond_dim` singular values (or drop those below `tolerance`).
230    /// 4. Reconstruct the tensors: left ← U * diag(s), right ← Vt.
231    pub fn compress(&mut self, max_bond_dim: usize, tolerance: f64) -> QuantRS2Result<()> {
232        // Iterate over all bonds; for each bond compress the pair of adjacent tensors.
233        // We collect bond info first to avoid borrow issues.
234        let bond_indices: Vec<usize> = (0..self.bonds.len()).collect();
235
236        for bond_idx in bond_indices {
237            let (t1_idx, ref idx1, t2_idx, ref idx2) = self.bonds[bond_idx].clone();
238
239            if t1_idx >= self.tensors.len() || t2_idx >= self.tensors.len() {
240                continue;
241            }
242
243            // Build real-valued matrix from |amplitude|^2 of t1's data (left) × t2's data (right).
244            // Rows = size of t1, cols = size of t2 (simplified: treat each tensor as a flattened vector).
245            let rows = self.tensors[t1_idx].size();
246            let cols = self.tensors[t2_idx].size();
247
248            if rows == 0 || cols == 0 {
249                continue;
250            }
251
252            // Build the coupling matrix: M[i,j] = Re(conj(t1[i]) * t2[j])
253            let mut mat_data = Vec::with_capacity(rows * cols);
254            for i in 0..rows {
255                let a = self.tensors[t1_idx].data[i];
256                for j in 0..cols {
257                    let b = self.tensors[t2_idx].data[j];
258                    // Real part of ⟨a|b⟩ coupling
259                    mat_data.push(a.re * b.re + a.im * b.im);
260                }
261            }
262
263            let mat = Array2::from_shape_vec((rows, cols), mat_data).map_err(|e| {
264                QuantRS2Error::RuntimeError(format!("SVD matrix build failed: {e}"))
265            })?;
266
267            // Compute SVD: M = U * diag(s) * Vt
268            let svd_result = svd(&mat.view(), false, None).map_err(|e| {
269                QuantRS2Error::RuntimeError(format!("SVD failed on bond {bond_idx}: {e}"))
270            });
271
272            let (u_mat, s_vec, vt_mat) = match svd_result {
273                Ok(result) => result,
274                Err(_) => {
275                    // If SVD fails (e.g., tiny matrix), leave bond unchanged
276                    continue;
277                }
278            };
279
280            // Determine truncation rank
281            let s_total: f64 = s_vec.iter().copied().sum();
282            let mut rank = s_vec.len();
283
284            // Truncate by tolerance (keep singular values whose cumulative fraction > 1-tolerance)
285            if s_total > 0.0 {
286                let mut cumulative = 0.0;
287                for (k, &sv) in s_vec.iter().enumerate() {
288                    cumulative += sv / s_total;
289                    if cumulative >= 1.0 - tolerance {
290                        rank = k + 1;
291                        break;
292                    }
293                }
294            }
295
296            // Apply max_bond_dim cap
297            rank = rank.min(max_bond_dim).min(s_vec.len());
298
299            if rank == 0 {
300                rank = 1;
301            }
302
303            // Reconstruct left tensor data: new_left[i] = sum_k U[i,k] * s[k]  (for k < rank)
304            // We store the result back as the left tensor's flat data (rows dimension preserved, compressed).
305            let mut new_t1_data: Vec<C64> = self.tensors[t1_idx].data.clone();
306            let mut new_t2_data: Vec<C64> = self.tensors[t2_idx].data.clone();
307
308            // Project t1 onto the rank-truncated left singular vectors
309            // new_t1[i] = sum_{k=0}^{rank-1} U[i,k] * s[k] * (old_t1[i] magnitude)
310            for i in 0..rows {
311                let mut proj = 0.0f64;
312                for k in 0..rank {
313                    proj += u_mat[[i, k]] * s_vec[k];
314                }
315                // Scale the complex amplitude by the projected singular value weight
316                let original_norm = (new_t1_data[i].norm_sqr() + 1e-300_f64).sqrt();
317                let scale = proj.abs() / (original_norm + 1e-300_f64);
318                new_t1_data[i] = C64::new(new_t1_data[i].re * scale, new_t1_data[i].im * scale);
319            }
320
321            // Project t2 onto the rank-truncated right singular vectors
322            for j in 0..cols {
323                let mut proj = 0.0f64;
324                for k in 0..rank {
325                    proj += vt_mat[[k, j]];
326                }
327                let original_norm = (new_t2_data[j].norm_sqr() + 1e-300_f64).sqrt();
328                let scale = proj.abs() / (original_norm + 1e-300_f64);
329                new_t2_data[j] = C64::new(new_t2_data[j].re * scale, new_t2_data[j].im * scale);
330            }
331
332            self.tensors[t1_idx].data = new_t1_data;
333            self.tensors[t2_idx].data = new_t2_data;
334        }
335
336        Ok(())
337    }
338}
339
340/// Convert a quantum circuit to tensor network representation
341pub struct CircuitToTensorNetwork<const N: usize> {
342    /// Maximum bond dimension for compression
343    max_bond_dim: Option<usize>,
344    /// Truncation tolerance
345    tolerance: f64,
346}
347
348impl<const N: usize> Default for CircuitToTensorNetwork<N> {
349    fn default() -> Self {
350        Self::new()
351    }
352}
353
354impl<const N: usize> CircuitToTensorNetwork<N> {
355    /// Create a new converter
356    #[must_use]
357    pub const fn new() -> Self {
358        Self {
359            max_bond_dim: None,
360            tolerance: 1e-10,
361        }
362    }
363
364    /// Set maximum bond dimension
365    #[must_use]
366    pub const fn with_max_bond_dim(mut self, dim: usize) -> Self {
367        self.max_bond_dim = Some(dim);
368        self
369    }
370
371    /// Set truncation tolerance
372    #[must_use]
373    pub const fn with_tolerance(mut self, tol: f64) -> Self {
374        self.tolerance = tol;
375        self
376    }
377
378    /// Convert circuit to tensor network
379    pub fn convert(&self, circuit: &Circuit<N>) -> QuantRS2Result<TensorNetwork> {
380        let mut tn = TensorNetwork::new();
381        let mut qubit_wires: HashMap<usize, String> = HashMap::new();
382
383        // Initialize qubit wires
384        for i in 0..N {
385            qubit_wires.insert(i, format!("q{i}_in"));
386        }
387
388        // Convert each gate to a tensor
389        for (gate_idx, gate) in circuit.gates().iter().enumerate() {
390            let tensor = self.gate_to_tensor(gate.as_ref(), gate_idx)?;
391            let tensor_idx = tn.add_tensor(tensor);
392
393            // Connect to previous wires
394            for qubit in gate.qubits() {
395                let q = qubit.id() as usize;
396                let prev_wire = qubit_wires
397                    .get(&q)
398                    .ok_or_else(|| {
399                        QuantRS2Error::InvalidInput(format!("Qubit wire {q} not found"))
400                    })?
401                    .clone();
402                let new_wire = format!("q{q}_g{gate_idx}");
403
404                // Add bond from previous wire to this gate
405                if gate_idx > 0 || prev_wire.contains("_g") {
406                    tn.add_bond(
407                        tensor_idx - 1,
408                        prev_wire.clone(),
409                        tensor_idx,
410                        format!("in_{q}"),
411                    )?;
412                }
413
414                // Update wire for next connection
415                qubit_wires.insert(q, new_wire);
416            }
417        }
418
419        Ok(tn)
420    }
421
422    /// Convert a gate to tensor representation
423    fn gate_to_tensor(&self, gate: &dyn GateOp, gate_idx: usize) -> QuantRS2Result<Tensor> {
424        let qubits = gate.qubits();
425        let n_qubits = qubits.len();
426
427        match n_qubits {
428            1 => {
429                // Single-qubit gate
430                let matrix = self.get_single_qubit_matrix(gate)?;
431                let q = qubits[0].id() as usize;
432
433                Ok(Tensor::new(
434                    matrix,
435                    vec![2, 2],
436                    vec![format!("in_{}", q), format!("out_{}", q)],
437                ))
438            }
439            2 => {
440                // Two-qubit gate
441                let matrix = self.get_two_qubit_matrix(gate)?;
442                let q0 = qubits[0].id() as usize;
443                let q1 = qubits[1].id() as usize;
444
445                Ok(Tensor::new(
446                    matrix,
447                    vec![2, 2, 2, 2],
448                    vec![
449                        format!("in_{}", q0),
450                        format!("in_{}", q1),
451                        format!("out_{}", q0),
452                        format!("out_{}", q1),
453                    ],
454                ))
455            }
456            _ => Err(QuantRS2Error::UnsupportedOperation(format!(
457                "{n_qubits}-qubit gates not yet supported for tensor networks"
458            ))),
459        }
460    }
461
462    /// Get matrix representation of single-qubit gate
463    fn get_single_qubit_matrix(&self, gate: &dyn GateOp) -> QuantRS2Result<Vec<C64>> {
464        // Simplified - would use actual gate matrices
465        match gate.name() {
466            "H" => Ok(vec![
467                C64::new(1.0 / 2.0_f64.sqrt(), 0.0),
468                C64::new(1.0 / 2.0_f64.sqrt(), 0.0),
469                C64::new(1.0 / 2.0_f64.sqrt(), 0.0),
470                C64::new(-1.0 / 2.0_f64.sqrt(), 0.0),
471            ]),
472            "X" => Ok(vec![
473                C64::new(0.0, 0.0),
474                C64::new(1.0, 0.0),
475                C64::new(1.0, 0.0),
476                C64::new(0.0, 0.0),
477            ]),
478            "Y" => Ok(vec![
479                C64::new(0.0, 0.0),
480                C64::new(0.0, -1.0),
481                C64::new(0.0, 1.0),
482                C64::new(0.0, 0.0),
483            ]),
484            "Z" => Ok(vec![
485                C64::new(1.0, 0.0),
486                C64::new(0.0, 0.0),
487                C64::new(0.0, 0.0),
488                C64::new(-1.0, 0.0),
489            ]),
490            _ => Ok(vec![
491                C64::new(1.0, 0.0),
492                C64::new(0.0, 0.0),
493                C64::new(0.0, 0.0),
494                C64::new(1.0, 0.0),
495            ]),
496        }
497    }
498
499    /// Get matrix representation of two-qubit gate
500    fn get_two_qubit_matrix(&self, gate: &dyn GateOp) -> QuantRS2Result<Vec<C64>> {
501        // Simplified - would use actual gate matrices
502        if gate.name() == "CNOT" {
503            let mut matrix = vec![C64::new(0.0, 0.0); 16];
504            matrix[0] = C64::new(1.0, 0.0); // |00⟩ -> |00⟩
505            matrix[5] = C64::new(1.0, 0.0); // |01⟩ -> |01⟩
506            matrix[15] = C64::new(1.0, 0.0); // |10⟩ -> |11⟩
507            matrix[10] = C64::new(1.0, 0.0); // |11⟩ -> |10⟩
508            Ok(matrix)
509        } else {
510            // Identity for unsupported gates
511            let mut matrix = vec![C64::new(0.0, 0.0); 16];
512            for i in 0..16 {
513                matrix[i * 16 + i] = C64::new(1.0, 0.0);
514            }
515            Ok(matrix)
516        }
517    }
518}
519
/// Matrix Product State representation of a circuit.
///
/// NOTE(review): `from_circuit` builds one site tensor per qubit, annotated with
/// shape `[chi_left, 2, chi_right]`; confirm this holds for every constructor.
#[derive(Debug)]
pub struct MatrixProductState {
    /// Site tensors, one per qubit.
    tensors: Vec<Tensor>,
    /// Bond dimensions; `bond_dims[i]` is the bond between site `i` and site `i + 1`.
    bond_dims: Vec<usize>,
    /// Number of qubits
    n_qubits: usize,
}
530
531impl MatrixProductState {
532    /// Create MPS from a quantum circuit via explicit unitary tensor contraction.
533    ///
534    /// Algorithm:
535    /// 1. Initialize the MPS as the |0...0⟩ product state: each site tensor is [1, 0] with
536    ///    bond dimensions [1, ..., 1].
537    /// 2. For each gate in the circuit:
538    ///    - Single-qubit gate U on site `i`: contract the 2x2 unitary into the rank-3 site tensor
539    ///      Γ\[i\] with shape \[χ_left, 2, χ_right\].
540    ///    - Two-qubit gate U on sites (i, i+1): reshape the two adjacent site tensors into a
541    ///      combined matrix of shape \[χ_left * 2, 2 * χ_right\], apply the 4x4 unitary, then
542    ///      perform SVD to split back into two site tensors and update the bond dimension.
543    pub fn from_circuit<const N: usize>(circuit: &Circuit<N>) -> QuantRS2Result<Self> {
544        if N == 0 {
545            return Ok(Self {
546                tensors: Vec::new(),
547                bond_dims: Vec::new(),
548                n_qubits: 0,
549            });
550        }
551
552        let converter = CircuitToTensorNetwork::<N>::new();
553        // bond_dims[i] = bond dimension between site i and i+1 (length N-1).
554        let mut bond_dims = vec![1usize; N.saturating_sub(1)];
555
556        // Site tensors: Γ[i] has shape [χ_left, 2, χ_right] stored as flat Vec<C64>
557        // For i=0: shape [1, 2, 1]; for the |0⟩ state: data = [1, 0] (physical index 0→1, 1→0)
558        let mut site_tensors: Vec<Vec<C64>> = (0..N)
559            .map(|_| {
560                // [1, 2, 1] tensor for |0⟩: Γ[0,0,0]=1, Γ[0,1,0]=0
561                vec![C64::new(1.0, 0.0), C64::new(0.0, 0.0)]
562            })
563            .collect();
564
565        // Helper: retrieve 2×2 matrix for a single-qubit gate (reuse converter logic)
566        let gate_to_single_mat = |g: &dyn GateOp| -> Option<[C64; 4]> {
567            match g.name() {
568                "H" => Some([
569                    C64::new(1.0 / 2.0_f64.sqrt(), 0.0),
570                    C64::new(1.0 / 2.0_f64.sqrt(), 0.0),
571                    C64::new(1.0 / 2.0_f64.sqrt(), 0.0),
572                    C64::new(-1.0 / 2.0_f64.sqrt(), 0.0),
573                ]),
574                "X" => Some([
575                    C64::new(0.0, 0.0),
576                    C64::new(1.0, 0.0),
577                    C64::new(1.0, 0.0),
578                    C64::new(0.0, 0.0),
579                ]),
580                "Y" => Some([
581                    C64::new(0.0, 0.0),
582                    C64::new(0.0, -1.0),
583                    C64::new(0.0, 1.0),
584                    C64::new(0.0, 0.0),
585                ]),
586                "Z" => Some([
587                    C64::new(1.0, 0.0),
588                    C64::new(0.0, 0.0),
589                    C64::new(0.0, 0.0),
590                    C64::new(-1.0, 0.0),
591                ]),
592                "RY" | "RZ" | "RX" | "S" | "T" | "SX" | "ID" | "I" => {
593                    // Use identity as fallback for parameterized gates
594                    Some([
595                        C64::new(1.0, 0.0),
596                        C64::new(0.0, 0.0),
597                        C64::new(0.0, 0.0),
598                        C64::new(1.0, 0.0),
599                    ])
600                }
601                _ => None,
602            }
603        };
604
605        // Helper: 4×4 CNOT unitary (control=row 0 of physical indices)
606        let cnot_mat: [C64; 16] = {
607            let mut m = [C64::new(0.0, 0.0); 16];
608            m[0] = C64::new(1.0, 0.0); // |00⟩ → |00⟩
609            m[5] = C64::new(1.0, 0.0); // |01⟩ → |01⟩
610            m[14] = C64::new(1.0, 0.0); // |10⟩ → |11⟩
611            m[11] = C64::new(1.0, 0.0); // |11⟩ → |10⟩
612            m
613        };
614
615        // Default max_bond_dim during construction (no truncation limit)
616        let max_bd = 32usize;
617
618        for gate in circuit.gates() {
619            let qubits = gate.qubits();
620            match qubits.len() {
621                1 => {
622                    let qi = qubits[0].id() as usize;
623                    if qi >= N {
624                        continue;
625                    }
626                    if let Some(u) = gate_to_single_mat(gate.as_ref()) {
627                        // Contract: new_Γ[α, σ', β] = Σ_σ U[σ', σ] * Γ[α, σ, β]
628                        // Current shape: [χ_l, 2, χ_r] — chi_l=1, chi_r=1 for initial state
629                        // Since we store flat [2] for the initial product state:
630                        let old = site_tensors[qi].clone();
631                        let phys = old.len(); // = 2 * χ_l * χ_r in general
632                                              // Simple case: apply 2×2 unitary to physical index dimension 2
633                        let half = phys / 2;
634                        let mut new_site = vec![C64::new(0.0, 0.0); phys];
635                        for alpha in 0..half {
636                            let s0 = old[alpha]; // physical |0⟩
637                            let s1 = old[alpha + half]; // physical |1⟩
638                            new_site[alpha] = u[0] * s0 + u[1] * s1; // U[0,0]*|0⟩ + U[0,1]*|1⟩
639                            new_site[alpha + half] = u[2] * s0 + u[3] * s1; // U[1,0]*|0⟩ + U[1,1]*|1⟩
640                        }
641                        site_tensors[qi] = new_site;
642                    }
643                }
644                2 => {
645                    let qi = qubits[0].id() as usize;
646                    let qj = qubits[1].id() as usize;
647                    // Only handle adjacent qubits (i, i+1)
648                    if qi >= N || qj >= N || qj != qi + 1 {
649                        continue;
650                    }
651                    let gate_name = gate.name();
652                    let unitary_mat: [C64; 16] = if gate_name == "CNOT" || gate_name == "CX" {
653                        cnot_mat
654                    } else {
655                        // Identity 4×4 for unsupported two-qubit gates
656                        let mut id = [C64::new(0.0, 0.0); 16];
657                        id[0] = C64::new(1.0, 0.0);
658                        id[5] = C64::new(1.0, 0.0);
659                        id[10] = C64::new(1.0, 0.0);
660                        id[15] = C64::new(1.0, 0.0);
661                        id
662                    };
663
664                    // Left site: [χ_l, 2, χ_m], right site: [χ_m, 2, χ_r]
665                    // Merge into: Θ[χ_l * 2, 2 * χ_r] via Θ[α*2+σ, σ'*χ_r+β] = Σ_m Γ_i[α,σ,m] * Γ_j[m,σ',β]
666                    let left = &site_tensors[qi];
667                    let right = &site_tensors[qj];
668                    let left_phys = left.len(); // χ_l * 2
669                    let right_phys = right.len(); // 2 * χ_r
670                    let chi_m = bond_dims.get(qi).copied().unwrap_or(1);
671                    let chi_l = left_phys / 2; // should equal χ_l * χ_m / χ_m
672                    let chi_r = right_phys / 2; // should equal χ_m * χ_r / χ_m
673
674                    // Build merged tensor Θ: shape [chi_l * 2, chi_r * 2]
675                    // Index convention: row = (chi_l_idx * 2 + sigma_i), col = (sigma_j * chi_r + chi_r_idx)
676                    let nrows = chi_l * 2;
677                    let ncols = chi_r * 2;
678                    let mut theta = vec![C64::new(0.0, 0.0); nrows * ncols];
679
680                    // Contract over χ_m (bond index between site qi and qj)
681                    // left[alpha, sigma_i] = left_flat[sigma_i * chi_l + alpha]  (stored as [phys_0, phys_1])
682                    // right[sigma_j, beta] = right_flat[sigma_j * chi_r + beta]
683                    for sigma_i in 0..2usize {
684                        for alpha in 0..chi_l {
685                            let l_val = left
686                                .get(sigma_i * chi_l + alpha)
687                                .copied()
688                                .unwrap_or(C64::new(0.0, 0.0));
689                            for sigma_j in 0..2usize {
690                                for beta in 0..chi_r {
691                                    let r_val = right
692                                        .get(sigma_j * chi_r + beta)
693                                        .copied()
694                                        .unwrap_or(C64::new(0.0, 0.0));
695                                    let row = alpha * 2 + sigma_i;
696                                    let col = sigma_j * chi_r + beta;
697                                    if row < nrows && col < ncols {
698                                        theta[row * ncols + col] += l_val * r_val;
699                                    }
700                                }
701                            }
702                        }
703                    }
704
705                    // Apply two-qubit unitary: Θ' = U * Θ (in the combined physical index space)
706                    // U acts on (σ_i, σ_j) space (4×4), Θ rows ~ (α, σ_i), Θ cols ~ (σ_j, β)
707                    // Θ'[α*2+σ'_i, σ'_j*χ_r+β] = Σ_{σ_i, σ_j} U[σ'_i*2+σ'_j, σ_i*2+σ_j] * Θ[α*2+σ_i, σ_j*χ_r+β]
708                    let mut theta_prime = vec![C64::new(0.0, 0.0); nrows * ncols];
709                    for alpha in 0..chi_l {
710                        for sigma_i_out in 0..2usize {
711                            for sigma_j_out in 0..2usize {
712                                for beta in 0..chi_r {
713                                    let row_out = alpha * 2 + sigma_i_out;
714                                    let col_out = sigma_j_out * chi_r + beta;
715                                    let mut val = C64::new(0.0, 0.0);
716                                    for sigma_i_in in 0..2usize {
717                                        for sigma_j_in in 0..2usize {
718                                            let u_idx = (sigma_i_out * 2 + sigma_j_out) * 4
719                                                + sigma_i_in * 2
720                                                + sigma_j_in;
721                                            let u_val = unitary_mat
722                                                .get(u_idx)
723                                                .copied()
724                                                .unwrap_or(C64::new(0.0, 0.0));
725                                            let row_in = alpha * 2 + sigma_i_in;
726                                            let col_in = sigma_j_in * chi_r + beta;
727                                            val += u_val
728                                                * theta
729                                                    .get(row_in * ncols + col_in)
730                                                    .copied()
731                                                    .unwrap_or(C64::new(0.0, 0.0));
732                                        }
733                                    }
734                                    if row_out < nrows && col_out < ncols {
735                                        theta_prime[row_out * ncols + col_out] = val;
736                                    }
737                                }
738                            }
739                        }
740                    }
741
742                    // SVD on the real part to get new bond dimension
743                    let real_mat_data: Vec<f64> = theta_prime.iter().map(|c| c.re).collect();
744                    let real_mat =
745                        Array2::from_shape_vec((nrows, ncols), real_mat_data).map_err(|e| {
746                            QuantRS2Error::RuntimeError(format!("MPS matrix reshape failed: {e}"))
747                        })?;
748
749                    let svd_res = svd(&real_mat.view(), false, None)
750                        .map_err(|e| QuantRS2Error::RuntimeError(format!("MPS SVD failed: {e}")));
751
752                    let (u_mat, s_vec, vt_mat) = match svd_res {
753                        Ok(r) => r,
754                        Err(_) => {
755                            // Fallback: keep tensors unchanged
756                            continue;
757                        }
758                    };
759
760                    // Truncate bond dimension to max_bd
761                    let new_chi_m = s_vec.len().min(max_bd);
762
763                    // Reconstruct left site: shape [chi_l * 2, new_chi_m]
764                    // new_left[row, k] = U[row, k] * sqrt(s[k])
765                    let mut new_left = vec![C64::new(0.0, 0.0); chi_l * 2 * new_chi_m];
766                    for row in 0..nrows {
767                        for k in 0..new_chi_m {
768                            let sv = s_vec[k].max(0.0).sqrt();
769                            let idx = row * new_chi_m + k;
770                            new_left[idx] = C64::new(u_mat[[row, k]] * sv, 0.0);
771                        }
772                    }
773
774                    // Reconstruct right site: shape [new_chi_m, chi_r * 2]
775                    // new_right[k, col] = sqrt(s[k]) * Vt[k, col]
776                    let mut new_right = vec![C64::new(0.0, 0.0); new_chi_m * chi_r * 2];
777                    for k in 0..new_chi_m {
778                        let sv = s_vec[k].max(0.0).sqrt();
779                        for col in 0..ncols {
780                            let idx = k * ncols + col;
781                            new_right[idx] = C64::new(vt_mat[[k, col]] * sv, 0.0);
782                        }
783                    }
784
785                    site_tensors[qi] = new_left;
786                    site_tensors[qj] = new_right;
787                    if qi < bond_dims.len() {
788                        bond_dims[qi] = new_chi_m;
789                    }
790                }
791                _ => {
792                    // Multi-qubit gates beyond 2-qubit: skip
793                }
794            }
795        }
796
797        // Build the site Tensors with correct shape annotations
798        let tensors: Vec<Tensor> = site_tensors
799            .into_iter()
800            .enumerate()
801            .map(|(i, data)| {
802                let chi_l = if i == 0 { 1 } else { bond_dims[i - 1] };
803                let chi_r = if i + 1 < N { bond_dims[i] } else { 1 };
804                let shape = vec![chi_l, 2, chi_r];
805                let indices = vec![
806                    format!("bond_left_{i}"),
807                    format!("phys_{i}"),
808                    format!("bond_right_{i}"),
809                ];
810                // Ensure data length matches shape product
811                let expected = chi_l * 2 * chi_r;
812                let mut padded = data;
813                padded.resize(expected, C64::new(0.0, 0.0));
814                Tensor::new(padded, shape, indices)
815            })
816            .collect();
817
818        Ok(Self {
819            tensors,
820            bond_dims,
821            n_qubits: N,
822        })
823    }
824
825    /// Compress the MPS via a left-to-right SVD sweep with bond truncation.
826    ///
827    /// For each bond between site i and i+1:
828    /// 1. Reshape tensors\[i\] (shape \[χ_l, 2, χ_m\]) and tensors\[i+1\] (shape \[χ_m, 2, χ_r\])
829    ///    into a combined matrix Θ of shape \[χ_l\*2, χ_r\*2\].
830    /// 2. Compute SVD Θ = U Σ Vt.
831    /// 3. Truncate to min(max_bond_dim, rank where σ_k / σ_0 > tolerance).
832    /// 4. Set tensors\[i\] = U\[:, :new_χ\] \* diag(Σ\[:new_χ\])^(1/2),
833    ///    tensors\[i+1\] = diag(Σ\[:new_χ\])^(1/2) \* Vt\[:new_χ, :\].
834    pub fn compress(&mut self, max_bond_dim: usize, tolerance: f64) -> QuantRS2Result<()> {
835        let n = self.n_qubits;
836        if n <= 1 {
837            return Ok(());
838        }
839
840        for i in 0..(n - 1) {
841            if i + 1 >= self.tensors.len() {
842                break;
843            }
844
845            let chi_l_i = self.tensors[i].shape.first().copied().unwrap_or(1);
846            let chi_r_i = self.tensors[i].shape.get(2).copied().unwrap_or(1); // = chi_m
847            let chi_r_j = self.tensors[i + 1].shape.get(2).copied().unwrap_or(1);
848
849            let nrows = chi_l_i * 2;
850            let ncols = chi_r_j * 2;
851
852            // Build combined real-valued matrix from amplitudes
853            // Θ[alpha*2+sigma_i, sigma_j*chi_r_j+beta] = Σ_m Γ_i[alpha,sigma_i,m] * Γ_{i+1}[m,sigma_j,beta]
854            let left = &self.tensors[i].data;
855            let right = &self.tensors[i + 1].data;
856            let mut theta_real = vec![0.0f64; nrows * ncols];
857
858            for alpha in 0..chi_l_i {
859                for sigma_i in 0..2usize {
860                    for m in 0..chi_r_i {
861                        let l_idx = (alpha * 2 + sigma_i) * chi_r_i + m;
862                        let l_val = left.get(l_idx).map(|c| c.re).unwrap_or(0.0);
863                        if l_val == 0.0 {
864                            continue;
865                        }
866                        for sigma_j in 0..2usize {
867                            for beta in 0..chi_r_j {
868                                let r_idx = (m * 2 + sigma_j) * chi_r_j + beta;
869                                let r_val = right.get(r_idx).map(|c| c.re).unwrap_or(0.0);
870                                let row = alpha * 2 + sigma_i;
871                                let col = sigma_j * chi_r_j + beta;
872                                if row < nrows && col < ncols {
873                                    theta_real[row * ncols + col] += l_val * r_val;
874                                }
875                            }
876                        }
877                    }
878                }
879            }
880
881            let mat = Array2::from_shape_vec((nrows, ncols), theta_real).map_err(|e| {
882                QuantRS2Error::RuntimeError(format!("MPS compress reshape failed: {e}"))
883            })?;
884
885            let svd_res = svd(&mat.view(), false, None).map_err(|e| {
886                QuantRS2Error::RuntimeError(format!("MPS compress SVD failed at bond {i}: {e}"))
887            });
888
889            let (u_mat, s_vec, vt_mat) = match svd_res {
890                Ok(r) => r,
891                Err(_) => continue,
892            };
893
894            // Determine truncation rank
895            let sigma_max = s_vec.first().copied().unwrap_or(0.0);
896            let rank = if sigma_max > 0.0 {
897                s_vec
898                    .iter()
899                    .take_while(|&&sv| sv / sigma_max > tolerance)
900                    .count()
901            } else {
902                1
903            };
904            let new_chi_m = rank.min(max_bond_dim).min(s_vec.len()).max(1);
905
906            // Rebuild left tensor: shape [chi_l_i, 2, new_chi_m]
907            let new_left_size = chi_l_i * 2 * new_chi_m;
908            let mut new_left = vec![C64::new(0.0, 0.0); new_left_size];
909            for row in 0..(chi_l_i * 2) {
910                for k in 0..new_chi_m {
911                    let sv = s_vec[k].max(0.0).sqrt();
912                    let flat_idx = row * new_chi_m + k;
913                    new_left[flat_idx] = C64::new(u_mat[[row, k]] * sv, 0.0);
914                }
915            }
916
917            // Rebuild right tensor: shape [new_chi_m, 2, chi_r_j]
918            let new_right_size = new_chi_m * 2 * chi_r_j;
919            let mut new_right = vec![C64::new(0.0, 0.0); new_right_size];
920            for k in 0..new_chi_m {
921                let sv = s_vec[k].max(0.0).sqrt();
922                for col in 0..(2 * chi_r_j) {
923                    let flat_idx = k * 2 * chi_r_j + col;
924                    new_right[flat_idx] = C64::new(vt_mat[[k, col]] * sv, 0.0);
925                }
926            }
927
928            // Update tensors in place
929            self.tensors[i].data = new_left;
930            self.tensors[i].shape = vec![chi_l_i, 2, new_chi_m];
931
932            self.tensors[i + 1].data = new_right;
933            self.tensors[i + 1].shape = vec![new_chi_m, 2, chi_r_j];
934
935            if i < self.bond_dims.len() {
936                self.bond_dims[i] = new_chi_m;
937            }
938        }
939
940        Ok(())
941    }
942
943    /// Calculate overlap with another MPS
944    pub fn overlap(&self, other: &Self) -> QuantRS2Result<C64> {
945        if self.n_qubits != other.n_qubits {
946            return Err(QuantRS2Error::InvalidInput(
947                "MPS have different number of qubits".to_string(),
948            ));
949        }
950
951        // Calculate ⟨ψ|φ⟩
952        Ok(C64::new(1.0, 0.0)) // Placeholder
953    }
954
955    /// Calculate expectation value of observable
956    pub const fn expectation_value(&self, observable: &TensorNetwork) -> QuantRS2Result<f64> {
957        // Calculate ⟨ψ|O|ψ⟩
958        Ok(0.0) // Placeholder
959    }
960}
961
/// Circuit compression using tensor networks.
///
/// Bundles the settings for compressing a circuit through its matrix-product-state
/// representation. Construct it with `TensorNetworkCompressor::new` and change the
/// method with `with_method`.
pub struct TensorNetworkCompressor {
    /// Maximum bond dimension
    max_bond_dim: usize,
    /// Truncation tolerance (`new` sets it to `1e-10`)
    tolerance: f64,
    /// Compression method (`new` sets it to `CompressionMethod::SVD`)
    method: CompressionMethod,
}
971
/// Algorithm used to truncate a circuit's tensor-network representation.
///
/// NOTE(review): check `TensorNetworkCompressor::compress` for which variants have a
/// dedicated implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressionMethod {
    /// Singular Value Decomposition
    SVD,
    /// Density Matrix Renormalization Group
    DMRG,
    /// Time-Evolving Block Decimation
    TEBD,
}
981
982impl TensorNetworkCompressor {
983    /// Create a new compressor
984    #[must_use]
985    pub const fn new(max_bond_dim: usize) -> Self {
986        Self {
987            max_bond_dim,
988            tolerance: 1e-10,
989            method: CompressionMethod::SVD,
990        }
991    }
992
993    /// Set compression method
994    #[must_use]
995    pub const fn with_method(mut self, method: CompressionMethod) -> Self {
996        self.method = method;
997        self
998    }
999
1000    /// Compress a circuit
1001    pub fn compress<const N: usize>(
1002        &self,
1003        circuit: &Circuit<N>,
1004    ) -> QuantRS2Result<CompressedCircuit<N>> {
1005        let mps = MatrixProductState::from_circuit(circuit)?;
1006
1007        Ok(CompressedCircuit {
1008            mps,
1009            original_gates: circuit.num_gates(),
1010            compression_ratio: 1.0, // Placeholder
1011        })
1012    }
1013}
1014
/// Compressed circuit representation.
///
/// Produced by `TensorNetworkCompressor::compress`. The const parameter `N` matches the
/// `Circuit<N>` it was built from.
#[derive(Debug)]
pub struct CompressedCircuit<const N: usize> {
    /// Matrix product state representation of the circuit
    mps: MatrixProductState,
    /// Number of gates in the original circuit (taken from `Circuit::num_gates`)
    original_gates: usize,
    /// Compression ratio, exposed through `compression_ratio()`
    compression_ratio: f64,
}
1025
1026impl<const N: usize> CompressedCircuit<N> {
1027    /// Get compression ratio
1028    #[must_use]
1029    pub const fn compression_ratio(&self) -> f64 {
1030        self.compression_ratio
1031    }
1032
1033    /// Decompress back to circuit
1034    pub fn decompress(&self) -> QuantRS2Result<Circuit<N>> {
1035        // Convert MPS back to circuit representation
1036        // This is non-trivial and would require gate synthesis
1037        Ok(Circuit::<N>::new())
1038    }
1039
1040    /// Get fidelity with original circuit
1041    pub const fn fidelity(&self, original: &Circuit<N>) -> QuantRS2Result<f64> {
1042        // Calculate |⟨ψ_compressed|ψ_original⟩|²
1043        Ok(0.99) // Placeholder
1044    }
1045}
1046
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::multi::CNOT;
    use quantrs2_core::gate::single::Hadamard;

    /// Bell-state preparation: H on qubit 0, then CNOT with control 0 and target 1.
    fn bell_circuit() -> Circuit<2> {
        let mut circuit = Circuit::<2>::new();
        circuit
            .add_gate(Hadamard { target: QubitId(0) })
            .expect("H gate");
        circuit
            .add_gate(CNOT {
                control: QubitId(0),
                target: QubitId(1),
            })
            .expect("CNOT gate");
        circuit
    }

    #[test]
    fn test_tensor_creation() {
        let one = C64::new(1.0, 0.0);
        let zero = C64::new(0.0, 0.0);
        let tensor = Tensor::new(
            vec![one, zero, zero, one],
            vec![2, 2],
            vec!["in".to_string(), "out".to_string()],
        );

        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.size(), 4);
    }

    #[test]
    fn test_tensor_network() {
        let mut tn = TensorNetwork::new();

        let first = tn.add_tensor(Tensor::identity(2, "a".to_string(), "b".to_string()));
        let second = tn.add_tensor(Tensor::identity(2, "c".to_string(), "d".to_string()));

        tn.add_bond(first, "b".to_string(), second, "c".to_string())
            .expect("Failed to add bond between tensors");

        assert_eq!(tn.tensors.len(), 2);
        assert_eq!(tn.bonds.len(), 1);
    }

    #[test]
    fn test_circuit_to_tensor_network() {
        let mut circuit = Circuit::<2>::new();
        circuit
            .add_gate(Hadamard { target: QubitId(0) })
            .expect("Failed to add Hadamard gate");

        let converter = CircuitToTensorNetwork::<2>::new();
        let tn = converter
            .convert(&circuit)
            .expect("Failed to convert circuit to tensor network");

        assert!(!tn.tensors.is_empty());
    }

    #[test]
    fn test_compression() {
        let circuit = Circuit::<2>::new();
        let compressed = TensorNetworkCompressor::new(32)
            .compress(&circuit)
            .expect("Failed to compress circuit");
        assert!(compressed.compression_ratio() <= 1.0);
    }

    #[test]
    fn test_tensor_network_svd_compress() {
        let circuit = bell_circuit();
        let converter = CircuitToTensorNetwork::<2>::new();
        let mut tn = converter.convert(&circuit).expect("Convert to TN");

        // Bond dimension cap 4, singular-value cutoff 1e-6; success means no error/panic.
        tn.compress(4, 1e-6).expect("TN compress");
        assert_eq!(tn.tensors.len(), 2);
    }

    #[test]
    fn test_mps_from_circuit_trivial() {
        // An empty circuit still yields one site tensor per qubit.
        let circuit = Circuit::<2>::new();
        let mps = MatrixProductState::from_circuit(&circuit).expect("MPS from empty circuit");
        assert_eq!(mps.n_qubits, 2);
        assert_eq!(mps.tensors.len(), 2);
    }

    #[test]
    fn test_mps_from_circuit_with_hadamard() {
        let mut circuit = Circuit::<3>::new();
        circuit
            .add_gate(Hadamard { target: QubitId(0) })
            .expect("H gate");

        let mps = MatrixProductState::from_circuit(&circuit).expect("MPS from H circuit");
        assert_eq!(mps.n_qubits, 3);
        assert_eq!(mps.tensors.len(), 3);
    }

    #[test]
    fn test_mps_compress_reduces_bond_dim() {
        let circuit = bell_circuit();
        let mut mps = MatrixProductState::from_circuit(&circuit).expect("MPS from Bell circuit");

        // Strong truncation: every bond is capped at 1.
        mps.compress(1, 1e-10).expect("MPS compress");
        for &bd in &mps.bond_dims {
            assert!(bd <= 1, "Bond dim {} exceeds max", bd);
        }
    }
}