use crate::{
    error::{QuantRS2Error, QuantRS2Result},
    gate::GateOp,
    register::Register,
};
use ndarray::{Array, Array2, ArrayD, IxDyn};
use num_complex::Complex;
use scirs2_linalg::svd;
use std::collections::{HashMap, HashSet};

type Complex64 = Complex<f64>;

/// A tensor in the network: complex-valued data with one named index per dimension.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Unique identifier of this tensor within a network.
    pub id: usize,
    /// The tensor data as a dynamically-shaped complex array.
    pub data: ArrayD<Complex64>,
    /// Index labels, one per dimension of `data`.
    pub indices: Vec<String>,
    /// The dimension of each index.
    pub shape: Vec<usize>,
}

impl Tensor {
    /// Creates a tensor from raw data and index labels.
    pub fn new(id: usize, data: ArrayD<Complex64>, indices: Vec<String>) -> Self {
        let shape = data.shape().to_vec();
        Self {
            id,
            data,
            indices,
            shape,
        }
    }

    /// Creates a rank-2 tensor from a gate matrix with input and output index labels.
    pub fn from_matrix(
        id: usize,
        matrix: Array2<Complex64>,
        in_idx: String,
        out_idx: String,
    ) -> Self {
        let shape = matrix.shape().to_vec();
        let data = matrix.into_dyn();
        Self {
            id,
            data,
            indices: vec![in_idx, out_idx],
            shape,
        }
    }

    /// Creates the single-qubit |0> basis state tensor.
    pub fn qubit_zero(id: usize, idx: String) -> Self {
        let mut data = Array::zeros(IxDyn(&[2]));
        data[[0]] = Complex64::new(1.0, 0.0);
        Self {
            id,
            data,
            indices: vec![idx],
            shape: vec![2],
        }
    }

    /// Creates the single-qubit |1> basis state tensor.
    pub fn qubit_one(id: usize, idx: String) -> Self {
        let mut data = Array::zeros(IxDyn(&[2]));
        data[[1]] = Complex64::new(1.0, 0.0);
        Self {
            id,
            data,
            indices: vec![idx],
            shape: vec![2],
        }
    }

    /// Creates a tensor from an ndarray of any dimensionality, generating
    /// index labels of the form `idx_{i}` from the given index numbers.
    pub fn from_array<D>(
        array: ndarray::ArrayBase<ndarray::OwnedRepr<Complex64>, D>,
        indices: Vec<usize>,
    ) -> Self
    where
        D: ndarray::Dimension,
    {
        let shape = array.shape().to_vec();
        let data = array.into_dyn();
        let index_labels: Vec<String> = indices.iter().map(|i| format!("idx_{}", i)).collect();
        Self {
            id: 0,
            data,
            indices: index_labels,
            shape,
        }
    }

    /// Returns the rank (number of indices) of the tensor.
    pub fn rank(&self) -> usize {
        self.indices.len()
    }

    /// Returns a reference to the underlying data array.
    pub fn tensor(&self) -> &ArrayD<Complex64> {
        &self.data
    }

    /// Returns the number of dimensions of the underlying data array.
    pub fn ndim(&self) -> usize {
        self.data.ndim()
    }

    /// Contracts this tensor with another over a single pair of matching indices.
    pub fn contract(
        &self,
        other: &Tensor,
        self_idx: &str,
        other_idx: &str,
    ) -> QuantRS2Result<Tensor> {
        let self_pos = self
            .indices
            .iter()
            .position(|s| s == self_idx)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!("Index {} not found in tensor", self_idx))
            })?;
        let other_pos = other
            .indices
            .iter()
            .position(|s| s == other_idx)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!("Index {} not found in tensor", other_idx))
            })?;

        if self.shape[self_pos] != other.shape[other_pos] {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Cannot contract indices with different dimensions: {} vs {}",
                self.shape[self_pos], other.shape[other_pos]
            )));
        }

        let contracted = self.contract_indices(other, self_pos, other_pos)?;

        // The result keeps every index except the two that were contracted.
        let mut new_indices = Vec::new();
        for (i, idx) in self.indices.iter().enumerate() {
            if i != self_pos {
                new_indices.push(idx.clone());
            }
        }
        for (i, idx) in other.indices.iter().enumerate() {
            if i != other_pos {
                new_indices.push(idx.clone());
            }
        }

        Ok(Tensor::new(
            self.id.max(other.id) + 1,
            contracted,
            new_indices,
        ))
    }

    /// Performs the actual contraction by reshaping both tensors into matrices
    /// and summing over the shared dimension.
    fn contract_indices(
        &self,
        other: &Tensor,
        self_idx: usize,
        other_idx: usize,
    ) -> QuantRS2Result<ArrayD<Complex64>> {
        let self_shape = self.data.shape();
        let other_shape = other.data.shape();

        // Products of the dimensions before and after the contracted index.
        let mut self_left_dims = 1;
        let mut self_right_dims = 1;
        for i in 0..self_idx {
            self_left_dims *= self_shape[i];
        }
        for i in (self_idx + 1)..self_shape.len() {
            self_right_dims *= self_shape[i];
        }

        let mut other_left_dims = 1;
        let mut other_right_dims = 1;
        for i in 0..other_idx {
            other_left_dims *= other_shape[i];
        }
        for i in (other_idx + 1)..other_shape.len() {
            other_right_dims *= other_shape[i];
        }

        let contract_dim = self_shape[self_idx];

        let self_mat = self
            .data
            .view()
            .into_shape_with_order((self_left_dims, contract_dim * self_right_dims))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))?
            .to_owned();
        let other_mat = other
            .data
            .view()
            .into_shape_with_order((other_left_dims * contract_dim, other_right_dims))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))?
            .to_owned();

        // Accumulate the contracted elements in row-major order of the result shape.
        let mut result_vec = Vec::new();
        for i in 0..self_left_dims {
            for j in 0..self_right_dims {
                for k in 0..other_left_dims {
                    for l in 0..other_right_dims {
                        let mut sum = Complex64::new(0.0, 0.0);
                        for c in 0..contract_dim {
                            sum += self_mat[[i, c * self_right_dims + j]]
                                * other_mat[[k * contract_dim + c, l]];
                        }
                        result_vec.push(sum);
                    }
                }
            }
        }

        // The result shape is the uncontracted dimensions of `self` followed by
        // the uncontracted dimensions of `other`.
        let mut result_shape = Vec::new();
        for i in 0..self_idx {
            result_shape.push(self_shape[i]);
        }
        for i in (self_idx + 1)..self_shape.len() {
            result_shape.push(self_shape[i]);
        }
        for i in 0..other_idx {
            result_shape.push(other_shape[i]);
        }
        for i in (other_idx + 1)..other_shape.len() {
            result_shape.push(other_shape[i]);
        }

        ArrayD::from_shape_vec(IxDyn(&result_shape), result_vec)
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))
    }

    /// Splits the tensor into two via SVD after index `idx`, optionally
    /// truncating to `max_rank`. Only the real part of the data is decomposed;
    /// the singular values are shared evenly (as square roots) between the factors.
    pub fn svd_decompose(
        &self,
        idx: usize,
        max_rank: Option<usize>,
    ) -> QuantRS2Result<(Tensor, Tensor)> {
        if idx >= self.rank() {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Index {} out of bounds for tensor with rank {}",
                idx,
                self.rank()
            )));
        }

        // Flatten into a matrix: indices 0..=idx form the rows, the rest the columns.
        let shape = self.data.shape();
        let mut left_dim = 1;
        let mut right_dim = 1;

        for i in 0..=idx {
            left_dim *= shape[i];
        }
        for i in (idx + 1)..shape.len() {
            right_dim *= shape[i];
        }

        let matrix = self
            .data
            .view()
            .into_shape_with_order((left_dim, right_dim))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))?
            .to_owned();

        let real_matrix = matrix.mapv(|c| c.re);
        let (u, s, vt) = svd(&real_matrix.view(), false)
            .map_err(|e| QuantRS2Error::ComputationError(format!("SVD failed: {:?}", e)))?;

        // Truncate to the requested maximum bond dimension, if any.
        let rank = if let Some(max_r) = max_rank {
            max_r.min(s.len())
        } else {
            s.len()
        };

        let u_trunc = u.slice(ndarray::s![.., ..rank]).to_owned();
        let s_trunc = s.slice(ndarray::s![..rank]).to_owned();
        let vt_trunc = vt.slice(ndarray::s![..rank, ..]).to_owned();

        // Distribute sqrt(singular values) to both sides of the split.
        let mut s_mat = Array2::zeros((rank, rank));
        for i in 0..rank {
            s_mat[[i, i]] = Complex64::new(s_trunc[i].sqrt(), 0.0);
        }

        let left_data = u_trunc.mapv(|x| Complex64::new(x, 0.0)).dot(&s_mat);
        let right_data = s_mat.dot(&vt_trunc.mapv(|x| Complex64::new(x, 0.0)));

        // The left tensor keeps indices 0..=idx plus a new bond index; the
        // right tensor gets the bond index followed by the remaining indices.
        let mut left_indices = self.indices[..=idx].to_vec();
        left_indices.push(format!("bond_{}", self.id));

        let mut right_indices = vec![format!("bond_{}", self.id)];
        right_indices.extend_from_slice(&self.indices[(idx + 1)..]);

        let left_tensor = Tensor::new(self.id * 2, left_data.into_dyn(), left_indices);
        let right_tensor = Tensor::new(self.id * 2 + 1, right_data.into_dyn(), right_indices);

        Ok((left_tensor, right_tensor))
    }
}

/// An edge connecting two tensors in the network along named indices.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TensorEdge {
    /// Id of the first tensor.
    pub tensor1: usize,
    /// Index name on the first tensor.
    pub index1: String,
    /// Id of the second tensor.
    pub tensor2: usize,
    /// Index name on the second tensor.
    pub index2: String,
}

/// A network of tensors connected along shared indices.
#[derive(Debug)]
pub struct TensorNetwork {
    /// All tensors in the network, keyed by id.
    pub tensors: HashMap<usize, Tensor>,
    /// Edges between connected tensor indices.
    pub edges: Vec<TensorEdge>,
    /// Indices of each tensor that are not yet connected to another tensor.
    pub open_indices: HashMap<usize, Vec<String>>,
    /// The next unused tensor id.
    next_id: usize,
}

impl TensorNetwork {
    /// Creates an empty tensor network.
    pub fn new() -> Self {
        Self {
            tensors: HashMap::new(),
            edges: Vec::new(),
            open_indices: HashMap::new(),
            next_id: 0,
        }
    }

    /// Adds a tensor to the network and returns its id. All of its indices
    /// start out open.
    pub fn add_tensor(&mut self, tensor: Tensor) -> usize {
        let id = tensor.id;
        self.open_indices.insert(id, tensor.indices.clone());
        self.tensors.insert(id, tensor);
        self.next_id = self.next_id.max(id + 1);
        id
    }

    /// Connects two tensors along the named indices, which must exist on the
    /// respective tensors and have the same dimension.
    pub fn connect(
        &mut self,
        tensor1: usize,
        index1: String,
        tensor2: usize,
        index2: String,
    ) -> QuantRS2Result<()> {
        if !self.tensors.contains_key(&tensor1) {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Tensor {} not found",
                tensor1
            )));
        }
        if !self.tensors.contains_key(&tensor2) {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Tensor {} not found",
                tensor2
            )));
        }

        let t1 = &self.tensors[&tensor1];
        let t2 = &self.tensors[&tensor2];

        let idx1_pos = t1
            .indices
            .iter()
            .position(|s| s == &index1)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!(
                    "Index {} not found in tensor {}",
                    index1, tensor1
                ))
            })?;
        let idx2_pos = t2
            .indices
            .iter()
            .position(|s| s == &index2)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!(
                    "Index {} not found in tensor {}",
                    index2, tensor2
                ))
            })?;

        if t1.shape[idx1_pos] != t2.shape[idx2_pos] {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Connected indices must have same dimension: {} vs {}",
                t1.shape[idx1_pos], t2.shape[idx2_pos]
            )));
        }

        self.edges.push(TensorEdge {
            tensor1,
            index1: index1.clone(),
            tensor2,
            index2: index2.clone(),
        });

        // The connected indices are no longer open.
        if let Some(indices) = self.open_indices.get_mut(&tensor1) {
            indices.retain(|s| s != &index1);
        }
        if let Some(indices) = self.open_indices.get_mut(&tensor2) {
            indices.retain(|s| s != &index2);
        }

        Ok(())
    }

    /// Finds a contraction order with a greedy heuristic: repeatedly pick the
    /// connected pair of remaining tensors with the lowest estimated cost.
    pub fn find_contraction_order(&self) -> Vec<(usize, usize)> {
        let mut remaining_tensors: HashSet<_> = self.tensors.keys().cloned().collect();
        let mut order = Vec::new();

        // Build an adjacency map from the edge list.
        let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
        for edge in &self.edges {
            adjacency
                .entry(edge.tensor1)
                .or_insert_with(Vec::new)
                .push(edge.tensor2);
            adjacency
                .entry(edge.tensor2)
                .or_insert_with(Vec::new)
                .push(edge.tensor1);
        }

        while remaining_tensors.len() > 1 {
            let mut best_pair = None;
            let mut min_cost = usize::MAX;

            for &t1 in &remaining_tensors {
                if let Some(neighbors) = adjacency.get(&t1) {
                    for &t2 in neighbors {
                        if t2 > t1 && remaining_tensors.contains(&t2) {
                            let cost = self.estimate_contraction_cost(t1, t2);
                            if cost < min_cost {
                                min_cost = cost;
                                best_pair = Some((t1, t2));
                            }
                        }
                    }
                }
            }

            if let Some((t1, t2)) = best_pair {
                order.push((t1, t2));
                remaining_tensors.remove(&t1);
                remaining_tensors.remove(&t2);

                // Represent the contraction result as a virtual tensor so that
                // its neighbours stay reachable in later iterations.
                let virtual_id = self.next_id + order.len();
                remaining_tensors.insert(virtual_id);

                let mut virtual_neighbors = HashSet::new();
                if let Some(n1) = adjacency.get(&t1) {
                    virtual_neighbors.extend(
                        n1.iter()
                            .filter(|&&n| n != t2 && remaining_tensors.contains(&n)),
                    );
                }
                if let Some(n2) = adjacency.get(&t2) {
                    virtual_neighbors.extend(
                        n2.iter()
                            .filter(|&&n| n != t1 && remaining_tensors.contains(&n)),
                    );
                }
                adjacency.insert(virtual_id, virtual_neighbors.into_iter().collect());
            } else {
                // No connected pair remains among the leftover tensors.
                break;
            }
        }

        order
    }

    /// Estimates the cost of contracting two tensors. Currently a constant
    /// placeholder; a full implementation would use the tensor dimensions.
    fn estimate_contraction_cost(&self, _t1: usize, _t2: usize) -> usize {
        1000
    }

    /// Contracts the entire network down to a single tensor, following the
    /// order produced by `find_contraction_order`.
    pub fn contract_all(&mut self) -> QuantRS2Result<Tensor> {
        if self.tensors.is_empty() {
            return Err(QuantRS2Error::InvalidInput(
                "Cannot contract empty tensor network".into(),
            ));
        }

        if self.tensors.len() == 1 {
            return Ok(self.tensors.values().next().unwrap().clone());
        }

        let order = self.find_contraction_order();

        let mut tensor_map = self.tensors.clone();
        let mut next_id = self.next_id;

        for (t1_id, t2_id) in order {
            // Find the edge that connects this pair.
            let edge = self
                .edges
                .iter()
                .find(|e| {
                    (e.tensor1 == t1_id && e.tensor2 == t2_id)
                        || (e.tensor1 == t2_id && e.tensor2 == t1_id)
                })
                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensors not connected".into()))?;

            let t1 = tensor_map
                .remove(&t1_id)
                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensor not found".into()))?;
            let t2 = tensor_map
                .remove(&t2_id)
                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensor not found".into()))?;

            // Orient the contraction according to the edge direction.
            let contracted = if edge.tensor1 == t1_id {
                t1.contract(&t2, &edge.index1, &edge.index2)?
            } else {
                t1.contract(&t2, &edge.index2, &edge.index1)?
            };

            let mut new_tensor = contracted;
            new_tensor.id = next_id;
            tensor_map.insert(next_id, new_tensor);
            next_id += 1;
        }

        tensor_map
            .into_values()
            .next()
            .ok_or_else(|| QuantRS2Error::InvalidInput("Contraction failed".into()))
    }

    /// Converts the network into a matrix product state (MPS) representation.
    /// Not yet implemented; currently returns an empty MPS.
    pub fn to_mps(&self, _max_bond_dim: Option<usize>) -> QuantRS2Result<Vec<Tensor>> {
        Ok(vec![])
    }

    /// Applies a matrix product operator (MPO) to the given qubits.
    /// Not yet implemented; currently leaves the network unchanged.
    pub fn apply_mpo(&mut self, _mpo: &[Tensor], _qubits: &[usize]) -> QuantRS2Result<()> {
        Ok(())
    }

    /// Returns references to all tensors in the network.
    pub fn tensors(&self) -> Vec<&Tensor> {
        self.tensors.values().collect()
    }

    /// Returns the tensor with the given id, if present.
    pub fn tensor(&self, id: usize) -> Option<&Tensor> {
        self.tensors.get(&id)
    }
}

/// Builds a tensor network incrementally from a sequence of quantum gates.
pub struct TensorNetworkBuilder {
    /// The network under construction.
    network: TensorNetwork,
    /// The initial index label assigned to each qubit.
    qubit_indices: HashMap<usize, String>,
    /// The most recent (output) index label of each qubit.
    current_indices: HashMap<usize, String>,
}

impl TensorNetworkBuilder {
    /// Creates a builder with `num_qubits` qubits, all initialized to |0>.
    pub fn new(num_qubits: usize) -> Self {
        let mut network = TensorNetwork::new();
        let mut qubit_indices = HashMap::new();
        let mut current_indices = HashMap::new();

        for i in 0..num_qubits {
            let idx = format!("q{}_0", i);
            let tensor = Tensor::qubit_zero(i, idx.clone());
            network.add_tensor(tensor);
            qubit_indices.insert(i, idx.clone());
            current_indices.insert(i, idx);
        }

        Self {
            network,
            qubit_indices,
            current_indices,
        }
    }

    /// Applies a single-qubit gate by adding its 2x2 matrix as a tensor and
    /// connecting it to the current index of the target qubit.
    pub fn apply_single_qubit_gate(
        &mut self,
        gate: &dyn GateOp,
        qubit: usize,
    ) -> QuantRS2Result<()> {
        let matrix_vec = gate.matrix()?;
        let matrix = Array2::from_shape_vec((2, 2), matrix_vec)
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))?;

        let in_idx = self.current_indices[&qubit].clone();
        let out_idx = format!("q{}_{}", qubit, self.network.next_id);
        let gate_tensor = Tensor::from_matrix(
            self.network.next_id,
            matrix,
            in_idx.clone(),
            out_idx.clone(),
        );

        // Find the tensor currently holding the input index before the gate
        // tensor (which carries the same label) is added to the network.
        let prev_tensor = self.find_tensor_with_index(&in_idx);

        let gate_id = self.network.add_tensor(gate_tensor);

        if let Some(prev_tensor) = prev_tensor {
            self.network
                .connect(prev_tensor, in_idx.clone(), gate_id, in_idx)?;
        }

        self.current_indices.insert(qubit, out_idx);

        Ok(())
    }

    /// Applies a two-qubit gate by reshaping its 4x4 matrix into a rank-4
    /// tensor and connecting it to the current indices of both qubits.
    pub fn apply_two_qubit_gate(
        &mut self,
        gate: &dyn GateOp,
        qubit1: usize,
        qubit2: usize,
    ) -> QuantRS2Result<()> {
        let matrix_vec = gate.matrix()?;
        let matrix = Array2::from_shape_vec((4, 4), matrix_vec)
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))?;

        let tensor_data = matrix
            .into_shape_with_order((2, 2, 2, 2))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))?
            .into_dyn();

        let in1_idx = self.current_indices[&qubit1].clone();
        let in2_idx = self.current_indices[&qubit2].clone();
        let out1_idx = format!("q{}_{}", qubit1, self.network.next_id);
        let out2_idx = format!("q{}_{}", qubit2, self.network.next_id);

        let gate_tensor = Tensor::new(
            self.network.next_id,
            tensor_data,
            vec![
                in1_idx.clone(),
                in2_idx.clone(),
                out1_idx.clone(),
                out2_idx.clone(),
            ],
        );

        // Find the tensors currently holding the input indices before the gate
        // tensor (which carries the same labels) is added to the network.
        let prev1 = self.find_tensor_with_index(&in1_idx);
        let prev2 = self.find_tensor_with_index(&in2_idx);

        let gate_id = self.network.add_tensor(gate_tensor);

        if let Some(prev1) = prev1 {
            self.network
                .connect(prev1, in1_idx.clone(), gate_id, in1_idx)?;
        }
        if let Some(prev2) = prev2 {
            self.network
                .connect(prev2, in2_idx.clone(), gate_id, in2_idx)?;
        }

        self.current_indices.insert(qubit1, out1_idx);
        self.current_indices.insert(qubit2, out2_idx);

        Ok(())
    }

    /// Finds the id of a tensor that owns the given index label, if any.
    fn find_tensor_with_index(&self, index: &str) -> Option<usize> {
        for (id, tensor) in &self.network.tensors {
            if tensor.indices.iter().any(|idx| idx == index) {
                return Some(*id);
            }
        }
        None
    }

    /// Consumes the builder and returns the constructed network.
    pub fn build(self) -> TensorNetwork {
        self.network
    }

    /// Contracts the network and returns the amplitudes of the resulting tensor.
    pub fn to_statevector(&mut self) -> QuantRS2Result<Vec<Complex64>> {
        let final_tensor = self.network.contract_all()?;
        Ok(final_tensor.data.into_raw_vec_and_offset().0)
    }
}

/// A quantum circuit simulator based on tensor network contraction.
pub struct TensorNetworkSimulator {
    /// Maximum bond dimension to keep during compression.
    max_bond_dim: usize,
    /// Whether to compress bonds via SVD truncation.
    use_compression: bool,
    /// Size threshold for switching to parallel contraction.
    parallel_threshold: usize,
}

impl TensorNetworkSimulator {
    /// Creates a simulator with default settings.
    pub fn new() -> Self {
        Self {
            max_bond_dim: 64,
            use_compression: true,
            parallel_threshold: 1000,
        }
    }

    /// Sets the maximum bond dimension.
    pub fn with_max_bond_dim(mut self, dim: usize) -> Self {
        self.max_bond_dim = dim;
        self
    }

    /// Enables or disables bond compression.
    pub fn with_compression(mut self, compress: bool) -> Self {
        self.use_compression = compress;
        self
    }

    /// Simulates a gate sequence on `N` qubits and returns the resulting register.
    /// Only one- and two-qubit gates are supported.
    pub fn simulate<const N: usize>(
        &self,
        gates: &[Box<dyn GateOp>],
    ) -> QuantRS2Result<Register<N>> {
        let mut builder = TensorNetworkBuilder::new(N);

        for gate in gates {
            let qubits = gate.qubits();
            match qubits.len() {
                1 => builder.apply_single_qubit_gate(gate.as_ref(), qubits[0].0 as usize)?,
                2 => builder.apply_two_qubit_gate(
                    gate.as_ref(),
                    qubits[0].0 as usize,
                    qubits[1].0 as usize,
                )?,
                _ => {
                    return Err(QuantRS2Error::UnsupportedOperation(format!(
                        "Gates with {} qubits not supported in tensor network",
                        qubits.len()
                    )))
                }
            }
        }

        let amplitudes = builder.to_statevector()?;
        Register::with_amplitudes(amplitudes)
    }
}

/// Strategies for choosing a tensor contraction order.
pub mod contraction_optimization {
    use super::*;

    /// Searches for a low-cost contraction order via recursion with memoization.
    pub struct DynamicProgrammingOptimizer {
        /// Memoized best (cost, order) for each set of tensor ids.
        memo: HashMap<Vec<usize>, (usize, Vec<(usize, usize)>)>,
    }

    impl DynamicProgrammingOptimizer {
        pub fn new() -> Self {
            Self {
                memo: HashMap::new(),
            }
        }

        /// Returns a contraction order for the whole network.
        pub fn optimize(&mut self, network: &TensorNetwork) -> Vec<(usize, usize)> {
            let tensor_ids: Vec<_> = network.tensors.keys().cloned().collect();
            self.find_optimal_order(&tensor_ids, network).1
        }

        /// Recursively finds the cheapest order for contracting `tensors`.
        fn find_optimal_order(
            &mut self,
            tensors: &[usize],
            network: &TensorNetwork,
        ) -> (usize, Vec<(usize, usize)>) {
            if tensors.len() <= 1 {
                return (0, vec![]);
            }

            let key = tensors.to_vec();
            if let Some(result) = self.memo.get(&key) {
                return result.clone();
            }

            let mut best_cost = usize::MAX;
            let mut best_order = vec![];

            for i in 0..tensors.len() {
                for j in (i + 1)..tensors.len() {
                    if self.are_connected(tensors[i], tensors[j], network) {
                        let cost = network.estimate_contraction_cost(tensors[i], tensors[j]);

                        // The remaining tensors plus a virtual id standing in for
                        // the intermediate result of contracting i and j.
                        let mut remaining = vec![];
                        for (k, &t) in tensors.iter().enumerate() {
                            if k != i && k != j {
                                remaining.push(t);
                            }
                        }
                        remaining.push(network.next_id + remaining.len());

                        let (sub_cost, sub_order) = self.find_optimal_order(&remaining, network);
                        // Saturate so an unsolvable sub-problem (usize::MAX) does not overflow.
                        let total_cost = cost.saturating_add(sub_cost);

                        if total_cost < best_cost {
                            best_cost = total_cost;
                            best_order = vec![(tensors[i], tensors[j])];
                            best_order.extend(sub_order);
                        }
                    }
                }
            }

            self.memo.insert(key, (best_cost, best_order.clone()));
            (best_cost, best_order)
        }

        /// Returns true if the two tensors share at least one edge.
        fn are_connected(&self, t1: usize, t2: usize, network: &TensorNetwork) -> bool {
            network.edges.iter().any(|e| {
                (e.tensor1 == t1 && e.tensor2 == t2) || (e.tensor1 == t2 && e.tensor2 == t1)
            })
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_creation() {
        let data = ArrayD::zeros(IxDyn(&[2, 2]));
        let tensor = Tensor::new(0, data, vec!["in".to_string(), "out".to_string()]);
        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.shape, vec![2, 2]);
    }

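    // Added example (sketch): verifies `Tensor::contract` by contracting the
    // |0> state with a 2x2 identity tensor over the shared "in" index, which
    // should leave |0> on the remaining "out" index. Uses only the types
    // defined in this module plus ndarray's `Array2::eye`.
    #[test]
    fn test_tensor_contraction_with_identity() {
        let eye = Array2::<f64>::eye(2).mapv(|x| Complex64::new(x, 0.0));
        let gate = Tensor::from_matrix(1, eye, "in".to_string(), "out".to_string());
        let state = Tensor::qubit_zero(0, "in".to_string());

        let result = gate.contract(&state, "in", "in").unwrap();
        assert_eq!(result.rank(), 1);
        assert_eq!(result.indices, vec!["out".to_string()]);
        assert_eq!(result.data[[0]], Complex64::new(1.0, 0.0));
        assert_eq!(result.data[[1]], Complex64::new(0.0, 0.0));
    }
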
    #[test]
    fn test_qubit_tensors() {
        let t0 = Tensor::qubit_zero(0, "q0".to_string());
        assert_eq!(t0.data[[0]], Complex64::new(1.0, 0.0));
        assert_eq!(t0.data[[1]], Complex64::new(0.0, 0.0));

        let t1 = Tensor::qubit_one(1, "q1".to_string());
        assert_eq!(t1.data[[0]], Complex64::new(0.0, 0.0));
        assert_eq!(t1.data[[1]], Complex64::new(1.0, 0.0));
    }

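    // Added example (sketch): the `svd_decompose` error path should reject a
    // split index that is out of bounds for the tensor's rank.
    #[test]
    fn test_svd_decompose_index_out_of_bounds() {
        let t = Tensor::qubit_zero(0, "q0".to_string());
        assert!(t.svd_decompose(3, None).is_err());
    }
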
    #[test]
    fn test_tensor_network_builder() {
        let builder = TensorNetworkBuilder::new(2);
        assert_eq!(builder.network.tensors.len(), 2);
    }

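    // Added example (sketch): for a network of two tensors joined by a single
    // edge, the greedy heuristic should schedule exactly that pair.
    #[test]
    fn test_find_contraction_order_single_pair() {
        let mut network = TensorNetwork::new();
        let state = Tensor::qubit_zero(0, "a".to_string());
        let eye = Array2::<f64>::eye(2).mapv(|x| Complex64::new(x, 0.0));
        let gate = Tensor::from_matrix(1, eye, "a".to_string(), "b".to_string());

        let id_state = network.add_tensor(state);
        let id_gate = network.add_tensor(gate);
        network
            .connect(id_state, "a".to_string(), id_gate, "a".to_string())
            .unwrap();

        assert_eq!(network.find_contraction_order(), vec![(id_state, id_gate)]);
    }
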
    #[test]
    fn test_network_connection() {
        let mut network = TensorNetwork::new();

        let t1 = Tensor::qubit_zero(0, "q0".to_string());
        let t2 = Tensor::qubit_zero(1, "q1".to_string());

        let id1 = network.add_tensor(t1);
        let id2 = network.add_tensor(t2);

        // Connecting along an index that neither tensor owns must fail.
        assert!(network
            .connect(id1, "bond".to_string(), id2, "bond".to_string())
            .is_err());
    }
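
    // Added example (sketch): end-to-end check of `contract_all`. A network of
    // |0> and a Pauli-X tensor connected along "a" should contract to the |1>
    // state on the open "b" index. The X matrix is built inline here.
    #[test]
    fn test_contract_all_applies_x_gate() {
        let mut network = TensorNetwork::new();
        let state = Tensor::qubit_zero(0, "a".to_string());
        let mut x = Array2::<Complex64>::zeros((2, 2));
        x[[0, 1]] = Complex64::new(1.0, 0.0);
        x[[1, 0]] = Complex64::new(1.0, 0.0);
        let gate = Tensor::from_matrix(1, x, "a".to_string(), "b".to_string());

        let id_state = network.add_tensor(state);
        let id_gate = network.add_tensor(gate);
        network
            .connect(id_state, "a".to_string(), id_gate, "a".to_string())
            .unwrap();

        let result = network.contract_all().unwrap();
        assert_eq!(result.indices, vec!["b".to_string()]);
        assert_eq!(result.data[[0]], Complex64::new(0.0, 0.0));
        assert_eq!(result.data[[1]], Complex64::new(1.0, 0.0));
    }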
}