use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2, Array3, Axis};
use scirs2_core::Complex64;
use std::collections::HashMap;
use super::{CType, TQDevice, TQModule, TQParameter};
#[derive(Debug, Clone)]
pub struct TensorNetworkConfig {
pub max_bond_dim: usize,
pub truncation_threshold: f64,
pub use_canonical_form: bool,
pub compression: CompressionMethod,
}
impl Default for TensorNetworkConfig {
fn default() -> Self {
Self {
max_bond_dim: 64,
truncation_threshold: 1e-12,
use_canonical_form: true,
compression: CompressionMethod::SVD,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionMethod {
SVD,
QR,
Variational,
}
#[derive(Debug, Clone)]
pub struct MPSTensor {
pub data: Array3<CType>,
pub site: usize,
}
impl MPSTensor {
pub fn new(data: Array3<CType>, site: usize) -> Self {
Self { data, site }
}
pub fn bond_dims(&self) -> (usize, usize) {
let shape = self.data.shape();
(shape[0], shape[2])
}
pub fn physical_dim(&self) -> usize {
self.data.shape()[1]
}
pub fn contract_right(&self, other: &MPSTensor) -> Array3<CType> {
let (d_left, phys_a, d_mid) = (
self.data.shape()[0],
self.data.shape()[1],
self.data.shape()[2],
);
let (_d_mid2, phys_b, d_right) = (
other.data.shape()[0],
other.data.shape()[1],
other.data.shape()[2],
);
let mut result = Array3::<CType>::zeros((d_left, phys_a * phys_b, d_right));
for i in 0..d_left {
for j in 0..phys_a {
for k in 0..d_mid {
for l in 0..phys_b {
for m in 0..d_right {
let combined_phys = j * phys_b + l;
result[[i, combined_phys, m]] +=
self.data[[i, j, k]] * other.data[[k, l, m]];
}
}
}
}
}
result
}
}
#[derive(Debug, Clone)]
pub struct MatrixProductState {
pub tensors: Vec<MPSTensor>,
pub n_qubits: usize,
pub config: TensorNetworkConfig,
pub norm: f64,
}
impl MatrixProductState {
pub fn from_computational_basis(n_qubits: usize, state: usize) -> Self {
let config = TensorNetworkConfig::default();
let mut tensors = Vec::with_capacity(n_qubits);
for site in 0..n_qubits {
let mut data = Array3::<CType>::zeros((1, 2, 1));
let bit = (state >> (n_qubits - 1 - site)) & 1;
data[[0, bit, 0]] = Complex64::new(1.0, 0.0);
tensors.push(MPSTensor::new(data, site));
}
Self {
tensors,
n_qubits,
config,
norm: 1.0,
}
}
pub fn from_tq_device(qdev: &TQDevice) -> Result<Self> {
let states = qdev.get_states_1d();
let state_vec: Vec<CType> = states.row(0).iter().cloned().collect();
Self::from_state_vector(&state_vec, qdev.n_wires)
}
pub fn from_state_vector(state_vec: &[CType], n_qubits: usize) -> Result<Self> {
let config = TensorNetworkConfig::default();
let dim = 1 << n_qubits;
if state_vec.len() != dim {
return Err(MLError::InvalidConfiguration(format!(
"State vector size {} doesn't match 2^{} = {}",
state_vec.len(),
n_qubits,
dim
)));
}
let mut tensors = Vec::with_capacity(n_qubits);
if n_qubits <= 4 {
for site in 0..n_qubits {
let bond_left = 1.min(1 << site);
let bond_right = 1.min(1 << (n_qubits - site - 1));
let mut data = Array3::<CType>::zeros((bond_left, 2, bond_right));
for idx in 0..dim {
let bit = (idx >> (n_qubits - 1 - site)) & 1;
let left_idx = (idx >> (n_qubits - site)) % bond_left;
let right_idx = idx % bond_right;
data[[left_idx, bit, right_idx]] += state_vec[idx];
}
tensors.push(MPSTensor::new(data, site));
}
} else {
let mut remaining = Array2::<CType>::from_shape_vec((1, dim), state_vec.to_vec())
.map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
for site in 0..n_qubits {
let rows = remaining.nrows();
let cols = remaining.ncols();
let new_cols = cols / 2;
let reshaped = remaining
.clone()
.into_shape_with_order((rows * 2, new_cols))
.map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
if site == n_qubits - 1 {
let mut data = Array3::<CType>::zeros((rows, 2, 1));
for i in 0..rows {
for j in 0..2 {
data[[i, j, 0]] = reshaped[[i * 2 + j, 0]];
}
}
tensors.push(MPSTensor::new(data, site));
} else {
let bond_dim = (rows * 2).min(config.max_bond_dim).min(new_cols);
let mut data = Array3::<CType>::zeros((rows, 2, bond_dim));
for i in 0..rows {
for j in 0..2 {
for k in 0..bond_dim {
if i * 2 + j < rows * 2 && k < new_cols {
data[[i, j, k]] = reshaped[[i * 2 + j, k]];
}
}
}
}
tensors.push(MPSTensor::new(data, site));
remaining = Array2::<CType>::zeros((bond_dim, new_cols));
for i in 0..bond_dim.min(rows * 2) {
for j in 0..new_cols {
remaining[[i.min(bond_dim - 1), j]] = reshaped[[i, j]];
}
}
}
}
}
Ok(Self {
tensors,
n_qubits,
config,
norm: 1.0,
})
}
pub fn apply_single_qubit_gate(&mut self, site: usize, gate: &Array2<CType>) -> Result<()> {
if site >= self.n_qubits {
return Err(MLError::InvalidConfiguration(format!(
"Site {} out of range for {} qubits",
site, self.n_qubits
)));
}
let tensor = &mut self.tensors[site];
let (d_left, _phys, d_right) = (
tensor.data.shape()[0],
tensor.data.shape()[1],
tensor.data.shape()[2],
);
let mut new_data = Array3::<CType>::zeros((d_left, 2, d_right));
for i in 0..d_left {
for k in 0..d_right {
let old_0 = tensor.data[[i, 0, k]];
let old_1 = tensor.data[[i, 1, k]];
new_data[[i, 0, k]] = gate[[0, 0]] * old_0 + gate[[0, 1]] * old_1;
new_data[[i, 1, k]] = gate[[1, 0]] * old_0 + gate[[1, 1]] * old_1;
}
}
tensor.data = new_data;
Ok(())
}
pub fn apply_two_qubit_gate(
&mut self,
site1: usize,
site2: usize,
gate: &Array2<CType>,
) -> Result<()> {
if site1.abs_diff(site2) != 1 {
return Err(MLError::InvalidConfiguration(
"Two-qubit gates on non-adjacent sites require SWAP operations".to_string(),
));
}
let (left_site, right_site) = if site1 < site2 {
(site1, site2)
} else {
(site2, site1)
};
let left_tensor = &self.tensors[left_site];
let right_tensor = &self.tensors[right_site];
let d_left = left_tensor.data.shape()[0];
let d_mid = left_tensor.data.shape()[2];
let d_right = right_tensor.data.shape()[2];
let mut contracted = Array3::<CType>::zeros((d_left, 4, d_right));
for i in 0..d_left {
for k in 0..d_mid {
for m in 0..d_right {
for j1 in 0..2 {
for j2 in 0..2 {
let combined = j1 * 2 + j2;
contracted[[i, combined, m]] +=
left_tensor.data[[i, j1, k]] * right_tensor.data[[k, j2, m]];
}
}
}
}
}
let mut gated = Array3::<CType>::zeros((d_left, 4, d_right));
for i in 0..d_left {
for m in 0..d_right {
for out_idx in 0..4 {
for in_idx in 0..4 {
gated[[i, out_idx, m]] +=
gate[[out_idx, in_idx]] * contracted[[i, in_idx, m]];
}
}
}
}
let new_bond = d_mid.min(self.config.max_bond_dim);
let mut new_left = Array3::<CType>::zeros((d_left, 2, new_bond));
let mut new_right = Array3::<CType>::zeros((new_bond, 2, d_right));
for i in 0..d_left {
for j1 in 0..2 {
for k in 0..new_bond {
for j2 in 0..2 {
for m in 0..d_right {
let combined = j1 * 2 + j2;
new_left[[i, j1, k]] += gated[[i, combined, m]]
* Complex64::new(1.0 / (new_bond * d_right) as f64, 0.0);
new_right[[k, j2, m]] += gated[[i, combined, m]]
* Complex64::new(1.0 / (d_left * 2) as f64, 0.0);
}
}
}
}
}
self.tensors[left_site] = MPSTensor::new(new_left, left_site);
self.tensors[right_site] = MPSTensor::new(new_right, right_site);
Ok(())
}
pub fn to_state_vector(&self) -> Result<Vec<CType>> {
let dim = 1 << self.n_qubits;
let mut state = vec![Complex64::new(0.0, 0.0); dim];
for idx in 0..dim {
let mut amp = Complex64::new(1.0, 0.0);
for site in 0..self.n_qubits {
let bit = (idx >> (self.n_qubits - 1 - site)) & 1;
amp *= self.tensors[site].data[[0, bit, 0]];
}
state[idx] = amp;
}
Ok(state)
}
pub fn overlap(&self, other: &MatrixProductState) -> Result<CType> {
if self.n_qubits != other.n_qubits {
return Err(MLError::InvalidConfiguration(
"MPS qubit counts don't match".to_string(),
));
}
let mut transfer = Array2::<CType>::eye(1);
for site in 0..self.n_qubits {
let t1 = &self.tensors[site];
let t2 = &other.tensors[site];
let d1_left = t1.data.shape()[0];
let d1_right = t1.data.shape()[2];
let d2_left = t2.data.shape()[0];
let d2_right = t2.data.shape()[2];
let mut new_transfer = Array2::<CType>::zeros((d1_right, d2_right));
for i1 in 0..d1_left {
for i2 in 0..d2_left {
for j in 0..2 {
for k1 in 0..d1_right {
for k2 in 0..d2_right {
new_transfer[[k1, k2]] += transfer
[[i1.min(transfer.nrows() - 1), i2.min(transfer.ncols() - 1)]]
* t1.data[[i1, j, k1]].conj()
* t2.data[[i2, j, k2]];
}
}
}
}
}
transfer = new_transfer;
}
Ok(transfer[[0, 0]])
}
pub fn max_bond_dim(&self) -> usize {
self.tensors
.iter()
.map(|t| t.bond_dims().1)
.max()
.unwrap_or(1)
}
}
#[derive(Debug, Clone)]
pub struct TQTensorNetworkBackend {
pub mps: Option<MatrixProductState>,
pub n_wires: usize,
pub config: TensorNetworkConfig,
pub static_mode: bool,
pub gate_cache: HashMap<String, Array2<CType>>,
}
impl TQTensorNetworkBackend {
pub fn new(n_wires: usize) -> Self {
Self {
mps: Some(MatrixProductState::from_computational_basis(n_wires, 0)),
n_wires,
config: TensorNetworkConfig::default(),
static_mode: false,
gate_cache: HashMap::new(),
}
}
pub fn with_config(n_wires: usize, config: TensorNetworkConfig) -> Self {
let mut mps = MatrixProductState::from_computational_basis(n_wires, 0);
mps.config = config.clone();
Self {
mps: Some(mps),
n_wires,
config,
static_mode: false,
gate_cache: HashMap::new(),
}
}
pub fn reset(&mut self) {
self.mps = Some(MatrixProductState::from_computational_basis(
self.n_wires,
0,
));
self.mps.as_mut().map(|m| m.config = self.config.clone());
}
pub fn apply_gate(&mut self, site: usize, gate: &Array2<CType>) -> Result<()> {
if let Some(ref mut mps) = self.mps {
mps.apply_single_qubit_gate(site, gate)
} else {
Err(MLError::InvalidConfiguration(
"MPS not initialized".to_string(),
))
}
}
pub fn apply_two_qubit_gate(
&mut self,
site1: usize,
site2: usize,
gate: &Array2<CType>,
) -> Result<()> {
if let Some(ref mut mps) = self.mps {
mps.apply_two_qubit_gate(site1, site2, gate)
} else {
Err(MLError::InvalidConfiguration(
"MPS not initialized".to_string(),
))
}
}
pub fn get_state_vector(&self) -> Result<Vec<CType>> {
if let Some(ref mps) = self.mps {
mps.to_state_vector()
} else {
Err(MLError::InvalidConfiguration(
"MPS not initialized".to_string(),
))
}
}
pub fn expectation_value(&self, observable: &Array2<CType>, sites: &[usize]) -> Result<f64> {
if sites.len() == 1 && observable.nrows() == 2 {
if let Some(ref mps) = self.mps {
let site = sites[0];
let tensor = &mps.tensors[site];
let mut exp_val = Complex64::new(0.0, 0.0);
for i in 0..2 {
for j in 0..2 {
let rho_ji = tensor.data[[0, j, 0]].conj() * tensor.data[[0, i, 0]];
exp_val += observable[[i, j]] * rho_ji;
}
}
return Ok(exp_val.re);
}
}
Err(MLError::NotSupported(
"Multi-site observables not yet implemented for MPS".to_string(),
))
}
pub fn bond_dimension(&self) -> usize {
self.mps.as_ref().map(|m| m.max_bond_dim()).unwrap_or(0)
}
pub fn to_tq_device(&self) -> Result<TQDevice> {
let state_vec = self.get_state_vector()?;
let mut qdev = TQDevice::new(self.n_wires);
use scirs2_core::ndarray::{ArrayD, IxDyn};
let mut shape = vec![1usize]; shape.extend(vec![2; self.n_wires]);
let states = ArrayD::from_shape_vec(IxDyn(&shape), state_vec)
.map_err(|e| MLError::InvalidConfiguration(e.to_string()))?;
qdev.set_states(states);
Ok(qdev)
}
pub fn from_tq_device(qdev: &TQDevice) -> Result<Self> {
let mps = MatrixProductState::from_tq_device(qdev)?;
Ok(Self {
n_wires: qdev.n_wires,
mps: Some(mps),
config: TensorNetworkConfig::default(),
static_mode: false,
gate_cache: HashMap::new(),
})
}
}
impl TQModule for TQTensorNetworkBackend {
fn forward(&mut self, _qdev: &mut TQDevice) -> Result<()> {
Ok(())
}
fn parameters(&self) -> Vec<TQParameter> {
Vec::new()
}
fn n_wires(&self) -> Option<usize> {
Some(self.n_wires)
}
fn set_n_wires(&mut self, n_wires: usize) {
self.n_wires = n_wires;
self.reset();
}
fn is_static_mode(&self) -> bool {
self.static_mode
}
fn static_on(&mut self) {
self.static_mode = true;
}
fn static_off(&mut self) {
self.static_mode = false;
self.gate_cache.clear();
}
fn name(&self) -> &str {
"TQTensorNetworkBackend"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mps_creation() {
let mps = MatrixProductState::from_computational_basis(4, 0);
assert_eq!(mps.n_qubits, 4);
assert_eq!(mps.tensors.len(), 4);
}
#[test]
fn test_mps_state_vector() {
let mps = MatrixProductState::from_computational_basis(2, 0);
let state = mps.to_state_vector().expect("Should succeed");
assert_eq!(state.len(), 4);
assert!((state[0].re - 1.0).abs() < 1e-10);
for i in 1..4 {
assert!(state[i].norm() < 1e-10);
}
}
#[test]
fn test_tensor_network_backend() {
let backend = TQTensorNetworkBackend::new(3);
assert_eq!(backend.n_wires, 3);
assert!(backend.mps.is_some());
}
#[test]
fn test_single_qubit_gate_application() {
let mut backend = TQTensorNetworkBackend::new(2);
let x_gate = Array2::from_shape_vec(
(2, 2),
vec![
Complex64::new(0.0, 0.0),
Complex64::new(1.0, 0.0),
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
],
)
.expect("Should create matrix");
backend.apply_gate(0, &x_gate).expect("Should apply gate");
let state = backend.get_state_vector().expect("Should get state");
assert!(state[0].norm() < 1e-10);
assert!(state[1].norm() < 1e-10);
assert!((state[2].re - 1.0).abs() < 1e-10);
assert!(state[3].norm() < 1e-10);
}
#[test]
fn test_config_defaults() {
let config = TensorNetworkConfig::default();
assert_eq!(config.max_bond_dim, 64);
assert_eq!(config.compression, CompressionMethod::SVD);
}
}