quantrs2_sim/
mps_simulator.rs

1//! Matrix Product State (MPS) quantum simulator
2//!
3//! This module implements an efficient quantum simulator using the Matrix Product State
4//! representation, which is particularly effective for simulating quantum systems with
5//! limited entanglement.
6
7use ndarray::{s, Array1, Array2, Array3, ArrayView2};
8use num_complex::Complex64;
9use quantrs2_circuit::builder::{Circuit, Simulator};
10use quantrs2_core::{
11    error::{QuantRS2Error, QuantRS2Result},
12    gate::GateOp,
13    prelude::QubitId,
14    register::Register,
15};
16
17/// MPS tensor for a single qubit
18#[derive(Debug, Clone)]
19struct MPSTensor {
20    /// The tensor data: left_bond x physical x right_bond
21    data: Array3<Complex64>,
22    /// Left bond dimension
23    left_dim: usize,
24    /// Right bond dimension
25    right_dim: usize,
26}
27
28impl MPSTensor {
29    /// Create a new MPS tensor
30    fn new(data: Array3<Complex64>) -> Self {
31        let shape = data.shape();
32        Self {
33            left_dim: shape[0],
34            right_dim: shape[2],
35            data,
36        }
37    }
38
39    /// Create initial tensor for |0> state
40    fn zero_state(is_first: bool, is_last: bool) -> Self {
41        let data = if is_first && is_last {
42            // Single qubit: 1x2x1 tensor
43            let mut tensor = Array3::zeros((1, 2, 1));
44            tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
45            tensor
46        } else if is_first {
47            // First qubit: 1x2xD tensor
48            let mut tensor = Array3::zeros((1, 2, 2));
49            tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
50            tensor
51        } else if is_last {
52            // Last qubit: Dx2x1 tensor
53            let mut tensor = Array3::zeros((2, 2, 1));
54            tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
55            tensor
56        } else {
57            // Middle qubit: Dx2xD tensor
58            let mut tensor = Array3::zeros((2, 2, 2));
59            tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
60            tensor
61        };
62        Self::new(data)
63    }
64}
65
66/// Matrix Product State representation of a quantum state
67pub struct MPS {
68    /// MPS tensors for each qubit
69    tensors: Vec<MPSTensor>,
70    /// Number of qubits
71    num_qubits: usize,
72    /// Maximum allowed bond dimension
73    max_bond_dim: usize,
74    /// SVD truncation threshold
75    truncation_threshold: f64,
76    /// Current orthogonality center (-1 if not in canonical form)
77    orthogonality_center: i32,
78}
79
80impl MPS {
81    /// Create a new MPS in the |0...0> state
82    pub fn new(num_qubits: usize, max_bond_dim: usize) -> Self {
83        let tensors = (0..num_qubits)
84            .map(|i| MPSTensor::zero_state(i == 0, i == num_qubits - 1))
85            .collect();
86
87        Self {
88            tensors,
89            num_qubits,
90            max_bond_dim,
91            truncation_threshold: 1e-10,
92            orthogonality_center: -1,
93        }
94    }
95
96    /// Set the truncation threshold for SVD
97    pub fn set_truncation_threshold(&mut self, threshold: f64) {
98        self.truncation_threshold = threshold;
99    }
100
101    /// Move orthogonality center to specified position
102    pub fn move_orthogonality_center(&mut self, target: usize) -> QuantRS2Result<()> {
103        if target >= self.num_qubits {
104            return Err(QuantRS2Error::InvalidQubitId(target as u32));
105        }
106
107        // If no current center, canonicalize from left
108        if self.orthogonality_center < 0 {
109            self.left_canonicalize_up_to(target)?;
110            self.orthogonality_center = target as i32;
111            return Ok(());
112        }
113
114        let current = self.orthogonality_center as usize;
115
116        if current < target {
117            // Move right
118            for i in current..target {
119                self.move_center_right(i)?;
120            }
121        } else if current > target {
122            // Move left
123            for i in (target + 1..=current).rev() {
124                self.move_center_left(i)?;
125            }
126        }
127
128        self.orthogonality_center = target as i32;
129        Ok(())
130    }
131
132    /// Left-canonicalize tensors up to position
133    fn left_canonicalize_up_to(&mut self, position: usize) -> QuantRS2Result<()> {
134        for i in 0..position {
135            let tensor = &self.tensors[i];
136            let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
137
138            // Reshape to matrix for QR decomposition
139            let matrix = tensor
140                .data
141                .view()
142                .into_shape((left_dim * phys_dim, right_dim))?;
143
144            // QR decomposition
145            let (q, r) = qr_decomposition(&matrix)?;
146
147            // Update current tensor with Q
148            let new_shape = (left_dim, phys_dim, q.shape()[1]);
149            self.tensors[i].data = q.into_shape(new_shape)?;
150            self.tensors[i].right_dim = new_shape.2;
151
152            // Absorb R into next tensor
153            if i + 1 < self.num_qubits {
154                let next = &mut self.tensors[i + 1];
155                let next_matrix = next
156                    .data
157                    .view()
158                    .into_shape((next.left_dim, 2 * next.right_dim))?;
159                let new_matrix = r.dot(&next_matrix);
160                next.data = new_matrix.into_shape((r.shape()[0], 2, next.right_dim))?;
161                next.left_dim = r.shape()[0];
162            }
163        }
164        Ok(())
165    }
166
167    /// Move orthogonality center one position to the right
168    fn move_center_right(&mut self, position: usize) -> QuantRS2Result<()> {
169        let tensor = &self.tensors[position];
170        let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
171
172        // Reshape and QR decompose
173        let matrix = tensor
174            .data
175            .view()
176            .into_shape((left_dim * phys_dim, right_dim))?;
177        let (q, r) = qr_decomposition(&matrix)?;
178
179        // Update current tensor
180        let q_cols = q.shape()[1];
181        self.tensors[position].data = q.into_shape((left_dim, phys_dim, q_cols))?;
182        self.tensors[position].right_dim = q_cols;
183
184        // Update next tensor
185        if position + 1 < self.num_qubits {
186            let next = &mut self.tensors[position + 1];
187            let next_matrix = next
188                .data
189                .view()
190                .into_shape((next.left_dim, 2 * next.right_dim))?;
191            let new_matrix = r.dot(&next_matrix);
192            next.data = new_matrix.into_shape((r.shape()[0], 2, next.right_dim))?;
193            next.left_dim = r.shape()[0];
194        }
195
196        Ok(())
197    }
198
199    /// Move orthogonality center one position to the left
200    fn move_center_left(&mut self, position: usize) -> QuantRS2Result<()> {
201        let tensor = &self.tensors[position];
202        let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
203
204        // Reshape and QR decompose from right
205        let matrix = tensor
206            .data
207            .view()
208            .permuted_axes([2, 1, 0])
209            .into_shape((right_dim * phys_dim, left_dim))?;
210        let (q, r) = qr_decomposition(&matrix)?;
211
212        // Update current tensor
213        let q_cols = q.shape()[1];
214        let q_reshaped = q.into_shape((right_dim, phys_dim, q_cols))?;
215        self.tensors[position].data = q_reshaped.permuted_axes([2, 1, 0]);
216        self.tensors[position].left_dim = q_cols;
217
218        // Update previous tensor
219        if position > 0 {
220            let prev = &mut self.tensors[position - 1];
221            let prev_matrix = prev
222                .data
223                .view()
224                .into_shape((prev.left_dim * 2, prev.right_dim))?;
225            let new_matrix = prev_matrix.dot(&r.t());
226            prev.data = new_matrix.into_shape((prev.left_dim, 2, r.shape()[0]))?;
227            prev.right_dim = r.shape()[0];
228        }
229
230        Ok(())
231    }
232
233    /// Apply single-qubit gate
234    pub fn apply_single_qubit_gate(
235        &mut self,
236        gate: &dyn GateOp,
237        qubit: usize,
238    ) -> QuantRS2Result<()> {
239        if qubit >= self.num_qubits {
240            return Err(QuantRS2Error::InvalidQubitId(qubit as u32));
241        }
242
243        // Get gate matrix
244        let gate_matrix = gate.matrix()?;
245        let gate_array = Array2::from_shape_vec((2, 2), gate_matrix)?;
246
247        // Apply gate to tensor
248        let tensor = &mut self.tensors[qubit];
249        let mut new_data = Array3::zeros(tensor.data.dim());
250
251        for left in 0..tensor.left_dim {
252            for right in 0..tensor.right_dim {
253                for i in 0..2 {
254                    for j in 0..2 {
255                        new_data[[left, i, right]] +=
256                            gate_array[[i, j]] * tensor.data[[left, j, right]];
257                    }
258                }
259            }
260        }
261
262        tensor.data = new_data;
263        Ok(())
264    }
265
266    /// Apply two-qubit gate using SVD compression
267    pub fn apply_two_qubit_gate(
268        &mut self,
269        gate: &dyn GateOp,
270        qubit1: usize,
271        qubit2: usize,
272    ) -> QuantRS2Result<()> {
273        // Ensure qubits are adjacent
274        if (qubit1 as i32 - qubit2 as i32).abs() != 1 {
275            return Err(QuantRS2Error::ComputationError(
276                "MPS simulator requires adjacent qubits for two-qubit gates".to_string(),
277            ));
278        }
279
280        let (left_qubit, right_qubit) = if qubit1 < qubit2 {
281            (qubit1, qubit2)
282        } else {
283            (qubit2, qubit1)
284        };
285
286        // Move orthogonality center to left qubit
287        self.move_orthogonality_center(left_qubit)?;
288
289        // Get gate matrix
290        let gate_matrix = gate.matrix()?;
291        let gate_array = Array2::from_shape_vec((4, 4), gate_matrix)?;
292
293        // Contract the two tensors
294        let left_tensor = &self.tensors[left_qubit];
295        let right_tensor = &self.tensors[right_qubit];
296
297        let left_dim = left_tensor.left_dim;
298        let right_dim = right_tensor.right_dim;
299
300        // Combine tensors
301        let mut combined = Array3::<Complex64>::zeros((left_dim, 4, right_dim));
302        for l in 0..left_dim {
303            for r in 0..right_dim {
304                for i in 0..2 {
305                    for j in 0..2 {
306                        for k in 0..left_tensor.right_dim {
307                            combined[[l, i * 2 + j, r]] +=
308                                left_tensor.data[[l, i, k]] * right_tensor.data[[k, j, r]];
309                        }
310                    }
311                }
312            }
313        }
314
315        // Apply gate
316        let mut gated = Array3::<Complex64>::zeros((left_dim, 4, right_dim));
317        for l in 0..left_dim {
318            for r in 0..right_dim {
319                for out_idx in 0..4 {
320                    for in_idx in 0..4 {
321                        gated[[l, out_idx, r]] +=
322                            gate_array[[out_idx, in_idx]] * combined[[l, in_idx, r]];
323                    }
324                }
325            }
326        }
327
328        // Decompose back using SVD
329        let matrix = gated.into_shape((left_dim * 2, 2 * right_dim))?;
330        let (u, s, vt) = svd_decomposition(&matrix, self.max_bond_dim, self.truncation_threshold)?;
331
332        // Update tensors
333        let new_bond = s.len();
334        self.tensors[left_qubit].data = u.into_shape((left_dim, 2, new_bond))?;
335        self.tensors[left_qubit].right_dim = new_bond;
336
337        // Convert s to complex diagonal matrix and multiply with vt
338        let mut sv = Array2::<Complex64>::zeros((new_bond, vt.shape()[1]));
339        for i in 0..new_bond {
340            for j in 0..vt.shape()[1] {
341                sv[[i, j]] = Complex64::new(s[i], 0.0) * vt[[i, j]];
342            }
343        }
344        self.tensors[right_qubit].data = sv.t().to_owned().into_shape((new_bond, 2, right_dim))?;
345        self.tensors[right_qubit].left_dim = new_bond;
346
347        self.orthogonality_center = right_qubit as i32;
348
349        Ok(())
350    }
351
352    /// Compute amplitude of a basis state
353    pub fn get_amplitude(&self, bitstring: &[bool]) -> QuantRS2Result<Complex64> {
354        if bitstring.len() != self.num_qubits {
355            return Err(QuantRS2Error::ComputationError(format!(
356                "Bitstring length {} doesn't match qubit count {}",
357                bitstring.len(),
358                self.num_qubits
359            )));
360        }
361
362        // Contract from left to right
363        let mut result = Array2::eye(1);
364
365        for (i, &bit) in bitstring.iter().enumerate() {
366            let tensor = &self.tensors[i];
367            let idx = if bit { 1 } else { 0 };
368
369            // Extract the matrix for this bit value
370            let matrix = tensor.data.slice(s![.., idx, ..]);
371            result = result.dot(&matrix);
372        }
373
374        Ok(result[[0, 0]])
375    }
376
377    /// Sample from the MPS
378    pub fn sample(&self) -> Vec<bool> {
379        use rand::{thread_rng, Rng};
380        let mut rng = thread_rng();
381        let mut result = vec![false; self.num_qubits];
382        let mut accumulated_matrix = Array2::eye(1);
383
384        for i in 0..self.num_qubits {
385            let tensor = &self.tensors[i];
386
387            // Compute probabilities for this qubit
388            let mut prob0 = Complex64::new(0.0, 0.0);
389            let mut prob1 = Complex64::new(0.0, 0.0);
390
391            // Probability of |0>
392            let matrix0 = tensor.data.slice(s![.., 0, ..]);
393            let temp0: Array2<Complex64> = accumulated_matrix.dot(&matrix0);
394
395            // Contract with remaining tensors
396            let mut right_contract = Array2::eye(temp0.shape()[1]);
397            for j in (i + 1)..self.num_qubits {
398                let sum_matrix = self.tensors[j].data.slice(s![.., 0, ..]).to_owned()
399                    + self.tensors[j].data.slice(s![.., 1, ..]).to_owned();
400                right_contract = right_contract.dot(&sum_matrix);
401            }
402
403            prob0 = temp0.dot(&right_contract)[[0, 0]];
404
405            // Similar for |1>
406            let matrix1 = tensor.data.slice(s![.., 1, ..]);
407            let temp1: Array2<Complex64> = accumulated_matrix.dot(&matrix1);
408            prob1 = temp1.dot(&right_contract)[[0, 0]];
409
410            // Normalize and sample
411            let total = prob0.norm_sqr() + prob1.norm_sqr();
412            let threshold = prob0.norm_sqr() / total;
413
414            if rng.gen::<f64>() < threshold {
415                result[i] = false;
416                accumulated_matrix = temp0;
417            } else {
418                result[i] = true;
419                accumulated_matrix = temp1;
420            }
421        }
422
423        result
424    }
425}
426
427/// QR decomposition helper
428fn qr_decomposition(
429    matrix: &ArrayView2<Complex64>,
430) -> QuantRS2Result<(Array2<Complex64>, Array2<Complex64>)> {
431    // Simple Gram-Schmidt QR decomposition
432    let (m, n) = matrix.dim();
433    let mut q = Array2::zeros((m, n.min(m)));
434    let mut r = Array2::zeros((n.min(m), n));
435
436    for j in 0..n.min(m) {
437        let mut v = matrix.column(j).to_owned();
438
439        // Orthogonalize against previous columns
440        for i in 0..j {
441            let proj = q.column(i).dot(&v);
442            r[[i, j]] = proj;
443            v = v - &(proj * &q.column(i).to_owned());
444        }
445
446        let norm = (v.dot(&v)).sqrt();
447        if norm.norm() > 1e-10 {
448            r[[j, j]] = norm;
449            q.column_mut(j).assign(&(v / norm));
450        }
451    }
452
453    // Copy remaining columns of R
454    if n > m {
455        for j in m..n {
456            for i in 0..m {
457                r[[i, j]] = q.column(i).dot(&matrix.column(j));
458            }
459        }
460    }
461
462    Ok((q, r))
463}
464
465/// SVD decomposition with truncation
466fn svd_decomposition(
467    matrix: &Array2<Complex64>,
468    max_bond: usize,
469    threshold: f64,
470) -> QuantRS2Result<(Array2<Complex64>, Array1<f64>, Array2<Complex64>)> {
471    // Placeholder - in real implementation would use proper SVD
472    // For now, return identity-like decomposition
473    let (m, n) = matrix.dim();
474    let k = m.min(n).min(max_bond);
475
476    let u = Array2::eye(m).slice(s![.., ..k]).to_owned();
477    let s = Array1::ones(k);
478    let vt = Array2::eye(n).slice(s![..k, ..]).to_owned();
479
480    Ok((u, s, vt))
481}
482
/// Circuit simulator backed by the matrix product state representation.
pub struct MPSSimulator {
    /// Bond-dimension cap forwarded to every MPS this simulator creates.
    max_bond_dimension: usize,
    /// SVD truncation threshold forwarded to every MPS this simulator creates.
    truncation_threshold: f64,
}
490
491impl MPSSimulator {
492    /// Create a new MPS simulator
493    pub fn new(max_bond_dimension: usize) -> Self {
494        Self {
495            max_bond_dimension,
496            truncation_threshold: 1e-10,
497        }
498    }
499
500    /// Set the truncation threshold
501    pub fn set_truncation_threshold(&mut self, threshold: f64) {
502        self.truncation_threshold = threshold;
503    }
504}
505
506impl<const N: usize> Simulator<N> for MPSSimulator {
507    fn run(&self, circuit: &Circuit<N>) -> QuantRS2Result<Register<N>> {
508        // Create initial MPS state
509        let mut mps = MPS::new(N, self.max_bond_dimension);
510        mps.set_truncation_threshold(self.truncation_threshold);
511
512        // Get gate sequence from circuit
513        // Note: This is a placeholder - would need actual circuit introspection
514        // For now, return a register in |0> state
515        Ok(Register::new())
516    }
517}
518
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::single::Hadamard;

    /// Absolute tolerance for amplitude comparisons.
    const TOLERANCE: f64 = 1e-10;

    #[test]
    fn test_mps_creation() {
        let state = MPS::new(4, 10);
        assert_eq!(state.num_qubits, 4);
        assert_eq!(state.tensors.len(), 4);
    }

    #[test]
    fn test_single_qubit_gate() {
        let mut state = MPS::new(1, 10);
        let hadamard = Hadamard {
            target: QubitId::new(0),
        };
        state.apply_single_qubit_gate(&hadamard, 0).unwrap();

        // H|0> = (|0> + |1>) / sqrt(2): both amplitudes are real and equal.
        let expected = 1.0 / 2.0_f64.sqrt();
        for &bit in &[false, true] {
            let amplitude = state.get_amplitude(&[bit]).unwrap();
            assert!((amplitude.re - expected).abs() < TOLERANCE);
        }
    }

    #[test]
    fn test_orthogonality_center() {
        let mut state = MPS::new(5, 10);
        for &target in &[2usize, 4, 0] {
            state.move_orthogonality_center(target).unwrap();
            assert_eq!(state.orthogonality_center, target as i32);
        }
    }
}
562}