1use crate::error::{MLError, Result};
14use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
15use scirs2_core::Complex64;
16use std::collections::HashMap;
17
18use super::{CType, TQDevice, TQModule, TQParameter};
19
/// Settings that control how the matrix-product-state simulator stores and
/// compresses its tensors.
#[derive(Debug, Clone)]
pub struct TensorNetworkConfig {
    /// Largest bond dimension to keep between neighbouring site tensors.
    pub max_bond_dim: usize,
    /// Singular-value cut-off used when truncating bonds.
    /// NOTE(review): whether it is absolute or relative is decided by the consuming code.
    pub truncation_threshold: f64,
    /// Request to keep the state in canonical form.
    /// NOTE(review): not read by the tensor routines in this file — confirm intended use.
    pub use_canonical_form: bool,
    /// Which bond-compression scheme is requested.
    pub compression: CompressionMethod,
}

impl Default for TensorNetworkConfig {
    /// Bond dimension 64, cut-off `1e-12`, canonical form on, SVD compression.
    fn default() -> Self {
        let max_bond_dim = 64;
        let truncation_threshold = 1e-12;
        let compression = CompressionMethod::SVD;
        TensorNetworkConfig {
            max_bond_dim,
            truncation_threshold,
            use_canonical_form: true,
            compression,
        }
    }
}

/// Bond-compression scheme selector. This enum only records the choice; it
/// carries no algorithm of its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionMethod {
    /// Singular value decomposition.
    SVD,
    /// QR decomposition.
    QR,
    /// Variational compression.
    Variational,
}
58
59#[derive(Debug, Clone)]
67pub struct MPSTensor {
68 pub data: Array3<CType>,
70 pub site: usize,
72}
73
74impl MPSTensor {
75 pub fn new(data: Array3<CType>, site: usize) -> Self {
77 Self { data, site }
78 }
79
80 pub fn bond_dims(&self) -> (usize, usize) {
82 let shape = self.data.shape();
83 (shape[0], shape[2])
84 }
85
86 pub fn physical_dim(&self) -> usize {
88 self.data.shape()[1]
89 }
90
91 pub fn contract_right(&self, other: &MPSTensor) -> Array3<CType> {
93 let (d_left, phys_a, d_mid) = (
94 self.data.shape()[0],
95 self.data.shape()[1],
96 self.data.shape()[2],
97 );
98 let (_d_mid2, phys_b, d_right) = (
99 other.data.shape()[0],
100 other.data.shape()[1],
101 other.data.shape()[2],
102 );
103
104 let mut result = Array3::<CType>::zeros((d_left, phys_a * phys_b, d_right));
106
107 for i in 0..d_left {
108 for j in 0..phys_a {
109 for k in 0..d_mid {
110 for l in 0..phys_b {
111 for m in 0..d_right {
112 let combined_phys = j * phys_b + l;
113 result[[i, combined_phys, m]] +=
114 self.data[[i, j, k]] * other.data[[k, l, m]];
115 }
116 }
117 }
118 }
119 }
120
121 result
122 }
123}
124
125#[derive(Debug, Clone)]
131pub struct MatrixProductState {
132 pub tensors: Vec<MPSTensor>,
134 pub n_qubits: usize,
136 pub config: TensorNetworkConfig,
138 pub norm: f64,
140}
141
142impl MatrixProductState {
143 pub fn from_computational_basis(n_qubits: usize, state: usize) -> Self {
145 let config = TensorNetworkConfig::default();
146 let mut tensors = Vec::with_capacity(n_qubits);
147
148 for site in 0..n_qubits {
149 let mut data = Array3::<CType>::zeros((1, 2, 1));
151 let bit = (state >> (n_qubits - 1 - site)) & 1;
152 data[[0, bit, 0]] = Complex64::new(1.0, 0.0);
153 tensors.push(MPSTensor::new(data, site));
154 }
155
156 Self {
157 tensors,
158 n_qubits,
159 config,
160 norm: 1.0,
161 }
162 }
163
164 pub fn from_tq_device(qdev: &TQDevice) -> Result<Self> {
166 let states = qdev.get_states_1d();
168 let state_vec: Vec<CType> = states.row(0).iter().cloned().collect();
169
170 Self::from_state_vector(&state_vec, qdev.n_wires)
171 }
172
173 pub fn from_state_vector(state_vec: &[CType], n_qubits: usize) -> Result<Self> {
175 let config = TensorNetworkConfig::default();
176 let dim = 1 << n_qubits;
177
178 if state_vec.len() != dim {
179 return Err(MLError::InvalidConfiguration(format!(
180 "State vector size {} doesn't match 2^{} = {}",
181 state_vec.len(),
182 n_qubits,
183 dim
184 )));
185 }
186
187 let mut tensors = Vec::with_capacity(n_qubits);
190
191 if n_qubits <= 4 {
192 for site in 0..n_qubits {
194 let bond_left = 1.min(1 << site);
195 let bond_right = 1.min(1 << (n_qubits - site - 1));
196 let mut data = Array3::<CType>::zeros((bond_left, 2, bond_right));
197
198 for idx in 0..dim {
200 let bit = (idx >> (n_qubits - 1 - site)) & 1;
201 let left_idx = (idx >> (n_qubits - site)) % bond_left;
202 let right_idx = idx % bond_right;
203 data[[left_idx, bit, right_idx]] += state_vec[idx];
204 }
205
206 tensors.push(MPSTensor::new(data, site));
207 }
208 } else {
209 let mut remaining = Array2::<CType>::from_shape_vec((1, dim), state_vec.to_vec())
211 .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
212
213 for site in 0..n_qubits {
214 let rows = remaining.nrows();
215 let cols = remaining.ncols();
216 let new_cols = cols / 2;
217
218 let reshaped = remaining
220 .clone()
221 .into_shape_with_order((rows * 2, new_cols))
222 .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
223
224 if site == n_qubits - 1 {
226 let mut data = Array3::<CType>::zeros((rows, 2, 1));
227 for i in 0..rows {
228 for j in 0..2 {
229 data[[i, j, 0]] = reshaped[[i * 2 + j, 0]];
230 }
231 }
232 tensors.push(MPSTensor::new(data, site));
233 } else {
234 let bond_dim = (rows * 2).min(config.max_bond_dim).min(new_cols);
236 let mut data = Array3::<CType>::zeros((rows, 2, bond_dim));
237
238 for i in 0..rows {
239 for j in 0..2 {
240 for k in 0..bond_dim {
241 if i * 2 + j < rows * 2 && k < new_cols {
242 data[[i, j, k]] = reshaped[[i * 2 + j, k]];
243 }
244 }
245 }
246 }
247
248 tensors.push(MPSTensor::new(data, site));
249
250 remaining = Array2::<CType>::zeros((bond_dim, new_cols));
252 for i in 0..bond_dim.min(rows * 2) {
253 for j in 0..new_cols {
254 remaining[[i.min(bond_dim - 1), j]] = reshaped[[i, j]];
255 }
256 }
257 }
258 }
259 }
260
261 Ok(Self {
262 tensors,
263 n_qubits,
264 config,
265 norm: 1.0,
266 })
267 }
268
269 pub fn apply_single_qubit_gate(&mut self, site: usize, gate: &Array2<CType>) -> Result<()> {
271 if site >= self.n_qubits {
272 return Err(MLError::InvalidConfiguration(format!(
273 "Site {} out of range for {} qubits",
274 site, self.n_qubits
275 )));
276 }
277
278 let tensor = &mut self.tensors[site];
279 let (d_left, _phys, d_right) = (
280 tensor.data.shape()[0],
281 tensor.data.shape()[1],
282 tensor.data.shape()[2],
283 );
284
285 let mut new_data = Array3::<CType>::zeros((d_left, 2, d_right));
286
287 for i in 0..d_left {
288 for k in 0..d_right {
289 let old_0 = tensor.data[[i, 0, k]];
290 let old_1 = tensor.data[[i, 1, k]];
291 new_data[[i, 0, k]] = gate[[0, 0]] * old_0 + gate[[0, 1]] * old_1;
292 new_data[[i, 1, k]] = gate[[1, 0]] * old_0 + gate[[1, 1]] * old_1;
293 }
294 }
295
296 tensor.data = new_data;
297 Ok(())
298 }
299
300 pub fn apply_two_qubit_gate(
302 &mut self,
303 site1: usize,
304 site2: usize,
305 gate: &Array2<CType>,
306 ) -> Result<()> {
307 if site1.abs_diff(site2) != 1 {
309 return Err(MLError::InvalidConfiguration(
310 "Two-qubit gates on non-adjacent sites require SWAP operations".to_string(),
311 ));
312 }
313
314 let (left_site, right_site) = if site1 < site2 {
315 (site1, site2)
316 } else {
317 (site2, site1)
318 };
319
320 let left_tensor = &self.tensors[left_site];
322 let right_tensor = &self.tensors[right_site];
323
324 let d_left = left_tensor.data.shape()[0];
325 let d_mid = left_tensor.data.shape()[2];
326 let d_right = right_tensor.data.shape()[2];
327
328 let mut contracted = Array3::<CType>::zeros((d_left, 4, d_right));
330
331 for i in 0..d_left {
332 for k in 0..d_mid {
333 for m in 0..d_right {
334 for j1 in 0..2 {
335 for j2 in 0..2 {
336 let combined = j1 * 2 + j2;
337 contracted[[i, combined, m]] +=
338 left_tensor.data[[i, j1, k]] * right_tensor.data[[k, j2, m]];
339 }
340 }
341 }
342 }
343 }
344
345 let mut gated = Array3::<CType>::zeros((d_left, 4, d_right));
347 for i in 0..d_left {
348 for m in 0..d_right {
349 for out_idx in 0..4 {
350 for in_idx in 0..4 {
351 gated[[i, out_idx, m]] +=
352 gate[[out_idx, in_idx]] * contracted[[i, in_idx, m]];
353 }
354 }
355 }
356 }
357
358 let new_bond = d_mid.min(self.config.max_bond_dim);
360
361 let mut new_left = Array3::<CType>::zeros((d_left, 2, new_bond));
362 let mut new_right = Array3::<CType>::zeros((new_bond, 2, d_right));
363
364 for i in 0..d_left {
366 for j1 in 0..2 {
367 for k in 0..new_bond {
368 for j2 in 0..2 {
369 for m in 0..d_right {
370 let combined = j1 * 2 + j2;
371 new_left[[i, j1, k]] += gated[[i, combined, m]]
373 * Complex64::new(1.0 / (new_bond * d_right) as f64, 0.0);
374 new_right[[k, j2, m]] += gated[[i, combined, m]]
375 * Complex64::new(1.0 / (d_left * 2) as f64, 0.0);
376 }
377 }
378 }
379 }
380 }
381
382 self.tensors[left_site] = MPSTensor::new(new_left, left_site);
383 self.tensors[right_site] = MPSTensor::new(new_right, right_site);
384
385 Ok(())
386 }
387
388 pub fn to_state_vector(&self) -> Result<Vec<CType>> {
390 let dim = 1 << self.n_qubits;
391 let mut state = vec![Complex64::new(0.0, 0.0); dim];
392
393 for idx in 0..dim {
395 let mut amp = Complex64::new(1.0, 0.0);
396
397 for site in 0..self.n_qubits {
398 let bit = (idx >> (self.n_qubits - 1 - site)) & 1;
399 amp *= self.tensors[site].data[[0, bit, 0]];
401 }
402
403 state[idx] = amp;
404 }
405
406 Ok(state)
407 }
408
409 pub fn overlap(&self, other: &MatrixProductState) -> Result<CType> {
411 if self.n_qubits != other.n_qubits {
412 return Err(MLError::InvalidConfiguration(
413 "MPS qubit counts don't match".to_string(),
414 ));
415 }
416
417 let mut transfer = Array2::<CType>::eye(1);
419
420 for site in 0..self.n_qubits {
421 let t1 = &self.tensors[site];
422 let t2 = &other.tensors[site];
423
424 let d1_left = t1.data.shape()[0];
425 let d1_right = t1.data.shape()[2];
426 let d2_left = t2.data.shape()[0];
427 let d2_right = t2.data.shape()[2];
428
429 let mut new_transfer = Array2::<CType>::zeros((d1_right, d2_right));
430
431 for i1 in 0..d1_left {
432 for i2 in 0..d2_left {
433 for j in 0..2 {
434 for k1 in 0..d1_right {
435 for k2 in 0..d2_right {
436 new_transfer[[k1, k2]] += transfer
437 [[i1.min(transfer.nrows() - 1), i2.min(transfer.ncols() - 1)]]
438 * t1.data[[i1, j, k1]].conj()
439 * t2.data[[i2, j, k2]];
440 }
441 }
442 }
443 }
444 }
445
446 transfer = new_transfer;
447 }
448
449 Ok(transfer[[0, 0]])
450 }
451
452 pub fn max_bond_dim(&self) -> usize {
454 self.tensors
455 .iter()
456 .map(|t| t.bond_dims().1)
457 .max()
458 .unwrap_or(1)
459 }
460}
461
462#[derive(Debug, Clone)]
470pub struct TQTensorNetworkBackend {
471 pub mps: Option<MatrixProductState>,
473 pub n_wires: usize,
475 pub config: TensorNetworkConfig,
477 pub static_mode: bool,
479 pub gate_cache: HashMap<String, Array2<CType>>,
481}
482
483impl TQTensorNetworkBackend {
484 pub fn new(n_wires: usize) -> Self {
486 Self {
487 mps: Some(MatrixProductState::from_computational_basis(n_wires, 0)),
488 n_wires,
489 config: TensorNetworkConfig::default(),
490 static_mode: false,
491 gate_cache: HashMap::new(),
492 }
493 }
494
495 pub fn with_config(n_wires: usize, config: TensorNetworkConfig) -> Self {
497 let mut mps = MatrixProductState::from_computational_basis(n_wires, 0);
498 mps.config = config.clone();
499
500 Self {
501 mps: Some(mps),
502 n_wires,
503 config,
504 static_mode: false,
505 gate_cache: HashMap::new(),
506 }
507 }
508
509 pub fn reset(&mut self) {
511 self.mps = Some(MatrixProductState::from_computational_basis(
512 self.n_wires,
513 0,
514 ));
515 self.mps.as_mut().map(|m| m.config = self.config.clone());
516 }
517
518 pub fn apply_gate(&mut self, site: usize, gate: &Array2<CType>) -> Result<()> {
520 if let Some(ref mut mps) = self.mps {
521 mps.apply_single_qubit_gate(site, gate)
522 } else {
523 Err(MLError::InvalidConfiguration(
524 "MPS not initialized".to_string(),
525 ))
526 }
527 }
528
529 pub fn apply_two_qubit_gate(
531 &mut self,
532 site1: usize,
533 site2: usize,
534 gate: &Array2<CType>,
535 ) -> Result<()> {
536 if let Some(ref mut mps) = self.mps {
537 mps.apply_two_qubit_gate(site1, site2, gate)
538 } else {
539 Err(MLError::InvalidConfiguration(
540 "MPS not initialized".to_string(),
541 ))
542 }
543 }
544
545 pub fn get_state_vector(&self) -> Result<Vec<CType>> {
547 if let Some(ref mps) = self.mps {
548 mps.to_state_vector()
549 } else {
550 Err(MLError::InvalidConfiguration(
551 "MPS not initialized".to_string(),
552 ))
553 }
554 }
555
556 pub fn expectation_value(&self, observable: &Array2<CType>, sites: &[usize]) -> Result<f64> {
558 if sites.len() == 1 && observable.nrows() == 2 {
560 if let Some(ref mps) = self.mps {
561 let site = sites[0];
562 let tensor = &mps.tensors[site];
563
564 let mut exp_val = Complex64::new(0.0, 0.0);
567
568 for i in 0..2 {
569 for j in 0..2 {
570 let rho_ji = tensor.data[[0, j, 0]].conj() * tensor.data[[0, i, 0]];
572 exp_val += observable[[i, j]] * rho_ji;
573 }
574 }
575
576 return Ok(exp_val.re);
577 }
578 }
579
580 Err(MLError::NotSupported(
581 "Multi-site observables not yet implemented for MPS".to_string(),
582 ))
583 }
584
585 pub fn bond_dimension(&self) -> usize {
587 self.mps.as_ref().map(|m| m.max_bond_dim()).unwrap_or(0)
588 }
589
590 pub fn to_tq_device(&self) -> Result<TQDevice> {
592 let state_vec = self.get_state_vector()?;
593 let mut qdev = TQDevice::new(self.n_wires);
594
595 use scirs2_core::ndarray::{ArrayD, IxDyn};
597 let mut shape = vec![1usize]; shape.extend(vec![2; self.n_wires]);
599
600 let states = ArrayD::from_shape_vec(IxDyn(&shape), state_vec)
601 .map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
602 qdev.set_states(states);
603
604 Ok(qdev)
605 }
606
607 pub fn from_tq_device(qdev: &TQDevice) -> Result<Self> {
609 let mps = MatrixProductState::from_tq_device(qdev)?;
610 Ok(Self {
611 n_wires: qdev.n_wires,
612 mps: Some(mps),
613 config: TensorNetworkConfig::default(),
614 static_mode: false,
615 gate_cache: HashMap::new(),
616 })
617 }
618}
619
620impl TQModule for TQTensorNetworkBackend {
621 fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
622 Ok(())
624 }
625
626 fn parameters(&self) -> Vec<TQParameter> {
627 Vec::new()
628 }
629
630 fn n_wires(&self) -> Option<usize> {
631 Some(self.n_wires)
632 }
633
634 fn set_n_wires(&mut self, n_wires: usize) {
635 self.n_wires = n_wires;
636 self.reset();
637 }
638
639 fn is_static_mode(&self) -> bool {
640 self.static_mode
641 }
642
643 fn static_on(&mut self) {
644 self.static_mode = true;
645 }
646
647 fn static_off(&mut self) {
648 self.static_mode = false;
649 self.gate_cache.clear();
650 }
651
652 fn name(&self) -> &str {
653 "TQTensorNetworkBackend"
654 }
655}
656
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used by every numeric check below.
    const TOL: f64 = 1e-10;

    /// Pauli-X (NOT) gate as a 2x2 complex matrix.
    fn pauli_x() -> Array2<CType> {
        let zero = Complex64::new(0.0, 0.0);
        let one = Complex64::new(1.0, 0.0);
        Array2::from_shape_vec((2, 2), vec![zero, one, one, zero]).expect("Should create matrix")
    }

    #[test]
    fn test_mps_creation() {
        let mps = MatrixProductState::from_computational_basis(4, 0);
        assert_eq!(mps.n_qubits, 4);
        assert_eq!(mps.tensors.len(), 4);
    }

    #[test]
    fn test_mps_state_vector() {
        let state = MatrixProductState::from_computational_basis(2, 0)
            .to_state_vector()
            .expect("Should succeed");
        assert_eq!(state.len(), 4);
        assert!((state[0].re - 1.0).abs() < TOL);
        assert!(state.iter().skip(1).all(|amp| amp.norm() < TOL));
    }

    #[test]
    fn test_tensor_network_backend() {
        let backend = TQTensorNetworkBackend::new(3);
        assert_eq!(backend.n_wires, 3);
        assert!(backend.mps.is_some());
    }

    #[test]
    fn test_single_qubit_gate_application() {
        let mut backend = TQTensorNetworkBackend::new(2);
        backend.apply_gate(0, &pauli_x()).expect("Should apply gate");

        let state = backend.get_state_vector().expect("Should get state");
        // Site 0 is the most significant bit, so X on site 0 maps |00> to |10> (index 2).
        for idx in [0, 1, 3] {
            assert!(state[idx].norm() < TOL);
        }
        assert!((state[2].re - 1.0).abs() < TOL);
    }

    #[test]
    fn test_config_defaults() {
        let TensorNetworkConfig {
            max_bond_dim,
            compression,
            ..
        } = TensorNetworkConfig::default();
        assert_eq!(max_bond_dim, 64);
        assert_eq!(compression, CompressionMethod::SVD);
    }
}