use crate::{
error::{QuantRS2Error, QuantRS2Result},
gate::GateOp,
linalg_stubs::svd,
register::Register,
};
use scirs2_core::ndarray::{Array, Array2, ArrayD, IxDyn};
use scirs2_core::Complex;
use std::collections::{HashMap, HashSet};
type Complex64 = Complex<f64>;
/// A node in a tensor network: n-dimensional complex data with one named
/// index (label) per axis.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Identifier used as the key when the tensor is stored in a network.
    pub id: usize,
    /// Tensor elements; one axis per entry in `indices`.
    pub data: ArrayD<Complex64>,
    /// Axis labels, in the same order as the axes of `data`.
    pub indices: Vec<String>,
    /// Cached copy of `data.shape()`; must be kept in sync with `data`.
    pub shape: Vec<usize>,
}
impl Tensor {
pub fn new(id: usize, data: ArrayD<Complex64>, indices: Vec<String>) -> Self {
let shape = data.shape().to_vec();
Self {
id,
data,
indices,
shape,
}
}
pub fn from_matrix(
id: usize,
matrix: Array2<Complex64>,
in_idx: String,
out_idx: String,
) -> Self {
let shape = matrix.shape().to_vec();
let data = matrix.into_dyn();
Self {
id,
data,
indices: vec![in_idx, out_idx],
shape,
}
}
pub fn qubit_zero(id: usize, idx: String) -> Self {
let mut data = Array::zeros(IxDyn(&[2]));
data[[0]] = Complex64::new(1.0, 0.0);
Self {
id,
data,
indices: vec![idx],
shape: vec![2],
}
}
pub fn qubit_one(id: usize, idx: String) -> Self {
let mut data = Array::zeros(IxDyn(&[2]));
data[[1]] = Complex64::new(1.0, 0.0);
Self {
id,
data,
indices: vec![idx],
shape: vec![2],
}
}
pub fn from_array<D>(
array: scirs2_core::ndarray::ArrayBase<scirs2_core::ndarray::OwnedRepr<Complex64>, D>,
indices: Vec<usize>,
) -> Self
where
D: scirs2_core::ndarray::Dimension,
{
let shape = array.shape().to_vec();
let data = array.into_dyn();
let index_labels: Vec<String> = indices.iter().map(|i| format!("idx_{i}")).collect();
Self {
id: 0, data,
indices: index_labels,
shape,
}
}
pub fn rank(&self) -> usize {
self.indices.len()
}
pub const fn tensor(&self) -> &ArrayD<Complex64> {
&self.data
}
pub fn ndim(&self) -> usize {
self.data.ndim()
}
pub fn contract(&self, other: &Self, self_idx: &str, other_idx: &str) -> QuantRS2Result<Self> {
let self_pos = self
.indices
.iter()
.position(|s| s == self_idx)
.ok_or_else(|| {
QuantRS2Error::InvalidInput(format!("Index {self_idx} not found in tensor"))
})?;
let other_pos = other
.indices
.iter()
.position(|s| s == other_idx)
.ok_or_else(|| {
QuantRS2Error::InvalidInput(format!("Index {other_idx} not found in tensor"))
})?;
if self.shape[self_pos] != other.shape[other_pos] {
return Err(QuantRS2Error::InvalidInput(format!(
"Cannot contract indices with different dimensions: {} vs {}",
self.shape[self_pos], other.shape[other_pos]
)));
}
let contracted = self.contract_indices(&other, self_pos, other_pos)?;
let mut new_indices = Vec::new();
for (i, idx) in self.indices.iter().enumerate() {
if i != self_pos {
new_indices.push(idx.clone());
}
}
for (i, idx) in other.indices.iter().enumerate() {
if i != other_pos {
new_indices.push(idx.clone());
}
}
Ok(Self::new(
self.id.max(other.id) + 1,
contracted,
new_indices,
))
}
fn contract_indices(
&self,
other: &Self,
self_idx: usize,
other_idx: usize,
) -> QuantRS2Result<ArrayD<Complex64>> {
let self_shape = self.data.shape();
let other_shape = other.data.shape();
let mut self_left_dims = 1;
let mut self_right_dims = 1;
for i in 0..self_idx {
self_left_dims *= self_shape[i];
}
for i in (self_idx + 1)..self_shape.len() {
self_right_dims *= self_shape[i];
}
let mut other_left_dims = 1;
let mut other_right_dims = 1;
for i in 0..other_idx {
other_left_dims *= other_shape[i];
}
for i in (other_idx + 1)..other_shape.len() {
other_right_dims *= other_shape[i];
}
let contract_dim = self_shape[self_idx];
let self_mat = self
.data
.view()
.into_shape_with_order((self_left_dims, contract_dim * self_right_dims))
.map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?
.to_owned();
let other_mat = other
.data
.view()
.into_shape_with_order((other_left_dims * contract_dim, other_right_dims))
.map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?
.to_owned();
let _result_mat: Array2<Complex64> = Array2::zeros((
self_left_dims * self_right_dims,
other_left_dims * other_right_dims,
));
let mut result_vec = Vec::new();
for i in 0..self_left_dims {
for j in 0..self_right_dims {
for k in 0..other_left_dims {
for l in 0..other_right_dims {
let mut sum = Complex64::new(0.0, 0.0);
for c in 0..contract_dim {
sum += self_mat[[i, c * self_right_dims + j]]
* other_mat[[k * contract_dim + c, l]];
}
result_vec.push(sum);
}
}
}
}
let mut result_shape = Vec::new();
for i in 0..self_idx {
result_shape.push(self_shape[i]);
}
for i in (self_idx + 1)..self_shape.len() {
result_shape.push(self_shape[i]);
}
for i in 0..other_idx {
result_shape.push(other_shape[i]);
}
for i in (other_idx + 1)..other_shape.len() {
result_shape.push(other_shape[i]);
}
ArrayD::from_shape_vec(IxDyn(&result_shape), result_vec)
.map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))
}
pub fn svd_decompose(
&self,
idx: usize,
max_rank: Option<usize>,
) -> QuantRS2Result<(Self, Self)> {
if idx >= self.rank() {
return Err(QuantRS2Error::InvalidInput(format!(
"Index {} out of bounds for tensor with rank {}",
idx,
self.rank()
)));
}
let shape = self.data.shape();
let mut left_dim = 1;
let mut right_dim = 1;
for i in 0..=idx {
left_dim *= shape[i];
}
for i in (idx + 1)..shape.len() {
right_dim *= shape[i];
}
let matrix = self
.data
.view()
.into_shape_with_order((left_dim, right_dim))
.map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?
.to_owned();
let real_matrix = matrix.mapv(|c| c.re);
let (u, s, vt) = svd(&real_matrix.view(), false, None)
.map_err(|e| QuantRS2Error::ComputationError(format!("SVD failed: {e:?}")))?;
let rank = if let Some(max_r) = max_rank {
max_r.min(s.len())
} else {
s.len()
};
let u_trunc = u.slice(scirs2_core::ndarray::s![.., ..rank]).to_owned();
let s_trunc = s.slice(scirs2_core::ndarray::s![..rank]).to_owned();
let vt_trunc = vt.slice(scirs2_core::ndarray::s![..rank, ..]).to_owned();
let mut s_mat = Array2::zeros((rank, rank));
for i in 0..rank {
s_mat[[i, i]] = Complex64::new(s_trunc[i].sqrt(), 0.0);
}
let left_data = u_trunc.mapv(|x| Complex64::new(x, 0.0)).dot(&s_mat);
let right_data = s_mat.dot(&vt_trunc.mapv(|x| Complex64::new(x, 0.0)));
let mut left_indices = self.indices[..=idx].to_vec();
left_indices.push(format!("bond_{}", self.id));
let mut right_indices = vec![format!("bond_{}", self.id)];
right_indices.extend_from_slice(&self.indices[(idx + 1)..]);
let left_tensor = Self::new(self.id * 2, left_data.into_dyn(), left_indices);
let right_tensor = Self::new(self.id * 2 + 1, right_data.into_dyn(), right_indices);
Ok((left_tensor, right_tensor))
}
}
/// A connection between one named index of `tensor1` and one named index of
/// `tensor2` in a [`TensorNetwork`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TensorEdge {
    /// Id of the first tensor.
    pub tensor1: usize,
    /// Index name on the first tensor.
    pub index1: String,
    /// Id of the second tensor.
    pub tensor2: usize,
    /// Index name on the second tensor.
    pub index2: String,
}
/// A collection of tensors plus the edges connecting their indices.
#[derive(Debug)]
pub struct TensorNetwork {
    /// Tensors keyed by their `id`.
    pub tensors: HashMap<usize, Tensor>,
    /// Recorded index-to-index connections.
    pub edges: Vec<TensorEdge>,
    /// Per-tensor list of indices not yet connected to anything.
    pub open_indices: HashMap<usize, Vec<String>>,
    /// Next fresh id; kept strictly greater than every id added so far.
    next_id: usize,
}
impl TensorNetwork {
    /// Creates an empty network with ids starting at 0.
    pub fn new() -> Self {
        Self {
            tensors: HashMap::new(),
            edges: Vec::new(),
            open_indices: HashMap::new(),
            next_id: 0,
        }
    }

    /// Inserts a tensor keyed by its own `id`, registers all of its indices as
    /// open, and returns that id.
    ///
    /// NOTE(review): a tensor whose id collides with an existing entry
    /// silently replaces it — confirm callers always supply fresh ids.
    pub fn add_tensor(&mut self, tensor: Tensor) -> usize {
        let id = tensor.id;
        self.open_indices.insert(id, tensor.indices.clone());
        self.tensors.insert(id, tensor);
        // Keep next_id ahead of every id seen so far.
        self.next_id = self.next_id.max(id + 1);
        id
    }

    /// Connects `index1` of `tensor1` to `index2` of `tensor2`, recording the
    /// edge and removing both indices from the open-index lists.
    ///
    /// # Errors
    /// Returns `InvalidInput` if either tensor or index does not exist, or if
    /// the two indices have different dimensions.
    pub fn connect(
        &mut self,
        tensor1: usize,
        index1: String,
        tensor2: usize,
        index2: String,
    ) -> QuantRS2Result<()> {
        if !self.tensors.contains_key(&tensor1) {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Tensor {tensor1} not found"
            )));
        }
        if !self.tensors.contains_key(&tensor2) {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Tensor {tensor2} not found"
            )));
        }
        let t1 = &self.tensors[&tensor1];
        let t2 = &self.tensors[&tensor2];
        let idx1_pos = t1
            .indices
            .iter()
            .position(|s| s == &index1)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!("Index {index1} not found in tensor {tensor1}"))
            })?;
        let idx2_pos = t2
            .indices
            .iter()
            .position(|s| s == &index2)
            .ok_or_else(|| {
                QuantRS2Error::InvalidInput(format!("Index {index2} not found in tensor {tensor2}"))
            })?;
        if t1.shape[idx1_pos] != t2.shape[idx2_pos] {
            return Err(QuantRS2Error::InvalidInput(format!(
                "Connected indices must have same dimension: {} vs {}",
                t1.shape[idx1_pos], t2.shape[idx2_pos]
            )));
        }
        self.edges.push(TensorEdge {
            tensor1,
            index1: index1.clone(),
            tensor2,
            index2: index2.clone(),
        });
        // The connected indices are no longer open.
        if let Some(indices) = self.open_indices.get_mut(&tensor1) {
            indices.retain(|s| s != &index1);
        }
        if let Some(indices) = self.open_indices.get_mut(&tensor2) {
            indices.retain(|s| s != &index2);
        }
        Ok(())
    }

    /// Greedy pairwise contraction order: repeatedly merges the cheapest
    /// connected pair of remaining tensors.
    ///
    /// NOTE(review): merged tensors get placeholder ids
    /// (`next_id + order.len()`) that have no entries in `edges`, so for a
    /// chain of three or more tensors no connected pair involving the
    /// placeholder is ever found and the loop stops after the first merge.
    /// `contract_all` additionally numbers intermediates as `next_id`,
    /// `next_id + 1`, ... — the two id schemes disagree. Verify multi-tensor
    /// contraction end to end before relying on it.
    pub fn find_contraction_order(&self) -> Vec<(usize, usize)> {
        let mut remaining_tensors: HashSet<_> = self.tensors.keys().copied().collect();
        let mut order = Vec::new();
        // Undirected adjacency over tensor ids (both directions inserted).
        let mut adjacency: HashMap<usize, Vec<usize>> = HashMap::new();
        for edge in &self.edges {
            adjacency
                .entry(edge.tensor1)
                .or_insert_with(Vec::new)
                .push(edge.tensor2);
            adjacency
                .entry(edge.tensor2)
                .or_insert_with(Vec::new)
                .push(edge.tensor1);
        }
        while remaining_tensors.len() > 1 {
            let mut best_pair = None;
            let mut min_cost = usize::MAX;
            // `t2 > t1` deduplicates the two directions of each edge.
            for &t1 in &remaining_tensors {
                if let Some(neighbors) = adjacency.get(&t1) {
                    for &t2 in neighbors {
                        if t2 > t1 && remaining_tensors.contains(&t2) {
                            let cost = self.estimate_contraction_cost(t1, t2);
                            if cost < min_cost {
                                min_cost = cost;
                                best_pair = Some((t1, t2));
                            }
                        }
                    }
                }
            }
            if let Some((t1, t2)) = best_pair {
                order.push((t1, t2));
                remaining_tensors.remove(&t1);
                remaining_tensors.remove(&t2);
                // Placeholder id for the merged tensor; see NOTE above.
                let virtual_id = self.next_id + order.len();
                remaining_tensors.insert(virtual_id);
                // The merged node inherits the still-remaining neighbours of
                // both parents.
                let mut virtual_neighbors = HashSet::new();
                if let Some(n1) = adjacency.get(&t1) {
                    virtual_neighbors.extend(
                        n1.iter()
                            .filter(|&&n| n != t2 && remaining_tensors.contains(&n)),
                    );
                }
                if let Some(n2) = adjacency.get(&t2) {
                    virtual_neighbors.extend(
                        n2.iter()
                            .filter(|&&n| n != t1 && remaining_tensors.contains(&n)),
                    );
                }
                adjacency.insert(virtual_id, virtual_neighbors.into_iter().collect());
            } else {
                // No connected pair left among the remaining tensors.
                break;
            }
        }
        order
    }

    /// Cost-model stub: every contraction is treated as equally expensive.
    const fn estimate_contraction_cost(&self, _t1: usize, _t2: usize) -> usize {
        1000
    }

    /// Contracts the network down to (ideally) a single tensor.
    ///
    /// NOTE(review): if `find_contraction_order` leaves more than one tensor
    /// uncontracted (see its doc), this returns an arbitrary survivor instead
    /// of erroring.
    ///
    /// # Errors
    /// `InvalidInput` for an empty network, for pairs that turn out not to be
    /// connected, or if a scheduled tensor is missing.
    pub fn contract_all(&mut self) -> QuantRS2Result<Tensor> {
        if self.tensors.is_empty() {
            return Err(QuantRS2Error::InvalidInput(
                "Cannot contract empty tensor network".into(),
            ));
        }
        if self.tensors.len() == 1 {
            return self
                .tensors
                .values()
                .next()
                .map(|t| t.clone())
                .ok_or_else(|| {
                    QuantRS2Error::InvalidInput("Single tensor expected but not found".into())
                });
        }
        let order = self.find_contraction_order();
        // Work on a copy so the network itself is left untouched.
        let mut tensor_map = self.tensors.clone();
        let mut next_id = self.next_id;
        for (t1_id, t2_id) in order {
            // Find the edge joining the pair (either orientation).
            let edge = self
                .edges
                .iter()
                .find(|e| {
                    (e.tensor1 == t1_id && e.tensor2 == t2_id)
                        || (e.tensor1 == t2_id && e.tensor2 == t1_id)
                })
                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensors not connected".into()))?;
            let t1 = tensor_map
                .remove(&t1_id)
                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensor not found".into()))?;
            let t2 = tensor_map
                .remove(&t2_id)
                .ok_or_else(|| QuantRS2Error::InvalidInput("Tensor not found".into()))?;
            // Orient the index names to match which side of the edge t1 is on.
            let contracted = if edge.tensor1 == t1_id {
                t1.contract(&t2, &edge.index1, &edge.index2)?
            } else {
                t1.contract(&t2, &edge.index2, &edge.index1)?
            };
            let mut new_tensor = contracted;
            new_tensor.id = next_id;
            tensor_map.insert(next_id, new_tensor);
            next_id += 1;
        }
        tensor_map
            .into_values()
            .next()
            .ok_or_else(|| QuantRS2Error::InvalidInput("Contraction failed".into()))
    }

    /// Stub: MPS conversion is not implemented; always returns an empty vector.
    pub const fn to_mps(&self, _max_bond_dim: Option<usize>) -> QuantRS2Result<Vec<Tensor>> {
        Ok(vec![])
    }

    /// Stub: MPO application is not implemented; always a no-op returning Ok.
    pub const fn apply_mpo(&mut self, _mpo: &[Tensor], _qubits: &[usize]) -> QuantRS2Result<()> {
        Ok(())
    }

    /// All tensors in the network, in arbitrary (HashMap) order.
    pub fn tensors(&self) -> Vec<&Tensor> {
        self.tensors.values().collect()
    }

    /// Looks up a tensor by id.
    pub fn tensor(&self, id: usize) -> Option<&Tensor> {
        self.tensors.get(&id)
    }
}
/// Incrementally builds a tensor network representing a qubit circuit.
pub struct TensorNetworkBuilder {
    /// The network under construction.
    network: TensorNetwork,
    /// Initial index label of each qubit (NOTE(review): written in `new` but
    /// never read in this file — possibly kept for future use).
    qubit_indices: HashMap<usize, String>,
    /// The current open (dangling) output index of each qubit's wire.
    current_indices: HashMap<usize, String>,
}
impl TensorNetworkBuilder {
    /// Creates a builder with `num_qubits` qubits, each initialised to |0>
    /// with open index `q{i}_0`.
    pub fn new(num_qubits: usize) -> Self {
        let mut network = TensorNetwork::new();
        let mut qubit_indices = HashMap::new();
        let mut current_indices = HashMap::new();
        for i in 0..num_qubits {
            let idx = format!("q{i}_0");
            let tensor = Tensor::qubit_zero(i, idx.clone());
            network.add_tensor(tensor);
            qubit_indices.insert(i, idx.clone());
            current_indices.insert(i, idx);
        }
        Self {
            network,
            qubit_indices,
            current_indices,
        }
    }

    /// Applies a single-qubit gate by appending its 2x2 matrix as a tensor
    /// wired to the qubit's current open index.
    ///
    /// # Errors
    /// Fails if the gate matrix is unavailable or not 2x2, or if wiring the
    /// tensors together fails.
    pub fn apply_single_qubit_gate(
        &mut self,
        gate: &dyn GateOp,
        qubit: usize,
    ) -> QuantRS2Result<()> {
        let matrix_vec = gate.matrix()?;
        let matrix = Array2::from_shape_vec((2, 2), matrix_vec)
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?;
        let in_idx = self.current_indices[&qubit].clone();
        let out_idx = format!("q{}_{}", qubit, self.network.next_id);
        // Resolve the predecessor tensor BEFORE inserting the gate tensor:
        // the gate tensor itself carries `in_idx`, so searching afterwards
        // could (HashMap order is arbitrary) find the gate and wire it to
        // itself.
        let prev_tensor = self.find_tensor_with_index(&in_idx);
        let gate_tensor = Tensor::from_matrix(
            self.network.next_id,
            matrix,
            in_idx.clone(),
            out_idx.clone(),
        );
        let gate_id = self.network.add_tensor(gate_tensor);
        if let Some(prev) = prev_tensor {
            self.network
                .connect(prev, in_idx.clone(), gate_id, in_idx)?;
        }
        self.current_indices.insert(qubit, out_idx);
        Ok(())
    }

    /// Applies a two-qubit gate by reshaping its 4x4 matrix into a rank-4
    /// tensor wired to both qubits' current open indices.
    ///
    /// NOTE(review): the rank-4 tensor's axes are labelled
    /// `[in1, in2, out1, out2]` over the row-major (2,2,2,2) reshape —
    /// confirm this axis-to-index assignment matches the gate-matrix
    /// convention (row index = outputs vs inputs).
    ///
    /// # Errors
    /// Fails if the gate matrix is unavailable or not 4x4, or if wiring fails.
    pub fn apply_two_qubit_gate(
        &mut self,
        gate: &dyn GateOp,
        qubit1: usize,
        qubit2: usize,
    ) -> QuantRS2Result<()> {
        let matrix_vec = gate.matrix()?;
        let matrix = Array2::from_shape_vec((4, 4), matrix_vec)
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?;
        let tensor_data = matrix
            .into_shape_with_order((2, 2, 2, 2))
            .map_err(|e| QuantRS2Error::InvalidInput(format!("Shape error: {e}")))?
            .into_dyn();
        let in1_idx = self.current_indices[&qubit1].clone();
        let in2_idx = self.current_indices[&qubit2].clone();
        let out1_idx = format!("q{}_{}", qubit1, self.network.next_id);
        let out2_idx = format!("q{}_{}", qubit2, self.network.next_id);
        // Resolve predecessors BEFORE inserting the gate tensor, which also
        // carries the input indices (otherwise the lookup could return the
        // gate itself and wire it to itself).
        let prev1 = self.find_tensor_with_index(&in1_idx);
        let prev2 = self.find_tensor_with_index(&in2_idx);
        let gate_tensor = Tensor::new(
            self.network.next_id,
            tensor_data,
            vec![
                in1_idx.clone(),
                in2_idx.clone(),
                out1_idx.clone(),
                out2_idx.clone(),
            ],
        );
        let gate_id = self.network.add_tensor(gate_tensor);
        if let Some(prev1) = prev1 {
            self.network
                .connect(prev1, in1_idx.clone(), gate_id, in1_idx)?;
        }
        if let Some(prev2) = prev2 {
            self.network
                .connect(prev2, in2_idx.clone(), gate_id, in2_idx)?;
        }
        self.current_indices.insert(qubit1, out1_idx);
        self.current_indices.insert(qubit2, out2_idx);
        Ok(())
    }

    /// Finds any tensor whose index list contains `index`.
    ///
    /// NOTE(review): HashMap iteration order is arbitrary, so if several
    /// tensors carry the same index name the choice is nondeterministic.
    fn find_tensor_with_index(&self, index: &str) -> Option<usize> {
        for (id, tensor) in &self.network.tensors {
            if tensor.indices.iter().any(|idx| idx == index) {
                return Some(*id);
            }
        }
        None
    }

    /// Consumes the builder and returns the constructed network.
    pub fn build(self) -> TensorNetwork {
        self.network
    }

    /// Contracts the whole network and returns the flattened amplitudes.
    ///
    /// NOTE(review): the amplitude ordering follows the index order of the
    /// final contracted tensor — verify it matches the register's qubit-order
    /// convention.
    ///
    /// # Errors
    /// Propagates any contraction failure.
    pub fn to_statevector(&mut self) -> QuantRS2Result<Vec<Complex64>> {
        let final_tensor = self.network.contract_all()?;
        Ok(final_tensor.data.into_raw_vec_and_offset().0)
    }
}
/// Circuit simulator that evaluates gate sequences via tensor-network
/// contraction.
pub struct TensorNetworkSimulator {
    /// Maximum bond dimension to keep (NOTE(review): configurable but not
    /// consulted by `simulate` in this file — verify intended use).
    max_bond_dim: usize,
    /// Whether to compress bonds via truncated SVD (NOTE(review): likewise not
    /// consulted here).
    use_compression: bool,
    /// Work-size threshold for parallel execution (NOTE(review): likewise not
    /// consulted here).
    parallel_threshold: usize,
}
impl TensorNetworkSimulator {
    /// Creates a simulator with the default configuration:
    /// bond dimension 64, compression enabled, parallel threshold 1000.
    pub const fn new() -> Self {
        Self {
            max_bond_dim: 64,
            use_compression: true,
            parallel_threshold: 1000,
        }
    }

    /// Builder-style setter for the maximum bond dimension.
    #[must_use]
    pub const fn with_max_bond_dim(mut self, dim: usize) -> Self {
        self.max_bond_dim = dim;
        self
    }

    /// Builder-style setter enabling or disabling bond compression.
    #[must_use]
    pub const fn with_compression(mut self, compress: bool) -> Self {
        self.use_compression = compress;
        self
    }

    /// Runs the gate sequence on an `N`-qubit register starting from |0...0>.
    ///
    /// Only one- and two-qubit gates are supported.
    ///
    /// # Errors
    /// Propagates builder and contraction failures, and returns
    /// `UnsupportedOperation` for gates acting on three or more qubits.
    pub fn simulate<const N: usize>(
        &self,
        gates: &[Box<dyn GateOp>],
    ) -> QuantRS2Result<Register<N>> {
        let mut builder = TensorNetworkBuilder::new(N);
        for gate in gates {
            let targets = gate.qubits();
            let arity = targets.len();
            if arity == 1 {
                builder.apply_single_qubit_gate(gate.as_ref(), targets[0].0 as usize)?;
            } else if arity == 2 {
                builder.apply_two_qubit_gate(
                    gate.as_ref(),
                    targets[0].0 as usize,
                    targets[1].0 as usize,
                )?;
            } else {
                return Err(QuantRS2Error::UnsupportedOperation(format!(
                    "Gates with {} qubits not supported in tensor network",
                    arity
                )));
            }
        }
        let amplitudes = builder.to_statevector()?;
        Register::with_amplitudes(amplitudes)
    }
}
pub mod contraction_optimization {
    use super::*;

    /// Memoised exhaustive search for a low-cost pairwise contraction order.
    pub struct DynamicProgrammingOptimizer {
        /// Memo table: tensor-id list -> (best cost, best order).
        /// NOTE(review): the key preserves list order, so permutations of the
        /// same tensor set are memoised separately — correct, but the cache is
        /// less effective than with a canonical (sorted) key.
        memo: HashMap<Vec<usize>, (usize, Vec<(usize, usize)>)>,
    }

    impl Default for DynamicProgrammingOptimizer {
        fn default() -> Self {
            Self::new()
        }
    }

    impl DynamicProgrammingOptimizer {
        /// Creates an optimizer with an empty memo table.
        pub fn new() -> Self {
            Self {
                memo: HashMap::new(),
            }
        }

        /// Returns a contraction order covering the tensors of `network`.
        pub fn optimize(&mut self, network: &TensorNetwork) -> Vec<(usize, usize)> {
            let tensor_ids: Vec<_> = network.tensors.keys().copied().collect();
            self.find_optimal_order(&tensor_ids, network).1
        }

        /// Recursively finds the cheapest order for the given tensor-id set.
        ///
        /// Returns `(usize::MAX, [])` when no connected pair exists, which
        /// callers treat as "no further contraction possible".
        fn find_optimal_order(
            &mut self,
            tensors: &[usize],
            network: &TensorNetwork,
        ) -> (usize, Vec<(usize, usize)>) {
            if tensors.len() <= 1 {
                return (0, vec![]);
            }
            let key = tensors.to_vec();
            if let Some(result) = self.memo.get(&key) {
                return result.clone();
            }
            let mut best_cost = usize::MAX;
            let mut best_order = vec![];
            for i in 0..tensors.len() {
                for j in (i + 1)..tensors.len() {
                    if self.are_connected(tensors[i], tensors[j], network) {
                        let cost = network.estimate_contraction_cost(tensors[i], tensors[j]);
                        let mut remaining = vec![];
                        for (k, &t) in tensors.iter().enumerate() {
                            if k != i && k != j {
                                remaining.push(t);
                            }
                        }
                        // Placeholder id for the merged tensor.
                        // NOTE(review): placeholder ids are not registered in
                        // `network.edges`, so subproblems never see them as
                        // connected — confirm this matches the intended search.
                        remaining.push(network.next_id + remaining.len());
                        let (sub_cost, sub_order) = self.find_optimal_order(&remaining, network);
                        // saturating_add: `sub_cost` may be usize::MAX when the
                        // subproblem has no connected pair; a plain `+` would
                        // overflow (panic in debug, wrap to a bogus small cost
                        // in release and corrupt the search).
                        let total_cost = cost.saturating_add(sub_cost);
                        if total_cost < best_cost {
                            best_cost = total_cost;
                            best_order = vec![(tensors[i], tensors[j])];
                            best_order.extend(sub_order);
                        }
                    }
                }
            }
            self.memo.insert(key, (best_cost, best_order.clone()));
            (best_cost, best_order)
        }

        /// True when an edge links `t1` and `t2` in either direction.
        fn are_connected(&self, t1: usize, t2: usize, network: &TensorNetwork) -> bool {
            network.edges.iter().any(|e| {
                (e.tensor1 == t1 && e.tensor2 == t2) || (e.tensor1 == t2 && e.tensor2 == t1)
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction caches rank and shape from the data array.
    #[test]
    fn test_tensor_creation() {
        let data = ArrayD::zeros(IxDyn(&[2, 2]));
        let tensor = Tensor::new(0, data, vec!["in".to_string(), "out".to_string()]);
        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.shape, vec![2, 2]);
    }

    /// |0> and |1> basis tensors have the expected one-hot amplitudes.
    #[test]
    fn test_qubit_tensors() {
        let t0 = Tensor::qubit_zero(0, "q0".to_string());
        assert_eq!(t0.data[[0]], Complex64::new(1.0, 0.0));
        assert_eq!(t0.data[[1]], Complex64::new(0.0, 0.0));
        let t1 = Tensor::qubit_one(1, "q1".to_string());
        assert_eq!(t1.data[[0]], Complex64::new(0.0, 0.0));
        assert_eq!(t1.data[[1]], Complex64::new(1.0, 0.0));
    }

    /// The builder seeds one |0> tensor per qubit.
    #[test]
    fn test_tensor_network_builder() {
        let builder = TensorNetworkBuilder::new(2);
        assert_eq!(builder.network.tensors.len(), 2);
    }

    /// Connecting via a nonexistent index name must fail.
    #[test]
    fn test_network_connection() {
        let mut network = TensorNetwork::new();
        let t1 = Tensor::qubit_zero(0, "q0".to_string());
        let t2 = Tensor::qubit_zero(1, "q1".to_string());
        let id1 = network.add_tensor(t1);
        let id2 = network.add_tensor(t2);
        assert!(network
            .connect(id1, "bond".to_string(), id2, "bond".to_string())
            .is_err());
    }
}