// quantrs2_sim/mps_simulator.rs

//! Matrix Product State (MPS) quantum simulator
//!
//! This module implements an efficient quantum simulator using the Matrix Product State
//! representation, which is particularly effective for simulating quantum states with
//! limited entanglement.
6
7use quantrs2_circuit::builder::{Circuit, Simulator};
8use quantrs2_core::{
9    error::{QuantRS2Error, QuantRS2Result},
10    gate::GateOp,
11    prelude::QubitId,
12    register::Register,
13};
14use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2};
15use scirs2_core::Complex64;
16
17/// MPS tensor for a single qubit
18#[derive(Debug, Clone)]
19struct MPSTensor {
20    /// The tensor data: `left_bond` x physical x `right_bond`
21    data: Array3<Complex64>,
22    /// Left bond dimension
23    left_dim: usize,
24    /// Right bond dimension
25    right_dim: usize,
26}
27
28impl MPSTensor {
29    /// Create a new MPS tensor
30    fn new(data: Array3<Complex64>) -> Self {
31        let shape = data.shape();
32        Self {
33            left_dim: shape[0],
34            right_dim: shape[2],
35            data,
36        }
37    }
38
39    /// Create initial tensor for |0> state
40    fn zero_state(is_first: bool, is_last: bool) -> Self {
41        let data = if is_first && is_last {
42            // Single qubit: 1x2x1 tensor
43            let mut tensor = Array3::zeros((1, 2, 1));
44            tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
45            tensor
46        } else if is_first {
47            // First qubit: 1x2xD tensor
48            let mut tensor = Array3::zeros((1, 2, 2));
49            tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
50            tensor
51        } else if is_last {
52            // Last qubit: Dx2x1 tensor
53            let mut tensor = Array3::zeros((2, 2, 1));
54            tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
55            tensor
56        } else {
57            // Middle qubit: Dx2xD tensor
58            let mut tensor = Array3::zeros((2, 2, 2));
59            tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
60            tensor
61        };
62        Self::new(data)
63    }
64}
65
66/// Matrix Product State representation of a quantum state
67pub struct MPS {
68    /// MPS tensors for each qubit
69    tensors: Vec<MPSTensor>,
70    /// Number of qubits
71    num_qubits: usize,
72    /// Maximum allowed bond dimension
73    max_bond_dim: usize,
74    /// SVD truncation threshold
75    truncation_threshold: f64,
76    /// Current orthogonality center (-1 if not in canonical form)
77    orthogonality_center: i32,
78}
79
80impl MPS {
81    /// Create a new MPS in the |0...0> state
82    #[must_use]
83    pub fn new(num_qubits: usize, max_bond_dim: usize) -> Self {
84        let tensors = (0..num_qubits)
85            .map(|i| MPSTensor::zero_state(i == 0, i == num_qubits - 1))
86            .collect();
87
88        Self {
89            tensors,
90            num_qubits,
91            max_bond_dim,
92            truncation_threshold: 1e-10,
93            orthogonality_center: -1,
94        }
95    }
96
97    /// Set the truncation threshold for SVD
98    pub const fn set_truncation_threshold(&mut self, threshold: f64) {
99        self.truncation_threshold = threshold;
100    }
101
102    /// Move orthogonality center to specified position
103    pub fn move_orthogonality_center(&mut self, target: usize) -> QuantRS2Result<()> {
104        if target >= self.num_qubits {
105            return Err(QuantRS2Error::InvalidQubitId(target as u32));
106        }
107
108        // If no current center, canonicalize from left
109        if self.orthogonality_center < 0 {
110            self.left_canonicalize_up_to(target)?;
111            self.orthogonality_center = target as i32;
112            return Ok(());
113        }
114
115        let current = self.orthogonality_center as usize;
116
117        if current < target {
118            // Move right
119            for i in current..target {
120                self.move_center_right(i)?;
121            }
122        } else if current > target {
123            // Move left
124            for i in (target + 1..=current).rev() {
125                self.move_center_left(i)?;
126            }
127        }
128
129        self.orthogonality_center = target as i32;
130        Ok(())
131    }
132
133    /// Left-canonicalize tensors up to position
134    fn left_canonicalize_up_to(&mut self, position: usize) -> QuantRS2Result<()> {
135        for i in 0..position {
136            let tensor = &self.tensors[i];
137            let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
138
139            // Reshape to matrix for QR decomposition
140            let matrix = tensor
141                .data
142                .view()
143                .into_shape((left_dim * phys_dim, right_dim))?;
144
145            // QR decomposition
146            let (q, r) = qr_decomposition(&matrix)?;
147
148            // Update current tensor with Q
149            let new_shape = (left_dim, phys_dim, q.shape()[1]);
150            self.tensors[i].data = q.into_shape(new_shape)?;
151            self.tensors[i].right_dim = new_shape.2;
152
153            // Absorb R into next tensor
154            if i + 1 < self.num_qubits {
155                let next = &mut self.tensors[i + 1];
156                let next_matrix = next
157                    .data
158                    .view()
159                    .into_shape((next.left_dim, 2 * next.right_dim))?;
160                let new_matrix = r.dot(&next_matrix);
161                next.data = new_matrix.into_shape((r.shape()[0], 2, next.right_dim))?;
162                next.left_dim = r.shape()[0];
163            }
164        }
165        Ok(())
166    }
167
168    /// Move orthogonality center one position to the right
169    fn move_center_right(&mut self, position: usize) -> QuantRS2Result<()> {
170        let tensor = &self.tensors[position];
171        let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
172
173        // Reshape and QR decompose
174        let matrix = tensor
175            .data
176            .view()
177            .into_shape((left_dim * phys_dim, right_dim))?;
178        let (q, r) = qr_decomposition(&matrix)?;
179
180        // Update current tensor
181        let q_cols = q.shape()[1];
182        self.tensors[position].data = q.into_shape((left_dim, phys_dim, q_cols))?;
183        self.tensors[position].right_dim = q_cols;
184
185        // Update next tensor
186        if position + 1 < self.num_qubits {
187            let next = &mut self.tensors[position + 1];
188            let next_matrix = next
189                .data
190                .view()
191                .into_shape((next.left_dim, 2 * next.right_dim))?;
192            let new_matrix = r.dot(&next_matrix);
193            next.data = new_matrix.into_shape((r.shape()[0], 2, next.right_dim))?;
194            next.left_dim = r.shape()[0];
195        }
196
197        Ok(())
198    }
199
200    /// Move orthogonality center one position to the left
201    fn move_center_left(&mut self, position: usize) -> QuantRS2Result<()> {
202        let tensor = &self.tensors[position];
203        let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
204
205        // Reshape and QR decompose from right
206        let matrix = tensor
207            .data
208            .view()
209            .permuted_axes([2, 1, 0])
210            .into_shape((right_dim * phys_dim, left_dim))?;
211        let (q, r) = qr_decomposition(&matrix)?;
212
213        // Update current tensor
214        let q_cols = q.shape()[1];
215        let q_reshaped = q.into_shape((right_dim, phys_dim, q_cols))?;
216        self.tensors[position].data = q_reshaped.permuted_axes([2, 1, 0]);
217        self.tensors[position].left_dim = q_cols;
218
219        // Update previous tensor
220        if position > 0 {
221            let prev = &mut self.tensors[position - 1];
222            let prev_matrix = prev
223                .data
224                .view()
225                .into_shape((prev.left_dim * 2, prev.right_dim))?;
226            let new_matrix = prev_matrix.dot(&r.t());
227            prev.data = new_matrix.into_shape((prev.left_dim, 2, r.shape()[0]))?;
228            prev.right_dim = r.shape()[0];
229        }
230
231        Ok(())
232    }
233
234    /// Apply single-qubit gate
235    pub fn apply_single_qubit_gate(
236        &mut self,
237        gate: &dyn GateOp,
238        qubit: usize,
239    ) -> QuantRS2Result<()> {
240        if qubit >= self.num_qubits {
241            return Err(QuantRS2Error::InvalidQubitId(qubit as u32));
242        }
243
244        // Get gate matrix
245        let gate_matrix = gate.matrix()?;
246        let gate_array = Array2::from_shape_vec((2, 2), gate_matrix)?;
247
248        // Apply gate to tensor
249        let tensor = &mut self.tensors[qubit];
250        let mut new_data = Array3::zeros(tensor.data.dim());
251
252        for left in 0..tensor.left_dim {
253            for right in 0..tensor.right_dim {
254                for i in 0..2 {
255                    for j in 0..2 {
256                        new_data[[left, i, right]] +=
257                            gate_array[[i, j]] * tensor.data[[left, j, right]];
258                    }
259                }
260            }
261        }
262
263        tensor.data = new_data;
264        Ok(())
265    }
266
267    /// Apply two-qubit gate using SVD compression
268    pub fn apply_two_qubit_gate(
269        &mut self,
270        gate: &dyn GateOp,
271        qubit1: usize,
272        qubit2: usize,
273    ) -> QuantRS2Result<()> {
274        // Ensure qubits are adjacent
275        if (qubit1 as i32 - qubit2 as i32).abs() != 1 {
276            return Err(QuantRS2Error::ComputationError(
277                "MPS simulator requires adjacent qubits for two-qubit gates".to_string(),
278            ));
279        }
280
281        let (left_qubit, right_qubit) = if qubit1 < qubit2 {
282            (qubit1, qubit2)
283        } else {
284            (qubit2, qubit1)
285        };
286
287        // Move orthogonality center to left qubit
288        self.move_orthogonality_center(left_qubit)?;
289
290        // Get gate matrix
291        let gate_matrix = gate.matrix()?;
292        let gate_array = Array2::from_shape_vec((4, 4), gate_matrix)?;
293
294        // Contract the two tensors
295        let left_tensor = &self.tensors[left_qubit];
296        let right_tensor = &self.tensors[right_qubit];
297
298        let left_dim = left_tensor.left_dim;
299        let right_dim = right_tensor.right_dim;
300
301        // Combine tensors
302        let mut combined = Array3::<Complex64>::zeros((left_dim, 4, right_dim));
303        for l in 0..left_dim {
304            for r in 0..right_dim {
305                for i in 0..2 {
306                    for j in 0..2 {
307                        for k in 0..left_tensor.right_dim {
308                            combined[[l, i * 2 + j, r]] +=
309                                left_tensor.data[[l, i, k]] * right_tensor.data[[k, j, r]];
310                        }
311                    }
312                }
313            }
314        }
315
316        // Apply gate
317        let mut gated = Array3::<Complex64>::zeros((left_dim, 4, right_dim));
318        for l in 0..left_dim {
319            for r in 0..right_dim {
320                for out_idx in 0..4 {
321                    for in_idx in 0..4 {
322                        gated[[l, out_idx, r]] +=
323                            gate_array[[out_idx, in_idx]] * combined[[l, in_idx, r]];
324                    }
325                }
326            }
327        }
328
329        // Decompose back using SVD
330        let matrix = gated.into_shape((left_dim * 2, 2 * right_dim))?;
331        let (u, s, vt) = svd_decomposition(&matrix, self.max_bond_dim, self.truncation_threshold)?;
332
333        // Update tensors
334        let new_bond = s.len();
335        self.tensors[left_qubit].data = u.into_shape((left_dim, 2, new_bond))?;
336        self.tensors[left_qubit].right_dim = new_bond;
337
338        // Convert s to complex diagonal matrix and multiply with vt
339        let mut sv = Array2::<Complex64>::zeros((new_bond, vt.shape()[1]));
340        for i in 0..new_bond {
341            for j in 0..vt.shape()[1] {
342                sv[[i, j]] = Complex64::new(s[i], 0.0) * vt[[i, j]];
343            }
344        }
345        self.tensors[right_qubit].data = sv.t().to_owned().into_shape((new_bond, 2, right_dim))?;
346        self.tensors[right_qubit].left_dim = new_bond;
347
348        self.orthogonality_center = right_qubit as i32;
349
350        Ok(())
351    }
352
353    /// Compute amplitude of a basis state
354    pub fn get_amplitude(&self, bitstring: &[bool]) -> QuantRS2Result<Complex64> {
355        if bitstring.len() != self.num_qubits {
356            return Err(QuantRS2Error::ComputationError(format!(
357                "Bitstring length {} doesn't match qubit count {}",
358                bitstring.len(),
359                self.num_qubits
360            )));
361        }
362
363        // Contract from left to right
364        let mut result = Array2::eye(1);
365
366        for (i, &bit) in bitstring.iter().enumerate() {
367            let tensor = &self.tensors[i];
368            let idx = i32::from(bit);
369
370            // Extract the matrix for this bit value
371            let matrix = tensor.data.slice(s![.., idx, ..]);
372            result = result.dot(&matrix);
373        }
374
375        Ok(result[[0, 0]])
376    }
377
378    /// Sample from the MPS
379    #[must_use]
380    pub fn sample(&self) -> Vec<bool> {
381        use scirs2_core::random::prelude::*;
382        let mut rng = thread_rng();
383        let mut result = vec![false; self.num_qubits];
384        let mut accumulated_matrix = Array2::eye(1);
385
386        for (i, tensor) in self.tensors.iter().enumerate() {
387            // Compute probabilities for this qubit
388            let mut prob0 = Complex64::new(0.0, 0.0);
389            let mut prob1 = Complex64::new(0.0, 0.0);
390
391            // Probability of |0>
392            let matrix0 = tensor.data.slice(s![.., 0, ..]);
393            let temp0: Array2<Complex64> = accumulated_matrix.dot(&matrix0);
394
395            // Contract with remaining tensors
396            let mut right_contract = Array2::eye(temp0.shape()[1]);
397            for j in (i + 1)..self.num_qubits {
398                let sum_matrix = self.tensors[j].data.slice(s![.., 0, ..]).to_owned()
399                    + self.tensors[j].data.slice(s![.., 1, ..]).to_owned();
400                right_contract = right_contract.dot(&sum_matrix);
401            }
402
403            prob0 = temp0.dot(&right_contract)[[0, 0]];
404
405            // Similar for |1>
406            let matrix1 = tensor.data.slice(s![.., 1, ..]);
407            let temp1: Array2<Complex64> = accumulated_matrix.dot(&matrix1);
408            prob1 = temp1.dot(&right_contract)[[0, 0]];
409
410            // Normalize and sample
411            let total = prob0.norm_sqr() + prob1.norm_sqr();
412            let threshold = prob0.norm_sqr() / total;
413
414            if rng.gen::<f64>() < threshold {
415                result[i] = false;
416                accumulated_matrix = temp0;
417            } else {
418                result[i] = true;
419                accumulated_matrix = temp1;
420            }
421        }
422
423        result
424    }
425}
426
427/// QR decomposition helper
428fn qr_decomposition(
429    matrix: &ArrayView2<Complex64>,
430) -> QuantRS2Result<(Array2<Complex64>, Array2<Complex64>)> {
431    // Simple Gram-Schmidt QR decomposition
432    let (m, n) = matrix.dim();
433    let mut q = Array2::zeros((m, n.min(m)));
434    let mut r = Array2::zeros((n.min(m), n));
435
436    for j in 0..n.min(m) {
437        let mut v = matrix.column(j).to_owned();
438
439        // Orthogonalize against previous columns
440        for i in 0..j {
441            let proj = q.column(i).dot(&v);
442            r[[i, j]] = proj;
443            v -= &(proj * &q.column(i).to_owned());
444        }
445
446        let norm = (v.dot(&v)).sqrt();
447        if norm.norm() > 1e-10 {
448            r[[j, j]] = norm;
449            q.column_mut(j).assign(&(v / norm));
450        }
451    }
452
453    // Copy remaining columns of R
454    if n > m {
455        for j in m..n {
456            for i in 0..m {
457                r[[i, j]] = q.column(i).dot(&matrix.column(j));
458            }
459        }
460    }
461
462    Ok((q, r))
463}
464
465/// SVD decomposition with truncation
466fn svd_decomposition(
467    matrix: &Array2<Complex64>,
468    max_bond: usize,
469    threshold: f64,
470) -> QuantRS2Result<(Array2<Complex64>, Array1<f64>, Array2<Complex64>)> {
471    // Placeholder - in real implementation would use proper SVD
472    // For now, return identity-like decomposition
473    let (m, n) = matrix.dim();
474    let k = m.min(n).min(max_bond);
475
476    let u = Array2::eye(m).slice(s![.., ..k]).to_owned();
477    let s = Array1::ones(k);
478    let vt = Array2::eye(n).slice(s![..k, ..]).to_owned();
479
480    Ok((u, s, vt))
481}
482
/// MPS quantum simulator.
///
/// Holds the settings (bond-dimension cap and SVD truncation threshold) applied to each
/// [`MPS`] it creates.
#[derive(Debug, Clone)]
pub struct MPSSimulator {
    /// Maximum bond dimension
    max_bond_dimension: usize,
    /// SVD truncation threshold
    truncation_threshold: f64,
}
490
491impl MPSSimulator {
492    /// Create a new MPS simulator
493    #[must_use]
494    pub const fn new(max_bond_dimension: usize) -> Self {
495        Self {
496            max_bond_dimension,
497            truncation_threshold: 1e-10,
498        }
499    }
500
501    /// Set the truncation threshold
502    pub const fn set_truncation_threshold(&mut self, threshold: f64) {
503        self.truncation_threshold = threshold;
504    }
505}
506
507impl<const N: usize> Simulator<N> for MPSSimulator {
508    fn run(&self, circuit: &Circuit<N>) -> QuantRS2Result<Register<N>> {
509        // Create initial MPS state
510        let mut mps = MPS::new(N, self.max_bond_dimension);
511        mps.set_truncation_threshold(self.truncation_threshold);
512
513        // Get gate sequence from circuit
514        // Note: This is a placeholder - would need actual circuit introspection
515        // For now, return a register in |0> state
516        Ok(Register::new())
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::single::Hadamard;

    /// Absolute tolerance for amplitude comparisons.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_mps_creation() {
        let state = MPS::new(4, 10);
        assert_eq!(state.num_qubits, 4);
        assert_eq!(state.tensors.len(), 4);
    }

    #[test]
    fn test_single_qubit_gate() {
        let mut state = MPS::new(1, 10);
        let hadamard = Hadamard {
            target: QubitId::new(0),
        };

        state
            .apply_single_qubit_gate(&hadamard, 0)
            .expect("Failed to apply single qubit gate");

        let inv_sqrt2 = 2.0_f64.sqrt().recip();
        let cases = [
            ([false], "Failed to get amplitude for |0>"),
            ([true], "Failed to get amplitude for |1>"),
        ];
        for (bits, context) in cases {
            let amplitude = state.get_amplitude(&bits).expect(context);
            assert!((amplitude.re - inv_sqrt2).abs() < EPS);
        }
    }

    #[test]
    fn test_orthogonality_center() {
        let mut state = MPS::new(5, 10);

        let moves = [
            (2, "Failed to move orthogonality center to 2"),
            (4, "Failed to move orthogonality center to 4"),
            (0, "Failed to move orthogonality center to 0"),
        ];
        for (target, context) in moves {
            state.move_orthogonality_center(target).expect(context);
            assert_eq!(state.orthogonality_center, target as i32);
        }
    }
}