quantrs2_sim/tensor_network/
tensor.rs1use quantrs2_core::error::{QuantRS2Error, QuantRS2Result};
7use scirs2_core::ndarray::{Array, Array1, Array2, ArrayD, Dimension, IxDyn};
8use scirs2_core::Complex64;
9
10#[derive(Debug, Clone)]
12pub struct Tensor {
13 pub data: ArrayD<Complex64>,
15
16 pub rank: usize,
18
19 pub dimensions: Vec<usize>,
21}
22
23impl Tensor {
24 pub fn new(data: ArrayD<Complex64>) -> Self {
26 let dimensions = data.shape().to_vec();
27 let rank = dimensions.len();
28
29 Self {
30 data,
31 rank,
32 dimensions,
33 }
34 }
35
36 pub fn from_matrix(matrix: &[Complex64], dim: usize) -> Self {
38 let _n = (matrix.len() as f64).sqrt() as usize;
40
41 let mut shape = Vec::new();
43 for _ in 0..dim {
44 shape.push(2); }
46
47 let mut data = ArrayD::zeros(IxDyn(&shape));
49
50 let flat_data = data
52 .as_slice_mut()
53 .expect("Tensor data should be contiguous in memory");
54 for (i, val) in matrix.iter().enumerate() {
55 if i < flat_data.len() {
56 flat_data[i] = *val;
57 }
58 }
59
60 Self::new(data)
61 }
62
63 pub fn qubit_zero() -> Self {
65 let data = Array::from_shape_vec(
66 IxDyn(&[2]),
67 vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
68 )
69 .expect("Valid shape for qubit |0> state");
70
71 Self::new(data)
72 }
73
74 pub fn qubit_one() -> Self {
76 let data = Array::from_shape_vec(
77 IxDyn(&[2]),
78 vec![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)],
79 )
80 .expect("Valid shape for qubit |1> state");
81
82 Self::new(data)
83 }
84
85 pub fn qubit_plus() -> Self {
87 let data = Array::from_shape_vec(
88 IxDyn(&[2]),
89 vec![
90 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
91 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
92 ],
93 )
94 .expect("Valid shape for qubit |+> state");
95
96 Self::new(data)
97 }
98
99 pub fn contract(
108 &self,
109 other: &Self,
110 self_axis: usize,
111 other_axis: usize,
112 ) -> QuantRS2Result<Self> {
113 if self_axis >= self.rank || other_axis >= other.rank {
115 return Err(QuantRS2Error::CircuitValidationFailed(format!(
116 "Invalid contraction axes: {self_axis} and {other_axis}"
117 )));
118 }
119
120 if self.dimensions[self_axis] != other.dimensions[other_axis] {
122 return Err(QuantRS2Error::CircuitValidationFailed(format!(
123 "Mismatched dimensions for contraction: {} and {}",
124 self.dimensions[self_axis], other.dimensions[other_axis]
125 )));
126 }
127
128 let _contract_dim = self.dimensions[self_axis];
129
130 let self_outer_dims: Vec<usize> = self
133 .dimensions
134 .iter()
135 .enumerate()
136 .filter(|&(i, _)| i != self_axis)
137 .map(|(_, &d)| d)
138 .collect();
139 let other_outer_dims: Vec<usize> = other
140 .dimensions
141 .iter()
142 .enumerate()
143 .filter(|&(i, _)| i != other_axis)
144 .map(|(_, &d)| d)
145 .collect();
146
147 let mut result_dims = self_outer_dims.clone();
148 result_dims.extend_from_slice(&other_outer_dims);
149
150 let result_is_scalar = result_dims.is_empty();
152
153 let result_shape = if result_is_scalar {
154 IxDyn(&[1usize])
155 } else {
156 IxDyn(result_dims.as_slice())
157 };
158
159 let mut result_data = ArrayD::zeros(result_shape);
160
161 for (self_idx, self_val) in self.data.indexed_iter() {
165 let self_raw = self_idx.slice();
166 let k = self_raw[self_axis];
167
168 let self_outer_idx: Vec<usize> = self_raw
170 .iter()
171 .enumerate()
172 .filter(|&(i, _)| i != self_axis)
173 .map(|(_, &v)| v)
174 .collect();
175
176 for (other_idx, other_val) in other.data.indexed_iter() {
177 let other_raw = other_idx.slice();
178 if other_raw[other_axis] != k {
179 continue;
180 }
181
182 let other_outer_idx: Vec<usize> = other_raw
184 .iter()
185 .enumerate()
186 .filter(|&(i, _)| i != other_axis)
187 .map(|(_, &v)| v)
188 .collect();
189
190 let mut res_idx = self_outer_idx.clone();
192 res_idx.extend_from_slice(&other_outer_idx);
193
194 let target = if result_is_scalar {
195 &mut result_data[IxDyn(&[0usize])]
196 } else {
197 &mut result_data[IxDyn(res_idx.as_slice())]
198 };
199 *target += *self_val * *other_val;
200 }
201 }
202
203 let final_data = if result_is_scalar {
205 let scalar_val = result_data[IxDyn(&[0usize])];
206 ArrayD::from_elem(IxDyn(&[]), scalar_val)
207 } else {
208 result_data
209 };
210
211 let result_rank = result_dims.len();
212 Ok(Self {
213 data: final_data,
214 dimensions: result_dims,
215 rank: result_rank,
216 })
217 }
218
219 pub fn svd(
230 &self,
231 left_axes: &[usize],
232 right_axes: &[usize],
233 max_bond_dim: usize,
234 ) -> QuantRS2Result<(Self, Self)> {
235 use scirs2_core::ndarray::ndarray_linalg::SVD;
236
237 let total_axes = left_axes.len() + right_axes.len();
239 if total_axes != self.rank {
240 return Err(QuantRS2Error::CircuitValidationFailed(format!(
241 "SVD: left_axes ({}) + right_axes ({}) must equal tensor rank ({})",
242 left_axes.len(),
243 right_axes.len(),
244 self.rank
245 )));
246 }
247 {
249 let mut seen = vec![false; self.rank];
250 for &ax in left_axes.iter().chain(right_axes.iter()) {
251 if ax >= self.rank {
252 return Err(QuantRS2Error::CircuitValidationFailed(format!(
253 "SVD: axis {ax} out of range for rank-{} tensor",
254 self.rank
255 )));
256 }
257 if seen[ax] {
258 return Err(QuantRS2Error::CircuitValidationFailed(format!(
259 "SVD: duplicate axis {ax}"
260 )));
261 }
262 seen[ax] = true;
263 }
264 }
265 if max_bond_dim == 0 {
266 return Err(QuantRS2Error::CircuitValidationFailed(
267 "SVD: max_bond_dim must be >= 1".to_string(),
268 ));
269 }
270
271 let left_dims: Vec<usize> = left_axes.iter().map(|&ax| self.dimensions[ax]).collect();
273 let right_dims: Vec<usize> = right_axes.iter().map(|&ax| self.dimensions[ax]).collect();
274
275 let left_size: usize = left_dims.iter().product::<usize>().max(1);
276 let right_size: usize = right_dims.iter().product::<usize>().max(1);
277
278 let permutation: Vec<usize> = left_axes.iter().chain(right_axes.iter()).copied().collect();
281
282 let perm_data: ArrayD<Complex64> = {
284 let view = self.data.view();
286 let permuted = view.permuted_axes(permutation.as_slice());
287 permuted.as_standard_layout().into_owned()
289 };
290
291 let flat: Vec<Complex64> = perm_data.into_raw_vec_and_offset().0;
295 let matrix: Array2<Complex64> = Array2::from_shape_vec((left_size, right_size), flat)
296 .map_err(|e| {
297 QuantRS2Error::CircuitValidationFailed(format!("SVD reshape to matrix failed: {e}"))
298 })?;
299
300 let (u_full, s_full, vt_full) = matrix.svd(true, true).map_err(|e| {
304 QuantRS2Error::CircuitValidationFailed(format!("SVD computation failed: {e}"))
305 })?;
306
307 let rank_cap = left_size.min(right_size);
309 let bond_dim = max_bond_dim.min(rank_cap).min(s_full.len());
310 let bond_dim = bond_dim.max(1);
311
312 let s_trunc: Array1<f64> = s_full
314 .slice(scirs2_core::ndarray::s![..bond_dim])
315 .to_owned();
316 let u_trunc: Array2<Complex64> = u_full
317 .slice(scirs2_core::ndarray::s![.., ..bond_dim])
318 .to_owned();
319 let vt_trunc: Array2<Complex64> = vt_full
320 .slice(scirs2_core::ndarray::s![..bond_dim, ..])
321 .to_owned();
322
323 let mut us: Array2<Complex64> = u_trunc;
326 for j in 0..bond_dim {
327 let sigma = Complex64::new(s_trunc[j], 0.0);
328 for i in 0..left_size {
329 us[[i, j]] *= sigma;
330 }
331 }
332
333 let mut left_shape = left_dims.clone();
334 left_shape.push(bond_dim);
335 let us_flat: Vec<Complex64> = us.as_standard_layout().iter().copied().collect();
340 let left_data: ArrayD<Complex64> =
341 Array::from_shape_vec(IxDyn(left_shape.as_slice()), us_flat).map_err(|e| {
342 QuantRS2Error::CircuitValidationFailed(format!("SVD left reshape failed: {e}"))
343 })?;
344 let left_rank = left_shape.len();
345
346 let mut right_shape = vec![bond_dim];
348 right_shape.extend_from_slice(&right_dims);
349 let vt_flat: Vec<Complex64> = vt_trunc.as_standard_layout().iter().copied().collect();
351 let right_data: ArrayD<Complex64> =
352 Array::from_shape_vec(IxDyn(right_shape.as_slice()), vt_flat).map_err(|e| {
353 QuantRS2Error::CircuitValidationFailed(format!("SVD right reshape failed: {e}"))
354 })?;
355 let right_rank = right_shape.len();
356
357 let left_tensor = Self {
358 data: left_data,
359 dimensions: left_shape,
360 rank: left_rank,
361 };
362 let right_tensor = Self {
363 data: right_data,
364 dimensions: right_shape,
365 rank: right_rank,
366 };
367
368 Ok((left_tensor, right_tensor))
369 }
370}
371
/// Address of one index of one tensor in a network: the owning tensor's id
/// plus the index position on that tensor.
///
/// The type is `Copy` and hashable, so it can be used directly as a map or
/// set key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TensorIndex {
    /// Identifier of the tensor that owns this index.
    pub tensor_id: usize,
    /// Position of the index on that tensor. This is presumably an axis
    /// number; confirm against callers.
    pub index: usize,
}