1use quantrs2_circuit::builder::{Circuit, Simulator};
6use quantrs2_core::{
7 error::{QuantRS2Error, QuantRS2Result},
8 gate::GateOp,
9 register::Register,
10};
11use scirs2_core::ndarray::{array, s, Array2, Array3};
12use scirs2_core::random::prelude::*;
13use scirs2_core::random::{thread_rng, Rng};
14use scirs2_core::Complex64;
15use std::f64::consts::SQRT_2;
16
/// Tunable limits for the basic MPS simulator.
#[derive(Clone, Debug)]
pub struct BasicMPSConfig {
    /// Upper bound applied to the bond dimension chosen when a two-qubit
    /// gate splits a merged pair of tensors.
    pub max_bond_dim: usize,
    /// Numerical cutoff for discarding negligible contributions when a bond
    /// is split. NOTE(review): the name suggests a singular-value threshold;
    /// confirm how the gate-application code is expected to use it.
    pub svd_threshold: f64,
}
25
26impl Default for BasicMPSConfig {
27 fn default() -> Self {
28 Self {
29 max_bond_dim: 64,
30 svd_threshold: 1e-10,
31 }
32 }
33}
34
35#[derive(Debug, Clone)]
37struct MPSTensor {
38 data: Array3<Complex64>,
40}
41
42impl MPSTensor {
43 fn zero_state(position: usize, num_qubits: usize) -> Self {
45 let is_first = position == 0;
46 let is_last = position == num_qubits - 1;
47
48 let data = if is_first && is_last {
49 let mut tensor = Array3::zeros((1, 2, 1));
51 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
52 tensor
53 } else if is_first {
54 let mut tensor = Array3::zeros((1, 2, 2));
56 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
57 tensor
58 } else if is_last {
59 let mut tensor = Array3::zeros((2, 2, 1));
61 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
62 tensor
63 } else {
64 let mut tensor = Array3::zeros((2, 2, 2));
66 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
67 tensor
68 };
69 Self { data }
70 }
71}
72
73pub struct BasicMPS {
75 tensors: Vec<MPSTensor>,
77 num_qubits: usize,
79 config: BasicMPSConfig,
81}
82
83impl BasicMPS {
84 #[must_use]
86 pub fn new(num_qubits: usize, config: BasicMPSConfig) -> Self {
87 let tensors = (0..num_qubits)
88 .map(|i| MPSTensor::zero_state(i, num_qubits))
89 .collect();
90
91 Self {
92 tensors,
93 num_qubits,
94 config,
95 }
96 }
97
98 pub fn apply_single_qubit_gate(
100 &mut self,
101 gate_matrix: &Array2<Complex64>,
102 qubit: usize,
103 ) -> QuantRS2Result<()> {
104 if qubit >= self.num_qubits {
105 return Err(QuantRS2Error::InvalidQubitId(qubit as u32));
106 }
107
108 let tensor = &mut self.tensors[qubit];
109 let shape = tensor.data.shape();
110 let (left_dim, _, right_dim) = (shape[0], shape[1], shape[2]);
111
112 let mut new_data = Array3::zeros((left_dim, 2, right_dim));
113
114 for l in 0..left_dim {
116 for r in 0..right_dim {
117 for new_phys in 0..2 {
118 for old_phys in 0..2 {
119 new_data[[l, new_phys, r]] +=
120 gate_matrix[[new_phys, old_phys]] * tensor.data[[l, old_phys, r]];
121 }
122 }
123 }
124 }
125
126 tensor.data = new_data;
127 Ok(())
128 }
129
130 pub fn apply_two_qubit_gate(
132 &mut self,
133 gate_matrix: &Array2<Complex64>,
134 qubit1: usize,
135 qubit2: usize,
136 ) -> QuantRS2Result<()> {
137 if (qubit1 as i32 - qubit2 as i32).abs() != 1 {
138 return Err(QuantRS2Error::InvalidInput(
139 "MPS requires adjacent qubits for two-qubit gates".to_string(),
140 ));
141 }
142
143 let (left_q, right_q) = if qubit1 < qubit2 {
144 (qubit1, qubit2)
145 } else {
146 (qubit2, qubit1)
147 };
148
149 let left_shape = self.tensors[left_q].data.shape().to_vec();
153 let right_shape = self.tensors[right_q].data.shape().to_vec();
154
155 let mut combined = Array3::<Complex64>::zeros((left_shape[0], 4, right_shape[2]));
157
158 for l in 0..left_shape[0] {
159 for r in 0..right_shape[2] {
160 for i in 0..2 {
161 for j in 0..2 {
162 for m in 0..left_shape[2] {
163 combined[[l, i * 2 + j, r]] += self.tensors[left_q].data[[l, i, m]]
164 * self.tensors[right_q].data[[m, j, r]];
165 }
166 }
167 }
168 }
169 }
170
171 let mut result = Array3::<Complex64>::zeros((left_shape[0], 4, right_shape[2]));
173 for l in 0..left_shape[0] {
174 for r in 0..right_shape[2] {
175 for out_idx in 0..4 {
176 for in_idx in 0..4 {
177 result[[l, out_idx, r]] +=
178 gate_matrix[[out_idx, in_idx]] * combined[[l, in_idx, r]];
179 }
180 }
181 }
182 }
183
184 let new_bond = 2.min(self.config.max_bond_dim);
187
188 let mut left_new = Array3::zeros((left_shape[0], 2, new_bond));
189 let mut right_new = Array3::zeros((new_bond, 2, right_shape[2]));
190
191 for l in 0..left_shape[0] {
193 for r in 0..right_shape[2] {
194 for i in 0..2 {
195 for j in 0..2 {
196 let bond_idx = (i + j) % new_bond;
197 left_new[[l, i, bond_idx]] = result[[l, i * 2 + j, r]];
198 right_new[[bond_idx, j, r]] = Complex64::new(1.0, 0.0);
199 }
200 }
201 }
202 }
203
204 self.tensors[left_q].data = left_new;
205 self.tensors[right_q].data = right_new;
206
207 Ok(())
208 }
209
210 pub fn get_amplitude(&self, bitstring: &[bool]) -> QuantRS2Result<Complex64> {
212 if bitstring.len() != self.num_qubits {
213 return Err(QuantRS2Error::InvalidInput(format!(
214 "Bitstring length {} doesn't match qubit count {}",
215 bitstring.len(),
216 self.num_qubits
217 )));
218 }
219
220 let mut result = Array2::from_elem((1, 1), Complex64::new(1.0, 0.0));
222
223 for (i, &bit) in bitstring.iter().enumerate() {
224 let tensor = &self.tensors[i];
225 let physical_idx = i32::from(bit);
226
227 let matrix = tensor.data.slice(s![.., physical_idx, ..]);
229
230 result = result.dot(&matrix);
232 }
233
234 Ok(result[[0, 0]])
235 }
236
237 #[must_use]
239 pub fn sample(&self) -> Vec<bool> {
240 let mut rng = thread_rng();
241 let mut result = vec![false; self.num_qubits];
242 let mut accumulated = Array2::from_elem((1, 1), Complex64::new(1.0, 0.0));
243
244 for (i, tensor) in self.tensors.iter().enumerate() {
245 let matrix0 = tensor.data.slice(s![.., 0, ..]);
247 let matrix1 = tensor.data.slice(s![.., 1, ..]);
248
249 let branch0: Array2<Complex64> = accumulated.dot(&matrix0);
250 let branch1: Array2<Complex64> = accumulated.dot(&matrix1);
251
252 let norm0_sq: f64 = branch0.iter().map(scirs2_core::Complex::norm_sqr).sum();
254 let norm1_sq: f64 = branch1.iter().map(scirs2_core::Complex::norm_sqr).sum();
255
256 let total = norm0_sq + norm1_sq;
257 let prob0 = norm0_sq / total;
258
259 if rng.gen::<f64>() < prob0 {
260 result[i] = false;
261 accumulated = branch0;
262 } else {
263 result[i] = true;
264 accumulated = branch1;
265 }
266
267 let norm_sq: f64 = accumulated.iter().map(scirs2_core::Complex::norm_sqr).sum();
269 if norm_sq > 0.0 {
270 accumulated /= Complex64::new(norm_sq.sqrt(), 0.0);
271 }
272 }
273
274 result
275 }
276}
277
278pub struct BasicMPSSimulator {
280 config: BasicMPSConfig,
281}
282
283impl BasicMPSSimulator {
284 #[must_use]
286 pub const fn new(config: BasicMPSConfig) -> Self {
287 Self { config }
288 }
289
290 #[must_use]
292 pub fn default() -> Self {
293 Self::new(BasicMPSConfig::default())
294 }
295}
296
297impl<const N: usize> Simulator<N> for BasicMPSSimulator {
298 fn run(&self, circuit: &Circuit<N>) -> QuantRS2Result<Register<N>> {
299 let mut mps = BasicMPS::new(N, self.config.clone());
301
302 for gate in circuit.gates() {
304 match gate.name() {
305 "H" => {
306 let h_matrix = {
307 let h = 1.0 / SQRT_2;
308 array![
309 [Complex64::new(h, 0.), Complex64::new(h, 0.)],
310 [Complex64::new(h, 0.), Complex64::new(-h, 0.)]
311 ]
312 };
313 if let Some(&qubit) = gate.qubits().first() {
314 mps.apply_single_qubit_gate(&h_matrix, qubit.id() as usize)?;
315 }
316 }
317 "X" => {
318 let x_matrix = array![
319 [Complex64::new(0., 0.), Complex64::new(1., 0.)],
320 [Complex64::new(1., 0.), Complex64::new(0., 0.)]
321 ];
322 if let Some(&qubit) = gate.qubits().first() {
323 mps.apply_single_qubit_gate(&x_matrix, qubit.id() as usize)?;
324 }
325 }
326 "CNOT" | "CX" => {
327 let cnot_matrix = array![
328 [
329 Complex64::new(1., 0.),
330 Complex64::new(0., 0.),
331 Complex64::new(0., 0.),
332 Complex64::new(0., 0.)
333 ],
334 [
335 Complex64::new(0., 0.),
336 Complex64::new(1., 0.),
337 Complex64::new(0., 0.),
338 Complex64::new(0., 0.)
339 ],
340 [
341 Complex64::new(0., 0.),
342 Complex64::new(0., 0.),
343 Complex64::new(0., 0.),
344 Complex64::new(1., 0.)
345 ],
346 [
347 Complex64::new(0., 0.),
348 Complex64::new(0., 0.),
349 Complex64::new(1., 0.),
350 Complex64::new(0., 0.)
351 ],
352 ];
353 let qubits = gate.qubits();
354 if qubits.len() == 2 {
355 mps.apply_two_qubit_gate(
356 &cnot_matrix,
357 qubits[0].id() as usize,
358 qubits[1].id() as usize,
359 )?;
360 }
361 }
362 _ => {
363 }
365 }
366 }
367
368 Ok(Register::new())
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Largest deviation tolerated when comparing amplitude magnitudes.
    const TOLERANCE: f64 = 1e-10;

    /// True when `value` is within `TOLERANCE` of `expected`.
    fn close_to(value: f64, expected: f64) -> bool {
        (value - expected).abs() < TOLERANCE
    }

    /// A freshly built MPS holds all of its weight on the all-zeros string.
    #[test]
    fn test_basic_mps_initialization() {
        let mps = BasicMPS::new(4, BasicMPSConfig::default());

        let all_zero = mps
            .get_amplitude(&[false; 4])
            .expect("Failed to get amplitude for |0000>");
        assert!(close_to(all_zero.norm(), 1.0));

        let first_flipped = mps
            .get_amplitude(&[true, false, false, false])
            .expect("Failed to get amplitude for |1000>");
        assert!(first_flipped.norm() < TOLERANCE);
    }

    /// An X gate on qubit 0 moves all weight to the string with bit 0 set.
    #[test]
    fn test_single_qubit_gate() {
        let zero = Complex64::new(0., 0.);
        let one = Complex64::new(1., 0.);
        let x_matrix = array![[zero, one], [one, zero]];

        let mut mps = BasicMPS::new(3, BasicMPSConfig::default());
        mps.apply_single_qubit_gate(&x_matrix, 0)
            .expect("Failed to apply X gate");

        let flipped = mps
            .get_amplitude(&[true, false, false])
            .expect("Failed to get amplitude for |100>");
        assert!(close_to(flipped.norm(), 1.0));
    }
}