1use scirs2_core::ndarray::{array, s, Array2, Array3};
6use scirs2_core::Complex64;
7use quantrs2_circuit::builder::{Circuit, Simulator};
8use quantrs2_core::{
9 error::{QuantRS2Error, QuantRS2Result},
10 gate::GateOp,
11 register::Register,
12};
13use scirs2_core::random::{thread_rng, Rng};
14use std::f64::consts::SQRT_2;
15use scirs2_core::random::prelude::*;
16
/// Tuning knobs for a [`BasicMPS`].
#[derive(Debug, Clone)]
pub struct BasicMPSConfig {
    /// Upper bound on the bond dimension kept between neighbouring sites.
    pub max_bond_dim: usize,
    /// Singular-value cutoff intended for two-site tensor splits.
    pub svd_threshold: f64,
}
25
26impl Default for BasicMPSConfig {
27 fn default() -> Self {
28 BasicMPSConfig {
29 max_bond_dim: 64,
30 svd_threshold: 1e-10,
31 }
32 }
33}
34
35#[derive(Debug, Clone)]
37struct MPSTensor {
38 data: Array3<Complex64>,
40}
41
42impl MPSTensor {
43 fn zero_state(position: usize, num_qubits: usize) -> Self {
45 let is_first = position == 0;
46 let is_last = position == num_qubits - 1;
47
48 let data = if is_first && is_last {
49 let mut tensor = Array3::zeros((1, 2, 1));
51 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
52 tensor
53 } else if is_first {
54 let mut tensor = Array3::zeros((1, 2, 2));
56 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
57 tensor
58 } else if is_last {
59 let mut tensor = Array3::zeros((2, 2, 1));
61 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
62 tensor
63 } else {
64 let mut tensor = Array3::zeros((2, 2, 2));
66 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
67 tensor
68 };
69 Self { data }
70 }
71}
72
73pub struct BasicMPS {
75 tensors: Vec<MPSTensor>,
77 num_qubits: usize,
79 config: BasicMPSConfig,
81}
82
83impl BasicMPS {
84 pub fn new(num_qubits: usize, config: BasicMPSConfig) -> Self {
86 let tensors = (0..num_qubits)
87 .map(|i| MPSTensor::zero_state(i, num_qubits))
88 .collect();
89
90 Self {
91 tensors,
92 num_qubits,
93 config,
94 }
95 }
96
97 pub fn apply_single_qubit_gate(
99 &mut self,
100 gate_matrix: &Array2<Complex64>,
101 qubit: usize,
102 ) -> QuantRS2Result<()> {
103 if qubit >= self.num_qubits {
104 return Err(QuantRS2Error::InvalidQubitId(qubit as u32));
105 }
106
107 let tensor = &mut self.tensors[qubit];
108 let shape = tensor.data.shape();
109 let (left_dim, _, right_dim) = (shape[0], shape[1], shape[2]);
110
111 let mut new_data = Array3::zeros((left_dim, 2, right_dim));
112
113 for l in 0..left_dim {
115 for r in 0..right_dim {
116 for new_phys in 0..2 {
117 for old_phys in 0..2 {
118 new_data[[l, new_phys, r]] +=
119 gate_matrix[[new_phys, old_phys]] * tensor.data[[l, old_phys, r]];
120 }
121 }
122 }
123 }
124
125 tensor.data = new_data;
126 Ok(())
127 }
128
129 pub fn apply_two_qubit_gate(
131 &mut self,
132 gate_matrix: &Array2<Complex64>,
133 qubit1: usize,
134 qubit2: usize,
135 ) -> QuantRS2Result<()> {
136 if (qubit1 as i32 - qubit2 as i32).abs() != 1 {
137 return Err(QuantRS2Error::InvalidInput(
138 "MPS requires adjacent qubits for two-qubit gates".to_string(),
139 ));
140 }
141
142 let (left_q, right_q) = if qubit1 < qubit2 {
143 (qubit1, qubit2)
144 } else {
145 (qubit2, qubit1)
146 };
147
148 let left_shape = self.tensors[left_q].data.shape().to_vec();
152 let right_shape = self.tensors[right_q].data.shape().to_vec();
153
154 let mut combined = Array3::<Complex64>::zeros((left_shape[0], 4, right_shape[2]));
156
157 for l in 0..left_shape[0] {
158 for r in 0..right_shape[2] {
159 for i in 0..2 {
160 for j in 0..2 {
161 for m in 0..left_shape[2] {
162 combined[[l, i * 2 + j, r]] += self.tensors[left_q].data[[l, i, m]]
163 * self.tensors[right_q].data[[m, j, r]];
164 }
165 }
166 }
167 }
168 }
169
170 let mut result = Array3::<Complex64>::zeros((left_shape[0], 4, right_shape[2]));
172 for l in 0..left_shape[0] {
173 for r in 0..right_shape[2] {
174 for out_idx in 0..4 {
175 for in_idx in 0..4 {
176 result[[l, out_idx, r]] +=
177 gate_matrix[[out_idx, in_idx]] * combined[[l, in_idx, r]];
178 }
179 }
180 }
181 }
182
183 let new_bond = 2.min(self.config.max_bond_dim);
186
187 let mut left_new = Array3::zeros((left_shape[0], 2, new_bond));
188 let mut right_new = Array3::zeros((new_bond, 2, right_shape[2]));
189
190 for l in 0..left_shape[0] {
192 for r in 0..right_shape[2] {
193 for i in 0..2 {
194 for j in 0..2 {
195 let bond_idx = (i + j) % new_bond;
196 left_new[[l, i, bond_idx]] = result[[l, i * 2 + j, r]];
197 right_new[[bond_idx, j, r]] = Complex64::new(1.0, 0.0);
198 }
199 }
200 }
201 }
202
203 self.tensors[left_q].data = left_new;
204 self.tensors[right_q].data = right_new;
205
206 Ok(())
207 }
208
209 pub fn get_amplitude(&self, bitstring: &[bool]) -> QuantRS2Result<Complex64> {
211 if bitstring.len() != self.num_qubits {
212 return Err(QuantRS2Error::InvalidInput(format!(
213 "Bitstring length {} doesn't match qubit count {}",
214 bitstring.len(),
215 self.num_qubits
216 )));
217 }
218
219 let mut result = Array2::from_elem((1, 1), Complex64::new(1.0, 0.0));
221
222 for (i, &bit) in bitstring.iter().enumerate() {
223 let tensor = &self.tensors[i];
224 let physical_idx = if bit { 1 } else { 0 };
225
226 let matrix = tensor.data.slice(s![.., physical_idx, ..]);
228
229 result = result.dot(&matrix);
231 }
232
233 Ok(result[[0, 0]])
234 }
235
236 pub fn sample(&self) -> Vec<bool> {
238 let mut rng = thread_rng();
239 let mut result = vec![false; self.num_qubits];
240 let mut accumulated = Array2::from_elem((1, 1), Complex64::new(1.0, 0.0));
241
242 for i in 0..self.num_qubits {
243 let tensor = &self.tensors[i];
244
245 let matrix0 = tensor.data.slice(s![.., 0, ..]);
247 let matrix1 = tensor.data.slice(s![.., 1, ..]);
248
249 let branch0: Array2<Complex64> = accumulated.dot(&matrix0);
250 let branch1: Array2<Complex64> = accumulated.dot(&matrix1);
251
252 let norm0_sq: f64 = branch0.iter().map(|x| x.norm_sqr()).sum();
254 let norm1_sq: f64 = branch1.iter().map(|x| x.norm_sqr()).sum();
255
256 let total = norm0_sq + norm1_sq;
257 let prob0 = norm0_sq / total;
258
259 if rng.gen::<f64>() < prob0 {
260 result[i] = false;
261 accumulated = branch0;
262 } else {
263 result[i] = true;
264 accumulated = branch1;
265 }
266
267 let norm_sq: f64 = accumulated.iter().map(|x| x.norm_sqr()).sum();
269 if norm_sq > 0.0 {
270 accumulated /= Complex64::new(norm_sq.sqrt(), 0.0);
271 }
272 }
273
274 result
275 }
276}
277
278pub struct BasicMPSSimulator {
280 config: BasicMPSConfig,
281}
282
283impl BasicMPSSimulator {
284 pub fn new(config: BasicMPSConfig) -> Self {
286 Self { config }
287 }
288
289 pub fn default() -> Self {
291 Self::new(BasicMPSConfig::default())
292 }
293}
294
295impl<const N: usize> Simulator<N> for BasicMPSSimulator {
296 fn run(&self, circuit: &Circuit<N>) -> QuantRS2Result<Register<N>> {
297 let mut mps = BasicMPS::new(N, self.config.clone());
299
300 for gate in circuit.gates() {
302 match gate.name().as_ref() {
303 "H" => {
304 let h_matrix = {
305 let h = 1.0 / SQRT_2;
306 array![
307 [Complex64::new(h, 0.), Complex64::new(h, 0.)],
308 [Complex64::new(h, 0.), Complex64::new(-h, 0.)]
309 ]
310 };
311 if let Some(&qubit) = gate.qubits().first() {
312 mps.apply_single_qubit_gate(&h_matrix, qubit.id() as usize)?;
313 }
314 }
315 "X" => {
316 let x_matrix = array![
317 [Complex64::new(0., 0.), Complex64::new(1., 0.)],
318 [Complex64::new(1., 0.), Complex64::new(0., 0.)]
319 ];
320 if let Some(&qubit) = gate.qubits().first() {
321 mps.apply_single_qubit_gate(&x_matrix, qubit.id() as usize)?;
322 }
323 }
324 "CNOT" | "CX" => {
325 let cnot_matrix = array![
326 [
327 Complex64::new(1., 0.),
328 Complex64::new(0., 0.),
329 Complex64::new(0., 0.),
330 Complex64::new(0., 0.)
331 ],
332 [
333 Complex64::new(0., 0.),
334 Complex64::new(1., 0.),
335 Complex64::new(0., 0.),
336 Complex64::new(0., 0.)
337 ],
338 [
339 Complex64::new(0., 0.),
340 Complex64::new(0., 0.),
341 Complex64::new(0., 0.),
342 Complex64::new(1., 0.)
343 ],
344 [
345 Complex64::new(0., 0.),
346 Complex64::new(0., 0.),
347 Complex64::new(1., 0.),
348 Complex64::new(0., 0.)
349 ],
350 ];
351 let qubits = gate.qubits();
352 if qubits.len() == 2 {
353 mps.apply_two_qubit_gate(
354 &cnot_matrix,
355 qubits[0].id() as usize,
356 qubits[1].id() as usize,
357 )?;
358 }
359 }
360 _ => {
361 }
363 }
364 }
365
366 Ok(Register::new())
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for amplitude comparisons.
    const EPS: f64 = 1e-10;

    /// Pauli-X (NOT) gate matrix.
    fn pauli_x() -> Array2<Complex64> {
        let zero = Complex64::new(0., 0.);
        let one = Complex64::new(1., 0.);
        array![[zero, one], [one, zero]]
    }

    #[test]
    fn test_basic_mps_initialization() {
        let mps = BasicMPS::new(4, BasicMPSConfig::default());

        let ground = mps.get_amplitude(&[false; 4]).unwrap();
        assert!((ground.norm() - 1.0).abs() < EPS);

        let excited = mps.get_amplitude(&[true, false, false, false]).unwrap();
        assert!(excited.norm() < EPS);
    }

    #[test]
    fn test_single_qubit_gate() {
        let mut mps = BasicMPS::new(3, BasicMPSConfig::default());
        mps.apply_single_qubit_gate(&pauli_x(), 0).unwrap();

        let flipped = mps.get_amplitude(&[true, false, false]).unwrap();
        assert!((flipped.norm() - 1.0).abs() < EPS);
    }
}
403}