1use ndarray::{s, Array1, Array2, Array3, ArrayView2};
8use num_complex::Complex64;
9use quantrs2_circuit::builder::{Circuit, Simulator};
10use quantrs2_core::{
11 error::{QuantRS2Error, QuantRS2Result},
12 gate::GateOp,
13 prelude::QubitId,
14 register::Register,
15};
16
17#[derive(Debug, Clone)]
19struct MPSTensor {
20 data: Array3<Complex64>,
22 left_dim: usize,
24 right_dim: usize,
26}
27
28impl MPSTensor {
29 fn new(data: Array3<Complex64>) -> Self {
31 let shape = data.shape();
32 Self {
33 left_dim: shape[0],
34 right_dim: shape[2],
35 data,
36 }
37 }
38
39 fn zero_state(is_first: bool, is_last: bool) -> Self {
41 let data = if is_first && is_last {
42 let mut tensor = Array3::zeros((1, 2, 1));
44 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
45 tensor
46 } else if is_first {
47 let mut tensor = Array3::zeros((1, 2, 2));
49 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
50 tensor
51 } else if is_last {
52 let mut tensor = Array3::zeros((2, 2, 1));
54 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
55 tensor
56 } else {
57 let mut tensor = Array3::zeros((2, 2, 2));
59 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
60 tensor
61 };
62 Self::new(data)
63 }
64}
65
66pub struct MPS {
68 tensors: Vec<MPSTensor>,
70 num_qubits: usize,
72 max_bond_dim: usize,
74 truncation_threshold: f64,
76 orthogonality_center: i32,
78}
79
80impl MPS {
81 pub fn new(num_qubits: usize, max_bond_dim: usize) -> Self {
83 let tensors = (0..num_qubits)
84 .map(|i| MPSTensor::zero_state(i == 0, i == num_qubits - 1))
85 .collect();
86
87 Self {
88 tensors,
89 num_qubits,
90 max_bond_dim,
91 truncation_threshold: 1e-10,
92 orthogonality_center: -1,
93 }
94 }
95
96 pub fn set_truncation_threshold(&mut self, threshold: f64) {
98 self.truncation_threshold = threshold;
99 }
100
101 pub fn move_orthogonality_center(&mut self, target: usize) -> QuantRS2Result<()> {
103 if target >= self.num_qubits {
104 return Err(QuantRS2Error::InvalidQubitId(target as u32));
105 }
106
107 if self.orthogonality_center < 0 {
109 self.left_canonicalize_up_to(target)?;
110 self.orthogonality_center = target as i32;
111 return Ok(());
112 }
113
114 let current = self.orthogonality_center as usize;
115
116 if current < target {
117 for i in current..target {
119 self.move_center_right(i)?;
120 }
121 } else if current > target {
122 for i in (target + 1..=current).rev() {
124 self.move_center_left(i)?;
125 }
126 }
127
128 self.orthogonality_center = target as i32;
129 Ok(())
130 }
131
132 fn left_canonicalize_up_to(&mut self, position: usize) -> QuantRS2Result<()> {
134 for i in 0..position {
135 let tensor = &self.tensors[i];
136 let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
137
138 let matrix = tensor
140 .data
141 .view()
142 .into_shape((left_dim * phys_dim, right_dim))?;
143
144 let (q, r) = qr_decomposition(&matrix)?;
146
147 let new_shape = (left_dim, phys_dim, q.shape()[1]);
149 self.tensors[i].data = q.into_shape(new_shape)?;
150 self.tensors[i].right_dim = new_shape.2;
151
152 if i + 1 < self.num_qubits {
154 let next = &mut self.tensors[i + 1];
155 let next_matrix = next
156 .data
157 .view()
158 .into_shape((next.left_dim, 2 * next.right_dim))?;
159 let new_matrix = r.dot(&next_matrix);
160 next.data = new_matrix.into_shape((r.shape()[0], 2, next.right_dim))?;
161 next.left_dim = r.shape()[0];
162 }
163 }
164 Ok(())
165 }
166
167 fn move_center_right(&mut self, position: usize) -> QuantRS2Result<()> {
169 let tensor = &self.tensors[position];
170 let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
171
172 let matrix = tensor
174 .data
175 .view()
176 .into_shape((left_dim * phys_dim, right_dim))?;
177 let (q, r) = qr_decomposition(&matrix)?;
178
179 let q_cols = q.shape()[1];
181 self.tensors[position].data = q.into_shape((left_dim, phys_dim, q_cols))?;
182 self.tensors[position].right_dim = q_cols;
183
184 if position + 1 < self.num_qubits {
186 let next = &mut self.tensors[position + 1];
187 let next_matrix = next
188 .data
189 .view()
190 .into_shape((next.left_dim, 2 * next.right_dim))?;
191 let new_matrix = r.dot(&next_matrix);
192 next.data = new_matrix.into_shape((r.shape()[0], 2, next.right_dim))?;
193 next.left_dim = r.shape()[0];
194 }
195
196 Ok(())
197 }
198
199 fn move_center_left(&mut self, position: usize) -> QuantRS2Result<()> {
201 let tensor = &self.tensors[position];
202 let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
203
204 let matrix = tensor
206 .data
207 .view()
208 .permuted_axes([2, 1, 0])
209 .into_shape((right_dim * phys_dim, left_dim))?;
210 let (q, r) = qr_decomposition(&matrix)?;
211
212 let q_cols = q.shape()[1];
214 let q_reshaped = q.into_shape((right_dim, phys_dim, q_cols))?;
215 self.tensors[position].data = q_reshaped.permuted_axes([2, 1, 0]);
216 self.tensors[position].left_dim = q_cols;
217
218 if position > 0 {
220 let prev = &mut self.tensors[position - 1];
221 let prev_matrix = prev
222 .data
223 .view()
224 .into_shape((prev.left_dim * 2, prev.right_dim))?;
225 let new_matrix = prev_matrix.dot(&r.t());
226 prev.data = new_matrix.into_shape((prev.left_dim, 2, r.shape()[0]))?;
227 prev.right_dim = r.shape()[0];
228 }
229
230 Ok(())
231 }
232
233 pub fn apply_single_qubit_gate(
235 &mut self,
236 gate: &dyn GateOp,
237 qubit: usize,
238 ) -> QuantRS2Result<()> {
239 if qubit >= self.num_qubits {
240 return Err(QuantRS2Error::InvalidQubitId(qubit as u32));
241 }
242
243 let gate_matrix = gate.matrix()?;
245 let gate_array = Array2::from_shape_vec((2, 2), gate_matrix)?;
246
247 let tensor = &mut self.tensors[qubit];
249 let mut new_data = Array3::zeros(tensor.data.dim());
250
251 for left in 0..tensor.left_dim {
252 for right in 0..tensor.right_dim {
253 for i in 0..2 {
254 for j in 0..2 {
255 new_data[[left, i, right]] +=
256 gate_array[[i, j]] * tensor.data[[left, j, right]];
257 }
258 }
259 }
260 }
261
262 tensor.data = new_data;
263 Ok(())
264 }
265
266 pub fn apply_two_qubit_gate(
268 &mut self,
269 gate: &dyn GateOp,
270 qubit1: usize,
271 qubit2: usize,
272 ) -> QuantRS2Result<()> {
273 if (qubit1 as i32 - qubit2 as i32).abs() != 1 {
275 return Err(QuantRS2Error::ComputationError(
276 "MPS simulator requires adjacent qubits for two-qubit gates".to_string(),
277 ));
278 }
279
280 let (left_qubit, right_qubit) = if qubit1 < qubit2 {
281 (qubit1, qubit2)
282 } else {
283 (qubit2, qubit1)
284 };
285
286 self.move_orthogonality_center(left_qubit)?;
288
289 let gate_matrix = gate.matrix()?;
291 let gate_array = Array2::from_shape_vec((4, 4), gate_matrix)?;
292
293 let left_tensor = &self.tensors[left_qubit];
295 let right_tensor = &self.tensors[right_qubit];
296
297 let left_dim = left_tensor.left_dim;
298 let right_dim = right_tensor.right_dim;
299
300 let mut combined = Array3::<Complex64>::zeros((left_dim, 4, right_dim));
302 for l in 0..left_dim {
303 for r in 0..right_dim {
304 for i in 0..2 {
305 for j in 0..2 {
306 for k in 0..left_tensor.right_dim {
307 combined[[l, i * 2 + j, r]] +=
308 left_tensor.data[[l, i, k]] * right_tensor.data[[k, j, r]];
309 }
310 }
311 }
312 }
313 }
314
315 let mut gated = Array3::<Complex64>::zeros((left_dim, 4, right_dim));
317 for l in 0..left_dim {
318 for r in 0..right_dim {
319 for out_idx in 0..4 {
320 for in_idx in 0..4 {
321 gated[[l, out_idx, r]] +=
322 gate_array[[out_idx, in_idx]] * combined[[l, in_idx, r]];
323 }
324 }
325 }
326 }
327
328 let matrix = gated.into_shape((left_dim * 2, 2 * right_dim))?;
330 let (u, s, vt) = svd_decomposition(&matrix, self.max_bond_dim, self.truncation_threshold)?;
331
332 let new_bond = s.len();
334 self.tensors[left_qubit].data = u.into_shape((left_dim, 2, new_bond))?;
335 self.tensors[left_qubit].right_dim = new_bond;
336
337 let mut sv = Array2::<Complex64>::zeros((new_bond, vt.shape()[1]));
339 for i in 0..new_bond {
340 for j in 0..vt.shape()[1] {
341 sv[[i, j]] = Complex64::new(s[i], 0.0) * vt[[i, j]];
342 }
343 }
344 self.tensors[right_qubit].data = sv.t().to_owned().into_shape((new_bond, 2, right_dim))?;
345 self.tensors[right_qubit].left_dim = new_bond;
346
347 self.orthogonality_center = right_qubit as i32;
348
349 Ok(())
350 }
351
352 pub fn get_amplitude(&self, bitstring: &[bool]) -> QuantRS2Result<Complex64> {
354 if bitstring.len() != self.num_qubits {
355 return Err(QuantRS2Error::ComputationError(format!(
356 "Bitstring length {} doesn't match qubit count {}",
357 bitstring.len(),
358 self.num_qubits
359 )));
360 }
361
362 let mut result = Array2::eye(1);
364
365 for (i, &bit) in bitstring.iter().enumerate() {
366 let tensor = &self.tensors[i];
367 let idx = if bit { 1 } else { 0 };
368
369 let matrix = tensor.data.slice(s![.., idx, ..]);
371 result = result.dot(&matrix);
372 }
373
374 Ok(result[[0, 0]])
375 }
376
377 pub fn sample(&self) -> Vec<bool> {
379 use rand::{thread_rng, Rng};
380 let mut rng = thread_rng();
381 let mut result = vec![false; self.num_qubits];
382 let mut accumulated_matrix = Array2::eye(1);
383
384 for i in 0..self.num_qubits {
385 let tensor = &self.tensors[i];
386
387 let mut prob0 = Complex64::new(0.0, 0.0);
389 let mut prob1 = Complex64::new(0.0, 0.0);
390
391 let matrix0 = tensor.data.slice(s![.., 0, ..]);
393 let temp0: Array2<Complex64> = accumulated_matrix.dot(&matrix0);
394
395 let mut right_contract = Array2::eye(temp0.shape()[1]);
397 for j in (i + 1)..self.num_qubits {
398 let sum_matrix = self.tensors[j].data.slice(s![.., 0, ..]).to_owned()
399 + self.tensors[j].data.slice(s![.., 1, ..]).to_owned();
400 right_contract = right_contract.dot(&sum_matrix);
401 }
402
403 prob0 = temp0.dot(&right_contract)[[0, 0]];
404
405 let matrix1 = tensor.data.slice(s![.., 1, ..]);
407 let temp1: Array2<Complex64> = accumulated_matrix.dot(&matrix1);
408 prob1 = temp1.dot(&right_contract)[[0, 0]];
409
410 let total = prob0.norm_sqr() + prob1.norm_sqr();
412 let threshold = prob0.norm_sqr() / total;
413
414 if rng.gen::<f64>() < threshold {
415 result[i] = false;
416 accumulated_matrix = temp0;
417 } else {
418 result[i] = true;
419 accumulated_matrix = temp1;
420 }
421 }
422
423 result
424 }
425}
426
427fn qr_decomposition(
429 matrix: &ArrayView2<Complex64>,
430) -> QuantRS2Result<(Array2<Complex64>, Array2<Complex64>)> {
431 let (m, n) = matrix.dim();
433 let mut q = Array2::zeros((m, n.min(m)));
434 let mut r = Array2::zeros((n.min(m), n));
435
436 for j in 0..n.min(m) {
437 let mut v = matrix.column(j).to_owned();
438
439 for i in 0..j {
441 let proj = q.column(i).dot(&v);
442 r[[i, j]] = proj;
443 v = v - &(proj * &q.column(i).to_owned());
444 }
445
446 let norm = (v.dot(&v)).sqrt();
447 if norm.norm() > 1e-10 {
448 r[[j, j]] = norm;
449 q.column_mut(j).assign(&(v / norm));
450 }
451 }
452
453 if n > m {
455 for j in m..n {
456 for i in 0..m {
457 r[[i, j]] = q.column(i).dot(&matrix.column(j));
458 }
459 }
460 }
461
462 Ok((q, r))
463}
464
465fn svd_decomposition(
467 matrix: &Array2<Complex64>,
468 max_bond: usize,
469 threshold: f64,
470) -> QuantRS2Result<(Array2<Complex64>, Array1<f64>, Array2<Complex64>)> {
471 let (m, n) = matrix.dim();
474 let k = m.min(n).min(max_bond);
475
476 let u = Array2::eye(m).slice(s![.., ..k]).to_owned();
477 let s = Array1::ones(k);
478 let vt = Array2::eye(n).slice(s![..k, ..]).to_owned();
479
480 Ok((u, s, vt))
481}
482
483pub struct MPSSimulator {
485 max_bond_dimension: usize,
487 truncation_threshold: f64,
489}
490
491impl MPSSimulator {
492 pub fn new(max_bond_dimension: usize) -> Self {
494 Self {
495 max_bond_dimension,
496 truncation_threshold: 1e-10,
497 }
498 }
499
500 pub fn set_truncation_threshold(&mut self, threshold: f64) {
502 self.truncation_threshold = threshold;
503 }
504}
505
506impl<const N: usize> Simulator<N> for MPSSimulator {
507 fn run(&self, circuit: &Circuit<N>) -> QuantRS2Result<Register<N>> {
508 let mut mps = MPS::new(N, self.max_bond_dimension);
510 mps.set_truncation_threshold(self.truncation_threshold);
511
512 Ok(Register::new())
516 }
517}
518
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::single::Hadamard;

    #[test]
    fn test_mps_creation() {
        let mps = MPS::new(4, 10);
        assert_eq!(mps.num_qubits, 4);
        assert_eq!(mps.tensors.len(), 4);
    }

    #[test]
    fn test_single_qubit_gate() {
        let mut mps = MPS::new(1, 10);
        let h = Hadamard {
            target: QubitId::new(0),
        };

        mps.apply_single_qubit_gate(&h, 0).unwrap();

        let amp0 = mps.get_amplitude(&[false]).unwrap();
        let amp1 = mps.get_amplitude(&[true]).unwrap();

        let expected = 1.0 / 2.0_f64.sqrt();
        assert!((amp0.re - expected).abs() < 1e-10);
        assert!((amp1.re - expected).abs() < 1e-10);
    }

    #[test]
    fn test_orthogonality_center() {
        let mut mps = MPS::new(5, 10);

        mps.move_orthogonality_center(2).unwrap();
        assert_eq!(mps.orthogonality_center, 2);

        mps.move_orthogonality_center(4).unwrap();
        assert_eq!(mps.orthogonality_center, 4);

        mps.move_orthogonality_center(0).unwrap();
        assert_eq!(mps.orthogonality_center, 0);
    }

    /// Moving the orthogonality centre must not change the represented state.
    #[test]
    fn test_orthogonality_moves_preserve_amplitudes() {
        let mut mps = MPS::new(5, 10);
        for target in [2usize, 4, 0, 3].iter().copied() {
            mps.move_orthogonality_center(target).unwrap();
            let amp = mps.get_amplitude(&[false; 5]).unwrap();
            assert!((amp - Complex64::new(1.0, 0.0)).norm() < 1e-10);
            let other = mps
                .get_amplitude(&[true, false, false, false, false])
                .unwrap();
            assert!(other.norm() < 1e-10);
        }
    }

    #[test]
    fn test_orthogonality_center_rejects_out_of_range() {
        let mut mps = MPS::new(3, 4);
        assert!(mps.move_orthogonality_center(3).is_err());
    }

    /// |000⟩ is a basis state, so every sample must be all zeros.
    #[test]
    fn test_sample_product_state() {
        let mps = MPS::new(3, 4);
        assert_eq!(mps.sample(), vec![false; 3]);
    }

    #[test]
    fn test_two_qubit_gate_requires_adjacent_qubits() {
        let mut mps = MPS::new(3, 4);
        let h = Hadamard {
            target: QubitId::new(0),
        };
        assert!(mps.apply_two_qubit_gate(&h, 0, 2).is_err());
    }

    #[test]
    fn test_amplitude_rejects_wrong_length() {
        let mps = MPS::new(2, 4);
        assert!(mps.get_amplitude(&[false]).is_err());
    }
}