1use ndarray::{array, s, Array2, Array3};
6use num_complex::Complex64;
7use quantrs2_circuit::builder::{Circuit, Simulator};
8use quantrs2_core::{
9 error::{QuantRS2Error, QuantRS2Result},
10 gate::GateOp,
11 register::Register,
12};
13use rand::{thread_rng, Rng};
14use std::f64::consts::SQRT_2;
15
/// Tuning parameters for the basic matrix-product-state simulator.
#[derive(Debug, Clone)]
pub struct BasicMPSConfig {
    /// Upper limit on the bond dimension kept between neighbouring site tensors.
    pub max_bond_dim: usize,
    /// Intended singular-value cutoff used when bonds are truncated.
    pub svd_threshold: f64,
}

impl Default for BasicMPSConfig {
    /// Bond dimension capped at 64, singular-value cutoff of `1e-10`.
    fn default() -> Self {
        let max_bond_dim = 64;
        let svd_threshold = 1e-10;
        Self {
            max_bond_dim,
            svd_threshold,
        }
    }
}
33
34#[derive(Debug, Clone)]
36struct MPSTensor {
37 data: Array3<Complex64>,
39}
40
41impl MPSTensor {
42 fn zero_state(position: usize, num_qubits: usize) -> Self {
44 let is_first = position == 0;
45 let is_last = position == num_qubits - 1;
46
47 let data = if is_first && is_last {
48 let mut tensor = Array3::zeros((1, 2, 1));
50 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
51 tensor
52 } else if is_first {
53 let mut tensor = Array3::zeros((1, 2, 2));
55 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
56 tensor
57 } else if is_last {
58 let mut tensor = Array3::zeros((2, 2, 1));
60 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
61 tensor
62 } else {
63 let mut tensor = Array3::zeros((2, 2, 2));
65 tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
66 tensor
67 };
68 Self { data }
69 }
70}
71
72pub struct BasicMPS {
74 tensors: Vec<MPSTensor>,
76 num_qubits: usize,
78 config: BasicMPSConfig,
80}
81
82impl BasicMPS {
83 pub fn new(num_qubits: usize, config: BasicMPSConfig) -> Self {
85 let tensors = (0..num_qubits)
86 .map(|i| MPSTensor::zero_state(i, num_qubits))
87 .collect();
88
89 Self {
90 tensors,
91 num_qubits,
92 config,
93 }
94 }
95
96 pub fn apply_single_qubit_gate(
98 &mut self,
99 gate_matrix: &Array2<Complex64>,
100 qubit: usize,
101 ) -> QuantRS2Result<()> {
102 if qubit >= self.num_qubits {
103 return Err(QuantRS2Error::InvalidQubitId(qubit as u32));
104 }
105
106 let tensor = &mut self.tensors[qubit];
107 let shape = tensor.data.shape();
108 let (left_dim, _, right_dim) = (shape[0], shape[1], shape[2]);
109
110 let mut new_data = Array3::zeros((left_dim, 2, right_dim));
111
112 for l in 0..left_dim {
114 for r in 0..right_dim {
115 for new_phys in 0..2 {
116 for old_phys in 0..2 {
117 new_data[[l, new_phys, r]] +=
118 gate_matrix[[new_phys, old_phys]] * tensor.data[[l, old_phys, r]];
119 }
120 }
121 }
122 }
123
124 tensor.data = new_data;
125 Ok(())
126 }
127
128 pub fn apply_two_qubit_gate(
130 &mut self,
131 gate_matrix: &Array2<Complex64>,
132 qubit1: usize,
133 qubit2: usize,
134 ) -> QuantRS2Result<()> {
135 if (qubit1 as i32 - qubit2 as i32).abs() != 1 {
136 return Err(QuantRS2Error::InvalidInput(
137 "MPS requires adjacent qubits for two-qubit gates".to_string(),
138 ));
139 }
140
141 let (left_q, right_q) = if qubit1 < qubit2 {
142 (qubit1, qubit2)
143 } else {
144 (qubit2, qubit1)
145 };
146
147 let left_shape = self.tensors[left_q].data.shape().to_vec();
151 let right_shape = self.tensors[right_q].data.shape().to_vec();
152
153 let mut combined = Array3::<Complex64>::zeros((left_shape[0], 4, right_shape[2]));
155
156 for l in 0..left_shape[0] {
157 for r in 0..right_shape[2] {
158 for i in 0..2 {
159 for j in 0..2 {
160 for m in 0..left_shape[2] {
161 combined[[l, i * 2 + j, r]] += self.tensors[left_q].data[[l, i, m]]
162 * self.tensors[right_q].data[[m, j, r]];
163 }
164 }
165 }
166 }
167 }
168
169 let mut result = Array3::<Complex64>::zeros((left_shape[0], 4, right_shape[2]));
171 for l in 0..left_shape[0] {
172 for r in 0..right_shape[2] {
173 for out_idx in 0..4 {
174 for in_idx in 0..4 {
175 result[[l, out_idx, r]] +=
176 gate_matrix[[out_idx, in_idx]] * combined[[l, in_idx, r]];
177 }
178 }
179 }
180 }
181
182 let new_bond = 2.min(self.config.max_bond_dim);
185
186 let mut left_new = Array3::zeros((left_shape[0], 2, new_bond));
187 let mut right_new = Array3::zeros((new_bond, 2, right_shape[2]));
188
189 for l in 0..left_shape[0] {
191 for r in 0..right_shape[2] {
192 for i in 0..2 {
193 for j in 0..2 {
194 let bond_idx = (i + j) % new_bond;
195 left_new[[l, i, bond_idx]] = result[[l, i * 2 + j, r]];
196 right_new[[bond_idx, j, r]] = Complex64::new(1.0, 0.0);
197 }
198 }
199 }
200 }
201
202 self.tensors[left_q].data = left_new;
203 self.tensors[right_q].data = right_new;
204
205 Ok(())
206 }
207
208 pub fn get_amplitude(&self, bitstring: &[bool]) -> QuantRS2Result<Complex64> {
210 if bitstring.len() != self.num_qubits {
211 return Err(QuantRS2Error::InvalidInput(format!(
212 "Bitstring length {} doesn't match qubit count {}",
213 bitstring.len(),
214 self.num_qubits
215 )));
216 }
217
218 let mut result = Array2::from_elem((1, 1), Complex64::new(1.0, 0.0));
220
221 for (i, &bit) in bitstring.iter().enumerate() {
222 let tensor = &self.tensors[i];
223 let physical_idx = if bit { 1 } else { 0 };
224
225 let matrix = tensor.data.slice(s![.., physical_idx, ..]);
227
228 result = result.dot(&matrix);
230 }
231
232 Ok(result[[0, 0]])
233 }
234
235 pub fn sample(&self) -> Vec<bool> {
237 let mut rng = thread_rng();
238 let mut result = vec![false; self.num_qubits];
239 let mut accumulated = Array2::from_elem((1, 1), Complex64::new(1.0, 0.0));
240
241 for i in 0..self.num_qubits {
242 let tensor = &self.tensors[i];
243
244 let matrix0 = tensor.data.slice(s![.., 0, ..]);
246 let matrix1 = tensor.data.slice(s![.., 1, ..]);
247
248 let branch0: Array2<Complex64> = accumulated.dot(&matrix0);
249 let branch1: Array2<Complex64> = accumulated.dot(&matrix1);
250
251 let norm0_sq: f64 = branch0.iter().map(|x| x.norm_sqr()).sum();
253 let norm1_sq: f64 = branch1.iter().map(|x| x.norm_sqr()).sum();
254
255 let total = norm0_sq + norm1_sq;
256 let prob0 = norm0_sq / total;
257
258 if rng.gen::<f64>() < prob0 {
259 result[i] = false;
260 accumulated = branch0;
261 } else {
262 result[i] = true;
263 accumulated = branch1;
264 }
265
266 let norm_sq: f64 = accumulated.iter().map(|x| x.norm_sqr()).sum();
268 if norm_sq > 0.0 {
269 accumulated /= Complex64::new(norm_sq.sqrt(), 0.0);
270 }
271 }
272
273 result
274 }
275}
276
277pub struct BasicMPSSimulator {
279 config: BasicMPSConfig,
280}
281
282impl BasicMPSSimulator {
283 pub fn new(config: BasicMPSConfig) -> Self {
285 Self { config }
286 }
287
288 pub fn default() -> Self {
290 Self::new(BasicMPSConfig::default())
291 }
292}
293
294impl<const N: usize> Simulator<N> for BasicMPSSimulator {
295 fn run(&self, circuit: &Circuit<N>) -> QuantRS2Result<Register<N>> {
296 let mut mps = BasicMPS::new(N, self.config.clone());
298
299 for gate in circuit.gates() {
301 match gate.name().as_ref() {
302 "H" => {
303 let h_matrix = {
304 let h = 1.0 / SQRT_2;
305 array![
306 [Complex64::new(h, 0.), Complex64::new(h, 0.)],
307 [Complex64::new(h, 0.), Complex64::new(-h, 0.)]
308 ]
309 };
310 if let Some(&qubit) = gate.qubits().first() {
311 mps.apply_single_qubit_gate(&h_matrix, qubit.id() as usize)?;
312 }
313 }
314 "X" => {
315 let x_matrix = array![
316 [Complex64::new(0., 0.), Complex64::new(1., 0.)],
317 [Complex64::new(1., 0.), Complex64::new(0., 0.)]
318 ];
319 if let Some(&qubit) = gate.qubits().first() {
320 mps.apply_single_qubit_gate(&x_matrix, qubit.id() as usize)?;
321 }
322 }
323 "CNOT" | "CX" => {
324 let cnot_matrix = array![
325 [
326 Complex64::new(1., 0.),
327 Complex64::new(0., 0.),
328 Complex64::new(0., 0.),
329 Complex64::new(0., 0.)
330 ],
331 [
332 Complex64::new(0., 0.),
333 Complex64::new(1., 0.),
334 Complex64::new(0., 0.),
335 Complex64::new(0., 0.)
336 ],
337 [
338 Complex64::new(0., 0.),
339 Complex64::new(0., 0.),
340 Complex64::new(0., 0.),
341 Complex64::new(1., 0.)
342 ],
343 [
344 Complex64::new(0., 0.),
345 Complex64::new(0., 0.),
346 Complex64::new(1., 0.),
347 Complex64::new(0., 0.)
348 ],
349 ];
350 let qubits = gate.qubits();
351 if qubits.len() == 2 {
352 mps.apply_two_qubit_gate(
353 &cnot_matrix,
354 qubits[0].id() as usize,
355 qubits[1].id() as usize,
356 )?;
357 }
358 }
359 _ => {
360 }
362 }
363 }
364
365 Ok(Register::new())
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for amplitude comparisons.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_basic_mps_initialization() {
        let state = BasicMPS::new(4, BasicMPSConfig::default());

        // A fresh MPS is |0000⟩: all weight on the all-false bitstring.
        let all_zero = state.get_amplitude(&[false; 4]).unwrap();
        assert!((all_zero.norm() - 1.0).abs() < EPS);

        let first_flipped = state.get_amplitude(&[true, false, false, false]).unwrap();
        assert!(first_flipped.norm() < EPS);
    }

    #[test]
    fn test_single_qubit_gate() {
        let mut state = BasicMPS::new(3, BasicMPSConfig::default());

        let zero = Complex64::new(0., 0.);
        let one = Complex64::new(1., 0.);
        let pauli_x = array![[zero, one], [one, zero]];
        state.apply_single_qubit_gate(&pauli_x, 0).unwrap();

        // X on qubit 0 moves all amplitude to |100⟩.
        let amp = state.get_amplitude(&[true, false, false]).unwrap();
        assert!((amp.norm() - 1.0).abs() < EPS);
    }
}
402}