Skip to main content

quantrs2_sim/tensor_network/
tensor.rs

1//! Tensor representation for quantum states and operations
2//!
3//! This module provides a tensor-based representation for quantum states
4//! and operations used in the tensor network simulator.
5
6use quantrs2_core::error::{QuantRS2Error, QuantRS2Result};
7use scirs2_core::ndarray::{Array, Array1, Array2, ArrayD, Dimension, IxDyn};
8use scirs2_core::Complex64;
9
10/// A tensor representing a quantum state or operation
11#[derive(Debug, Clone)]
12pub struct Tensor {
13    /// The tensor data
14    pub data: ArrayD<Complex64>,
15
16    /// The tensor rank (number of indices)
17    pub rank: usize,
18
19    /// The dimensions of each index
20    pub dimensions: Vec<usize>,
21}
22
23impl Tensor {
24    /// Create a new tensor from a multi-dimensional array
25    pub fn new(data: ArrayD<Complex64>) -> Self {
26        let dimensions = data.shape().to_vec();
27        let rank = dimensions.len();
28
29        Self {
30            data,
31            rank,
32            dimensions,
33        }
34    }
35
36    /// Create a tensor from a matrix (gate)
37    pub fn from_matrix(matrix: &[Complex64], dim: usize) -> Self {
38        // Determine the shape based on the matrix size and dimension
39        let _n = (matrix.len() as f64).sqrt() as usize;
40
41        // Reshape the matrix into a multi-dimensional array
42        let mut shape = Vec::new();
43        for _ in 0..dim {
44            shape.push(2); // Each qubit has dimension 2
45        }
46
47        // Create the tensor data
48        let mut data = ArrayD::zeros(IxDyn(&shape));
49
50        // Fill the tensor with matrix elements
51        let flat_data = data
52            .as_slice_mut()
53            .expect("Tensor data should be contiguous in memory");
54        for (i, val) in matrix.iter().enumerate() {
55            if i < flat_data.len() {
56                flat_data[i] = *val;
57            }
58        }
59
60        Self::new(data)
61    }
62
63    /// Create a tensor representing the |0⟩ state
64    pub fn qubit_zero() -> Self {
65        let data = Array::from_shape_vec(
66            IxDyn(&[2]),
67            vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
68        )
69        .expect("Valid shape for qubit |0> state");
70
71        Self::new(data)
72    }
73
74    /// Create a tensor representing the |1⟩ state
75    pub fn qubit_one() -> Self {
76        let data = Array::from_shape_vec(
77            IxDyn(&[2]),
78            vec![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
79        )
80        .expect("Valid shape for qubit |1> state");
81
82        Self::new(data)
83    }
84
85    /// Create a tensor representing the |+⟩ state
86    pub fn qubit_plus() -> Self {
87        let data = Array::from_shape_vec(
88            IxDyn(&[2]),
89            vec![
90                Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
91                Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
92            ],
93        )
94        .expect("Valid shape for qubit |+> state");
95
96        Self::new(data)
97    }
98
99    /// Contract this tensor with another tensor along specified axes.
100    ///
101    /// Performs the Einstein summation over one pair of indices:
102    ///   result[i₀,…,iₙ₋₁, j₀,…,jₘ₋₁] = Σₖ self[…k…] * other[…k…]
103    /// where `k` runs over `self.dimensions[self_axis]` (= `other.dimensions[other_axis]`).
104    ///
105    /// The output shape is `self.dimensions` with `self_axis` removed, followed by
106    /// `other.dimensions` with `other_axis` removed.
107    pub fn contract(
108        &self,
109        other: &Self,
110        self_axis: usize,
111        other_axis: usize,
112    ) -> QuantRS2Result<Self> {
113        // Validate axis indices
114        if self_axis >= self.rank || other_axis >= other.rank {
115            return Err(QuantRS2Error::CircuitValidationFailed(format!(
116                "Invalid contraction axes: {self_axis} and {other_axis}"
117            )));
118        }
119
120        // Validate axis dimensions match (contraction is only valid when dims agree)
121        if self.dimensions[self_axis] != other.dimensions[other_axis] {
122            return Err(QuantRS2Error::CircuitValidationFailed(format!(
123                "Mismatched dimensions for contraction: {} and {}",
124                self.dimensions[self_axis], other.dimensions[other_axis]
125            )));
126        }
127
128        let _contract_dim = self.dimensions[self_axis];
129
130        // Build the result dimensions:
131        //   all self dims except self_axis, then all other dims except other_axis
132        let self_outer_dims: Vec<usize> = self
133            .dimensions
134            .iter()
135            .enumerate()
136            .filter(|&(i, _)| i != self_axis)
137            .map(|(_, &d)| d)
138            .collect();
139        let other_outer_dims: Vec<usize> = other
140            .dimensions
141            .iter()
142            .enumerate()
143            .filter(|&(i, _)| i != other_axis)
144            .map(|(_, &d)| d)
145            .collect();
146
147        let mut result_dims = self_outer_dims.clone();
148        result_dims.extend_from_slice(&other_outer_dims);
149
150        // For scalar output (both tensors were rank-1 vectors)
151        let result_is_scalar = result_dims.is_empty();
152
153        let result_shape = if result_is_scalar {
154            IxDyn(&[1usize])
155        } else {
156            IxDyn(result_dims.as_slice())
157        };
158
159        let mut result_data = ArrayD::zeros(result_shape);
160
161        // Perform contraction via explicit index iteration.
162        // This is O(N_self * N_other) but is simple and correct for any rank,
163        // which is appropriate for small quantum-circuit tensors (dim 2–16).
164        for (self_idx, self_val) in self.data.indexed_iter() {
165            let self_raw = self_idx.slice();
166            let k = self_raw[self_axis];
167
168            // Build the partial result index from self (excluding self_axis)
169            let self_outer_idx: Vec<usize> = self_raw
170                .iter()
171                .enumerate()
172                .filter(|&(i, _)| i != self_axis)
173                .map(|(_, &v)| v)
174                .collect();
175
176            for (other_idx, other_val) in other.data.indexed_iter() {
177                let other_raw = other_idx.slice();
178                if other_raw[other_axis] != k {
179                    continue;
180                }
181
182                // Build the partial result index from other (excluding other_axis)
183                let other_outer_idx: Vec<usize> = other_raw
184                    .iter()
185                    .enumerate()
186                    .filter(|&(i, _)| i != other_axis)
187                    .map(|(_, &v)| v)
188                    .collect();
189
190                // Concatenate to get full result index
191                let mut res_idx = self_outer_idx.clone();
192                res_idx.extend_from_slice(&other_outer_idx);
193
194                let target = if result_is_scalar {
195                    &mut result_data[IxDyn(&[0usize])]
196                } else {
197                    &mut result_data[IxDyn(res_idx.as_slice())]
198                };
199                *target += *self_val * *other_val;
200            }
201        }
202
203        // Unwrap scalar result back to empty shape
204        let final_data = if result_is_scalar {
205            let scalar_val = result_data[IxDyn(&[0usize])];
206            ArrayD::from_elem(IxDyn(&[]), scalar_val)
207        } else {
208            result_data
209        };
210
211        let result_rank = result_dims.len();
212        Ok(Self {
213            data: final_data,
214            dimensions: result_dims,
215            rank: result_rank,
216        })
217    }
218
219    /// Perform SVD decomposition on this tensor, splitting it into two lower-rank tensors.
220    ///
221    /// The tensor is logically reshaped into a matrix by grouping `left_axes` into rows
222    /// and `right_axes` into columns. The SVD is then computed and the result is split
223    /// into two tensors:
224    ///
225    /// - `left_tensor`: shape `(*left_dims, bond_dim)` — absorbs U * diag(S)
226    /// - `right_tensor`: shape `(bond_dim, *right_dims)` — contains Vᴴ
227    ///
228    /// `max_bond_dim` caps how many singular values are kept (bond dimension).
229    pub fn svd(
230        &self,
231        left_axes: &[usize],
232        right_axes: &[usize],
233        max_bond_dim: usize,
234    ) -> QuantRS2Result<(Self, Self)> {
235        use scirs2_core::ndarray::ndarray_linalg::SVD;
236
237        // ---- validation -------------------------------------------------------
238        let total_axes = left_axes.len() + right_axes.len();
239        if total_axes != self.rank {
240            return Err(QuantRS2Error::CircuitValidationFailed(format!(
241                "SVD: left_axes ({}) + right_axes ({}) must equal tensor rank ({})",
242                left_axes.len(),
243                right_axes.len(),
244                self.rank
245            )));
246        }
247        // Check no duplicates and all in range
248        {
249            let mut seen = vec![false; self.rank];
250            for &ax in left_axes.iter().chain(right_axes.iter()) {
251                if ax >= self.rank {
252                    return Err(QuantRS2Error::CircuitValidationFailed(format!(
253                        "SVD: axis {ax} out of range for rank-{} tensor",
254                        self.rank
255                    )));
256                }
257                if seen[ax] {
258                    return Err(QuantRS2Error::CircuitValidationFailed(format!(
259                        "SVD: duplicate axis {ax}"
260                    )));
261                }
262                seen[ax] = true;
263            }
264        }
265        if max_bond_dim == 0 {
266            return Err(QuantRS2Error::CircuitValidationFailed(
267                "SVD: max_bond_dim must be >= 1".to_string(),
268            ));
269        }
270
271        // ---- compute row/col sizes for the reshaped matrix --------------------
272        let left_dims: Vec<usize> = left_axes.iter().map(|&ax| self.dimensions[ax]).collect();
273        let right_dims: Vec<usize> = right_axes.iter().map(|&ax| self.dimensions[ax]).collect();
274
275        let left_size: usize = left_dims.iter().product::<usize>().max(1);
276        let right_size: usize = right_dims.iter().product::<usize>().max(1);
277
278        // ---- permute and reshape to matrix (left_size, right_size) ------------
279        // Build a permutation: left_axes first, then right_axes
280        let permutation: Vec<usize> = left_axes.iter().chain(right_axes.iter()).copied().collect();
281
282        // Collect self.data into standard layout after permuting axes
283        let perm_data: ArrayD<Complex64> = {
284            // permuted_axes on a dynamic array view returns a view with reordered axes
285            let view = self.data.view();
286            let permuted = view.permuted_axes(permutation.as_slice());
287            // Force into owned contiguous array (standard layout)
288            permuted.as_standard_layout().into_owned()
289        };
290
291        // Reshape to 2D matrix by collecting the permuted data into a flat vec,
292        // then building an Array2.  This approach avoids ndarray dimensionality
293        // conversion subtleties with IxDyn vs. Ix2.
294        let flat: Vec<Complex64> = perm_data.into_raw_vec_and_offset().0;
295        let matrix: Array2<Complex64> = Array2::from_shape_vec((left_size, right_size), flat)
296            .map_err(|e| {
297                QuantRS2Error::CircuitValidationFailed(format!("SVD reshape to matrix failed: {e}"))
298            })?;
299
300        // ---- SVD via OxiBLAS/ndarray_linalg -----------------------------------
301        // SVD trait: (U, S, Vt) where U is (m,k), S is (k,), Vt is (k,n)
302        // with compute_u=true, compute_vt=true and thin=true (economy SVD)
303        let (u_full, s_full, vt_full) = matrix.svd(true, true).map_err(|e| {
304            QuantRS2Error::CircuitValidationFailed(format!("SVD computation failed: {e}"))
305        })?;
306
307        // ---- truncation -------------------------------------------------------
308        let rank_cap = left_size.min(right_size);
309        let bond_dim = max_bond_dim.min(rank_cap).min(s_full.len());
310        let bond_dim = bond_dim.max(1);
311
312        // Keep only the top `bond_dim` singular triplets
313        let s_trunc: Array1<f64> = s_full
314            .slice(scirs2_core::ndarray::s![..bond_dim])
315            .to_owned();
316        let u_trunc: Array2<Complex64> = u_full
317            .slice(scirs2_core::ndarray::s![.., ..bond_dim])
318            .to_owned();
319        let vt_trunc: Array2<Complex64> = vt_full
320            .slice(scirs2_core::ndarray::s![..bond_dim, ..])
321            .to_owned();
322
323        // ---- build left tensor: U * diag(S), shape (*left_dims, bond_dim) ----
324        // Absorb singular values into U columns
325        let mut us: Array2<Complex64> = u_trunc;
326        for j in 0..bond_dim {
327            let sigma = Complex64::new(s_trunc[j], 0.0);
328            for i in 0..left_size {
329                us[[i, j]] *= sigma;
330            }
331        }
332
333        let mut left_shape = left_dims.clone();
334        left_shape.push(bond_dim);
335        // Flatten us to a vec, then rebuild as ArrayD with the desired shape.
336        // OxiBLAS returns U/Vᴴ as column-major (Fortran) arrays, so we must
337        // flatten in logical row-major order (not raw memory order) to match
338        // `Array::from_shape_vec`, which interprets the vec as row-major.
339        let us_flat: Vec<Complex64> = us.as_standard_layout().iter().copied().collect();
340        let left_data: ArrayD<Complex64> =
341            Array::from_shape_vec(IxDyn(left_shape.as_slice()), us_flat).map_err(|e| {
342                QuantRS2Error::CircuitValidationFailed(format!("SVD left reshape failed: {e}"))
343            })?;
344        let left_rank = left_shape.len();
345
346        // ---- build right tensor: Vᴴ, shape (bond_dim, *right_dims) -----------
347        let mut right_shape = vec![bond_dim];
348        right_shape.extend_from_slice(&right_dims);
349        // Same column-major → row-major fix as for the left tensor above.
350        let vt_flat: Vec<Complex64> = vt_trunc.as_standard_layout().iter().copied().collect();
351        let right_data: ArrayD<Complex64> =
352            Array::from_shape_vec(IxDyn(right_shape.as_slice()), vt_flat).map_err(|e| {
353                QuantRS2Error::CircuitValidationFailed(format!("SVD right reshape failed: {e}"))
354            })?;
355        let right_rank = right_shape.len();
356
357        let left_tensor = Self {
358            data: left_data,
359            dimensions: left_shape,
360            rank: left_rank,
361        };
362        let right_tensor = Self {
363            data: right_data,
364            dimensions: right_shape,
365            rank: right_rank,
366        };
367
368        Ok((left_tensor, right_tensor))
369    }
370}
371
/// Identifies a single index (axis) of a specific tensor, which is addressed by
/// its ID.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub struct TensorIndex {
    /// ID of the tensor this index belongs to.
    pub tensor_id: usize,

    /// Position of the index within that tensor.
    pub index: usize,
}