1use quantrs2_circuit::builder::{Circuit, Simulator};
8use quantrs2_core::{
9 error::{QuantRS2Error, QuantRS2Result},
10 gate::GateOp,
11 prelude::QubitId,
12 register::Register,
13};
14use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2};
15use scirs2_core::Complex64;
16
17#[derive(Debug, Clone)]
19struct MPSTensor {
20 data: Array3<Complex64>,
22 left_dim: usize,
24 right_dim: usize,
26}
27
28impl MPSTensor {
29 fn new(data: Array3<Complex64>) -> Self {
31 let shape = data.shape();
32 Self {
33 left_dim: shape[0],
34 right_dim: shape[2],
35 data,
36 }
37 }
38
39 fn zero_state(is_first: bool, is_last: bool) -> Self {
41 let data = if is_first && is_last {
42 let mut tensor = Array3::zeros((1, 2, 1));
44 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
45 tensor
46 } else if is_first {
47 let mut tensor = Array3::zeros((1, 2, 2));
49 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
50 tensor
51 } else if is_last {
52 let mut tensor = Array3::zeros((2, 2, 1));
54 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
55 tensor
56 } else {
57 let mut tensor = Array3::zeros((2, 2, 2));
59 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
60 tensor
61 };
62 Self::new(data)
63 }
64}
65
66pub struct MPS {
68 tensors: Vec<MPSTensor>,
70 num_qubits: usize,
72 max_bond_dim: usize,
74 truncation_threshold: f64,
76 orthogonality_center: i32,
78}
79
80impl MPS {
81 #[must_use]
83 pub fn new(num_qubits: usize, max_bond_dim: usize) -> Self {
84 let tensors = (0..num_qubits)
85 .map(|i| MPSTensor::zero_state(i == 0, i == num_qubits - 1))
86 .collect();
87
88 Self {
89 tensors,
90 num_qubits,
91 max_bond_dim,
92 truncation_threshold: 1e-10,
93 orthogonality_center: -1,
94 }
95 }
96
97 pub const fn set_truncation_threshold(&mut self, threshold: f64) {
99 self.truncation_threshold = threshold;
100 }
101
102 pub fn move_orthogonality_center(&mut self, target: usize) -> QuantRS2Result<()> {
104 if target >= self.num_qubits {
105 return Err(QuantRS2Error::InvalidQubitId(target as u32));
106 }
107
108 if self.orthogonality_center < 0 {
110 self.left_canonicalize_up_to(target)?;
111 self.orthogonality_center = target as i32;
112 return Ok(());
113 }
114
115 let current = self.orthogonality_center as usize;
116
117 if current < target {
118 for i in current..target {
120 self.move_center_right(i)?;
121 }
122 } else if current > target {
123 for i in (target + 1..=current).rev() {
125 self.move_center_left(i)?;
126 }
127 }
128
129 self.orthogonality_center = target as i32;
130 Ok(())
131 }
132
133 fn left_canonicalize_up_to(&mut self, position: usize) -> QuantRS2Result<()> {
135 for i in 0..position {
136 let tensor = &self.tensors[i];
137 let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
138
139 let matrix = tensor
141 .data
142 .view()
143 .into_shape((left_dim * phys_dim, right_dim))?;
144
145 let (q, r) = qr_decomposition(&matrix)?;
147
148 let new_shape = (left_dim, phys_dim, q.shape()[1]);
150 self.tensors[i].data = q.into_shape(new_shape)?;
151 self.tensors[i].right_dim = new_shape.2;
152
153 if i + 1 < self.num_qubits {
155 let next = &mut self.tensors[i + 1];
156 let next_matrix = next
157 .data
158 .view()
159 .into_shape((next.left_dim, 2 * next.right_dim))?;
160 let new_matrix = r.dot(&next_matrix);
161 next.data = new_matrix.into_shape((r.shape()[0], 2, next.right_dim))?;
162 next.left_dim = r.shape()[0];
163 }
164 }
165 Ok(())
166 }
167
168 fn move_center_right(&mut self, position: usize) -> QuantRS2Result<()> {
170 let tensor = &self.tensors[position];
171 let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
172
173 let matrix = tensor
175 .data
176 .view()
177 .into_shape((left_dim * phys_dim, right_dim))?;
178 let (q, r) = qr_decomposition(&matrix)?;
179
180 let q_cols = q.shape()[1];
182 self.tensors[position].data = q.into_shape((left_dim, phys_dim, q_cols))?;
183 self.tensors[position].right_dim = q_cols;
184
185 if position + 1 < self.num_qubits {
187 let next = &mut self.tensors[position + 1];
188 let next_matrix = next
189 .data
190 .view()
191 .into_shape((next.left_dim, 2 * next.right_dim))?;
192 let new_matrix = r.dot(&next_matrix);
193 next.data = new_matrix.into_shape((r.shape()[0], 2, next.right_dim))?;
194 next.left_dim = r.shape()[0];
195 }
196
197 Ok(())
198 }
199
200 fn move_center_left(&mut self, position: usize) -> QuantRS2Result<()> {
202 let tensor = &self.tensors[position];
203 let (left_dim, phys_dim, right_dim) = (tensor.left_dim, 2, tensor.right_dim);
204
205 let matrix = tensor
207 .data
208 .view()
209 .permuted_axes([2, 1, 0])
210 .into_shape((right_dim * phys_dim, left_dim))?;
211 let (q, r) = qr_decomposition(&matrix)?;
212
213 let q_cols = q.shape()[1];
215 let q_reshaped = q.into_shape((right_dim, phys_dim, q_cols))?;
216 self.tensors[position].data = q_reshaped.permuted_axes([2, 1, 0]);
217 self.tensors[position].left_dim = q_cols;
218
219 if position > 0 {
221 let prev = &mut self.tensors[position - 1];
222 let prev_matrix = prev
223 .data
224 .view()
225 .into_shape((prev.left_dim * 2, prev.right_dim))?;
226 let new_matrix = prev_matrix.dot(&r.t());
227 prev.data = new_matrix.into_shape((prev.left_dim, 2, r.shape()[0]))?;
228 prev.right_dim = r.shape()[0];
229 }
230
231 Ok(())
232 }
233
234 pub fn apply_single_qubit_gate(
236 &mut self,
237 gate: &dyn GateOp,
238 qubit: usize,
239 ) -> QuantRS2Result<()> {
240 if qubit >= self.num_qubits {
241 return Err(QuantRS2Error::InvalidQubitId(qubit as u32));
242 }
243
244 let gate_matrix = gate.matrix()?;
246 let gate_array = Array2::from_shape_vec((2, 2), gate_matrix)?;
247
248 let tensor = &mut self.tensors[qubit];
250 let mut new_data = Array3::zeros(tensor.data.dim());
251
252 for left in 0..tensor.left_dim {
253 for right in 0..tensor.right_dim {
254 for i in 0..2 {
255 for j in 0..2 {
256 new_data[[left, i, right]] +=
257 gate_array[[i, j]] * tensor.data[[left, j, right]];
258 }
259 }
260 }
261 }
262
263 tensor.data = new_data;
264 Ok(())
265 }
266
267 pub fn apply_two_qubit_gate(
269 &mut self,
270 gate: &dyn GateOp,
271 qubit1: usize,
272 qubit2: usize,
273 ) -> QuantRS2Result<()> {
274 if (qubit1 as i32 - qubit2 as i32).abs() != 1 {
276 return Err(QuantRS2Error::ComputationError(
277 "MPS simulator requires adjacent qubits for two-qubit gates".to_string(),
278 ));
279 }
280
281 let (left_qubit, right_qubit) = if qubit1 < qubit2 {
282 (qubit1, qubit2)
283 } else {
284 (qubit2, qubit1)
285 };
286
287 self.move_orthogonality_center(left_qubit)?;
289
290 let gate_matrix = gate.matrix()?;
292 let gate_array = Array2::from_shape_vec((4, 4), gate_matrix)?;
293
294 let left_tensor = &self.tensors[left_qubit];
296 let right_tensor = &self.tensors[right_qubit];
297
298 let left_dim = left_tensor.left_dim;
299 let right_dim = right_tensor.right_dim;
300
301 let mut combined = Array3::<Complex64>::zeros((left_dim, 4, right_dim));
303 for l in 0..left_dim {
304 for r in 0..right_dim {
305 for i in 0..2 {
306 for j in 0..2 {
307 for k in 0..left_tensor.right_dim {
308 combined[[l, i * 2 + j, r]] +=
309 left_tensor.data[[l, i, k]] * right_tensor.data[[k, j, r]];
310 }
311 }
312 }
313 }
314 }
315
316 let mut gated = Array3::<Complex64>::zeros((left_dim, 4, right_dim));
318 for l in 0..left_dim {
319 for r in 0..right_dim {
320 for out_idx in 0..4 {
321 for in_idx in 0..4 {
322 gated[[l, out_idx, r]] +=
323 gate_array[[out_idx, in_idx]] * combined[[l, in_idx, r]];
324 }
325 }
326 }
327 }
328
329 let matrix = gated.into_shape((left_dim * 2, 2 * right_dim))?;
331 let (u, s, vt) = svd_decomposition(&matrix, self.max_bond_dim, self.truncation_threshold)?;
332
333 let new_bond = s.len();
335 self.tensors[left_qubit].data = u.into_shape((left_dim, 2, new_bond))?;
336 self.tensors[left_qubit].right_dim = new_bond;
337
338 let mut sv = Array2::<Complex64>::zeros((new_bond, vt.shape()[1]));
340 for i in 0..new_bond {
341 for j in 0..vt.shape()[1] {
342 sv[[i, j]] = Complex64::new(s[i], 0.0) * vt[[i, j]];
343 }
344 }
345 self.tensors[right_qubit].data = sv.t().to_owned().into_shape((new_bond, 2, right_dim))?;
346 self.tensors[right_qubit].left_dim = new_bond;
347
348 self.orthogonality_center = right_qubit as i32;
349
350 Ok(())
351 }
352
353 pub fn get_amplitude(&self, bitstring: &[bool]) -> QuantRS2Result<Complex64> {
355 if bitstring.len() != self.num_qubits {
356 return Err(QuantRS2Error::ComputationError(format!(
357 "Bitstring length {} doesn't match qubit count {}",
358 bitstring.len(),
359 self.num_qubits
360 )));
361 }
362
363 let mut result = Array2::eye(1);
365
366 for (i, &bit) in bitstring.iter().enumerate() {
367 let tensor = &self.tensors[i];
368 let idx = i32::from(bit);
369
370 let matrix = tensor.data.slice(s![.., idx, ..]);
372 result = result.dot(&matrix);
373 }
374
375 Ok(result[[0, 0]])
376 }
377
378 #[must_use]
380 pub fn sample(&self) -> Vec<bool> {
381 use scirs2_core::random::prelude::*;
382 let mut rng = thread_rng();
383 let mut result = vec![false; self.num_qubits];
384 let mut accumulated_matrix = Array2::eye(1);
385
386 for (i, tensor) in self.tensors.iter().enumerate() {
387 let mut prob0 = Complex64::new(0.0, 0.0);
389 let mut prob1 = Complex64::new(0.0, 0.0);
390
391 let matrix0 = tensor.data.slice(s![.., 0, ..]);
393 let temp0: Array2<Complex64> = accumulated_matrix.dot(&matrix0);
394
395 let mut right_contract = Array2::eye(temp0.shape()[1]);
397 for j in (i + 1)..self.num_qubits {
398 let sum_matrix = self.tensors[j].data.slice(s![.., 0, ..]).to_owned()
399 + self.tensors[j].data.slice(s![.., 1, ..]).to_owned();
400 right_contract = right_contract.dot(&sum_matrix);
401 }
402
403 prob0 = temp0.dot(&right_contract)[[0, 0]];
404
405 let matrix1 = tensor.data.slice(s![.., 1, ..]);
407 let temp1: Array2<Complex64> = accumulated_matrix.dot(&matrix1);
408 prob1 = temp1.dot(&right_contract)[[0, 0]];
409
410 let total = prob0.norm_sqr() + prob1.norm_sqr();
412 let threshold = prob0.norm_sqr() / total;
413
414 if rng.gen::<f64>() < threshold {
415 result[i] = false;
416 accumulated_matrix = temp0;
417 } else {
418 result[i] = true;
419 accumulated_matrix = temp1;
420 }
421 }
422
423 result
424 }
425}
426
427fn qr_decomposition(
429 matrix: &ArrayView2<Complex64>,
430) -> QuantRS2Result<(Array2<Complex64>, Array2<Complex64>)> {
431 let (m, n) = matrix.dim();
433 let mut q = Array2::zeros((m, n.min(m)));
434 let mut r = Array2::zeros((n.min(m), n));
435
436 for j in 0..n.min(m) {
437 let mut v = matrix.column(j).to_owned();
438
439 for i in 0..j {
441 let proj = q.column(i).dot(&v);
442 r[[i, j]] = proj;
443 v -= &(proj * &q.column(i).to_owned());
444 }
445
446 let norm = (v.dot(&v)).sqrt();
447 if norm.norm() > 1e-10 {
448 r[[j, j]] = norm;
449 q.column_mut(j).assign(&(v / norm));
450 }
451 }
452
453 if n > m {
455 for j in m..n {
456 for i in 0..m {
457 r[[i, j]] = q.column(i).dot(&matrix.column(j));
458 }
459 }
460 }
461
462 Ok((q, r))
463}
464
465fn svd_decomposition(
467 matrix: &Array2<Complex64>,
468 max_bond: usize,
469 threshold: f64,
470) -> QuantRS2Result<(Array2<Complex64>, Array1<f64>, Array2<Complex64>)> {
471 let (m, n) = matrix.dim();
474 let k = m.min(n).min(max_bond);
475
476 let u = Array2::eye(m).slice(s![.., ..k]).to_owned();
477 let s = Array1::ones(k);
478 let vt = Array2::eye(n).slice(s![..k, ..]).to_owned();
479
480 Ok((u, s, vt))
481}
482
/// Configuration for running circuits on a matrix product state.
pub struct MPSSimulator {
    /// Bond-dimension cap handed to each `MPS` the simulator creates.
    max_bond_dimension: usize,
    /// Truncation threshold handed to each `MPS` the simulator creates.
    truncation_threshold: f64,
}

impl MPSSimulator {
    /// Creates a simulator with the given bond-dimension cap and a default
    /// truncation threshold of `1e-10`.
    #[must_use]
    pub const fn new(max_bond_dimension: usize) -> Self {
        let truncation_threshold = 1e-10;
        Self {
            max_bond_dimension,
            truncation_threshold,
        }
    }

    /// Replaces the truncation threshold used for subsequent runs.
    pub const fn set_truncation_threshold(&mut self, threshold: f64) {
        self.truncation_threshold = threshold;
    }
}
506
507impl<const N: usize> Simulator<N> for MPSSimulator {
508 fn run(&self, circuit: &Circuit<N>) -> QuantRS2Result<Register<N>> {
509 let mut mps = MPS::new(N, self.max_bond_dimension);
511 mps.set_truncation_threshold(self.truncation_threshold);
512
513 Ok(Register::new())
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::gate::single::Hadamard;

    /// A new MPS has one tensor per requested qubit.
    #[test]
    fn test_mps_creation() {
        let state = MPS::new(4, 10);
        assert_eq!(state.num_qubits, 4);
        assert_eq!(state.tensors.len(), 4);
    }

    /// Hadamard on `|0>` gives amplitude `1/sqrt(2)` on both basis states.
    #[test]
    fn test_single_qubit_gate() {
        let mut state = MPS::new(1, 10);
        let hadamard = Hadamard {
            target: QubitId::new(0),
        };

        state
            .apply_single_qubit_gate(&hadamard, 0)
            .expect("Failed to apply single qubit gate");

        let inv_sqrt2 = 1.0 / 2.0_f64.sqrt();
        let cases = [
            ([false], "Failed to get amplitude for |0>"),
            ([true], "Failed to get amplitude for |1>"),
        ];
        for (bits, message) in cases {
            let amplitude = state.get_amplitude(&bits).expect(message);
            assert!((amplitude.re - inv_sqrt2).abs() < 1e-10);
        }
    }

    /// The orthogonality center can be set, moved right and moved left.
    #[test]
    fn test_orthogonality_center() {
        let mut state = MPS::new(5, 10);

        let moves = [
            (2_usize, 2_i32, "Failed to move orthogonality center to 2"),
            (4, 4, "Failed to move orthogonality center to 4"),
            (0, 0, "Failed to move orthogonality center to 0"),
        ];
        for (target, expected_center, message) in moves {
            state.move_orthogonality_center(target).expect(message);
            assert_eq!(state.orthogonality_center, expected_center);
        }
    }
}