quantrs2_core/
tensor_network.rs

1//! Tensor Network representations for quantum circuits
2//!
3//! This module provides tensor network representations and operations for quantum circuits,
4//! leveraging SciRS2 for efficient tensor manipulations and contractions.
5
6use crate::{
7    error::{QuantRS2Error, QuantRS2Result},
8    gate::GateOp,
9    linalg_stubs::svd,
10    register::Register,
11};
12use scirs2_core::ndarray::{Array, Array2, ArrayD, IxDyn};
13use scirs2_core::Complex;
14// use scirs2_linalg::svd;
15use std::collections::{HashMap, HashSet};
16
17/// Type alias for complex numbers
18type Complex64 = Complex<f64>;
19
/// A single tensor (node) of a tensor network.
///
/// Each axis of `data` is named by the label stored at the same position in
/// `indices`; contraction and connection look axes up by these labels.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Identifier of this tensor; used as the map key when the tensor is
    /// stored in a [`TensorNetwork`].
    pub id: usize,
    /// The tensor entries as a dynamic-rank complex array.
    pub data: ArrayD<Complex64>,
    /// Label for each axis of the tensor, in axis order.
    pub indices: Vec<String>,
    /// Extent of each axis, copied from `data` by the constructors.
    // NOTE(review): the fields are public, so `shape` can drift from
    // `data.shape()` if `data` is replaced after construction.
    pub shape: Vec<usize>,
}
32
33impl Tensor {
34    /// Create a new tensor
35    pub fn new(id: usize, data: ArrayD<Complex64>, indices: Vec<String>) -> Self {
36        let shape = data.shape().to_vec();
37        Self {
38            id,
39            data,
40            indices,
41            shape,
42        }
43    }
44
45    /// Create a tensor from a 2D array (matrix)
46    pub fn from_matrix(
47        id: usize,
48        matrix: Array2<Complex64>,
49        in_idx: String,
50        out_idx: String,
51    ) -> Self {
52        let shape = matrix.shape().to_vec();
53        let data = matrix.into_dyn();
54        Self {
55            id,
56            data,
57            indices: vec![in_idx, out_idx],
58            shape,
59        }
60    }
61
62    /// Create a qubit tensor in |0⟩ state
63    pub fn qubit_zero(id: usize, idx: String) -> Self {
64        let mut data = Array::zeros(IxDyn(&[2]));
65        data[[0]] = Complex64::new(1.0, 0.0);
66        Self {
67            id,
68            data,
69            indices: vec![idx],
70            shape: vec![2],
71        }
72    }
73
74    /// Create a qubit tensor in |1⟩ state
75    pub fn qubit_one(id: usize, idx: String) -> Self {
76        let mut data = Array::zeros(IxDyn(&[2]));
77        data[[1]] = Complex64::new(1.0, 0.0);
78        Self {
79            id,
80            data,
81            indices: vec![idx],
82            shape: vec![2],
83        }
84    }
85
86    /// Create a tensor from an ndarray with specified indices
87    pub fn from_array<D>(
88        array: scirs2_core::ndarray::ArrayBase<scirs2_core::ndarray::OwnedRepr<Complex64>, D>,
89        indices: Vec<usize>,
90    ) -> Self
91    where
92        D: scirs2_core::ndarray::Dimension,
93    {
94        let shape = array.shape().to_vec();
95        let data = array.into_dyn();
96        let index_labels: Vec<String> = indices.iter().map(|i| format!("idx_{i}")).collect();
97        Self {
98            id: 0, // Default ID
99            data,
100            indices: index_labels,
101            shape,
102        }
103    }
104
105    /// Get the rank (number of indices) of the tensor
106    pub fn rank(&self) -> usize {
107        self.indices.len()
108    }
109
110    /// Get a reference to the tensor data
111    pub const fn tensor(&self) -> &ArrayD<Complex64> {
112        &self.data
113    }
114
115    /// Get the number of dimensions
116    pub fn ndim(&self) -> usize {
117        self.data.ndim()
118    }
119
120    /// Contract this tensor with another over specified indices
121    pub fn contract(&self, other: &Self, self_idx: &str, other_idx: &str) -> QuantRS2Result<Self> {
122        // Find the positions of the indices to contract
123        let self_pos = self
124            .indices
125            .iter()
126            .position(|s| s == self_idx)
127            .ok_or_else(|| {
128                QuantRS2Error::InvalidInput(format!("Index {self_idx} not found in tensor"))
129            })?;
130        let other_pos = other
131            .indices
132            .iter()
133            .position(|s| s == other_idx)
134            .ok_or_else(|| {
135                QuantRS2Error::InvalidInput(format!("Index {other_idx} not found in tensor"))
136            })?;
137
138        // Check dimensions match
139        if self.shape[self_pos] != other.shape[other_pos] {
140            return Err(QuantRS2Error::InvalidInput(format!(
141                "Cannot contract indices with different dimensions: {} vs {}",
142                self.shape[self_pos], other.shape[other_pos]
143            )));
144        }
145
146        // Perform tensor contraction using einsum-like operation
147        let contracted = self.contract_indices(&other, self_pos, other_pos)?;
148
149        // Build new index list
150        let mut new_indices = Vec::new();
151        for (i, idx) in self.indices.iter().enumerate() {
152            if i != self_pos {
153                new_indices.push(idx.clone());
154            }
155        }
156        for (i, idx) in other.indices.iter().enumerate() {
157            if i != other_pos {
158                new_indices.push(idx.clone());
159            }
160        }
161
162        Ok(Self::new(
163            self.id.max(other.id) + 1,
164            contracted,
165            new_indices,
166        ))
167    }
168
169    /// Perform the actual index contraction
170    fn contract_indices(
171        &self,
172        other: &Self,
173        self_idx: usize,
174        other_idx: usize,
175    ) -> QuantRS2Result<ArrayD<Complex64>> {
176        // Reshape tensors for matrix multiplication
177        let self_shape = self.data.shape();
178        let other_shape = other.data.shape();
179
180        // Calculate dimensions for reshaping
181        let mut self_left_dims = 1;
182        let mut self_right_dims = 1;
183        for i in 0..self_idx {
184            self_left_dims *= self_shape[i];
185        }
186        for i in (self_idx + 1)..self_shape.len() {
187            self_right_dims *= self_shape[i];
188        }
189
190        let mut other_left_dims = 1;
191        let mut other_right_dims = 1;
192        for i in 0..other_idx {
193            other_left_dims *= other_shape[i];
194        }
195        for i in (other_idx + 1)..other_shape.len() {
196            other_right_dims *= other_shape[i];
197        }
198
199        let contract_dim = self_shape[self_idx];
200
201        // Reshape to matrices
202        let self_mat = self
203            .data
204            .view()
205            .into_shape_with_order((self_left_dims, contract_dim * self_right_dims))
206            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?
207            .to_owned();
208        let other_mat = other
209            .data
210            .view()
211            .into_shape_with_order((other_left_dims * contract_dim, other_right_dims))
212            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?
213            .to_owned();
214
215        // Perform contraction via matrix multiplication
216        let _result_mat: Array2<Complex64> = Array2::zeros((
217            self_left_dims * self_right_dims,
218            other_left_dims * other_right_dims,
219        ));
220
221        // This is a simplified contraction - a full implementation would be more efficient
222        let mut result_vec = Vec::new();
223        for i in 0..self_left_dims {
224            for j in 0..self_right_dims {
225                for k in 0..other_left_dims {
226                    for l in 0..other_right_dims {
227                        let mut sum = Complex64::new(0.0, 0.0);
228                        for c in 0..contract_dim {
229                            // Commented out - index calculations unused
230                            // let _ = i * contract_dim * self_right_dims + c * self_right_dims + j;
231                            // let _ = k * contract_dim * other_right_dims + c * other_right_dims + l;
232                            sum += self_mat[[i, c * self_right_dims + j]]
233                                * other_mat[[k * contract_dim + c, l]];
234                        }
235                        result_vec.push(sum);
236                    }
237                }
238            }
239        }
240
241        // Build result shape
242        let mut result_shape = Vec::new();
243        for i in 0..self_idx {
244            result_shape.push(self_shape[i]);
245        }
246        for i in (self_idx + 1)..self_shape.len() {
247            result_shape.push(self_shape[i]);
248        }
249        for i in 0..other_idx {
250            result_shape.push(other_shape[i]);
251        }
252        for i in (other_idx + 1)..other_shape.len() {
253            result_shape.push(other_shape[i]);
254        }
255
256        ArrayD::from_shape_vec(IxDyn(&result_shape), result_vec)
257            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))
258    }
259
260    /// Apply SVD decomposition to split tensor along specified index
261    pub fn svd_decompose(
262        &self,
263        idx: usize,
264        max_rank: Option<usize>,
265    ) -> QuantRS2Result<(Self, Self)> {
266        if idx >= self.rank() {
267            return Err(QuantRS2Error::InvalidInput(format!(
268                "Index {} out of bounds for tensor with rank {}",
269                idx,
270                self.rank()
271            )));
272        }
273
274        // Reshape tensor into matrix
275        let shape = self.data.shape();
276        let mut left_dim = 1;
277        let mut right_dim = 1;
278
279        for i in 0..=idx {
280            left_dim *= shape[i];
281        }
282        for i in (idx + 1)..shape.len() {
283            right_dim *= shape[i];
284        }
285
286        // Convert to matrix
287        let matrix = self
288            .data
289            .view()
290            .into_shape_with_order((left_dim, right_dim))
291            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?
292            .to_owned();
293
294        // Perform SVD using SciRS2
295        let real_matrix = matrix.mapv(|c| c.re);
296        let (u, s, vt) = svd(&real_matrix.view(), false, None)
297            .map_err(|e| QuantRS2Error::ComputationError(format!("SVD failed: {e:?}")))?;
298
299        // Determine rank to keep
300        let rank = if let Some(max_r) = max_rank {
301            max_r.min(s.len())
302        } else {
303            s.len()
304        };
305
306        // Truncate based on rank
307        let u_trunc = u.slice(scirs2_core::ndarray::s![.., ..rank]).to_owned();
308        let s_trunc = s.slice(scirs2_core::ndarray::s![..rank]).to_owned();
309        let vt_trunc = vt.slice(scirs2_core::ndarray::s![..rank, ..]).to_owned();
310
311        // Create S matrix
312        let mut s_mat = Array2::zeros((rank, rank));
313        for i in 0..rank {
314            s_mat[[i, i]] = Complex64::new(s_trunc[i].sqrt(), 0.0);
315        }
316
317        // Multiply U * sqrt(S) and sqrt(S) * V^T
318        let left_data = u_trunc.mapv(|x| Complex64::new(x, 0.0)).dot(&s_mat);
319        let right_data = s_mat.dot(&vt_trunc.mapv(|x| Complex64::new(x, 0.0)));
320
321        // Create new tensors with appropriate shapes and indices
322        let mut left_indices = self.indices[..=idx].to_vec();
323        left_indices.push(format!("bond_{}", self.id));
324
325        let mut right_indices = vec![format!("bond_{}", self.id)];
326        right_indices.extend_from_slice(&self.indices[(idx + 1)..]);
327
328        let left_tensor = Self::new(self.id * 2, left_data.into_dyn(), left_indices);
329
330        let right_tensor = Self::new(self.id * 2 + 1, right_data.into_dyn(), right_indices);
331
332        Ok((left_tensor, right_tensor))
333    }
334}
335
/// A bond between one axis of one tensor and one axis of another.
///
/// Tensors are referred to by id and axes by label.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TensorEdge {
    /// Id of the tensor at the first end of the bond.
    pub tensor1: usize,
    /// Axis label on `tensor1` that takes part in the bond.
    pub index1: String,
    /// Id of the tensor at the second end of the bond.
    pub tensor2: usize,
    /// Axis label on `tensor2` that takes part in the bond.
    pub index2: String,
}
348
349/// Tensor network representation
350#[derive(Debug)]
351pub struct TensorNetwork {
352    /// Tensors in the network
353    pub tensors: HashMap<usize, Tensor>,
354    /// Edges connecting tensors
355    pub edges: Vec<TensorEdge>,
356    /// Open indices (not connected to other tensors)
357    pub open_indices: HashMap<usize, Vec<String>>,
358    /// Next available tensor ID
359    next_id: usize,
360}
361
362impl TensorNetwork {
363    /// Create a new empty tensor network
364    pub fn new() -> Self {
365        Self {
366            tensors: HashMap::new(),
367            edges: Vec::new(),
368            open_indices: HashMap::new(),
369            next_id: 0,
370        }
371    }
372
373    /// Add a tensor to the network
374    pub fn add_tensor(&mut self, tensor: Tensor) -> usize {
375        let id = tensor.id;
376        self.open_indices.insert(id, tensor.indices.clone());
377        self.tensors.insert(id, tensor);
378        self.next_id = self.next_id.max(id + 1);
379        id
380    }
381
382    /// Connect two tensor indices
383    pub fn connect(
384        &mut self,
385        tensor1: usize,
386        index1: String,
387        tensor2: usize,
388        index2: String,
389    ) -> QuantRS2Result<()> {
390        // Verify tensors exist
391        if !self.tensors.contains_key(&tensor1) {
392            return Err(QuantRS2Error::InvalidInput(format!(
393                "Tensor {tensor1} not found"
394            )));
395        }
396        if !self.tensors.contains_key(&tensor2) {
397            return Err(QuantRS2Error::InvalidInput(format!(
398                "Tensor {tensor2} not found"
399            )));
400        }
401
402        // Verify indices exist and match dimensions
403        let t1 = &self.tensors[&tensor1];
404        let t2 = &self.tensors[&tensor2];
405
406        let idx1_pos = t1
407            .indices
408            .iter()
409            .position(|s| s == &index1)
410            .ok_or_else(|| {
411                QuantRS2Error::InvalidInput(format!("Index {index1} not found in tensor {tensor1}"))
412            })?;
413        let idx2_pos = t2
414            .indices
415            .iter()
416            .position(|s| s == &index2)
417            .ok_or_else(|| {
418                QuantRS2Error::InvalidInput(format!("Index {index2} not found in tensor {tensor2}"))
419            })?;
420
421        if t1.shape[idx1_pos] != t2.shape[idx2_pos] {
422            return Err(QuantRS2Error::InvalidInput(format!(
423                "Connected indices must have same dimension: {} vs {}",
424                t1.shape[idx1_pos], t2.shape[idx2_pos]
425            )));
426        }
427
428        // Add edge
429        self.edges.push(TensorEdge {
430            tensor1,
431            index1: index1.clone(),
432            tensor2,
433            index2: index2.clone(),
434        });
435
436        // Remove from open indices
437        if let Some(indices) = self.open_indices.get_mut(&tensor1) {
438            indices.retain(|s| s != &index1);
439        }
440        if let Some(indices) = self.open_indices.get_mut(&tensor2) {
441            indices.retain(|s| s != &index2);
442        }
443
444        Ok(())
445    }
446
447    /// Find optimal contraction order using greedy algorithm
448    pub fn find_contraction_order(&self) -> Vec<(usize, usize)> {
449        // Simple greedy algorithm: contract pairs that minimize intermediate tensor size
450        let mut remaining_tensors: HashSet<_> = self.tensors.keys().copied().collect();
451        let mut order = Vec::new();
452
453        // Build adjacency list
454        let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
455        for edge in &self.edges {
456            adjacency
457                .entry(edge.tensor1)
458                .or_insert_with(Vec::new)
459                .push(edge.tensor2);
460            adjacency
461                .entry(edge.tensor2)
462                .or_insert_with(Vec::new)
463                .push(edge.tensor1);
464        }
465
466        while remaining_tensors.len() > 1 {
467            let mut best_pair = None;
468            let mut min_cost = usize::MAX;
469
470            // Consider all pairs of connected tensors
471            for &t1 in &remaining_tensors {
472                if let Some(neighbors) = adjacency.get(&t1) {
473                    for &t2 in neighbors {
474                        if t2 > t1 && remaining_tensors.contains(&t2) {
475                            // Estimate cost as product of remaining dimensions
476                            let cost = self.estimate_contraction_cost(t1, t2);
477                            if cost < min_cost {
478                                min_cost = cost;
479                                best_pair = Some((t1, t2));
480                            }
481                        }
482                    }
483                }
484            }
485
486            if let Some((t1, t2)) = best_pair {
487                order.push((t1, t2));
488                remaining_tensors.remove(&t1);
489                remaining_tensors.remove(&t2);
490
491                // Add a virtual tensor representing the contraction result
492                let virtual_id = self.next_id + order.len();
493                remaining_tensors.insert(virtual_id);
494
495                // Update adjacency for virtual tensor
496                let mut virtual_neighbors = HashSet::new();
497                if let Some(n1) = adjacency.get(&t1) {
498                    virtual_neighbors.extend(
499                        n1.iter()
500                            .filter(|&&n| n != t2 && remaining_tensors.contains(&n)),
501                    );
502                }
503                if let Some(n2) = adjacency.get(&t2) {
504                    virtual_neighbors.extend(
505                        n2.iter()
506                            .filter(|&&n| n != t1 && remaining_tensors.contains(&n)),
507                    );
508                }
509                adjacency.insert(virtual_id, virtual_neighbors.into_iter().collect());
510            } else {
511                break;
512            }
513        }
514
515        order
516    }
517
518    /// Estimate the computational cost of contracting two tensors
519    const fn estimate_contraction_cost(&self, _t1: usize, _t2: usize) -> usize {
520        // Cost is roughly the product of all dimensions in the result
521        // This is a simplified estimate
522        1000 // Placeholder
523    }
524
525    /// Contract the entire network to a single tensor
526    pub fn contract_all(&mut self) -> QuantRS2Result<Tensor> {
527        if self.tensors.is_empty() {
528            return Err(QuantRS2Error::InvalidInput(
529                "Cannot contract empty tensor network".into(),
530            ));
531        }
532
533        if self.tensors.len() == 1 {
534            return self
535                .tensors
536                .values()
537                .next()
538                .map(|t| t.clone())
539                .ok_or_else(|| {
540                    QuantRS2Error::InvalidInput("Single tensor expected but not found".into())
541                });
542        }
543
544        // Find contraction order
545        let order = self.find_contraction_order();
546
547        // Execute contractions
548        let mut tensor_map = self.tensors.clone();
549        let mut next_id = self.next_id;
550
551        for (t1_id, t2_id) in order {
552            // Find the edge connecting these tensors
553            let edge = self
554                .edges
555                .iter()
556                .find(|e| {
557                    (e.tensor1 == t1_id && e.tensor2 == t2_id)
558                        || (e.tensor1 == t2_id && e.tensor2 == t1_id)
559                })
560                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensors not connected".into()))?;
561
562            let t1 = tensor_map
563                .remove(&t1_id)
564                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensor not found".into()))?;
565            let t2 = tensor_map
566                .remove(&t2_id)
567                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensor not found".into()))?;
568
569            // Contract tensors
570            let contracted = if edge.tensor1 == t1_id {
571                t1.contract(&t2, &edge.index1, &edge.index2)?
572            } else {
573                t1.contract(&t2, &edge.index2, &edge.index1)?
574            };
575
576            // Add result back
577            let mut new_tensor = contracted;
578            new_tensor.id = next_id;
579            tensor_map.insert(next_id, new_tensor);
580            next_id += 1;
581        }
582
583        // Return the final tensor
584        tensor_map
585            .into_values()
586            .next()
587            .ok_or_else(|| QuantRS2Error::InvalidInput("Contraction failed".into()))
588    }
589
590    /// Apply Matrix Product State (MPS) decomposition
591    pub const fn to_mps(&self, _max_bond_dim: Option<usize>) -> QuantRS2Result<Vec<Tensor>> {
592        // This would decompose the network into a chain of tensors
593        // For now, return a placeholder
594        Ok(vec![])
595    }
596
597    /// Apply Matrix Product Operator (MPO) representation
598    pub const fn apply_mpo(&mut self, _mpo: &[Tensor], _qubits: &[usize]) -> QuantRS2Result<()> {
599        // Apply an MPO to specified qubits
600        Ok(())
601    }
602
603    /// Get a reference to the tensors in the network
604    pub fn tensors(&self) -> Vec<&Tensor> {
605        self.tensors.values().collect()
606    }
607
608    /// Get a reference to a tensor by ID
609    pub fn tensor(&self, id: usize) -> Option<&Tensor> {
610        self.tensors.get(&id)
611    }
612}
613
/// Incrementally builds the tensor network of a quantum circuit.
///
/// Every qubit starts as a |0⟩ tensor; each applied gate adds one tensor and
/// bonds it to the tensor that currently ends the qubit's wire.
pub struct TensorNetworkBuilder {
    /// The network under construction.
    network: TensorNetwork,
    /// Label of each qubit's initial axis, keyed by qubit number.
    qubit_indices: HashMap<usize, String>,
    /// Label of the axis that currently ends each qubit's wire, keyed by
    /// qubit number; updated by every applied gate.
    current_indices: HashMap<usize, String>,
}
620
621impl TensorNetworkBuilder {
622    /// Create a new tensor network builder for n qubits
623    pub fn new(num_qubits: usize) -> Self {
624        let mut network = TensorNetwork::new();
625        let mut qubit_indices = HashMap::new();
626        let mut current_indices = HashMap::new();
627
628        // Initialize qubits in |0⟩ state
629        for i in 0..num_qubits {
630            let idx = format!("q{i}_0");
631            let tensor = Tensor::qubit_zero(i, idx.clone());
632            network.add_tensor(tensor);
633            qubit_indices.insert(i, idx.clone());
634            current_indices.insert(i, idx);
635        }
636
637        Self {
638            network,
639            qubit_indices,
640            current_indices,
641        }
642    }
643
644    /// Apply a single-qubit gate
645    pub fn apply_single_qubit_gate(
646        &mut self,
647        gate: &dyn GateOp,
648        qubit: usize,
649    ) -> QuantRS2Result<()> {
650        let matrix_vec = gate.matrix()?;
651        let matrix = Array2::from_shape_vec((2, 2), matrix_vec)
652            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?;
653
654        // Create gate tensor
655        let in_idx = self.current_indices[&qubit].clone();
656        let out_idx = format!("q{}_{}", qubit, self.network.next_id);
657        let gate_tensor = Tensor::from_matrix(
658            self.network.next_id,
659            matrix,
660            in_idx.clone(),
661            out_idx.clone(),
662        );
663
664        // Add to network
665        let gate_id = self.network.add_tensor(gate_tensor);
666
667        // Connect to previous tensor on this qubit
668        if let Some(prev_tensor) = self.find_tensor_with_index(&in_idx) {
669            self.network
670                .connect(prev_tensor, in_idx.clone(), gate_id, in_idx)?;
671        }
672
673        // Update current index
674        self.current_indices.insert(qubit, out_idx);
675
676        Ok(())
677    }
678
679    /// Apply a two-qubit gate
680    pub fn apply_two_qubit_gate(
681        &mut self,
682        gate: &dyn GateOp,
683        qubit1: usize,
684        qubit2: usize,
685    ) -> QuantRS2Result<()> {
686        let matrix_vec = gate.matrix()?;
687        let matrix = Array2::from_shape_vec((4, 4), matrix_vec)
688            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?;
689
690        // Reshape to rank-4 tensor
691        let tensor_data = matrix
692            .into_shape_with_order((2, 2, 2, 2))
693            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?
694            .into_dyn();
695
696        // Create indices
697        let in1_idx = self.current_indices[&qubit1].clone();
698        let in2_idx = self.current_indices[&qubit2].clone();
699        let out1_idx = format!("q{}_{}", qubit1, self.network.next_id);
700        let out2_idx = format!("q{}_{}", qubit2, self.network.next_id);
701
702        let gate_tensor = Tensor::new(
703            self.network.next_id,
704            tensor_data,
705            vec![
706                in1_idx.clone(),
707                in2_idx.clone(),
708                out1_idx.clone(),
709                out2_idx.clone(),
710            ],
711        );
712
713        // Add to network
714        let gate_id = self.network.add_tensor(gate_tensor);
715
716        // Connect to previous tensors
717        if let Some(prev1) = self.find_tensor_with_index(&in1_idx) {
718            self.network
719                .connect(prev1, in1_idx.clone(), gate_id, in1_idx)?;
720        }
721        if let Some(prev2) = self.find_tensor_with_index(&in2_idx) {
722            self.network
723                .connect(prev2, in2_idx.clone(), gate_id, in2_idx)?;
724        }
725
726        // Update current indices
727        self.current_indices.insert(qubit1, out1_idx);
728        self.current_indices.insert(qubit2, out2_idx);
729
730        Ok(())
731    }
732
733    /// Find tensor that has the given index as output
734    fn find_tensor_with_index(&self, index: &str) -> Option<usize> {
735        for (id, tensor) in &self.network.tensors {
736            if tensor.indices.iter().any(|idx| idx == index) {
737                return Some(*id);
738            }
739        }
740        None
741    }
742
743    /// Build the final tensor network
744    pub fn build(self) -> TensorNetwork {
745        self.network
746    }
747
748    /// Contract the network and return the quantum state
749    #[must_use]
750    pub fn to_statevector(&mut self) -> QuantRS2Result<Vec<Complex64>> {
751        let final_tensor = self.network.contract_all()?;
752        Ok(final_tensor.data.into_raw_vec_and_offset().0)
753    }
754}
755
756/// Quantum circuit simulation using tensor networks
757pub struct TensorNetworkSimulator {
758    /// Maximum bond dimension for MPS
759    max_bond_dim: usize,
760    /// Use SVD compression
761    use_compression: bool,
762    /// Parallelization threshold
763    parallel_threshold: usize,
764}
765
766impl TensorNetworkSimulator {
767    /// Create a new tensor network simulator
768    pub const fn new() -> Self {
769        Self {
770            max_bond_dim: 64,
771            use_compression: true,
772            parallel_threshold: 1000,
773        }
774    }
775
776    /// Set maximum bond dimension
777    #[must_use]
778    pub const fn with_max_bond_dim(mut self, dim: usize) -> Self {
779        self.max_bond_dim = dim;
780        self
781    }
782
783    /// Enable or disable compression
784    #[must_use]
785    pub const fn with_compression(mut self, compress: bool) -> Self {
786        self.use_compression = compress;
787        self
788    }
789
790    /// Simulate a quantum circuit
791    pub fn simulate<const N: usize>(
792        &self,
793        gates: &[Box<dyn GateOp>],
794    ) -> QuantRS2Result<Register<N>> {
795        let mut builder = TensorNetworkBuilder::new(N);
796
797        // Apply gates
798        for gate in gates {
799            let qubits = gate.qubits();
800            match qubits.len() {
801                1 => builder.apply_single_qubit_gate(gate.as_ref(), qubits[0].0 as usize)?,
802                2 => builder.apply_two_qubit_gate(
803                    gate.as_ref(),
804                    qubits[0].0 as usize,
805                    qubits[1].0 as usize,
806                )?,
807                _ => {
808                    return Err(QuantRS2Error::UnsupportedOperation(format!(
809                        "Gates with {} qubits not supported in tensor network",
810                        qubits.len()
811                    )))
812                }
813            }
814        }
815
816        // Contract to get statevector
817        let amplitudes = builder.to_statevector()?;
818        Register::with_amplitudes(amplitudes)
819    }
820}
821
822/// Optimized contraction strategies
823pub mod contraction_optimization {
824    use super::*;
825
826    /// Dynamic programming algorithm for optimal contraction order
827    pub struct DynamicProgrammingOptimizer {
828        memo: HashMap<Vec<usize>, (usize, Vec<(usize, usize)>)>,
829    }
830
831    impl DynamicProgrammingOptimizer {
832        pub fn new() -> Self {
833            Self {
834                memo: HashMap::new(),
835            }
836        }
837
838        /// Find optimal contraction order using dynamic programming
839        pub fn optimize(&mut self, network: &TensorNetwork) -> Vec<(usize, usize)> {
840            let tensor_ids: Vec<_> = network.tensors.keys().copied().collect();
841            self.find_optimal_order(&tensor_ids, network).1
842        }
843
844        fn find_optimal_order(
845            &mut self,
846            tensors: &[usize],
847            network: &TensorNetwork,
848        ) -> (usize, Vec<(usize, usize)>) {
849            if tensors.len() <= 1 {
850                return (0, vec![]);
851            }
852
853            let key = tensors.to_vec();
854            if let Some(result) = self.memo.get(&key) {
855                return result.clone();
856            }
857
858            let mut best_cost = usize::MAX;
859            let mut best_order = vec![];
860
861            // Try all possible pairings
862            for i in 0..tensors.len() {
863                for j in (i + 1)..tensors.len() {
864                    // Check if tensors are connected
865                    if self.are_connected(tensors[i], tensors[j], network) {
866                        let cost = network.estimate_contraction_cost(tensors[i], tensors[j]);
867
868                        // Remaining tensors after contraction
869                        let mut remaining = vec![];
870                        for (k, &t) in tensors.iter().enumerate() {
871                            if k != i && k != j {
872                                remaining.push(t);
873                            }
874                        }
875                        remaining.push(network.next_id + remaining.len()); // Virtual tensor
876
877                        let (sub_cost, sub_order) = self.find_optimal_order(&remaining, network);
878                        let total_cost = cost + sub_cost;
879
880                        if total_cost < best_cost {
881                            best_cost = total_cost;
882                            best_order = vec![(tensors[i], tensors[j])];
883                            best_order.extend(sub_order);
884                        }
885                    }
886                }
887            }
888
889            self.memo.insert(key, (best_cost, best_order.clone()));
890            (best_cost, best_order)
891        }
892
893        fn are_connected(&self, t1: usize, t2: usize, network: &TensorNetwork) -> bool {
894            network.edges.iter().any(|e| {
895                (e.tensor1 == t1 && e.tensor2 == t2) || (e.tensor1 == t2 && e.tensor2 == t1)
896            })
897        }
898    }
899}
900
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a complex number with zero imaginary part.
    fn real(value: f64) -> Complex64 {
        Complex64::new(value, 0.0)
    }

    /// A tensor built with `new` reports the rank and shape of its data.
    #[test]
    fn test_tensor_creation() {
        let labels = vec![String::from("in"), String::from("out")];
        let tensor = Tensor::new(0, ArrayD::zeros(IxDyn(&[2, 2])), labels);

        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.shape, vec![2, 2]);
    }

    /// The basis-state constructors put the single 1 in the right slot.
    #[test]
    fn test_qubit_tensors() {
        let zero = Tensor::qubit_zero(0, String::from("q0"));
        assert_eq!(zero.data[[0]], real(1.0));
        assert_eq!(zero.data[[1]], real(0.0));

        let one = Tensor::qubit_one(1, String::from("q1"));
        assert_eq!(one.data[[0]], real(0.0));
        assert_eq!(one.data[[1]], real(1.0));
    }

    /// A fresh builder holds one tensor per qubit.
    #[test]
    fn test_tensor_network_builder() {
        let builder = TensorNetworkBuilder::new(2);
        assert_eq!(builder.network.tensors.len(), 2);
    }

    /// Connecting labels that exist on neither tensor is rejected.
    #[test]
    fn test_network_connection() {
        let mut network = TensorNetwork::new();

        let first = network.add_tensor(Tensor::qubit_zero(0, String::from("q0")));
        let second = network.add_tensor(Tensor::qubit_zero(1, String::from("q1")));

        let outcome = network.connect(first, String::from("bond"), second, String::from("bond"));
        assert!(outcome.is_err());
    }
}
945}