use crate::{
    error::{QuantRS2Error, QuantRS2Result},
    gate::GateOp,
    linalg_stubs::svd,
    register::Register,
};
use scirs2_core::ndarray::{Array, Array2, ArrayD, IxDyn};
use scirs2_core::Complex;
use std::collections::{HashMap, HashSet};

type Complex64 = Complex<f64>;

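/// A tensor in a tensor network: a dense complex array together with one
/// string label per axis, used to match indices when contracting.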
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Identifier of this tensor within a network.
    pub id: usize,
    /// Dense tensor data with dynamic dimensionality.
    pub data: ArrayD<Complex64>,
    /// One index label per axis; matching labels are contracted together.
    pub indices: Vec<String>,
    /// Cached shape of `data`.
    pub shape: Vec<usize>,
}

impl Tensor {
    pub fn new(id: usize, data: ArrayD<Complex64>, indices: Vec<String>) -> Self {
        let shape = data.shape().to_vec();
        Self {
            id,
            data,
            indices,
            shape,
        }
    }

    pub fn from_matrix(
        id: usize,
        matrix: Array2<Complex64>,
        in_idx: String,
        out_idx: String,
    ) -> Self {
        let shape = matrix.shape().to_vec();
        let data = matrix.into_dyn();
        Self {
            id,
            data,
            indices: vec![in_idx, out_idx],
            shape,
        }
    }

    pub fn qubit_zero(id: usize, idx: String) -> Self {
        let mut data = Array::zeros(IxDyn(&[2]));
        data[[0]] = Complex64::new(1.0, 0.0);
        Self {
            id,
            data,
            indices: vec![idx],
            shape: vec![2],
        }
    }

    pub fn qubit_one(id: usize, idx: String) -> Self {
        let mut data = Array::zeros(IxDyn(&[2]));
        data[[1]] = Complex64::new(1.0, 0.0);
        Self {
            id,
            data,
            indices: vec![idx],
            shape: vec![2],
        }
    }

    pub fn from_array<D>(
        array: scirs2_core::ndarray::ArrayBase<scirs2_core::ndarray::OwnedRepr<Complex64>, D>,
        indices: Vec<usize>,
    ) -> Self
    where
        D: scirs2_core::ndarray::Dimension,
    {
        let shape = array.shape().to_vec();
        let data = array.into_dyn();
        let index_labels: Vec<String> = indices.iter().map(|i| format!("idx_{i}")).collect();
        Self {
            // Default id; callers can overwrite it before adding the tensor to a network.
            id: 0,
            data,
            indices: index_labels,
            shape,
        }
    }

    pub fn rank(&self) -> usize {
        self.indices.len()
    }

    pub const fn tensor(&self) -> &ArrayD<Complex64> {
        &self.data
    }

    pub fn ndim(&self) -> usize {
        self.data.ndim()
    }

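    /// Contracts this tensor with `other` over the single pair of indices
    /// named `self_idx` (in `self`) and `other_idx` (in `other`), returning a
    /// new tensor whose indices are the remaining indices of both operands.
    ///
    /// Illustrative sketch (doctest ignored):
    /// ```ignore
    /// let a = Tensor::qubit_zero(0, "i".to_string());
    /// let b = Tensor::qubit_one(1, "i".to_string());
    /// // Rank-0 result holding the overlap <0|1> = 0.
    /// let overlap = a.contract(&b, "i", "i")?;
    /// ```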
    pub fn contract(&self, other: &Self, self_idx: &str, other_idx: &str) -> QuantRS2Result<Self> {
        let self_pos = self
            .indices
            .iter()
            .position(|s| s == self_idx)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!("Index {self_idx} not found in tensor"))
            })?;
        let other_pos = other
            .indices
            .iter()
            .position(|s| s == other_idx)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!("Index {other_idx} not found in tensor"))
            })?;

        if self.shape[self_pos] != other.shape[other_pos] {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Cannot contract indices with different dimensions: {} vs {}",
                self.shape[self_pos], other.shape[other_pos]
            )));
        }

        let contracted = self.contract_indices(other, self_pos, other_pos)?;

        let mut new_indices = Vec::new();
        for (i, idx) in self.indices.iter().enumerate() {
            if i != self_pos {
                new_indices.push(idx.clone());
            }
        }
        for (i, idx) in other.indices.iter().enumerate() {
            if i != other_pos {
                new_indices.push(idx.clone());
            }
        }

        Ok(Self::new(
            self.id.max(other.id) + 1,
            contracted,
            new_indices,
        ))
    }

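    /// Performs the actual contraction over positions `self_idx`/`other_idx`
    /// by flattening both operands into matrices and summing over the shared
    /// dimension with a naive nested loop.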
    fn contract_indices(
        &self,
        other: &Self,
        self_idx: usize,
        other_idx: usize,
    ) -> QuantRS2Result<ArrayD<Complex64>> {
        let self_shape = self.data.shape();
        let other_shape = other.data.shape();

        // Products of the dimensions before and after the contracted axis.
        let self_left_dims: usize = self_shape[..self_idx].iter().product();
        let self_right_dims: usize = self_shape[self_idx + 1..].iter().product();
        let other_left_dims: usize = other_shape[..other_idx].iter().product();
        let other_right_dims: usize = other_shape[other_idx + 1..].iter().product();

        let contract_dim = self_shape[self_idx];

        // Flatten both tensors into matrices that expose the contracted axis.
        let self_mat = self
            .data
            .view()
            .into_shape_with_order((self_left_dims, contract_dim * self_right_dims))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?
            .to_owned();
        let other_mat = other
            .data
            .view()
            .into_shape_with_order((other_left_dims * contract_dim, other_right_dims))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?
            .to_owned();

        // Naive contraction: sum over the shared index for every combination
        // of the remaining (left/right) multi-indices of both tensors.
        let mut result_vec = Vec::with_capacity(
            self_left_dims * self_right_dims * other_left_dims * other_right_dims,
        );
        for i in 0..self_left_dims {
            for j in 0..self_right_dims {
                for k in 0..other_left_dims {
                    for l in 0..other_right_dims {
                        let mut sum = Complex64::new(0.0, 0.0);
                        for c in 0..contract_dim {
                            sum += self_mat[[i, c * self_right_dims + j]]
                                * other_mat[[k * contract_dim + c, l]];
                        }
                        result_vec.push(sum);
                    }
                }
            }
        }

        // The result keeps every uncontracted axis of `self` followed by
        // every uncontracted axis of `other`.
        let mut result_shape = Vec::new();
        result_shape.extend_from_slice(&self_shape[..self_idx]);
        result_shape.extend_from_slice(&self_shape[self_idx + 1..]);
        result_shape.extend_from_slice(&other_shape[..other_idx]);
        result_shape.extend_from_slice(&other_shape[other_idx + 1..]);

        ArrayD::from_shape_vec(IxDyn(&result_shape), result_vec)
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))
    }

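    /// Splits the tensor into left and right factors across axis `idx` using a
    /// real-valued SVD (only the real part of the data is decomposed); the
    /// square roots of the singular values are absorbed into both factors.
    /// `max_rank` optionally truncates the new bond dimension.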
    pub fn svd_decompose(
        &self,
        idx: usize,
        max_rank: Option<usize>,
    ) -> QuantRS2Result<(Self, Self)> {
        if idx >= self.rank() {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Index {} out of bounds for tensor with rank {}",
                idx,
                self.rank()
            )));
        }

        let shape = self.data.shape();
        // Group the axes up to and including `idx` into the rows of a matrix
        // and the remaining axes into its columns.
        let left_dim: usize = shape[..=idx].iter().product();
        let right_dim: usize = shape[idx + 1..].iter().product();

        let matrix = self
            .data
            .view()
            .into_shape_with_order((left_dim, right_dim))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?
            .to_owned();

        // The SVD stub operates on real matrices, so the imaginary part is dropped here.
        let real_matrix = matrix.mapv(|c| c.re);
        let (u, s, vt) = svd(&real_matrix.view(), false, None)
            .map_err(|e| QuantRS2Error::ComputationError(format!("SVD failed: {e:?}")))?;

        let rank = if let Some(max_r) = max_rank {
            max_r.min(s.len())
        } else {
            s.len()
        };

        let u_trunc = u.slice(scirs2_core::ndarray::s![.., ..rank]).to_owned();
        let s_trunc = s.slice(scirs2_core::ndarray::s![..rank]).to_owned();
        let vt_trunc = vt.slice(scirs2_core::ndarray::s![..rank, ..]).to_owned();

        // Split the singular values symmetrically: each factor absorbs sqrt(S).
        let mut s_mat = Array2::zeros((rank, rank));
        for i in 0..rank {
            s_mat[[i, i]] = Complex64::new(s_trunc[i].sqrt(), 0.0);
        }

        let left_data = u_trunc.mapv(|x| Complex64::new(x, 0.0)).dot(&s_mat);
        let right_data = s_mat.dot(&vt_trunc.mapv(|x| Complex64::new(x, 0.0)));

        let mut left_indices = self.indices[..=idx].to_vec();
        left_indices.push(format!("bond_{}", self.id));

        let mut right_indices = vec![format!("bond_{}", self.id)];
        right_indices.extend_from_slice(&self.indices[(idx + 1)..]);

        let left_tensor = Self::new(self.id * 2, left_data.into_dyn(), left_indices);
        let right_tensor = Self::new(self.id * 2 + 1, right_data.into_dyn(), right_indices);

        Ok((left_tensor, right_tensor))
    }
}

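/// An edge in the network: a pair of tensors connected along one matching
/// index label on each side.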
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TensorEdge {
    /// Id of the first tensor.
    pub tensor1: usize,
    /// Contracted index label on the first tensor.
    pub index1: String,
    /// Id of the second tensor.
    pub tensor2: usize,
    /// Contracted index label on the second tensor.
    pub index2: String,
}

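/// A collection of tensors plus the edges that connect them, tracking which
/// indices remain open (uncontracted) on each tensor.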
#[derive(Debug)]
pub struct TensorNetwork {
    /// All tensors in the network, keyed by id.
    pub tensors: HashMap<usize, Tensor>,
    /// Connections between tensors.
    pub edges: Vec<TensorEdge>,
    /// Indices of each tensor that are not yet connected by an edge.
    pub open_indices: HashMap<usize, Vec<String>>,
    /// Next unused tensor id.
    next_id: usize,
}

impl TensorNetwork {
    pub fn new() -> Self {
        Self {
            tensors: HashMap::new(),
            edges: Vec::new(),
            open_indices: HashMap::new(),
            next_id: 0,
        }
    }

    pub fn add_tensor(&mut self, tensor: Tensor) -> usize {
        let id = tensor.id;
        self.open_indices.insert(id, tensor.indices.clone());
        self.tensors.insert(id, tensor);
        self.next_id = self.next_id.max(id + 1);
        id
    }

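    /// Connects `index1` of `tensor1` to `index2` of `tensor2`, recording the
    /// edge and removing both indices from the open-index lists. Fails if a
    /// tensor or index is missing or if the dimensions do not match.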
    pub fn connect(
        &mut self,
        tensor1: usize,
        index1: String,
        tensor2: usize,
        index2: String,
    ) -> QuantRS2Result<()> {
        if !self.tensors.contains_key(&tensor1) {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Tensor {tensor1} not found"
            )));
        }
        if !self.tensors.contains_key(&tensor2) {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Tensor {tensor2} not found"
            )));
        }

        let t1 = &self.tensors[&tensor1];
        let t2 = &self.tensors[&tensor2];

        let idx1_pos = t1
            .indices
            .iter()
            .position(|s| s == &index1)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!("Index {index1} not found in tensor {tensor1}"))
            })?;
        let idx2_pos = t2
            .indices
            .iter()
            .position(|s| s == &index2)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!("Index {index2} not found in tensor {tensor2}"))
            })?;

        if t1.shape[idx1_pos] != t2.shape[idx2_pos] {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Connected indices must have same dimension: {} vs {}",
                t1.shape[idx1_pos], t2.shape[idx2_pos]
            )));
        }

        self.edges.push(TensorEdge {
            tensor1,
            index1: index1.clone(),
            tensor2,
            index2: index2.clone(),
        });

        if let Some(indices) = self.open_indices.get_mut(&tensor1) {
            indices.retain(|s| s != &index1);
        }
        if let Some(indices) = self.open_indices.get_mut(&tensor2) {
            indices.retain(|s| s != &index2);
        }

        Ok(())
    }

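    /// Greedily chooses pairs of connected tensors to contract, always picking
    /// the pair with the lowest estimated cost, and returns the resulting
    /// contraction order.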
    pub fn find_contraction_order(&self) -> Vec<(usize, usize)> {
        let mut remaining_tensors: HashSet<_> = self.tensors.keys().copied().collect();
        let mut order = Vec::new();

        // Build an adjacency list from the edges.
        let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
        for edge in &self.edges {
            adjacency.entry(edge.tensor1).or_default().push(edge.tensor2);
            adjacency.entry(edge.tensor2).or_default().push(edge.tensor1);
        }

        while remaining_tensors.len() > 1 {
            let mut best_pair = None;
            let mut min_cost = usize::MAX;

            for &t1 in &remaining_tensors {
                if let Some(neighbors) = adjacency.get(&t1) {
                    for &t2 in neighbors {
                        if t2 > t1 && remaining_tensors.contains(&t2) {
                            let cost = self.estimate_contraction_cost(t1, t2);
                            if cost < min_cost {
                                min_cost = cost;
                                best_pair = Some((t1, t2));
                            }
                        }
                    }
                }
            }

            if let Some((t1, t2)) = best_pair {
                order.push((t1, t2));
                remaining_tensors.remove(&t1);
                remaining_tensors.remove(&t2);

                // Represent the contraction result as a virtual tensor that
                // inherits the neighbors of both contracted tensors.
                let virtual_id = self.next_id + order.len();
                remaining_tensors.insert(virtual_id);

                let mut virtual_neighbors = HashSet::new();
                if let Some(n1) = adjacency.get(&t1) {
                    virtual_neighbors.extend(
                        n1.iter()
                            .filter(|&&n| n != t2 && remaining_tensors.contains(&n)),
                    );
                }
                if let Some(n2) = adjacency.get(&t2) {
                    virtual_neighbors.extend(
                        n2.iter()
                            .filter(|&&n| n != t1 && remaining_tensors.contains(&n)),
                    );
                }
                adjacency.insert(virtual_id, virtual_neighbors.into_iter().collect());
            } else {
                break;
            }
        }

        order
    }

    /// Estimated cost of contracting two tensors. Currently a constant
    /// placeholder; a real estimate would use the tensor dimensions.
    const fn estimate_contraction_cost(&self, _t1: usize, _t2: usize) -> usize {
        1000
    }

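    /// Contracts the whole network down to a single tensor, following the
    /// order produced by `find_contraction_order`.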
    pub fn contract_all(&mut self) -> QuantRS2Result<Tensor> {
        if self.tensors.is_empty() {
            return Err(QuantRS2Error::InvalidInput(
                "Cannot contract empty tensor network".into(),
            ));
        }

        if self.tensors.len() == 1 {
            return self.tensors.values().next().cloned().ok_or_else(|| {
                QuantRS2Error::InvalidInput("Single tensor expected but not found".into())
            });
        }

        let order = self.find_contraction_order();

        let mut tensor_map = self.tensors.clone();
        let mut next_id = self.next_id;

        for (t1_id, t2_id) in order {
            let edge = self
                .edges
                .iter()
                .find(|e| {
                    (e.tensor1 == t1_id && e.tensor2 == t2_id)
                        || (e.tensor1 == t2_id && e.tensor2 == t1_id)
                })
                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensors not connected".into()))?;

            let t1 = tensor_map
                .remove(&t1_id)
                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensor not found".into()))?;
            let t2 = tensor_map
                .remove(&t2_id)
                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensor not found".into()))?;

            let contracted = if edge.tensor1 == t1_id {
                t1.contract(&t2, &edge.index1, &edge.index2)?
            } else {
                t1.contract(&t2, &edge.index2, &edge.index1)?
            };

            let mut new_tensor = contracted;
            new_tensor.id = next_id;
            tensor_map.insert(next_id, new_tensor);
            next_id += 1;
        }

        tensor_map
            .into_values()
            .next()
            .ok_or_else(|| QuantRS2Error::InvalidInput("Contraction failed".into()))
    }

    /// Converts the network to a matrix product state.
    /// Currently a placeholder that returns an empty chain.
    pub const fn to_mps(&self, _max_bond_dim: Option<usize>) -> QuantRS2Result<Vec<Tensor>> {
        Ok(vec![])
    }

    /// Applies a matrix product operator to the given qubits.
    /// Currently a placeholder that does nothing.
    pub const fn apply_mpo(&mut self, _mpo: &[Tensor], _qubits: &[usize]) -> QuantRS2Result<()> {
        Ok(())
    }

    pub fn tensors(&self) -> Vec<&Tensor> {
        self.tensors.values().collect()
    }

    pub fn tensor(&self, id: usize) -> Option<&Tensor> {
        self.tensors.get(&id)
    }
}

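/// Builds a tensor network for a quantum circuit: each qubit starts in |0>
/// and gate tensors are appended by connecting them to the qubit's current
/// open index.
///
/// Illustrative sketch (doctest ignored; `h_gate` stands for any single-qubit
/// `GateOp` implementation from this crate):
/// ```ignore
/// let mut builder = TensorNetworkBuilder::new(2);
/// builder.apply_single_qubit_gate(&h_gate, 0)?;
/// let amplitudes = builder.to_statevector()?;
/// ```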
pub struct TensorNetworkBuilder {
    /// The tensor network under construction.
    network: TensorNetwork,
    /// Initial index label assigned to each qubit.
    qubit_indices: HashMap<usize, String>,
    /// The currently open index label of each qubit.
    current_indices: HashMap<usize, String>,
}

impl TensorNetworkBuilder {
    pub fn new(num_qubits: usize) -> Self {
        let mut network = TensorNetwork::new();
        let mut qubit_indices = HashMap::new();
        let mut current_indices = HashMap::new();

        // Start every qubit in the |0> state with an open index "q{i}_0".
        for i in 0..num_qubits {
            let idx = format!("q{i}_0");
            let tensor = Tensor::qubit_zero(i, idx.clone());
            network.add_tensor(tensor);
            qubit_indices.insert(i, idx.clone());
            current_indices.insert(i, idx);
        }

        Self {
            network,
            qubit_indices,
            current_indices,
        }
    }

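    /// Applies a single-qubit gate by adding its 2x2 matrix as a new tensor
    /// and connecting its input index to the qubit's current open index.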
    pub fn apply_single_qubit_gate(
        &mut self,
        gate: &dyn GateOp,
        qubit: usize,
    ) -> QuantRS2Result<()> {
        let matrix_vec = gate.matrix()?;
        let matrix = Array2::from_shape_vec((2, 2), matrix_vec)
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?;

        let in_idx = self.current_indices[&qubit].clone();
        let out_idx = format!("q{}_{}", qubit, self.network.next_id);
        let gate_tensor = Tensor::from_matrix(
            self.network.next_id,
            matrix,
            in_idx.clone(),
            out_idx.clone(),
        );

        // Find the tensor currently holding the qubit's open index before the
        // gate tensor (which reuses the same label) is added to the network.
        let prev_tensor = self.find_tensor_with_index(&in_idx);

        let gate_id = self.network.add_tensor(gate_tensor);

        if let Some(prev) = prev_tensor {
            self.network.connect(prev, in_idx.clone(), gate_id, in_idx)?;
        }

        self.current_indices.insert(qubit, out_idx);

        Ok(())
    }

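    /// Applies a two-qubit gate by reshaping its 4x4 matrix into a rank-4
    /// tensor and connecting its two input indices to the qubits' current
    /// open indices.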
    pub fn apply_two_qubit_gate(
        &mut self,
        gate: &dyn GateOp,
        qubit1: usize,
        qubit2: usize,
    ) -> QuantRS2Result<()> {
        let matrix_vec = gate.matrix()?;
        let matrix = Array2::from_shape_vec((4, 4), matrix_vec)
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?;

        // Reshape the 4x4 matrix into a (2, 2, 2, 2) tensor: two inputs, two outputs.
        let tensor_data = matrix
            .into_shape_with_order((2, 2, 2, 2))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?
            .into_dyn();

        let in1_idx = self.current_indices[&qubit1].clone();
        let in2_idx = self.current_indices[&qubit2].clone();
        let out1_idx = format!("q{}_{}", qubit1, self.network.next_id);
        let out2_idx = format!("q{}_{}", qubit2, self.network.next_id);

        let gate_tensor = Tensor::new(
            self.network.next_id,
            tensor_data,
            vec![
                in1_idx.clone(),
                in2_idx.clone(),
                out1_idx.clone(),
                out2_idx.clone(),
            ],
        );

        // Locate the predecessors before the gate tensor (which reuses the
        // same input labels) is added to the network.
        let prev1 = self.find_tensor_with_index(&in1_idx);
        let prev2 = self.find_tensor_with_index(&in2_idx);

        let gate_id = self.network.add_tensor(gate_tensor);

        if let Some(prev1) = prev1 {
            self.network
                .connect(prev1, in1_idx.clone(), gate_id, in1_idx)?;
        }
        if let Some(prev2) = prev2 {
            self.network
                .connect(prev2, in2_idx.clone(), gate_id, in2_idx)?;
        }

        self.current_indices.insert(qubit1, out1_idx);
        self.current_indices.insert(qubit2, out2_idx);

        Ok(())
    }

    /// Returns the id of any tensor whose index list contains `index`.
    fn find_tensor_with_index(&self, index: &str) -> Option<usize> {
        for (id, tensor) in &self.network.tensors {
            if tensor.indices.iter().any(|idx| idx == index) {
                return Some(*id);
            }
        }
        None
    }

    /// Consumes the builder and returns the constructed network.
    pub fn build(self) -> TensorNetwork {
        self.network
    }

    /// Contracts the whole network and returns the flattened amplitude vector.
    #[must_use]
    pub fn to_statevector(&mut self) -> QuantRS2Result<Vec<Complex64>> {
        let final_tensor = self.network.contract_all()?;
        Ok(final_tensor.data.into_raw_vec_and_offset().0)
    }
}

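/// A circuit simulator backed by tensor network contraction.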
pub struct TensorNetworkSimulator {
    /// Maximum bond dimension for truncated decompositions.
    max_bond_dim: usize,
    /// Whether to compress bonds during contraction.
    use_compression: bool,
    /// Problem-size threshold for switching to parallel contraction.
    parallel_threshold: usize,
}

impl TensorNetworkSimulator {
    /// Creates a simulator with default settings (bond dimension 64, compression enabled).
    pub const fn new() -> Self {
        Self {
            max_bond_dim: 64,
            use_compression: true,
            parallel_threshold: 1000,
        }
    }

    /// Sets the maximum bond dimension.
    #[must_use]
    pub const fn with_max_bond_dim(mut self, dim: usize) -> Self {
        self.max_bond_dim = dim;
        self
    }

    /// Enables or disables bond compression.
    #[must_use]
    pub const fn with_compression(mut self, compress: bool) -> Self {
        self.use_compression = compress;
        self
    }

    /// Simulates a circuit of one- and two-qubit gates on `N` qubits by
    /// building a tensor network and contracting it to a state vector.
    pub fn simulate<const N: usize>(
        &self,
        gates: &[Box<dyn GateOp>],
    ) -> QuantRS2Result<Register<N>> {
        let mut builder = TensorNetworkBuilder::new(N);

        for gate in gates {
            let qubits = gate.qubits();
            match qubits.len() {
                1 => builder.apply_single_qubit_gate(gate.as_ref(), qubits[0].0 as usize)?,
                2 => builder.apply_two_qubit_gate(
                    gate.as_ref(),
                    qubits[0].0 as usize,
                    qubits[1].0 as usize,
                )?,
                _ => {
                    return Err(QuantRS2Error::UnsupportedOperation(format!(
                        "Gates with {} qubits not supported in tensor network",
                        qubits.len()
                    )))
                }
            }
        }

        let amplitudes = builder.to_statevector()?;
        Register::with_amplitudes(amplitudes)
    }
}

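/// Optimizers for choosing a good tensor contraction order.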
pub mod contraction_optimization {
    use super::*;

    /// Exhaustive contraction-order search with memoization over tensor-id sets.
    pub struct DynamicProgrammingOptimizer {
        memo: HashMap<Vec<usize>, (usize, Vec<(usize, usize)>)>,
    }

    impl DynamicProgrammingOptimizer {
        pub fn new() -> Self {
            Self {
                memo: HashMap::new(),
            }
        }

        /// Returns a contraction order with the lowest total estimated cost.
        pub fn optimize(&mut self, network: &TensorNetwork) -> Vec<(usize, usize)> {
            let tensor_ids: Vec<_> = network.tensors.keys().copied().collect();
            self.find_optimal_order(&tensor_ids, network).1
        }

        fn find_optimal_order(
            &mut self,
            tensors: &[usize],
            network: &TensorNetwork,
        ) -> (usize, Vec<(usize, usize)>) {
            if tensors.len() <= 1 {
                return (0, vec![]);
            }

            let key = tensors.to_vec();
            if let Some(result) = self.memo.get(&key) {
                return result.clone();
            }

            let mut best_cost = usize::MAX;
            let mut best_order = vec![];

            // Try every connected pair, recursing on the remaining tensors plus
            // a virtual tensor standing in for the contraction result.
            for i in 0..tensors.len() {
                for j in (i + 1)..tensors.len() {
                    if self.are_connected(tensors[i], tensors[j], network) {
                        let cost = network.estimate_contraction_cost(tensors[i], tensors[j]);

                        let mut remaining = vec![];
                        for (k, &t) in tensors.iter().enumerate() {
                            if k != i && k != j {
                                remaining.push(t);
                            }
                        }
                        remaining.push(network.next_id + remaining.len());

                        let (sub_cost, sub_order) = self.find_optimal_order(&remaining, network);
                        let total_cost = cost + sub_cost;

                        if total_cost < best_cost {
                            best_cost = total_cost;
                            best_order = vec![(tensors[i], tensors[j])];
                            best_order.extend(sub_order);
                        }
                    }
                }
            }

            self.memo.insert(key, (best_cost, best_order.clone()));
            (best_cost, best_order)
        }

        fn are_connected(&self, t1: usize, t2: usize, network: &TensorNetwork) -> bool {
            network.edges.iter().any(|e| {
                (e.tensor1 == t1 && e.tensor2 == t2) || (e.tensor1 == t2 && e.tensor2 == t1)
            })
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_creation() {
        let data = ArrayD::zeros(IxDyn(&[2, 2]));
        let tensor = Tensor::new(0, data, vec!["in".to_string(), "out".to_string()]);
        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.shape, vec![2, 2]);
    }

    #[test]
    fn test_qubit_tensors() {
        let t0 = Tensor::qubit_zero(0, "q0".to_string());
        assert_eq!(t0.data[[0]], Complex64::new(1.0, 0.0));
        assert_eq!(t0.data[[1]], Complex64::new(0.0, 0.0));

        let t1 = Tensor::qubit_one(1, "q1".to_string());
        assert_eq!(t1.data[[0]], Complex64::new(0.0, 0.0));
        assert_eq!(t1.data[[1]], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_tensor_network_builder() {
        let builder = TensorNetworkBuilder::new(2);
        assert_eq!(builder.network.tensors.len(), 2);
    }

    #[test]
    fn test_network_connection() {
        let mut network = TensorNetwork::new();

        let t1 = Tensor::qubit_zero(0, "q0".to_string());
        let t2 = Tensor::qubit_zero(1, "q1".to_string());

        let id1 = network.add_tensor(t1);
        let id2 = network.add_tensor(t2);

        // Connecting along an index that neither tensor owns must fail.
        assert!(network
            .connect(id1, "bond".to_string(), id2, "bond".to_string())
            .is_err());
    }
}