use crate::{
    error::{QuantRS2Error, QuantRS2Result},
    gate::GateOp,
    linalg_stubs::svd,
    register::Register,
};
use scirs2_core::ndarray::{Array, Array2, ArrayD, IxDyn};
use scirs2_core::Complex;
use std::collections::{HashMap, HashSet};

type Complex64 = Complex<f64>;

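/// A tensor in a network: complex-valued data whose dimensions are identified
/// by named indices.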
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Unique identifier of this tensor within a network
    pub id: usize,
    /// Tensor elements
    pub data: ArrayD<Complex64>,
    /// Index label for each dimension
    pub indices: Vec<String>,
    /// Dimension sizes
    pub shape: Vec<usize>,
}

impl Tensor {
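    /// Creates a tensor from dynamic-dimensional data and one index label per dimension.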
    pub fn new(id: usize, data: ArrayD<Complex64>, indices: Vec<String>) -> Self {
        let shape = data.shape().to_vec();
        Self {
            id,
            data,
            indices,
            shape,
        }
    }

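    /// Builds a rank-2 tensor from a matrix, labelling its two dimensions with
    /// the given input and output index names.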
    pub fn from_matrix(
        id: usize,
        matrix: Array2<Complex64>,
        in_idx: String,
        out_idx: String,
    ) -> Self {
        let shape = matrix.shape().to_vec();
        let data = matrix.into_dyn();
        Self {
            id,
            data,
            indices: vec![in_idx, out_idx],
            shape,
        }
    }

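    /// Rank-1 tensor representing the |0⟩ basis state.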
    pub fn qubit_zero(id: usize, idx: String) -> Self {
        let mut data = Array::zeros(IxDyn(&[2]));
        data[[0]] = Complex64::new(1.0, 0.0);
        Self {
            id,
            data,
            indices: vec![idx],
            shape: vec![2],
        }
    }

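    /// Rank-1 tensor representing the |1⟩ basis state.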
    pub fn qubit_one(id: usize, idx: String) -> Self {
        let mut data = Array::zeros(IxDyn(&[2]));
        data[[1]] = Complex64::new(1.0, 0.0);
        Self {
            id,
            data,
            indices: vec![idx],
            shape: vec![2],
        }
    }

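    /// Wraps an owned ndarray of any dimensionality, generating index labels of
    /// the form `idx_{i}` from the given numeric indices. The id defaults to 0.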
    pub fn from_array<D>(
        array: scirs2_core::ndarray::ArrayBase<scirs2_core::ndarray::OwnedRepr<Complex64>, D>,
        indices: Vec<usize>,
    ) -> Self
    where
        D: scirs2_core::ndarray::Dimension,
    {
        let shape = array.shape().to_vec();
        let data = array.into_dyn();
        let index_labels: Vec<String> = indices.iter().map(|i| format!("idx_{}", i)).collect();
        Self {
            id: 0,
            data,
            indices: index_labels,
            shape,
        }
    }

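    /// Number of indices (the tensor rank).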
    pub fn rank(&self) -> usize {
        self.indices.len()
    }

    /// Returns a reference to the underlying data array.
    pub fn tensor(&self) -> &ArrayD<Complex64> {
        &self.data
    }

    /// Number of dimensions of the underlying data array.
    pub fn ndim(&self) -> usize {
        self.data.ndim()
    }

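    /// Contracts this tensor with `other` over the index pair named
    /// `self_idx` / `other_idx`, returning a new tensor that carries all
    /// remaining indices of `self` followed by those of `other`.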
    pub fn contract(
        &self,
        other: &Tensor,
        self_idx: &str,
        other_idx: &str,
    ) -> QuantRS2Result<Tensor> {
        let self_pos = self
            .indices
            .iter()
            .position(|s| s == self_idx)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!("Index {} not found in tensor", self_idx))
            })?;
        let other_pos = other
            .indices
            .iter()
            .position(|s| s == other_idx)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!("Index {} not found in tensor", other_idx))
            })?;

        if self.shape[self_pos] != other.shape[other_pos] {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Cannot contract indices with different dimensions: {} vs {}",
                self.shape[self_pos], other.shape[other_pos]
            )));
        }

        let contracted = self.contract_indices(other, self_pos, other_pos)?;

        // The result keeps every uncontracted index: self's first, then other's.
        let mut new_indices = Vec::new();
        for (i, idx) in self.indices.iter().enumerate() {
            if i != self_pos {
                new_indices.push(idx.clone());
            }
        }
        for (i, idx) in other.indices.iter().enumerate() {
            if i != other_pos {
                new_indices.push(idx.clone());
            }
        }

        Ok(Tensor::new(
            self.id.max(other.id) + 1,
            contracted,
            new_indices,
        ))
    }

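    /// Low-level contraction: reshapes both tensors into matrices around the
    /// contracted index position and accumulates the summed products.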
    fn contract_indices(
        &self,
        other: &Tensor,
        self_idx: usize,
        other_idx: usize,
    ) -> QuantRS2Result<ArrayD<Complex64>> {
        let self_shape = self.data.shape();
        let other_shape = other.data.shape();

        // Dimensions to the left and right of the contracted index in each tensor.
        let mut self_left_dims = 1;
        let mut self_right_dims = 1;
        for i in 0..self_idx {
            self_left_dims *= self_shape[i];
        }
        for i in (self_idx + 1)..self_shape.len() {
            self_right_dims *= self_shape[i];
        }

        let mut other_left_dims = 1;
        let mut other_right_dims = 1;
        for i in 0..other_idx {
            other_left_dims *= other_shape[i];
        }
        for i in (other_idx + 1)..other_shape.len() {
            other_right_dims *= other_shape[i];
        }

        let contract_dim = self_shape[self_idx];

        // Reshape both tensors into matrices so the contracted index can be
        // addressed explicitly.
        let self_mat = self
            .data
            .view()
            .into_shape_with_order((self_left_dims, contract_dim * self_right_dims))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))?
            .to_owned();
        let other_mat = other
            .data
            .view()
            .into_shape_with_order((other_left_dims * contract_dim, other_right_dims))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))?
            .to_owned();

        // Accumulate the contraction element by element, summing over the shared index.
        let mut result_vec = Vec::new();
        for i in 0..self_left_dims {
            for j in 0..self_right_dims {
                for k in 0..other_left_dims {
                    for l in 0..other_right_dims {
                        let mut sum = Complex64::new(0.0, 0.0);
                        for c in 0..contract_dim {
                            sum += self_mat[[i, c * self_right_dims + j]]
                                * other_mat[[k * contract_dim + c, l]];
                        }
                        result_vec.push(sum);
                    }
                }
            }
        }

        // Result shape: self's remaining dimensions followed by other's.
        let mut result_shape = Vec::new();
        for i in 0..self_idx {
            result_shape.push(self_shape[i]);
        }
        for i in (self_idx + 1)..self_shape.len() {
            result_shape.push(self_shape[i]);
        }
        for i in 0..other_idx {
            result_shape.push(other_shape[i]);
        }
        for i in (other_idx + 1)..other_shape.len() {
            result_shape.push(other_shape[i]);
        }

        ArrayD::from_shape_vec(IxDyn(&result_shape), result_vec)
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))
    }

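    /// Splits the tensor in two across index position `idx` using an SVD,
    /// optionally truncating to `max_rank` singular values. The current
    /// implementation performs the SVD on the real part only.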
    pub fn svd_decompose(
        &self,
        idx: usize,
        max_rank: Option<usize>,
    ) -> QuantRS2Result<(Tensor, Tensor)> {
        if idx >= self.rank() {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Index {} out of bounds for tensor with rank {}",
                idx,
                self.rank()
            )));
        }

        // Group indices [0..=idx] into the left factor and the rest into the right.
        let shape = self.data.shape();
        let mut left_dim = 1;
        let mut right_dim = 1;

        for i in 0..=idx {
            left_dim *= shape[i];
        }
        for i in (idx + 1)..shape.len() {
            right_dim *= shape[i];
        }

        let matrix = self
            .data
            .view()
            .into_shape_with_order((left_dim, right_dim))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))?
            .to_owned();

        // NOTE: the SVD stub operates on real matrices, so imaginary parts are dropped here.
        let real_matrix = matrix.mapv(|c| c.re);
        let (u, s, vt) = svd(&real_matrix.view(), false, None)
            .map_err(|e| QuantRS2Error::ComputationError(format!("SVD failed: {:?}", e)))?;

        // Optionally truncate to the requested maximum bond rank.
        let rank = if let Some(max_r) = max_rank {
            max_r.min(s.len())
        } else {
            s.len()
        };

        let u_trunc = u.slice(scirs2_core::ndarray::s![.., ..rank]).to_owned();
        let s_trunc = s.slice(scirs2_core::ndarray::s![..rank]).to_owned();
        let vt_trunc = vt.slice(scirs2_core::ndarray::s![..rank, ..]).to_owned();

        // Split the singular values symmetrically: each factor absorbs sqrt(S).
        let mut s_mat = Array2::zeros((rank, rank));
        for i in 0..rank {
            s_mat[[i, i]] = Complex64::new(s_trunc[i].sqrt(), 0.0);
        }

        let left_data = u_trunc.mapv(|x| Complex64::new(x, 0.0)).dot(&s_mat);
        let right_data = s_mat.dot(&vt_trunc.mapv(|x| Complex64::new(x, 0.0)));

        let mut left_indices = self.indices[..=idx].to_vec();
        left_indices.push(format!("bond_{}", self.id));

        let mut right_indices = vec![format!("bond_{}", self.id)];
        right_indices.extend_from_slice(&self.indices[(idx + 1)..]);

        let left_tensor = Tensor::new(self.id * 2, left_data.into_dyn(), left_indices);
        let right_tensor = Tensor::new(self.id * 2 + 1, right_data.into_dyn(), right_indices);

        Ok((left_tensor, right_tensor))
    }
}

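/// An edge connecting two tensors through a pair of contracted indices.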
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TensorEdge {
    /// Id of the first tensor
    pub tensor1: usize,
    /// Contracted index on the first tensor
    pub index1: String,
    /// Id of the second tensor
    pub tensor2: usize,
    /// Contracted index on the second tensor
    pub index2: String,
}

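/// A collection of tensors together with the edges (contractions) between them
/// and the indices that remain open.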
#[derive(Debug)]
pub struct TensorNetwork {
    /// Tensors keyed by id
    pub tensors: HashMap<usize, Tensor>,
    /// Contractions between tensors
    pub edges: Vec<TensorEdge>,
    /// Uncontracted indices per tensor
    pub open_indices: HashMap<usize, Vec<String>>,
    next_id: usize,
}

impl TensorNetwork {
    pub fn new() -> Self {
        Self {
            tensors: HashMap::new(),
            edges: Vec::new(),
            open_indices: HashMap::new(),
            next_id: 0,
        }
    }

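    /// Adds a tensor to the network; all of its indices start out open.
    /// Returns the tensor's id.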
    pub fn add_tensor(&mut self, tensor: Tensor) -> usize {
        let id = tensor.id;
        self.open_indices.insert(id, tensor.indices.clone());
        self.tensors.insert(id, tensor);
        self.next_id = self.next_id.max(id + 1);
        id
    }

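    /// Connects `index1` of `tensor1` to `index2` of `tensor2`, recording an
    /// edge and removing both indices from the open-index sets.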
    pub fn connect(
        &mut self,
        tensor1: usize,
        index1: String,
        tensor2: usize,
        index2: String,
    ) -> QuantRS2Result<()> {
        if !self.tensors.contains_key(&tensor1) {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Tensor {} not found",
                tensor1
            )));
        }
        if !self.tensors.contains_key(&tensor2) {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Tensor {} not found",
                tensor2
            )));
        }

        let t1 = &self.tensors[&tensor1];
        let t2 = &self.tensors[&tensor2];

        let idx1_pos = t1
            .indices
            .iter()
            .position(|s| s == &index1)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!(
                    "Index {} not found in tensor {}",
                    index1, tensor1
                ))
            })?;
        let idx2_pos = t2
            .indices
            .iter()
            .position(|s| s == &index2)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!(
                    "Index {} not found in tensor {}",
                    index2, tensor2
                ))
            })?;

        if t1.shape[idx1_pos] != t2.shape[idx2_pos] {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Connected indices must have same dimension: {} vs {}",
                t1.shape[idx1_pos], t2.shape[idx2_pos]
            )));
        }

        self.edges.push(TensorEdge {
            tensor1,
            index1: index1.clone(),
            tensor2,
            index2: index2.clone(),
        });

        // The connected indices are no longer open.
        if let Some(indices) = self.open_indices.get_mut(&tensor1) {
            indices.retain(|s| s != &index1);
        }
        if let Some(indices) = self.open_indices.get_mut(&tensor2) {
            indices.retain(|s| s != &index2);
        }

        Ok(())
    }

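    /// Greedy contraction ordering: repeatedly picks the cheapest connected
    /// pair among the remaining tensors, as estimated by
    /// `estimate_contraction_cost`.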
    pub fn find_contraction_order(&self) -> Vec<(usize, usize)> {
        let mut remaining_tensors: HashSet<_> = self.tensors.keys().cloned().collect();
        let mut order = Vec::new();

        // Build an adjacency list from the edges.
        let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
        for edge in &self.edges {
            adjacency
                .entry(edge.tensor1)
                .or_insert_with(Vec::new)
                .push(edge.tensor2);
            adjacency
                .entry(edge.tensor2)
                .or_insert_with(Vec::new)
                .push(edge.tensor1);
        }

        while remaining_tensors.len() > 1 {
            let mut best_pair = None;
            let mut min_cost = usize::MAX;

            // Pick the cheapest connected pair among the remaining tensors.
            for &t1 in &remaining_tensors {
                if let Some(neighbors) = adjacency.get(&t1) {
                    for &t2 in neighbors {
                        if t2 > t1 && remaining_tensors.contains(&t2) {
                            let cost = self.estimate_contraction_cost(t1, t2);
                            if cost < min_cost {
                                min_cost = cost;
                                best_pair = Some((t1, t2));
                            }
                        }
                    }
                }
            }

            if let Some((t1, t2)) = best_pair {
                order.push((t1, t2));
                remaining_tensors.remove(&t1);
                remaining_tensors.remove(&t2);

                // Replace the contracted pair with a virtual tensor that inherits
                // the pair's remaining neighbors.
                let virtual_id = self.next_id + order.len();
                remaining_tensors.insert(virtual_id);

                let mut virtual_neighbors = HashSet::new();
                if let Some(n1) = adjacency.get(&t1) {
                    virtual_neighbors.extend(
                        n1.iter()
                            .filter(|&&n| n != t2 && remaining_tensors.contains(&n)),
                    );
                }
                if let Some(n2) = adjacency.get(&t2) {
                    virtual_neighbors.extend(
                        n2.iter()
                            .filter(|&&n| n != t1 && remaining_tensors.contains(&n)),
                    );
                }
                adjacency.insert(virtual_id, virtual_neighbors.into_iter().collect());
            } else {
                break;
            }
        }

        order
    }

    fn estimate_contraction_cost(&self, _t1: usize, _t2: usize) -> usize {
        // Placeholder: a flat cost until a size-based estimate is implemented.
        1000
    }

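    /// Contracts the whole network following the greedy order and returns the
    /// single remaining tensor.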
    pub fn contract_all(&mut self) -> QuantRS2Result<Tensor> {
        if self.tensors.is_empty() {
            return Err(QuantRS2Error::InvalidInput(
                "Cannot contract empty tensor network".into(),
            ));
        }

        if self.tensors.len() == 1 {
            return Ok(self.tensors.values().next().unwrap().clone());
        }

        let order = self.find_contraction_order();

        let mut tensor_map = self.tensors.clone();
        let mut next_id = self.next_id;

        for (t1_id, t2_id) in order {
            // Find the edge that connects this pair so we know which indices to contract.
            let edge = self
                .edges
                .iter()
                .find(|e| {
                    (e.tensor1 == t1_id && e.tensor2 == t2_id)
                        || (e.tensor1 == t2_id && e.tensor2 == t1_id)
                })
                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensors not connected".into()))?;

            let t1 = tensor_map
                .remove(&t1_id)
                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensor not found".into()))?;
            let t2 = tensor_map
                .remove(&t2_id)
                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensor not found".into()))?;

            let contracted = if edge.tensor1 == t1_id {
                t1.contract(&t2, &edge.index1, &edge.index2)?
            } else {
                t1.contract(&t2, &edge.index2, &edge.index1)?
            };

            let mut new_tensor = contracted;
            new_tensor.id = next_id;
            tensor_map.insert(next_id, new_tensor);
            next_id += 1;
        }

        tensor_map
            .into_values()
            .next()
            .ok_or_else(|| QuantRS2Error::InvalidInput("Contraction failed".into()))
    }

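    /// Converts the network into a matrix product state. Currently a stub that
    /// returns an empty chain.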
    pub fn to_mps(&self, _max_bond_dim: Option<usize>) -> QuantRS2Result<Vec<Tensor>> {
        Ok(vec![])
    }

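    /// Applies a matrix product operator to the given qubits. Currently a no-op stub.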
    pub fn apply_mpo(&mut self, _mpo: &[Tensor], _qubits: &[usize]) -> QuantRS2Result<()> {
        Ok(())
    }

    /// Returns references to all tensors in the network.
    pub fn tensors(&self) -> Vec<&Tensor> {
        self.tensors.values().collect()
    }

    /// Looks up a tensor by id.
    pub fn tensor(&self, id: usize) -> Option<&Tensor> {
        self.tensors.get(&id)
    }
}

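/// Builds a tensor network for a quantum circuit gate by gate, tracking the
/// current open output index of each qubit.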
pub struct TensorNetworkBuilder {
    network: TensorNetwork,
    qubit_indices: HashMap<usize, String>,
    current_indices: HashMap<usize, String>,
}

impl TensorNetworkBuilder {
    /// Initializes every qubit in the |0⟩ state with an open output index.
    pub fn new(num_qubits: usize) -> Self {
        let mut network = TensorNetwork::new();
        let mut qubit_indices = HashMap::new();
        let mut current_indices = HashMap::new();

        for i in 0..num_qubits {
            let idx = format!("q{}_0", i);
            let tensor = Tensor::qubit_zero(i, idx.clone());
            network.add_tensor(tensor);
            qubit_indices.insert(i, idx.clone());
            current_indices.insert(i, idx);
        }

        Self {
            network,
            qubit_indices,
            current_indices,
        }
    }

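    /// Adds a single-qubit gate tensor and wires its input index to the
    /// qubit's current output index.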
    pub fn apply_single_qubit_gate(
        &mut self,
        gate: &dyn GateOp,
        qubit: usize,
    ) -> QuantRS2Result<()> {
        let matrix_vec = gate.matrix()?;
        let matrix = Array2::from_shape_vec((2, 2), matrix_vec)
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))?;

        let in_idx = self.current_indices[&qubit].clone();
        let out_idx = format!("q{}_{}", qubit, self.network.next_id);
        let gate_tensor = Tensor::from_matrix(
            self.network.next_id,
            matrix,
            in_idx.clone(),
            out_idx.clone(),
        );

        // Locate the tensor that currently owns the qubit's open index before the
        // gate tensor (which reuses that index name) is added to the network.
        let prev_tensor = self.find_tensor_with_index(&in_idx);

        let gate_id = self.network.add_tensor(gate_tensor);

        if let Some(prev) = prev_tensor {
            self.network
                .connect(prev, in_idx.clone(), gate_id, in_idx)?;
        }

        self.current_indices.insert(qubit, out_idx);

        Ok(())
    }

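    /// Adds a two-qubit gate as a rank-4 tensor (reshaped from its 4x4 matrix)
    /// and wires both input indices to the qubits' current outputs.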
    pub fn apply_two_qubit_gate(
        &mut self,
        gate: &dyn GateOp,
        qubit1: usize,
        qubit2: usize,
    ) -> QuantRS2Result<()> {
        let matrix_vec = gate.matrix()?;
        let matrix = Array2::from_shape_vec((4, 4), matrix_vec)
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))?;

        // Reshape the 4x4 gate matrix into a rank-4 tensor.
        let tensor_data = matrix
            .into_shape_with_order((2, 2, 2, 2))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {}", e)))?
            .into_dyn();

        let in1_idx = self.current_indices[&qubit1].clone();
        let in2_idx = self.current_indices[&qubit2].clone();
        let out1_idx = format!("q{}_{}", qubit1, self.network.next_id);
        let out2_idx = format!("q{}_{}", qubit2, self.network.next_id);

        let gate_tensor = Tensor::new(
            self.network.next_id,
            tensor_data,
            vec![
                in1_idx.clone(),
                in2_idx.clone(),
                out1_idx.clone(),
                out2_idx.clone(),
            ],
        );

        // Locate the current owners of both input indices before the gate tensor
        // (which reuses those index names) is added to the network.
        let prev1 = self.find_tensor_with_index(&in1_idx);
        let prev2 = self.find_tensor_with_index(&in2_idx);

        let gate_id = self.network.add_tensor(gate_tensor);

        if let Some(prev1) = prev1 {
            self.network
                .connect(prev1, in1_idx.clone(), gate_id, in1_idx)?;
        }
        if let Some(prev2) = prev2 {
            self.network
                .connect(prev2, in2_idx.clone(), gate_id, in2_idx)?;
        }

        self.current_indices.insert(qubit1, out1_idx);
        self.current_indices.insert(qubit2, out2_idx);

        Ok(())
    }

    /// Returns the id of the first tensor whose index list contains `index`.
    fn find_tensor_with_index(&self, index: &str) -> Option<usize> {
        for (id, tensor) in &self.network.tensors {
            if tensor.indices.iter().any(|idx| idx == index) {
                return Some(*id);
            }
        }
        None
    }

    /// Consumes the builder and returns the constructed network.
    pub fn build(self) -> TensorNetwork {
        self.network
    }

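    /// Contracts the full network and returns the resulting amplitudes as a flat vector.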
    pub fn to_statevector(&mut self) -> QuantRS2Result<Vec<Complex64>> {
        let final_tensor = self.network.contract_all()?;
        Ok(final_tensor.data.into_raw_vec_and_offset().0)
    }
}

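/// Circuit simulator backed by tensor-network contraction.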
pub struct TensorNetworkSimulator {
    max_bond_dim: usize,
    use_compression: bool,
    parallel_threshold: usize,
}

impl TensorNetworkSimulator {
    pub fn new() -> Self {
        Self {
            max_bond_dim: 64,
            use_compression: true,
            parallel_threshold: 1000,
        }
    }

    pub fn with_max_bond_dim(mut self, dim: usize) -> Self {
        self.max_bond_dim = dim;
        self
    }

    pub fn with_compression(mut self, compress: bool) -> Self {
        self.use_compression = compress;
        self
    }

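    /// Simulates a gate sequence on `N` qubits starting from |0...0⟩.
    /// Only one- and two-qubit gates are supported.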
    pub fn simulate<const N: usize>(
        &self,
        gates: &[Box<dyn GateOp>],
    ) -> QuantRS2Result<Register<N>> {
        let mut builder = TensorNetworkBuilder::new(N);

        for gate in gates {
            let qubits = gate.qubits();
            match qubits.len() {
                1 => builder.apply_single_qubit_gate(gate.as_ref(), qubits[0].0 as usize)?,
                2 => builder.apply_two_qubit_gate(
                    gate.as_ref(),
                    qubits[0].0 as usize,
                    qubits[1].0 as usize,
                )?,
                _ => {
                    return Err(QuantRS2Error::UnsupportedOperation(format!(
                        "Gates with {} qubits not supported in tensor network",
                        qubits.len()
                    )))
                }
            }
        }

        let amplitudes = builder.to_statevector()?;
        Register::with_amplitudes(amplitudes)
    }
}

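/// Utilities for finding better contraction orders.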
pub mod contraction_optimization {
    use super::*;

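    /// Searches for a low-cost contraction order, memoizing results over
    /// subsets of tensor ids.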
    pub struct DynamicProgrammingOptimizer {
        /// Memoized (cost, order) results keyed by the remaining tensor ids
        memo: HashMap<Vec<usize>, (usize, Vec<(usize, usize)>)>,
    }

    impl DynamicProgrammingOptimizer {
        pub fn new() -> Self {
            Self {
                memo: HashMap::new(),
            }
        }

        /// Returns an optimized contraction order for the given network.
        pub fn optimize(&mut self, network: &TensorNetwork) -> Vec<(usize, usize)> {
            let tensor_ids: Vec<_> = network.tensors.keys().cloned().collect();
            self.find_optimal_order(&tensor_ids, network).1
        }

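        /// Recursively tries every connected pair, keeping the cheapest total
        /// order; results are memoized keyed by the remaining tensor-id list.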
        fn find_optimal_order(
            &mut self,
            tensors: &[usize],
            network: &TensorNetwork,
        ) -> (usize, Vec<(usize, usize)>) {
            if tensors.len() <= 1 {
                return (0, vec![]);
            }

            let key = tensors.to_vec();
            if let Some(result) = self.memo.get(&key) {
                return result.clone();
            }

            let mut best_cost = usize::MAX;
            let mut best_order = vec![];

            for i in 0..tensors.len() {
                for j in (i + 1)..tensors.len() {
                    if self.are_connected(tensors[i], tensors[j], network) {
                        let cost = network.estimate_contraction_cost(tensors[i], tensors[j]);

                        // The contracted pair is replaced by a virtual tensor id.
                        let mut remaining = vec![];
                        for (k, &t) in tensors.iter().enumerate() {
                            if k != i && k != j {
                                remaining.push(t);
                            }
                        }
                        remaining.push(network.next_id + remaining.len());

                        let (sub_cost, sub_order) = self.find_optimal_order(&remaining, network);
                        let total_cost = cost + sub_cost;

                        if total_cost < best_cost {
                            best_cost = total_cost;
                            best_order = vec![(tensors[i], tensors[j])];
                            best_order.extend(sub_order);
                        }
                    }
                }
            }

            self.memo.insert(key, (best_cost, best_order.clone()));
            (best_cost, best_order)
        }

        fn are_connected(&self, t1: usize, t2: usize, network: &TensorNetwork) -> bool {
            network.edges.iter().any(|e| {
                (e.tensor1 == t1 && e.tensor2 == t2) || (e.tensor1 == t2 && e.tensor2 == t1)
            })
        }
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_creation() {
        let data = ArrayD::zeros(IxDyn(&[2, 2]));
        let tensor = Tensor::new(0, data, vec!["in".to_string(), "out".to_string()]);
        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.shape, vec![2, 2]);
    }

    #[test]
    fn test_qubit_tensors() {
        let t0 = Tensor::qubit_zero(0, "q0".to_string());
        assert_eq!(t0.data[[0]], Complex64::new(1.0, 0.0));
        assert_eq!(t0.data[[1]], Complex64::new(0.0, 0.0));

        let t1 = Tensor::qubit_one(1, "q1".to_string());
        assert_eq!(t1.data[[0]], Complex64::new(0.0, 0.0));
        assert_eq!(t1.data[[1]], Complex64::new(1.0, 0.0));
    }

    #[test]
    fn test_tensor_network_builder() {
        let builder = TensorNetworkBuilder::new(2);
        assert_eq!(builder.network.tensors.len(), 2);
    }

    #[test]
    fn test_network_connection() {
        let mut network = TensorNetwork::new();

        let t1 = Tensor::qubit_zero(0, "q0".to_string());
        let t2 = Tensor::qubit_zero(1, "q1".to_string());

        let id1 = network.add_tensor(t1);
        let id2 = network.add_tensor(t2);

        // Connecting through an index that neither tensor has must fail.
        assert!(network
            .connect(id1, "bond".to_string(), id2, "bond".to_string())
            .is_err());
    }
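
    // Additional check (a minimal sketch using only the APIs defined above):
    // contracting an identity gate tensor with a |0⟩ state over the shared "in"
    // index should reproduce |0⟩ on the gate's remaining "out" index.
    #[test]
    fn test_identity_contraction() {
        let mut identity: Array2<Complex64> = Array2::zeros((2, 2));
        identity[[0, 0]] = Complex64::new(1.0, 0.0);
        identity[[1, 1]] = Complex64::new(1.0, 0.0);
        let gate = Tensor::from_matrix(1, identity, "in".to_string(), "out".to_string());
        let state = Tensor::qubit_zero(0, "in".to_string());

        let result = state.contract(&gate, "in", "in").unwrap();
        assert_eq!(result.rank(), 1);
        assert_eq!(result.indices, vec!["out".to_string()]);
        assert_eq!(result.data[[0]], Complex64::new(1.0, 0.0));
        assert_eq!(result.data[[1]], Complex64::new(0.0, 0.0));
    }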
}