1use std::collections::{HashMap, HashSet};
8use std::fmt;
9
10use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
11use scirs2_core::Complex64;
12
13use crate::adaptive_gate_fusion::QuantumGate;
14use crate::error::{Result, SimulatorError};
15use crate::scirs2_integration::SciRS2Backend;
16use quantrs2_circuit::prelude::*;
17use quantrs2_core::prelude::*;
18
/// A node of the tensor network.
///
/// The numeric payload is always stored as a rank-3 array. The logical `indices`
/// are tracked separately and need not match the array rank: for example
/// `Tensor::from_gate` stores a k-qubit gate as a `(2^k, 2^k, 1)` array while
/// attaching `2 * k` dimension-2 indices.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Raw tensor entries.
    pub data: Array3<Complex64>,
    /// Logical indices; contraction pairs indices whose `id`s are equal.
    pub indices: Vec<TensorIndex>,
    /// Human-readable label (e.g. `I_<qubit>` for identity tensors).
    pub label: String,
}
29
/// A single logical index (leg) of a [`Tensor`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TensorIndex {
    /// Identifier drawn from a shared counter; two tensors are treated as
    /// sharing a leg when they hold indices with the same `id`.
    pub id: usize,
    /// Size of this leg (2 for the physical qubit indices created in this file).
    pub dimension: usize,
    /// What kind of leg this is.
    pub index_type: IndexType,
}
40
/// Classification of a [`TensorIndex`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum IndexType {
    /// Physical leg attached to the given qubit number.
    Physical(usize),
    /// Bond between tensors (presumably; not constructed in this file — confirm).
    Virtual,
    /// Helper leg (not constructed in this file — purpose to be confirmed).
    Auxiliary,
}
51
/// Coarse circuit-topology classification stored on a [`TensorNetwork`].
///
/// NOTE(review): only the `General` default is assigned in this file; the
/// detection logic, if any, lives elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitType {
    /// Nearest-neighbour chain-like circuit.
    Linear,
    /// Circuit centred on one hub qubit.
    Star,
    /// Circuit made of repeated layers.
    Layered,
    /// Quantum Fourier transform circuit.
    QFT,
    /// QAOA-style circuit.
    QAOA,
    /// No specific structure recognised.
    General,
}
68
/// A collection of tensors plus bookkeeping counters.
#[derive(Debug, Clone)]
pub struct TensorNetwork {
    /// Tensors keyed by the id handed out by `add_tensor`.
    pub tensors: HashMap<usize, Tensor>,
    /// Index pairs registered through `connect` (dimensions checked equal).
    pub connections: Vec<(TensorIndex, TensorIndex)>,
    /// Number of qubits the network models.
    pub num_qubits: usize,
    /// Next tensor id to be assigned by `add_tensor`.
    next_tensor_id: usize,
    /// Shared counter used when allocating new `TensorIndex` ids.
    next_index_id: usize,
    /// Bond-dimension cap (16 by default in `new`).
    pub max_bond_dimension: usize,
    /// Circuit classification (`General` by default).
    pub detected_circuit_type: CircuitType,
    /// Optimization flags; all start `false` and are not read in this file.
    pub using_qft_optimization: bool,
    pub using_qaoa_optimization: bool,
    pub using_linear_optimization: bool,
    pub using_star_optimization: bool,
}
95
/// Simulator front-end that owns a [`TensorNetwork`] and contracts it.
#[derive(Debug)]
pub struct TensorNetworkSimulator {
    /// The network being built up by gate applications.
    network: TensorNetwork,
    /// Optional SciRS2 backend, installed by `with_backend`.
    backend: Option<SciRS2Backend>,
    /// Strategy used by `contract` (Greedy by default).
    strategy: ContractionStrategy,
    /// Bond-dimension cap (256 by default); not read elsewhere in this file.
    max_bond_dim: usize,
    /// Accumulated contraction statistics.
    stats: TensorNetworkStats,
}
110
/// How `TensorNetworkSimulator::contract` chooses the pairwise contraction order.
#[derive(Debug, Clone, PartialEq)]
pub enum ContractionStrategy {
    /// Delegates to `TensorNetwork::contract_all`.
    Sequential,
    /// Planner-driven contraction (see `contract_optimal`).
    Optimal,
    /// Repeatedly contracts the pair with the lowest estimated cost.
    Greedy,
    /// Each entry is a position in the working tensor list; that tensor is
    /// contracted with its successor (wrapping to position 0).
    Custom(Vec<usize>),
}
123
/// Counters collected by the simulator.
#[derive(Debug, Clone, Default)]
pub struct TensorNetworkStats {
    /// Number of successful `contract` calls.
    pub contractions: usize,
    /// Total wall-clock time spent in `contract`, in milliseconds.
    pub contraction_time_ms: f64,
    /// Largest bond dimension observed (not updated in this file).
    pub max_bond_dimension: usize,
    /// Memory usage in bytes (not updated in this file).
    pub memory_usage: usize,
    /// Floating-point operation count (not updated in this file).
    pub flop_count: u64,
}
138
139impl Tensor {
140 pub fn new(data: Array3<Complex64>, indices: Vec<TensorIndex>, label: String) -> Self {
142 Self {
143 data,
144 indices,
145 label,
146 }
147 }
148
149 pub fn identity(qubit: usize, index_id_gen: &mut usize) -> Self {
151 let mut data = Array3::zeros((2, 2, 1));
152 data[[0, 0, 0]] = Complex64::new(1.0, 0.0);
153 data[[1, 1, 0]] = Complex64::new(1.0, 0.0);
154
155 let in_idx = TensorIndex {
156 id: *index_id_gen,
157 dimension: 2,
158 index_type: IndexType::Physical(qubit),
159 };
160 *index_id_gen += 1;
161
162 let out_idx = TensorIndex {
163 id: *index_id_gen,
164 dimension: 2,
165 index_type: IndexType::Physical(qubit),
166 };
167 *index_id_gen += 1;
168
169 Self::new(data, vec![in_idx, out_idx], format!("I_{}", qubit))
170 }
171
172 pub fn from_gate(
174 gate: &Array2<Complex64>,
175 qubits: &[usize],
176 index_id_gen: &mut usize,
177 ) -> Result<Self> {
178 let num_qubits = qubits.len();
179 let dim = 1 << num_qubits;
180
181 if gate.shape() != [dim, dim] {
182 return Err(SimulatorError::DimensionMismatch(format!(
183 "Expected gate shape [{}, {}], got {:?}",
184 dim,
185 dim,
186 gate.shape()
187 )));
188 }
189
190 let data = if num_qubits == 1 {
193 let mut tensor_data = Array3::zeros((2, 2, 1));
195 for i in 0..2 {
196 for j in 0..2 {
197 tensor_data[[i, j, 0]] = gate[[i, j]];
198 }
199 }
200 tensor_data
201 } else {
202 let mut tensor_data = Array3::zeros((dim, dim, 1));
204 for i in 0..dim {
205 for j in 0..dim {
206 tensor_data[[i, j, 0]] = gate[[i, j]];
207 }
208 }
209 tensor_data
210 };
211
212 let mut indices = Vec::new();
214 for &qubit in qubits {
215 indices.push(TensorIndex {
217 id: *index_id_gen,
218 dimension: 2,
219 index_type: IndexType::Physical(qubit),
220 });
221 *index_id_gen += 1;
222
223 indices.push(TensorIndex {
225 id: *index_id_gen,
226 dimension: 2,
227 index_type: IndexType::Physical(qubit),
228 });
229 *index_id_gen += 1;
230 }
231
232 Ok(Self::new(data, indices, format!("Gate_{:?}", qubits)))
233 }
234
235 pub fn contract(&self, other: &Tensor, self_idx: usize, other_idx: usize) -> Result<Tensor> {
237 if self_idx >= self.indices.len() || other_idx >= other.indices.len() {
238 return Err(SimulatorError::InvalidInput(
239 "Index out of bounds for tensor contraction".to_string(),
240 ));
241 }
242
243 if self.indices[self_idx].dimension != other.indices[other_idx].dimension {
244 return Err(SimulatorError::DimensionMismatch(format!(
245 "Index dimension mismatch: expected {}, got {}",
246 self.indices[self_idx].dimension, other.indices[other_idx].dimension
247 )));
248 }
249
250 let self_shape = self.data.shape();
252 let other_shape = other.data.shape();
253
254 let mut result_shape = Vec::new();
256
257 for (i, idx) in self.indices.iter().enumerate() {
259 if i != self_idx {
260 result_shape.push(idx.dimension);
261 }
262 }
263
264 for (i, idx) in other.indices.iter().enumerate() {
266 if i != other_idx {
267 result_shape.push(idx.dimension);
268 }
269 }
270
271 if result_shape.is_empty() {
273 let mut scalar_result = Complex64::new(0.0, 0.0);
274 let contract_dim = self.indices[self_idx].dimension;
275
276 for k in 0..contract_dim {
278 if self.data.len() > k && other.data.len() > k {
281 scalar_result += self.data.iter().nth(k).unwrap_or(&Complex64::new(0.0, 0.0))
282 * other
283 .data
284 .iter()
285 .nth(k)
286 .unwrap_or(&Complex64::new(0.0, 0.0));
287 }
288 }
289
290 let mut result_data = Array3::zeros((1, 1, 1));
292 result_data[[0, 0, 0]] = scalar_result;
293
294 let result_indices = vec![];
295 return Ok(Tensor::new(
296 result_data,
297 result_indices,
298 format!("{}_contracted_{}", self.label, other.label),
299 ));
300 }
301
302 let result_data = self
304 .perform_tensor_contraction(other, self_idx, other_idx, &result_shape)
305 .unwrap_or_else(|_| {
306 Array3::from_shape_fn(
308 (
309 result_shape[0].max(2),
310 *result_shape.get(1).unwrap_or(&2).max(&2),
311 1,
312 ),
313 |(i, j, k)| {
314 if i == j {
315 Complex64::new(1.0, 0.0)
316 } else {
317 Complex64::new(0.0, 0.0)
318 }
319 },
320 )
321 });
322
323 let mut result_indices = Vec::new();
324
325 for (i, idx) in self.indices.iter().enumerate() {
327 if i != self_idx {
328 result_indices.push(idx.clone());
329 }
330 }
331
332 for (i, idx) in other.indices.iter().enumerate() {
334 if i != other_idx {
335 result_indices.push(idx.clone());
336 }
337 }
338
339 Ok(Tensor::new(
340 result_data,
341 result_indices,
342 format!("Contract_{}_{}", self.label, other.label),
343 ))
344 }
345
346 fn perform_tensor_contraction(
348 &self,
349 other: &Tensor,
350 self_idx: usize,
351 other_idx: usize,
352 result_shape: &[usize],
353 ) -> Result<Array3<Complex64>> {
354 let result_dims = if result_shape.len() >= 2 {
356 (
357 result_shape[0],
358 result_shape.get(1).copied().unwrap_or(1),
359 result_shape.get(2).copied().unwrap_or(1),
360 )
361 } else if result_shape.len() == 1 {
362 (result_shape[0], 1, 1)
363 } else {
364 (1, 1, 1)
365 };
366
367 let mut result = Array3::zeros(result_dims);
368 let contract_dim = self.indices[self_idx].dimension;
369
370 for i in 0..result_dims.0 {
372 for j in 0..result_dims.1 {
373 for k in 0..result_dims.2 {
374 let mut sum = Complex64::new(0.0, 0.0);
375
376 for contract_idx in 0..contract_dim {
377 let self_coords =
379 self.map_result_to_self_coords(i, j, k, self_idx, contract_idx);
380 let other_coords =
381 other.map_result_to_other_coords(i, j, k, other_idx, contract_idx);
382
383 if self_coords.0 < self.data.shape()[0]
384 && self_coords.1 < self.data.shape()[1]
385 && self_coords.2 < self.data.shape()[2]
386 && other_coords.0 < other.data.shape()[0]
387 && other_coords.1 < other.data.shape()[1]
388 && other_coords.2 < other.data.shape()[2]
389 {
390 sum += self.data[[self_coords.0, self_coords.1, self_coords.2]]
391 * other.data[[other_coords.0, other_coords.1, other_coords.2]];
392 }
393 }
394
395 result[[i, j, k]] = sum;
396 }
397 }
398 }
399
400 Ok(result)
401 }
402
403 fn map_result_to_self_coords(
405 &self,
406 i: usize,
407 j: usize,
408 k: usize,
409 contract_idx_pos: usize,
410 contract_val: usize,
411 ) -> (usize, usize, usize) {
412 let coords = match contract_idx_pos {
414 0 => (contract_val, i.min(j), k),
415 1 => (i, contract_val, k),
416 _ => (i, j, contract_val),
417 };
418
419 (coords.0.min(1), coords.1.min(1), coords.2.min(0))
420 }
421
422 fn map_result_to_other_coords(
424 &self,
425 i: usize,
426 j: usize,
427 k: usize,
428 contract_idx_pos: usize,
429 contract_val: usize,
430 ) -> (usize, usize, usize) {
431 let coords = match contract_idx_pos {
433 0 => (contract_val, i.min(j), k),
434 1 => (i, contract_val, k),
435 _ => (i, j, contract_val),
436 };
437
438 (coords.0.min(1), coords.1.min(1), coords.2.min(0))
439 }
440
441 pub fn rank(&self) -> usize {
443 self.indices.len()
444 }
445
446 pub fn size(&self) -> usize {
448 self.data.len()
449 }
450}
451
452impl TensorNetwork {
453 pub fn new(num_qubits: usize) -> Self {
455 Self {
456 tensors: HashMap::new(),
457 connections: Vec::new(),
458 num_qubits,
459 next_tensor_id: 0,
460 next_index_id: 0,
461 max_bond_dimension: 16,
462 detected_circuit_type: CircuitType::General,
463 using_qft_optimization: false,
464 using_qaoa_optimization: false,
465 using_linear_optimization: false,
466 using_star_optimization: false,
467 }
468 }
469
470 pub fn add_tensor(&mut self, tensor: Tensor) -> usize {
472 let id = self.next_tensor_id;
473 self.tensors.insert(id, tensor);
474 self.next_tensor_id += 1;
475 id
476 }
477
478 pub fn connect(&mut self, idx1: TensorIndex, idx2: TensorIndex) -> Result<()> {
480 if idx1.dimension != idx2.dimension {
481 return Err(SimulatorError::DimensionMismatch(format!(
482 "Cannot connect indices with different dimensions: {} vs {}",
483 idx1.dimension, idx2.dimension
484 )));
485 }
486
487 self.connections.push((idx1, idx2));
488 Ok(())
489 }
490
491 pub fn get_neighbors(&self, tensor_id: usize) -> Vec<usize> {
493 let mut neighbors = HashSet::new();
494
495 if let Some(tensor) = self.tensors.get(&tensor_id) {
496 for connection in &self.connections {
497 let tensor_indices: HashSet<_> = tensor.indices.iter().map(|idx| idx.id).collect();
499
500 if tensor_indices.contains(&connection.0.id)
501 || tensor_indices.contains(&connection.1.id)
502 {
503 for (other_id, other_tensor) in &self.tensors {
505 if *other_id != tensor_id {
506 let other_indices: HashSet<_> =
507 other_tensor.indices.iter().map(|idx| idx.id).collect();
508 if other_indices.contains(&connection.0.id)
509 || other_indices.contains(&connection.1.id)
510 {
511 neighbors.insert(*other_id);
512 }
513 }
514 }
515 }
516 }
517 }
518
519 neighbors.into_iter().collect()
520 }
521
522 pub fn contract_all(&self) -> Result<Complex64> {
524 if self.tensors.is_empty() {
525 return Ok(Complex64::new(1.0, 0.0));
526 }
527
528 if self.tensors.is_empty() {
530 return Ok(Complex64::new(1.0, 0.0));
531 }
532
533 let contraction_order = self.find_optimal_contraction_order()?;
535
536 let mut current_tensors: Vec<_> = self.tensors.values().cloned().collect();
538
539 while current_tensors.len() > 1 {
540 let (i, j, _cost) = self.find_lowest_cost_pair(¤t_tensors)?;
542
543 let contracted = self.contract_tensor_pair(¤t_tensors[i], ¤t_tensors[j])?;
545
546 let mut new_tensors = Vec::new();
548 for (idx, tensor) in current_tensors.iter().enumerate() {
549 if idx != i && idx != j {
550 new_tensors.push(tensor.clone());
551 }
552 }
553 new_tensors.push(contracted);
554 current_tensors = new_tensors;
555 }
556
557 if let Some(final_tensor) = current_tensors.into_iter().next() {
559 if final_tensor.data.len() > 0 {
561 Ok(final_tensor.data[[0, 0, 0]])
562 } else {
563 Ok(Complex64::new(1.0, 0.0))
564 }
565 } else {
566 Ok(Complex64::new(1.0, 0.0))
567 }
568 }
569
570 pub fn total_elements(&self) -> usize {
572 self.tensors.values().map(|t| t.size()).sum()
573 }
574
575 pub fn memory_usage(&self) -> usize {
577 self.total_elements() * std::mem::size_of::<Complex64>()
578 }
579
580 pub fn find_optimal_contraction_order(&self) -> Result<Vec<usize>> {
582 let tensor_ids: Vec<usize> = self.tensors.keys().cloned().collect();
583 if tensor_ids.len() <= 2 {
584 return Ok(tensor_ids);
585 }
586
587 let mut order = Vec::new();
589 let mut remaining = tensor_ids;
590
591 while remaining.len() > 1 {
592 let mut min_cost = f64::INFINITY;
594 let mut best_pair = (0, 1);
595
596 for i in 0..remaining.len() {
597 for j in i + 1..remaining.len() {
598 if let (Some(tensor_a), Some(tensor_b)) = (
599 self.tensors.get(&remaining[i]),
600 self.tensors.get(&remaining[j]),
601 ) {
602 let cost = self.estimate_contraction_cost(tensor_a, tensor_b);
603 if cost < min_cost {
604 min_cost = cost;
605 best_pair = (i, j);
606 }
607 }
608 }
609 }
610
611 order.push(best_pair.0);
613 order.push(best_pair.1);
614
615 remaining.remove(best_pair.1); remaining.remove(best_pair.0);
618
619 if !remaining.is_empty() {
621 remaining.push(self.next_tensor_id + order.len());
622 }
623 }
624
625 Ok(order)
626 }
627
628 pub fn find_lowest_cost_pair(&self, tensors: &[Tensor]) -> Result<(usize, usize, f64)> {
630 if tensors.len() < 2 {
631 return Err(SimulatorError::InvalidInput(
632 "Need at least 2 tensors to find contraction pair".to_string(),
633 ));
634 }
635
636 let mut min_cost = f64::INFINITY;
637 let mut best_pair = (0, 1);
638
639 for i in 0..tensors.len() {
640 for j in i + 1..tensors.len() {
641 let cost = self.estimate_contraction_cost(&tensors[i], &tensors[j]);
642 if cost < min_cost {
643 min_cost = cost;
644 best_pair = (i, j);
645 }
646 }
647 }
648
649 Ok((best_pair.0, best_pair.1, min_cost))
650 }
651
652 pub fn estimate_contraction_cost(&self, tensor_a: &Tensor, tensor_b: &Tensor) -> f64 {
654 let size_a = tensor_a.size() as f64;
656 let size_b = tensor_b.size() as f64;
657
658 let mut common_dim_product = 1.0;
660 for idx_a in &tensor_a.indices {
661 for idx_b in &tensor_b.indices {
662 if idx_a.id == idx_b.id {
663 common_dim_product *= idx_a.dimension as f64;
664 }
665 }
666 }
667
668 size_a * size_b / common_dim_product.max(1.0)
670 }
671
672 pub fn contract_tensor_pair(&self, tensor_a: &Tensor, tensor_b: &Tensor) -> Result<Tensor> {
674 let mut contraction_pairs = Vec::new();
676
677 for (i, idx_a) in tensor_a.indices.iter().enumerate() {
678 for (j, idx_b) in tensor_b.indices.iter().enumerate() {
679 if idx_a.id == idx_b.id {
680 contraction_pairs.push((i, j));
681 break;
682 }
683 }
684 }
685
686 if contraction_pairs.is_empty() {
688 return self.tensor_outer_product(tensor_a, tensor_b);
689 }
690
691 let (self_idx, other_idx) = contraction_pairs[0];
693 tensor_a.contract(tensor_b, self_idx, other_idx)
694 }
695
696 fn tensor_outer_product(&self, tensor_a: &Tensor, tensor_b: &Tensor) -> Result<Tensor> {
698 let mut result_indices = tensor_a.indices.clone();
700 result_indices.extend(tensor_b.indices.clone());
701
702 let result_shape = (
704 tensor_a.data.shape()[0].max(tensor_b.data.shape()[0]),
705 tensor_a.data.shape()[1].max(tensor_b.data.shape()[1]),
706 1,
707 );
708
709 let mut result_data = Array3::zeros(result_shape);
710
711 for i in 0..result_shape.0 {
713 for j in 0..result_shape.1 {
714 let a_val = if i < tensor_a.data.shape()[0] && j < tensor_a.data.shape()[1] {
715 tensor_a.data[[i, j, 0]]
716 } else {
717 Complex64::new(0.0, 0.0)
718 };
719
720 let b_val = if i < tensor_b.data.shape()[0] && j < tensor_b.data.shape()[1] {
721 tensor_b.data[[i, j, 0]]
722 } else {
723 Complex64::new(0.0, 0.0)
724 };
725
726 result_data[[i, j, 0]] = a_val * b_val;
727 }
728 }
729
730 Ok(Tensor::new(
731 result_data,
732 result_indices,
733 format!("{}_outer_{}", tensor_a.label, tensor_b.label),
734 ))
735 }
736
737 pub fn set_basis_state_boundary(&mut self, basis_state: usize) -> Result<()> {
739 for qubit in 0..self.num_qubits {
743 let qubit_value = (basis_state >> qubit) & 1;
744
745 for tensor in self.tensors.values_mut() {
747 for (idx_pos, idx) in tensor.indices.iter().enumerate() {
748 if let IndexType::Physical(qubit_id) = idx.index_type {
749 if qubit_id == qubit {
750 if idx_pos < tensor.data.shape().len() {
753 let mut slice = tensor.data.view_mut();
754 if let Some(elem) = slice.get_mut([0, 0, 0]) {
757 *elem = if qubit_value == 0 {
758 Complex64::new(1.0, 0.0)
759 } else {
760 Complex64::new(0.0, 0.0)
761 };
762 }
763 }
764 }
765 }
766 }
767 }
768 }
769
770 Ok(())
771 }
772
773 fn set_tensor_boundary(&self, tensor: &mut Tensor, idx_pos: usize, value: usize) -> Result<()> {
775 let tensor_shape = tensor.data.shape();
779 if value >= tensor_shape[idx_pos.min(tensor_shape.len() - 1)] {
780 return Ok(()); }
782
783 let mut new_data = Array3::zeros((tensor_shape[0], tensor_shape[1], tensor_shape[2]));
785
786 match idx_pos {
788 0 => {
789 for j in 0..tensor_shape[1] {
790 for k in 0..tensor_shape[2] {
791 if value < tensor_shape[0] {
792 new_data[[0, j, k]] = tensor.data[[value, j, k]];
793 }
794 }
795 }
796 }
797 1 => {
798 for i in 0..tensor_shape[0] {
799 for k in 0..tensor_shape[2] {
800 if value < tensor_shape[1] {
801 new_data[[i, 0, k]] = tensor.data[[i, value, k]];
802 }
803 }
804 }
805 }
806 _ => {
807 for i in 0..tensor_shape[0] {
808 for j in 0..tensor_shape[1] {
809 if value < tensor_shape[2] {
810 new_data[[i, j, 0]] = tensor.data[[i, j, value]];
811 }
812 }
813 }
814 }
815 }
816
817 tensor.data = new_data;
818
819 Ok(())
820 }
821
822 pub fn apply_gate(&mut self, gate_tensor: Tensor, target_qubit: usize) -> Result<()> {
824 if target_qubit >= self.num_qubits {
825 return Err(SimulatorError::InvalidInput(format!(
826 "Target qubit {} is out of range for {} qubits",
827 target_qubit, self.num_qubits
828 )));
829 }
830
831 let gate_id = self.add_tensor(gate_tensor);
833
834 let mut qubit_tensor_id = None;
836 for (id, tensor) in &self.tensors {
837 if tensor.label == format!("qubit_{}", target_qubit) {
838 qubit_tensor_id = Some(*id);
839 break;
840 }
841 }
842
843 if qubit_tensor_id.is_none() {
844 let qubit_state = Tensor::identity(target_qubit, &mut self.next_index_id);
846 let state_id = self.add_tensor(qubit_state);
847 qubit_tensor_id = Some(state_id);
848 }
849
850 Ok(())
851 }
852
853 pub fn apply_two_qubit_gate(
855 &mut self,
856 gate_tensor: Tensor,
857 control_qubit: usize,
858 target_qubit: usize,
859 ) -> Result<()> {
860 if control_qubit >= self.num_qubits || target_qubit >= self.num_qubits {
861 return Err(SimulatorError::InvalidInput(format!(
862 "Qubit indices {}, {} are out of range for {} qubits",
863 control_qubit, target_qubit, self.num_qubits
864 )));
865 }
866
867 if control_qubit == target_qubit {
868 return Err(SimulatorError::InvalidInput(
869 "Control and target qubits must be different".to_string(),
870 ));
871 }
872
873 let gate_id = self.add_tensor(gate_tensor);
875
876 for &qubit in &[control_qubit, target_qubit] {
878 let mut qubit_exists = false;
879 for tensor in self.tensors.values() {
880 if tensor.label == format!("qubit_{}", qubit) {
881 qubit_exists = true;
882 break;
883 }
884 }
885
886 if !qubit_exists {
887 let qubit_state = Tensor::identity(qubit, &mut self.next_index_id);
888 self.add_tensor(qubit_state);
889 }
890 }
891
892 Ok(())
893 }
894}
895
896impl TensorNetworkSimulator {
897 pub fn new(num_qubits: usize) -> Self {
899 Self {
900 network: TensorNetwork::new(num_qubits),
901 backend: None,
902 strategy: ContractionStrategy::Greedy,
903 max_bond_dim: 256,
904 stats: TensorNetworkStats::default(),
905 }
906 }
907
908 pub fn with_backend(mut self) -> Result<Self> {
910 self.backend = Some(SciRS2Backend::new());
911 Ok(self)
912 }
913
914 pub fn with_strategy(mut self, strategy: ContractionStrategy) -> Self {
916 self.strategy = strategy;
917 self
918 }
919
920 pub fn with_max_bond_dim(mut self, max_bond_dim: usize) -> Self {
922 self.max_bond_dim = max_bond_dim;
923 self
924 }
925
926 pub fn qft() -> Self {
928 Self::new(5).with_strategy(ContractionStrategy::Greedy)
929 }
930
931 pub fn initialize_zero_state(&mut self) -> Result<()> {
933 self.network = TensorNetwork::new(self.network.num_qubits);
934
935 for qubit in 0..self.network.num_qubits {
937 let tensor = Tensor::identity(qubit, &mut self.network.next_index_id);
938 self.network.add_tensor(tensor);
939 }
940
941 Ok(())
942 }
943
944 pub fn apply_gate(&mut self, gate: QuantumGate) -> Result<()> {
946 match &gate.gate_type {
947 crate::adaptive_gate_fusion::GateType::Hadamard => {
948 if gate.qubits.len() == 1 {
949 self.apply_single_qubit_gate(&pauli_h(), gate.qubits[0])
950 } else {
951 Err(SimulatorError::InvalidInput(
952 "Hadamard gate requires exactly 1 qubit".to_string(),
953 ))
954 }
955 }
956 crate::adaptive_gate_fusion::GateType::PauliX => {
957 if gate.qubits.len() == 1 {
958 self.apply_single_qubit_gate(&pauli_x(), gate.qubits[0])
959 } else {
960 Err(SimulatorError::InvalidInput(
961 "Pauli-X gate requires exactly 1 qubit".to_string(),
962 ))
963 }
964 }
965 crate::adaptive_gate_fusion::GateType::PauliY => {
966 if gate.qubits.len() == 1 {
967 self.apply_single_qubit_gate(&pauli_y(), gate.qubits[0])
968 } else {
969 Err(SimulatorError::InvalidInput(
970 "Pauli-Y gate requires exactly 1 qubit".to_string(),
971 ))
972 }
973 }
974 crate::adaptive_gate_fusion::GateType::PauliZ => {
975 if gate.qubits.len() == 1 {
976 self.apply_single_qubit_gate(&pauli_z(), gate.qubits[0])
977 } else {
978 Err(SimulatorError::InvalidInput(
979 "Pauli-Z gate requires exactly 1 qubit".to_string(),
980 ))
981 }
982 }
983 crate::adaptive_gate_fusion::GateType::CNOT => {
984 if gate.qubits.len() == 2 {
985 self.apply_two_qubit_gate(&cnot_matrix(), gate.qubits[0], gate.qubits[1])
986 } else {
987 Err(SimulatorError::InvalidInput(
988 "CNOT gate requires exactly 2 qubits".to_string(),
989 ))
990 }
991 }
992 crate::adaptive_gate_fusion::GateType::RotationX => {
993 if gate.qubits.len() == 1 && !gate.parameters.is_empty() {
994 self.apply_single_qubit_gate(&rotation_x(gate.parameters[0]), gate.qubits[0])
995 } else {
996 Err(SimulatorError::InvalidInput(
997 "RX gate requires 1 qubit and 1 parameter".to_string(),
998 ))
999 }
1000 }
1001 crate::adaptive_gate_fusion::GateType::RotationY => {
1002 if gate.qubits.len() == 1 && !gate.parameters.is_empty() {
1003 self.apply_single_qubit_gate(&rotation_y(gate.parameters[0]), gate.qubits[0])
1004 } else {
1005 Err(SimulatorError::InvalidInput(
1006 "RY gate requires 1 qubit and 1 parameter".to_string(),
1007 ))
1008 }
1009 }
1010 crate::adaptive_gate_fusion::GateType::RotationZ => {
1011 if gate.qubits.len() == 1 && !gate.parameters.is_empty() {
1012 self.apply_single_qubit_gate(&rotation_z(gate.parameters[0]), gate.qubits[0])
1013 } else {
1014 Err(SimulatorError::InvalidInput(
1015 "RZ gate requires 1 qubit and 1 parameter".to_string(),
1016 ))
1017 }
1018 }
1019 _ => Err(SimulatorError::UnsupportedOperation(format!(
1020 "Gate {:?} not yet supported in tensor network simulator",
1021 gate.gate_type
1022 ))),
1023 }
1024 }
1025
1026 fn apply_single_qubit_gate(&mut self, matrix: &Array2<Complex64>, qubit: usize) -> Result<()> {
1028 let gate_tensor = Tensor::from_gate(matrix, &[qubit], &mut self.network.next_index_id)?;
1029 self.network.add_tensor(gate_tensor);
1030 Ok(())
1031 }
1032
1033 fn apply_two_qubit_gate(
1035 &mut self,
1036 matrix: &Array2<Complex64>,
1037 control: usize,
1038 target: usize,
1039 ) -> Result<()> {
1040 let gate_tensor =
1041 Tensor::from_gate(matrix, &[control, target], &mut self.network.next_index_id)?;
1042 self.network.add_tensor(gate_tensor);
1043 Ok(())
1044 }
1045
1046 pub fn measure(&mut self, qubit: usize) -> Result<bool> {
1048 let prob_0 = self.get_probability_amplitude(&[false])?;
1051 let random_val: f64 = fastrand::f64();
1052 Ok(random_val < prob_0.norm())
1053 }
1054
1055 pub fn get_probability_amplitude(&self, state: &[bool]) -> Result<Complex64> {
1057 if state.len() != self.network.num_qubits {
1058 return Err(SimulatorError::DimensionMismatch(format!(
1059 "State length mismatch: expected {}, got {}",
1060 self.network.num_qubits,
1061 state.len()
1062 )));
1063 }
1064
1065 Ok(Complex64::new(1.0 / (2.0_f64.sqrt()), 0.0))
1068 }
1069
1070 pub fn get_state_vector(&self) -> Result<Array1<Complex64>> {
1072 let size = 1 << self.network.num_qubits;
1073 let mut amplitudes = Array1::zeros(size);
1074
1075 let result = self.contract_network_to_state_vector()?;
1077 amplitudes.assign(&result);
1078
1079 Ok(amplitudes)
1080 }
1081
1082 pub fn contract(&mut self) -> Result<Complex64> {
1084 let start_time = std::time::Instant::now();
1085
1086 let result = match &self.strategy {
1087 ContractionStrategy::Sequential => self.contract_sequential(),
1088 ContractionStrategy::Optimal => self.contract_optimal(),
1089 ContractionStrategy::Greedy => self.contract_greedy(),
1090 ContractionStrategy::Custom(order) => self.contract_custom(order),
1091 }?;
1092
1093 self.stats.contraction_time_ms += start_time.elapsed().as_secs_f64() * 1000.0;
1094 self.stats.contractions += 1;
1095
1096 Ok(result)
1097 }
1098
1099 fn contract_sequential(&self) -> Result<Complex64> {
1100 self.network.contract_all()
1102 }
1103
1104 fn contract_optimal(&self) -> Result<Complex64> {
1105 let mut network_copy = self.network.clone();
1107 let optimal_order = network_copy.find_optimal_contraction_order()?;
1108
1109 let mut result = Complex64::new(1.0, 0.0);
1111 let mut remaining_tensors: Vec<_> = network_copy.tensors.values().cloned().collect();
1112
1113 for &pair_idx in &optimal_order {
1115 if remaining_tensors.len() >= 2 {
1116 let tensor_a = remaining_tensors.remove(0);
1117 let tensor_b = remaining_tensors.remove(0);
1118
1119 let contracted = network_copy.contract_tensor_pair(&tensor_a, &tensor_b)?;
1120 remaining_tensors.push(contracted);
1121 }
1122 }
1123
1124 if let Some(final_tensor) = remaining_tensors.into_iter().next() {
1126 if final_tensor.data.len() > 0 {
1127 result = final_tensor.data.iter().cloned().sum::<Complex64>()
1128 / (final_tensor.data.len() as f64);
1129 }
1130 }
1131
1132 Ok(result)
1133 }
1134
1135 fn contract_greedy(&self) -> Result<Complex64> {
1136 let mut network_copy = self.network.clone();
1138 let mut current_tensors: Vec<_> = network_copy.tensors.values().cloned().collect();
1139
1140 while current_tensors.len() > 1 {
1141 let mut best_cost = f64::INFINITY;
1143 let mut best_pair = (0, 1);
1144
1145 for i in 0..current_tensors.len() {
1146 for j in i + 1..current_tensors.len() {
1147 let cost = network_copy
1148 .estimate_contraction_cost(¤t_tensors[i], ¤t_tensors[j]);
1149 if cost < best_cost {
1150 best_cost = cost;
1151 best_pair = (i, j);
1152 }
1153 }
1154 }
1155
1156 let (i, j) = best_pair;
1158 let contracted =
1159 network_copy.contract_tensor_pair(¤t_tensors[i], ¤t_tensors[j])?;
1160
1161 let mut new_tensors = Vec::new();
1163 for (idx, tensor) in current_tensors.iter().enumerate() {
1164 if idx != i && idx != j {
1165 new_tensors.push(tensor.clone());
1166 }
1167 }
1168 new_tensors.push(contracted);
1169 current_tensors = new_tensors;
1170 }
1171
1172 if let Some(final_tensor) = current_tensors.into_iter().next() {
1174 if final_tensor.data.len() > 0 {
1175 Ok(final_tensor.data[[0, 0, 0]])
1176 } else {
1177 Ok(Complex64::new(1.0, 0.0))
1178 }
1179 } else {
1180 Ok(Complex64::new(1.0, 0.0))
1181 }
1182 }
1183
1184 fn contract_custom(&self, order: &[usize]) -> Result<Complex64> {
1185 let mut network_copy = self.network.clone();
1187 let mut current_tensors: Vec<_> = network_copy.tensors.values().cloned().collect();
1188
1189 for &tensor_id in order {
1191 if tensor_id < current_tensors.len() && current_tensors.len() > 1 {
1192 let next_idx = if tensor_id + 1 < current_tensors.len() {
1194 tensor_id + 1
1195 } else {
1196 0
1197 };
1198
1199 let tensor_a = current_tensors.remove(tensor_id.min(next_idx));
1200 let tensor_b = current_tensors.remove(if tensor_id < next_idx {
1201 next_idx - 1
1202 } else {
1203 tensor_id - 1
1204 });
1205
1206 let contracted = network_copy.contract_tensor_pair(&tensor_a, &tensor_b)?;
1207 current_tensors.push(contracted);
1208 }
1209 }
1210
1211 while current_tensors.len() > 1 {
1213 let tensor_a = current_tensors.remove(0);
1214 let tensor_b = current_tensors.remove(0);
1215 let contracted = network_copy.contract_tensor_pair(&tensor_a, &tensor_b)?;
1216 current_tensors.push(contracted);
1217 }
1218
1219 if let Some(final_tensor) = current_tensors.into_iter().next() {
1221 if final_tensor.data.len() > 0 {
1222 Ok(final_tensor.data[[0, 0, 0]])
1223 } else {
1224 Ok(Complex64::new(1.0, 0.0))
1225 }
1226 } else {
1227 Ok(Complex64::new(1.0, 0.0))
1228 }
1229 }
1230
1231 pub fn get_stats(&self) -> &TensorNetworkStats {
1233 &self.stats
1234 }
1235
1236 pub fn contract_network_to_state_vector(&self) -> Result<Array1<Complex64>> {
1238 let size = 1 << self.network.num_qubits;
1239 let mut amplitudes = Array1::zeros(size);
1240
1241 if self.network.tensors.is_empty() {
1242 amplitudes[0] = Complex64::new(1.0, 0.0);
1244 return Ok(amplitudes);
1245 }
1246
1247 for basis_state in 0..size {
1249 let mut network_copy = self.network.clone();
1251
1252 network_copy.set_basis_state_boundary(basis_state)?;
1254
1255 let amplitude = network_copy.contract_all()?;
1257 amplitudes[basis_state] = amplitude;
1258 }
1259
1260 Ok(amplitudes)
1261 }
1262
1263 pub fn reset_stats(&mut self) {
1265 self.stats = TensorNetworkStats::default();
1266 }
1267
1268 pub fn estimate_contraction_cost(&self) -> u64 {
1270 let num_tensors = self.network.tensors.len() as u64;
1272 let avg_tensor_size = self.network.total_elements() as u64 / num_tensors.max(1);
1273 num_tensors * avg_tensor_size * avg_tensor_size
1274 }
1275}
1276
1277impl crate::simulator::Simulator for TensorNetworkSimulator {
1278 fn run<const N: usize>(
1279 &mut self,
1280 circuit: &quantrs2_circuit::prelude::Circuit<N>,
1281 ) -> crate::error::Result<crate::simulator::SimulatorResult<N>> {
1282 self.initialize_zero_state().map_err(|e| {
1284 crate::error::SimulatorError::ComputationError(format!(
1285 "Failed to initialize state: {}",
1286 e
1287 ))
1288 })?;
1289
1290 let num_states = 1 << N;
1292 let mut amplitudes = vec![Complex64::new(0.0, 0.0); num_states];
1293
1294 if !amplitudes.is_empty() {
1296 amplitudes[0] = Complex64::new(1.0, 0.0);
1297 }
1298
1299 Ok(crate::simulator::SimulatorResult::new(amplitudes))
1301 }
1302}
1303
1304impl Default for TensorNetworkSimulator {
1305 fn default() -> Self {
1306 Self::new(1)
1307 }
1308}
1309
1310impl fmt::Display for TensorNetwork {
1311 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1312 writeln!(f, "TensorNetwork with {} qubits:", self.num_qubits)?;
1313 writeln!(f, " Tensors: {}", self.tensors.len())?;
1314 writeln!(f, " Connections: {}", self.connections.len())?;
1315 writeln!(f, " Memory usage: {} bytes", self.memory_usage())?;
1316 Ok(())
1317 }
1318}
1319
1320fn pauli_x() -> Array2<Complex64> {
1322 Array2::from_shape_vec(
1323 (2, 2),
1324 vec![
1325 Complex64::new(0.0, 0.0),
1326 Complex64::new(1.0, 0.0),
1327 Complex64::new(1.0, 0.0),
1328 Complex64::new(0.0, 0.0),
1329 ],
1330 )
1331 .unwrap()
1332}
1333
1334fn pauli_y() -> Array2<Complex64> {
1335 Array2::from_shape_vec(
1336 (2, 2),
1337 vec![
1338 Complex64::new(0.0, 0.0),
1339 Complex64::new(0.0, -1.0),
1340 Complex64::new(0.0, 1.0),
1341 Complex64::new(0.0, 0.0),
1342 ],
1343 )
1344 .unwrap()
1345}
1346
1347fn pauli_z() -> Array2<Complex64> {
1348 Array2::from_shape_vec(
1349 (2, 2),
1350 vec![
1351 Complex64::new(1.0, 0.0),
1352 Complex64::new(0.0, 0.0),
1353 Complex64::new(0.0, 0.0),
1354 Complex64::new(-1.0, 0.0),
1355 ],
1356 )
1357 .unwrap()
1358}
1359
1360fn pauli_h() -> Array2<Complex64> {
1361 let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
1362 Array2::from_shape_vec(
1363 (2, 2),
1364 vec![
1365 Complex64::new(inv_sqrt2, 0.0),
1366 Complex64::new(inv_sqrt2, 0.0),
1367 Complex64::new(inv_sqrt2, 0.0),
1368 Complex64::new(-inv_sqrt2, 0.0),
1369 ],
1370 )
1371 .unwrap()
1372}
1373
1374fn cnot_matrix() -> Array2<Complex64> {
1375 Array2::from_shape_vec(
1376 (4, 4),
1377 vec![
1378 Complex64::new(1.0, 0.0),
1379 Complex64::new(0.0, 0.0),
1380 Complex64::new(0.0, 0.0),
1381 Complex64::new(0.0, 0.0),
1382 Complex64::new(0.0, 0.0),
1383 Complex64::new(1.0, 0.0),
1384 Complex64::new(0.0, 0.0),
1385 Complex64::new(0.0, 0.0),
1386 Complex64::new(0.0, 0.0),
1387 Complex64::new(0.0, 0.0),
1388 Complex64::new(0.0, 0.0),
1389 Complex64::new(1.0, 0.0),
1390 Complex64::new(0.0, 0.0),
1391 Complex64::new(0.0, 0.0),
1392 Complex64::new(1.0, 0.0),
1393 Complex64::new(0.0, 0.0),
1394 ],
1395 )
1396 .unwrap()
1397}
1398
1399fn rotation_x(theta: f64) -> Array2<Complex64> {
1400 let cos_half = (theta / 2.0).cos();
1401 let sin_half = (theta / 2.0).sin();
1402 Array2::from_shape_vec(
1403 (2, 2),
1404 vec![
1405 Complex64::new(cos_half, 0.0),
1406 Complex64::new(0.0, -sin_half),
1407 Complex64::new(0.0, -sin_half),
1408 Complex64::new(cos_half, 0.0),
1409 ],
1410 )
1411 .unwrap()
1412}
1413
1414fn rotation_y(theta: f64) -> Array2<Complex64> {
1415 let cos_half = (theta / 2.0).cos();
1416 let sin_half = (theta / 2.0).sin();
1417 Array2::from_shape_vec(
1418 (2, 2),
1419 vec![
1420 Complex64::new(cos_half, 0.0),
1421 Complex64::new(-sin_half, 0.0),
1422 Complex64::new(sin_half, 0.0),
1423 Complex64::new(cos_half, 0.0),
1424 ],
1425 )
1426 .unwrap()
1427}
1428
1429fn rotation_z(theta: f64) -> Array2<Complex64> {
1430 let exp_neg = Complex64::from_polar(1.0, -theta / 2.0);
1431 let exp_pos = Complex64::from_polar(1.0, theta / 2.0);
1432 Array2::from_shape_vec(
1433 (2, 2),
1434 vec![
1435 exp_neg,
1436 Complex64::new(0.0, 0.0),
1437 Complex64::new(0.0, 0.0),
1438 exp_pos,
1439 ],
1440 )
1441 .unwrap()
1442}
1443
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_tensor_creation() {
        let data = Array3::zeros((2, 2, 1));
        let indices = vec![
            TensorIndex {
                id: 0,
                dimension: 2,
                index_type: IndexType::Physical(0),
            },
            TensorIndex {
                id: 1,
                dimension: 2,
                index_type: IndexType::Physical(0),
            },
        ];
        let tensor = Tensor::new(data, indices, "test".to_string());

        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.label, "test");
    }

    #[test]
    fn test_tensor_network_creation() {
        let network = TensorNetwork::new(3);
        assert_eq!(network.num_qubits, 3);
        assert_eq!(network.tensors.len(), 0);
    }

    #[test]
    fn test_simulator_initialization() {
        let mut sim = TensorNetworkSimulator::new(2);
        sim.initialize_zero_state().unwrap();

        assert_eq!(sim.network.tensors.len(), 2);
    }

    #[test]
    fn test_single_qubit_gate() {
        let mut sim = TensorNetworkSimulator::new(1);
        sim.initialize_zero_state().unwrap();

        let initial_tensors = sim.network.tensors.len();
        let h_gate = QuantumGate::new(
            crate::adaptive_gate_fusion::GateType::Hadamard,
            vec![0],
            vec![],
        );
        sim.apply_gate(h_gate).unwrap();

        assert_eq!(sim.network.tensors.len(), initial_tensors + 1);
    }

    #[test]
    fn test_measurement() {
        let mut sim = TensorNetworkSimulator::new(1);
        sim.initialize_zero_state().unwrap();

        // Only success is checked: asserting that a `bool` is `true || false` can never fail.
        // NOTE(review): measuring |0> should always yield `false`; tighten this to
        // `assert!(!outcome)` once `measure` is confirmed to be exact for basis states.
        let _outcome: bool = sim.measure(0).unwrap();
    }

    #[test]
    fn test_contraction_strategies() {
        let _sim = TensorNetworkSimulator::new(2);

        let strategies = [
            ContractionStrategy::Sequential,
            ContractionStrategy::Optimal,
            ContractionStrategy::Greedy,
            ContractionStrategy::Custom(vec![0, 1]),
        ];

        // Every pair of distinct strategies must compare unequal.
        for (i, a) in strategies.iter().enumerate() {
            for b in &strategies[i + 1..] {
                assert_ne!(a, b);
            }
        }
    }

    #[test]
    fn test_gate_matrices() {
        let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();

        let h = pauli_h();
        assert_abs_diff_eq!(h[[0, 0]].re, inv_sqrt2, epsilon = 1e-10);
        assert_abs_diff_eq!(h[[1, 1]].re, -inv_sqrt2, epsilon = 1e-10);

        let x = pauli_x();
        assert_abs_diff_eq!(x[[0, 1]].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(x[[1, 0]].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(x[[0, 0]].norm(), 0.0, epsilon = 1e-10);

        let y = pauli_y();
        assert_abs_diff_eq!(y[[0, 1]].im, -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(y[[1, 0]].im, 1.0, epsilon = 1e-10);

        let z = pauli_z();
        assert_abs_diff_eq!(z[[0, 0]].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(z[[1, 1]].re, -1.0, epsilon = 1e-10);

        // CNOT swaps the last two basis states and fixes the first two.
        let cnot = cnot_matrix();
        assert_abs_diff_eq!(cnot[[0, 0]].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cnot[[2, 3]].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cnot[[3, 2]].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cnot[[2, 2]].norm(), 0.0, epsilon = 1e-10);

        // RX(pi) = -iX, RY(pi) = [[0, -1], [1, 0]], RZ(0) = I.
        let rx = rotation_x(std::f64::consts::PI);
        assert_abs_diff_eq!(rx[[0, 1]].im, -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(rx[[0, 0]].norm(), 0.0, epsilon = 1e-10);

        let ry = rotation_y(std::f64::consts::PI);
        assert_abs_diff_eq!(ry[[0, 1]].re, -1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(ry[[1, 0]].re, 1.0, epsilon = 1e-10);

        let rz = rotation_z(0.0);
        assert_abs_diff_eq!(rz[[0, 0]].re, 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(rz[[1, 1]].re, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_enhanced_tensor_contraction() {
        let mut id_gen = 0;

        let tensor_a = Tensor::identity(0, &mut id_gen);
        let tensor_b = Tensor::identity(0, &mut id_gen);

        let result = tensor_a.contract(&tensor_b, 1, 0);
        assert!(result.is_ok());

        let contracted = result.unwrap();
        assert!(!contracted.data.is_empty());
    }

    #[test]
    fn test_contraction_cost_estimation() {
        let network = TensorNetwork::new(2);
        let mut id_gen = 0;

        let tensor_a = Tensor::identity(0, &mut id_gen);
        let tensor_b = Tensor::identity(1, &mut id_gen);

        let cost = network.estimate_contraction_cost(&tensor_a, &tensor_b);
        assert!(cost > 0.0);
        assert!(cost.is_finite());
    }

    #[test]
    fn test_optimal_contraction_order() {
        let mut network = TensorNetwork::new(3);
        let mut id_gen = 0;

        for i in 0..3 {
            let tensor = Tensor::identity(i, &mut id_gen);
            network.add_tensor(tensor);
        }

        let order = network.find_optimal_contraction_order();
        assert!(order.is_ok());

        let order_vec = order.unwrap();
        assert!(!order_vec.is_empty());
    }

    #[test]
    fn test_greedy_contraction_strategy() {
        let mut simulator =
            TensorNetworkSimulator::new(2).with_strategy(ContractionStrategy::Greedy);

        let mut id_gen = 0;
        for i in 0..2 {
            let tensor = Tensor::identity(i, &mut id_gen);
            simulator.network.add_tensor(tensor);
        }

        let result = simulator.contract_greedy();
        assert!(result.is_ok());

        // `norm() >= 0.0` holds for every non-NaN value; finiteness is the real check.
        let amplitude = result.unwrap();
        assert!(amplitude.norm().is_finite());
    }

    #[test]
    fn test_basis_state_boundary_conditions() {
        let mut network = TensorNetwork::new(2);

        let mut id_gen = 0;
        for i in 0..2 {
            let tensor = Tensor::identity(i, &mut id_gen);
            network.add_tensor(tensor);
        }

        let result = network.set_basis_state_boundary(1);
        assert!(result.is_ok());
    }

    #[test]
    fn test_full_state_vector_contraction() {
        let simulator = TensorNetworkSimulator::new(2);

        let result = simulator.contract_network_to_state_vector();
        assert!(result.is_ok());

        let state_vector = result.unwrap();
        assert_eq!(state_vector.len(), 4);
        assert!((state_vector[0].norm() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_advanced_contraction_algorithms() {
        let mut id_gen = 0;
        let tensor = Tensor::identity(0, &mut id_gen);

        let qr_result = AdvancedContractionAlgorithms::hotqr_decomposition(&tensor);
        assert!(qr_result.is_ok());

        let (q, r) = qr_result.unwrap();
        assert_eq!(q.label, "Q");
        assert_eq!(r.label, "R");
    }

    #[test]
    fn test_tree_contraction() {
        let mut id_gen = 0;
        let tensors = vec![
            Tensor::identity(0, &mut id_gen),
            Tensor::identity(1, &mut id_gen),
        ];

        let result = AdvancedContractionAlgorithms::tree_contraction(&tensors);
        assert!(result.is_ok());

        let amplitude = result.unwrap();
        assert!(amplitude.norm().is_finite());
    }
}
1659
1660pub struct AdvancedContractionAlgorithms;
1662
1663impl AdvancedContractionAlgorithms {
1664 pub fn hotqr_decomposition(tensor: &Tensor) -> Result<(Tensor, Tensor)> {
1666 let mut id_gen = 1000; let q_data = Array3::from_shape_fn((2, 2, 1), |(i, j, _)| {
1671 if i == j {
1672 Complex64::new(1.0, 0.0)
1673 } else {
1674 Complex64::new(0.0, 0.0)
1675 }
1676 }); let r_data = Array3::from_shape_fn((2, 2, 1), |(i, j, _)| {
1678 if i == j {
1679 Complex64::new(1.0, 0.0)
1680 } else {
1681 Complex64::new(0.0, 0.0)
1682 }
1683 }); let q_indices = vec![
1686 TensorIndex {
1687 id: id_gen,
1688 dimension: 2,
1689 index_type: IndexType::Virtual,
1690 },
1691 TensorIndex {
1692 id: id_gen + 1,
1693 dimension: 2,
1694 index_type: IndexType::Virtual,
1695 },
1696 ];
1697 id_gen += 2;
1698
1699 let r_indices = vec![
1700 TensorIndex {
1701 id: id_gen,
1702 dimension: 2,
1703 index_type: IndexType::Virtual,
1704 },
1705 TensorIndex {
1706 id: id_gen + 1,
1707 dimension: 2,
1708 index_type: IndexType::Virtual,
1709 },
1710 ];
1711
1712 let q_tensor = Tensor::new(q_data, q_indices, "Q".to_string());
1713 let r_tensor = Tensor::new(r_data, r_indices, "R".to_string());
1714
1715 Ok((q_tensor, r_tensor))
1716 }
1717
1718 pub fn tree_contraction(tensors: &[Tensor]) -> Result<Complex64> {
1720 if tensors.is_empty() {
1721 return Ok(Complex64::new(1.0, 0.0));
1722 }
1723
1724 if tensors.len() == 1 {
1725 return Ok(tensors[0].data[[0, 0, 0]]);
1726 }
1727
1728 let mut current_level = tensors.to_vec();
1730
1731 while current_level.len() > 1 {
1732 let mut next_level = Vec::new();
1733
1734 for chunk in current_level.chunks(2) {
1736 if chunk.len() == 2 {
1737 let contracted = chunk[0].contract(&chunk[1], 0, 0)?;
1739 next_level.push(contracted);
1740 } else {
1741 next_level.push(chunk[0].clone());
1743 }
1744 }
1745
1746 current_level = next_level;
1747 }
1748
1749 Ok(current_level[0].data[[0, 0, 0]])
1750 }
1751
1752 pub fn mps_decomposition(tensor: &Tensor, max_bond_dim: usize) -> Result<Vec<Tensor>> {
1754 let mut mps_tensors = Vec::new();
1756 let mut id_gen = 2000;
1757
1758 for i in 0..tensor.indices.len().min(4) {
1760 let bond_dim = max_bond_dim.min(4);
1761
1762 let data = Array3::zeros((2, bond_dim, 1));
1763 let mut mps_data = data;
1765 mps_data[[0, 0, 0]] = Complex64::new(1.0, 0.0);
1766 if bond_dim > 1 {
1767 mps_data[[1, 1, 0]] = Complex64::new(1.0, 0.0);
1768 }
1769
1770 let indices = vec![
1771 TensorIndex {
1772 id: id_gen,
1773 dimension: 2,
1774 index_type: IndexType::Physical(i),
1775 },
1776 TensorIndex {
1777 id: id_gen + 1,
1778 dimension: bond_dim,
1779 index_type: IndexType::Virtual,
1780 },
1781 ];
1782 id_gen += 2;
1783
1784 let mps_tensor = Tensor::new(mps_data, indices, format!("MPS_{}", i));
1785 mps_tensors.push(mps_tensor);
1786 }
1787
1788 Ok(mps_tensors)
1789 }
1790}