use crate::error::{PgmError, Result};
use crate::graph::FactorGraph;
use crate::linear_chain_crf::LinearChainCRF;
use quantrs2_sim::Complex64;
use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayD};
use serde::{Deserialize, Serialize};
/// A named dense tensor whose axes are labelled by string indices.
///
/// `indices[k]` labels axis `k` of `data` and `bond_dims[k]` records its extent.
/// `Tensor::new` derives `bond_dims` from `data.shape()`; since every field is
/// public, code that mutates the fields directly must keep the three consistent.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Human-readable label; contraction results combine the operands' names.
    pub name: String,
    /// Complex-valued entries of the tensor.
    pub data: ArrayD<Complex64>,
    /// One label per axis; a label shared with another tensor marks an axis to contract.
    pub indices: Vec<String>,
    /// Extent of each axis, parallel to `indices`.
    pub bond_dims: Vec<usize>,
}
impl Tensor {
    /// Creates a tensor from complex data, taking `bond_dims` from `data.shape()`.
    ///
    /// `indices` should label the axes of `data` in order. This constructor does not
    /// verify that the lengths agree; [`Tensor::contract`] rejects inconsistent
    /// tensors before it indexes into them.
    pub fn new(name: String, data: ArrayD<Complex64>, indices: Vec<String>) -> Self {
        let bond_dims = data.shape().to_vec();
        Self {
            name,
            data,
            indices,
            bond_dims,
        }
    }

    /// Creates a tensor from real data by embedding every value as `x + 0i`.
    pub fn from_real(name: String, data: ArrayD<f64>, indices: Vec<String>) -> Self {
        let complex_data = data.mapv(|x| Complex64::new(x, 0.0));
        Self::new(name, complex_data, indices)
    }

    /// Number of index labels carried by the tensor.
    pub fn rank(&self) -> usize {
        self.indices.len()
    }

    /// Extent of the axis labelled `index`.
    ///
    /// Returns `None` when no axis carries that label, or when `bond_dims` has no
    /// entry at that position.
    pub fn bond_dim(&self, index: &str) -> Option<usize> {
        self.indices
            .iter()
            .position(|i| i == index)
            .and_then(|pos| self.bond_dims.get(pos).copied())
    }

    /// Contracts `self` with `other`, summing over every index label they share.
    ///
    /// The result's indices are `self`'s unshared indices followed by `other`'s
    /// unshared indices, each in their original order. When no label is shared,
    /// this is the outer product. A label repeated within a single tensor is not
    /// supported.
    ///
    /// # Errors
    /// Returns `PgmError::InvalidDistribution` in three cases:
    /// - either tensor's `indices`/`bond_dims` disagree with its data shape;
    /// - a shared label has different extents in the two tensors;
    /// - the result array cannot be built.
    pub fn contract(&self, other: &Tensor) -> Result<Tensor> {
        // Pairs (axis in self, axis in other) that carry the same label.
        let shared: Vec<(usize, usize)> = self
            .indices
            .iter()
            .enumerate()
            .filter_map(|(i, idx)| {
                other
                    .indices
                    .iter()
                    .position(|oidx| oidx == idx)
                    .map(|j| (i, j))
            })
            .collect();
        if shared.is_empty() {
            return self.outer_product(other);
        }
        self.check_layout()?;
        other.check_layout()?;
        for &(i, j) in &shared {
            if self.bond_dims[i] != other.bond_dims[j] {
                return Err(PgmError::InvalidDistribution(format!(
                    "Contraction failed: index '{}' has dimension {} in '{}' but {} in '{}'",
                    self.indices[i], self.bond_dims[i], self.name, other.bond_dims[j], other.name
                )));
            }
        }
        let result_indices: Vec<String> = self
            .indices
            .iter()
            .enumerate()
            .filter(|(i, _)| !shared.iter().any(|(si, _)| si == i))
            .map(|(_, idx)| idx.clone())
            .chain(
                other
                    .indices
                    .iter()
                    .enumerate()
                    .filter(|(j, _)| !shared.iter().any(|(_, sj)| sj == j))
                    .map(|(_, idx)| idx.clone()),
            )
            .collect();
        let result_shape: Vec<usize> = self
            .bond_dims
            .iter()
            .enumerate()
            .filter(|(i, _)| !shared.iter().any(|(si, _)| si == i))
            .map(|(_, &d)| d)
            .chain(
                other
                    .bond_dims
                    .iter()
                    .enumerate()
                    .filter(|(j, _)| !shared.iter().any(|(_, sj)| sj == j))
                    .map(|(_, &d)| d),
            )
            .collect();
        let result_data = self.contract_data(other, &shared, &result_shape)?;
        Ok(Tensor {
            name: format!("{}*{}", self.name, other.name),
            data: result_data,
            indices: result_indices,
            bond_dims: result_shape,
        })
    }

    /// Ensures `indices` and `bond_dims` describe `data` exactly.
    fn check_layout(&self) -> Result<()> {
        if self.indices.len() != self.data.ndim() || self.bond_dims.as_slice() != self.data.shape()
        {
            return Err(PgmError::InvalidDistribution(format!(
                "Tensor '{}' has {} indices and bond dims {:?} but data shape {:?}",
                self.name,
                self.indices.len(),
                self.bond_dims,
                self.data.shape()
            )));
        }
        Ok(())
    }

    /// Computes the contracted entries.
    ///
    /// `shared` lists (axis in self, axis in other) pairs to sum over. The output
    /// axes are `self`'s free axes followed by `other`'s free axes; `result_shape`
    /// must list their extents in that order. Both tensors are expected to have
    /// passed `check_layout`.
    fn contract_data(
        &self,
        other: &Tensor,
        shared: &[(usize, usize)],
        result_shape: &[usize],
    ) -> Result<ArrayD<Complex64>> {
        let self_free: Vec<usize> = (0..self.data.ndim())
            .filter(|axis| !shared.iter().any(|&(si, _)| si == *axis))
            .collect();
        let other_free: Vec<usize> = (0..other.data.ndim())
            .filter(|axis| !shared.iter().any(|&(_, sj)| sj == *axis))
            .collect();
        if self_free.len() + other_free.len() != result_shape.len() {
            return Err(PgmError::InvalidDistribution(format!(
                "Contraction failed: expected {} free axes, got result shape {:?}",
                self_free.len() + other_free.len(),
                result_shape
            )));
        }
        let shared_shape: Vec<usize> = shared
            .iter()
            .map(|&(i, _)| self.data.shape()[i])
            .collect();
        let total_size: usize = result_shape.iter().product();
        let shared_size: usize = shared_shape.iter().product();

        // Scratch multi-indices reused across iterations to avoid reallocating.
        let mut self_idx = vec![0usize; self.data.ndim()];
        let mut other_idx = vec![0usize; other.data.ndim()];
        let mut out_idx = vec![0usize; result_shape.len()];
        let mut shared_idx = vec![0usize; shared.len()];
        let mut data = Vec::with_capacity(total_size);

        for flat in 0..total_size {
            Self::unravel_index(flat, result_shape, &mut out_idx);
            let (self_part, other_part) = out_idx.split_at(self_free.len());
            for (&axis, &value) in self_free.iter().zip(self_part) {
                self_idx[axis] = value;
            }
            for (&axis, &value) in other_free.iter().zip(other_part) {
                other_idx[axis] = value;
            }
            let mut acc = Complex64::new(0.0, 0.0);
            for s in 0..shared_size {
                Self::unravel_index(s, &shared_shape, &mut shared_idx);
                for (&(i, j), &value) in shared.iter().zip(&shared_idx) {
                    self_idx[i] = value;
                    other_idx[j] = value;
                }
                acc += self.data[&self_idx[..]] * other.data[&other_idx[..]];
            }
            data.push(acc);
        }
        ArrayD::from_shape_vec(result_shape.to_vec(), data)
            .map_err(|e| PgmError::InvalidDistribution(format!("Contraction failed: {}", e)))
    }

    /// Writes the row-major multi-index of `flat` within `shape` into `out`.
    ///
    /// Callers only invoke this when every extent in `shape` is non-zero.
    fn unravel_index(mut flat: usize, shape: &[usize], out: &mut [usize]) {
        for (slot, &extent) in out.iter_mut().zip(shape).rev() {
            *slot = flat % extent;
            flat /= extent;
        }
    }

    /// Outer (tensor) product: concatenates indices and pairs every entry of
    /// `self` with every entry of `other`, in row-major order.
    fn outer_product(&self, other: &Tensor) -> Result<Tensor> {
        let result_indices: Vec<String> = self
            .indices
            .iter()
            .chain(other.indices.iter())
            .cloned()
            .collect();
        let result_shape: Vec<usize> = self
            .bond_dims
            .iter()
            .chain(other.bond_dims.iter())
            .copied()
            .collect();
        // `iter()` walks logical (row-major) order, so the nested product lines up
        // with the concatenated shape. Zero-sized operands give an empty result.
        let data: Vec<Complex64> = self
            .data
            .iter()
            .flat_map(|&a| other.data.iter().map(move |&b| a * b))
            .collect();
        Ok(Tensor {
            name: format!("{}⊗{}", self.name, other.name),
            data: ArrayD::from_shape_vec(result_shape.clone(), data).map_err(|e| {
                PgmError::InvalidDistribution(format!("Outer product failed: {}", e))
            })?,
            indices: result_indices,
            bond_dims: result_shape,
        })
    }
}
/// A collection of [`Tensor`]s plus bookkeeping lists of index labels.
#[derive(Debug, Clone)]
pub struct TensorNetwork {
    /// Tensors in insertion order; `contract` folds them from first to last.
    tensors: Vec<Tensor>,
    /// Labels registered through `add_physical_index` (deduplicated). Used only for
    /// counting and statistics; contraction does not consult this list.
    physical_indices: Vec<String>,
    /// Labels registered through `add_bond_index` (deduplicated). Used only for
    /// statistics; contraction does not consult this list.
    bond_indices: Vec<String>,
}
impl TensorNetwork {
pub fn new() -> Self {
Self {
tensors: Vec::new(),
physical_indices: Vec::new(),
bond_indices: Vec::new(),
}
}
pub fn add_tensor(&mut self, tensor: Tensor) {
self.tensors.push(tensor);
}
pub fn add_physical_index(&mut self, index: String) {
if !self.physical_indices.contains(&index) {
self.physical_indices.push(index);
}
}
pub fn add_bond_index(&mut self, index: String) {
if !self.bond_indices.contains(&index) {
self.bond_indices.push(index);
}
}
pub fn num_tensors(&self) -> usize {
self.tensors.len()
}
pub fn num_physical_indices(&self) -> usize {
self.physical_indices.len()
}
pub fn total_bond_dim(&self) -> usize {
self.tensors
.iter()
.map(|t| t.bond_dims.iter().product::<usize>())
.sum()
}
pub fn contract(&self) -> Result<Tensor> {
if self.tensors.is_empty() {
return Err(PgmError::InvalidGraph(
"Cannot contract empty tensor network".to_string(),
));
}
let mut result = self.tensors[0].clone();
for tensor in self.tensors.iter().skip(1) {
result = result.contract(tensor)?;
}
Ok(result)
}
pub fn partition_function(&self) -> Result<Complex64> {
let contracted = self.contract()?;
Ok(contracted.data.iter().sum())
}
pub fn marginal(&self, indices: &[String]) -> Result<Tensor> {
let contracted = self.contract()?;
let keep_positions: Vec<usize> = contracted
.indices
.iter()
.enumerate()
.filter_map(
|(i, idx)| {
if indices.contains(idx) {
Some(i)
} else {
None
}
},
)
.collect();
if keep_positions.is_empty() {
let sum: Complex64 = contracted.data.iter().sum();
return Ok(Tensor::new(
"marginal".to_string(),
ArrayD::from_elem(vec![], sum),
vec![],
));
}
let result_shape: Vec<usize> = keep_positions
.iter()
.map(|&i| contracted.bond_dims[i])
.collect();
let result_indices: Vec<String> = keep_positions
.iter()
.map(|&i| contracted.indices[i].clone())
.collect();
Ok(Tensor {
name: "marginal".to_string(),
data: contracted.data, indices: result_indices,
bond_dims: result_shape,
})
}
}
impl Default for TensorNetwork {
fn default() -> Self {
Self::new()
}
}
/// Builds a [`TensorNetwork`] from a factor graph.
///
/// Every variable name is registered as a physical index. Every factor becomes one
/// real-valued tensor named after the factor, with the factor's variable names as
/// its index labels.
///
/// NOTE(review): `TensorNetwork::contract` matches labels pairwise, so a variable
/// shared by more than two factors is summed out at its first contraction —
/// confirm callers do not rely on hyperedge semantics.
pub fn factor_graph_to_tensor_network(graph: &FactorGraph) -> Result<TensorNetwork> {
    let mut network = TensorNetwork::new();
    for name in graph.variable_names() {
        network.add_physical_index(name.clone());
    }
    for factor in graph.factors() {
        network.add_tensor(Tensor::from_real(
            factor.name.clone(),
            factor.values.clone(),
            factor.variables.clone(),
        ));
    }
    Ok(network)
}
/// A matrix product state: a chain of rank-3 site tensors.
///
/// Site tensors are shaped `(left_bond, physical, right_bond)`, as built by
/// `MatrixProductState::new`.
#[derive(Debug, Clone)]
pub struct MatrixProductState {
    /// One rank-3 tensor per site, indexed `[left_bond, physical, right_bond]`.
    pub tensors: Vec<Array3<Complex64>>,
    /// Physical (local state) dimension of each site.
    pub physical_dims: Vec<usize>,
    /// Bond dimensions between sites; the constructors store `length + 1` entries
    /// with `1` at both ends.
    pub bond_dims: Vec<usize>,
}
impl MatrixProductState {
    /// Builds an open-boundary MPS with `length` sites and uniform entries.
    ///
    /// Bond dimensions are `[1, bond_dim, ..., bond_dim, 1]`, or `[1, 1]` when
    /// `length == 1`. Every entry of a site tensor shaped `(l, p, r)` equals
    /// `1 / (l * p * r)`.
    pub fn new(length: usize, physical_dim: usize, bond_dim: usize) -> Self {
        let mut tensors = Vec::with_capacity(length);
        let mut bond_dims = Vec::with_capacity(length + 1);
        bond_dims.push(1);
        for i in 0..length {
            let left_dim = bond_dims[i];
            let right_dim = if i == length - 1 { 1 } else { bond_dim };
            bond_dims.push(right_dim);
            let tensor = Array3::from_shape_fn((left_dim, physical_dim, right_dim), |_| {
                Complex64::new(1.0 / (left_dim * physical_dim * right_dim) as f64, 0.0)
            });
            tensors.push(tensor);
        }
        Self {
            tensors,
            physical_dims: vec![physical_dim; length],
            bond_dims,
        }
    }

    /// Builds the product state in which every site is in physical state 0.
    ///
    /// All bonds have dimension 1. With `physical_dim == 0` the tensors are empty.
    pub fn product_state(length: usize, physical_dim: usize) -> Self {
        let mut tensors = Vec::with_capacity(length);
        for _ in 0..length {
            let mut tensor = Array3::zeros((1, physical_dim, 1));
            if physical_dim > 0 {
                tensor[[0, 0, 0]] = Complex64::new(1.0, 0.0);
            }
            tensors.push(tensor);
        }
        Self {
            tensors,
            physical_dims: vec![physical_dim; length],
            bond_dims: vec![1; length + 1],
        }
    }

    /// Number of sites.
    pub fn length(&self) -> usize {
        self.tensors.len()
    }

    /// Largest stored bond dimension (1 if `bond_dims` is empty).
    pub fn max_bond_dim(&self) -> usize {
        *self.bond_dims.iter().max().unwrap_or(&1)
    }

    /// Contracts the chain into the raw (unnormalised) amplitude of every basis state.
    ///
    /// Basis states are enumerated in row-major order, with site 0 as the most
    /// significant digit. The left boundary selects bond value 0. The right boundary
    /// sums over the last bond, which is a no-op for the usual right bond dimension
    /// of 1. An MPS with no sites yields `[1]`.
    fn amplitudes(&self) -> Result<Array1<Complex64>> {
        let one = Complex64::new(1.0, 0.0);
        let zero = Complex64::new(0.0, 0.0);
        if self.tensors.is_empty() {
            return Ok(Array1::from(vec![one]));
        }
        if self.physical_dims.len() != self.tensors.len() {
            return Err(PgmError::InvalidDistribution(format!(
                "MPS has {} site tensors but {} physical dimensions",
                self.tensors.len(),
                self.physical_dims.len()
            )));
        }
        // Validate the chain once, so the hot loop below can index without checks.
        for (site, (tensor, &dim)) in self.tensors.iter().zip(&self.physical_dims).enumerate() {
            let (_, phys_dim, right_dim) = tensor.dim();
            let next_left = self.tensors.get(site + 1).map(|t| t.dim().0);
            if phys_dim < dim || next_left.map_or(false, |left| left != right_dim) {
                return Err(PgmError::InvalidDistribution(format!(
                    "Inconsistent MPS tensor shape at site {}",
                    site
                )));
            }
        }

        let total_dim: usize = self.physical_dims.iter().product();
        let mut state = Array1::zeros(total_dim);
        let mut digits = vec![0usize; self.tensors.len()];
        for basis_idx in 0..total_dim {
            let mut rest = basis_idx;
            for site in (0..digits.len()).rev() {
                digits[site] = rest % self.physical_dims[site];
                rest /= self.physical_dims[site];
            }
            // Row vector carried along the chain: v <- v · A_site[:, digit, :].
            let mut carry = vec![zero; self.tensors[0].dim().0];
            if let Some(first) = carry.first_mut() {
                *first = one;
            }
            for (tensor, &phys) in self.tensors.iter().zip(&digits) {
                let mut next = vec![zero; tensor.dim().2];
                for (left, &weight) in carry.iter().enumerate() {
                    for (right, slot) in next.iter_mut().enumerate() {
                        *slot += weight * tensor[[left, phys, right]];
                    }
                }
                carry = next;
            }
            state[basis_idx] = carry.iter().sum();
        }
        Ok(state)
    }

    /// Full state vector of the MPS, normalised to unit L2 norm.
    ///
    /// The ordering is the same as `amplitudes`: row-major, site 0 most significant.
    /// A vector whose norm is at most `1e-10` is returned unscaled.
    ///
    /// # Errors
    /// `PgmError::InvalidDistribution` if the site tensors do not form a consistent
    /// chain.
    pub fn to_state_vector(&self) -> Result<Array1<Complex64>> {
        let mut state = self.amplitudes()?;
        let norm = state.iter().map(|x| x.norm_sqr()).sum::<f64>().sqrt();
        if norm > 1e-10 {
            state.mapv_inplace(|x| x / norm);
        }
        Ok(state)
    }

    /// L2 norm of the (unnormalised) state represented by this MPS.
    ///
    /// Returns `0.0` if the MPS cannot be contracted.
    pub fn norm(&self) -> f64 {
        match self.amplitudes() {
            Ok(amplitudes) => amplitudes
                .iter()
                .map(|x| x.norm_sqr())
                .sum::<f64>()
                .sqrt(),
            Err(_) => 0.0,
        }
    }

    /// Computes `<ψ|O_site|ψ>` for a single-site operator on the normalised state.
    ///
    /// `operator[[new, old]]` is the matrix element mapping local state `old` to
    /// `new`. Only its leading `d × d` block is used, where `d` is the site's
    /// physical dimension.
    ///
    /// # Errors
    /// - `PgmError::VariableNotFound` if `site` is out of range.
    /// - `PgmError::InvalidDistribution` if `operator` is smaller than `d × d`, or
    ///   if the MPS cannot be contracted.
    pub fn expectation_local(
        &self,
        site: usize,
        operator: &Array2<Complex64>,
    ) -> Result<Complex64> {
        if site >= self.tensors.len() {
            return Err(PgmError::VariableNotFound(format!(
                "Site {} out of range",
                site
            )));
        }
        let site_dim = *self.physical_dims.get(site).ok_or_else(|| {
            PgmError::VariableNotFound(format!("Site {} out of range", site))
        })?;
        if operator.nrows() < site_dim || operator.ncols() < site_dim {
            return Err(PgmError::InvalidDistribution(format!(
                "Operator of shape {:?} cannot act on site {} with physical dimension {}",
                operator.shape(),
                site,
                site_dim
            )));
        }
        let state = self.to_state_vector()?;
        // Distance between consecutive values of this site's digit in the
        // row-major basis enumeration.
        let stride: usize = self.physical_dims[site + 1..].iter().product();
        let mut result = Complex64::new(0.0, 0.0);
        for (basis_idx, &amplitude) in state.iter().enumerate() {
            let old = (basis_idx / stride) % site_dim;
            let base = basis_idx - old * stride;
            for new in 0..site_dim {
                let op_elem = operator[[new, old]];
                if op_elem.norm_sqr() > 1e-20 {
                    result += state[base + new * stride].conj() * op_elem * amplitude;
                }
            }
        }
        Ok(result)
    }
}
pub fn linear_chain_to_mps(
crf: &LinearChainCRF,
input_sequence: &[usize],
) -> Result<MatrixProductState> {
let factor_graph = crf.to_factor_graph(input_sequence)?;
let num_sites = input_sequence.len();
if num_sites == 0 {
return Err(PgmError::InvalidGraph("Empty sequence".to_string()));
}
let num_states = factor_graph
.get_variable("y_0")
.map(|v| v.cardinality)
.unwrap_or(2);
let mut mps = MatrixProductState::new(num_sites, num_states, num_states);
for t in 0..num_sites {
let emission_name = format!("emission_{}", t);
let transition_name = format!("transition_{}", t);
if let Some(emission) = factor_graph.get_factor_by_name(&emission_name) {
for (s, &val) in emission.values.iter().enumerate() {
if s < num_states {
mps.tensors[t][[0, s, 0]] = Complex64::new(val.sqrt(), 0.0);
}
}
}
if t > 0 {
if let Some(transition) = factor_graph.get_factor_by_name(&transition_name) {
for s_prev in 0..num_states {
for s_curr in 0..num_states {
if s_prev < transition.values.shape()[0]
&& s_curr < transition.values.shape()[1]
{
let val = transition.values[[s_prev, s_curr]];
let tensor = &mut mps.tensors[t];
let left_dim = tensor.shape()[0];
if s_prev < left_dim && s_curr < num_states {
tensor[[s_prev.min(left_dim - 1), s_curr, 0]] =
Complex64::new(val.sqrt(), 0.0);
}
}
}
}
}
}
}
Ok(mps)
}
/// Summary statistics for a [`TensorNetwork`], as produced by `TensorNetwork::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorNetworkStats {
    /// Number of tensors in the network.
    pub num_tensors: usize,
    /// Sum of the element counts of every tensor's data array.
    pub total_elements: usize,
    /// Largest tensor rank (number of index labels); 0 for an empty network.
    pub max_rank: usize,
    /// Mean tensor rank; 0.0 for an empty network.
    pub avg_rank: f64,
    /// Number of registered physical index labels.
    pub num_physical_indices: usize,
    /// Number of registered bond index labels.
    pub num_bond_indices: usize,
}
impl TensorNetwork {
pub fn stats(&self) -> TensorNetworkStats {
let num_tensors = self.tensors.len();
let total_elements: usize = self.tensors.iter().map(|t| t.data.len()).sum();
let max_rank = self.tensors.iter().map(|t| t.rank()).max().unwrap_or(0);
let avg_rank = if num_tensors > 0 {
self.tensors.iter().map(|t| t.rank()).sum::<usize>() as f64 / num_tensors as f64
} else {
0.0
};
TensorNetworkStats {
num_tensors,
total_elements,
max_rank,
avg_rank,
num_physical_indices: self.physical_indices.len(),
num_bond_indices: self.bond_indices.len(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::FactorGraph;
    use approx::assert_abs_diff_eq;

    /// Builds a real-valued tensor from a shape, row-major values and axis labels.
    fn real_tensor(name: &str, shape: Vec<usize>, values: Vec<f64>, labels: &[&str]) -> Tensor {
        let data = ArrayD::from_shape_vec(shape, values).expect("Array creation failed");
        let labels = labels.iter().map(|&label| label.to_string()).collect();
        Tensor::from_real(name.to_string(), data, labels)
    }

    #[test]
    fn test_tensor_creation() {
        let t = real_tensor(
            "test",
            vec![2, 3],
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &["i", "j"],
        );
        assert_eq!(t.rank(), 2);
        assert_eq!(t.bond_dim("i"), Some(2));
        assert_eq!(t.bond_dim("j"), Some(3));
    }

    #[test]
    fn test_tensor_network_creation() {
        let mut network = TensorNetwork::new();
        network.add_tensor(real_tensor("A", vec![2], vec![1.0, 0.0], &["x"]));
        network.add_physical_index("x".to_string());
        assert_eq!(network.num_tensors(), 1);
        assert_eq!(network.num_physical_indices(), 1);
    }

    #[test]
    fn test_factor_graph_to_tn() {
        let mut graph = FactorGraph::new();
        for name in ["x", "y"] {
            graph.add_variable_with_card(name.to_string(), "Binary".to_string(), 2);
        }
        let network = factor_graph_to_tensor_network(&graph).expect("TN creation failed");
        assert_eq!(network.num_physical_indices(), 2);
    }

    #[test]
    fn test_mps_creation() {
        let mps = MatrixProductState::new(4, 2, 4);
        assert_eq!(mps.length(), 4);
        assert_eq!(mps.physical_dims.len(), 4);
        assert!(mps.max_bond_dim() <= 4);
    }

    #[test]
    fn test_mps_product_state() {
        let mps = MatrixProductState::product_state(3, 2);
        assert_eq!(mps.length(), 3);
        assert_eq!(mps.max_bond_dim(), 1);
        assert_abs_diff_eq!(mps.norm(), 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_mps_to_state_vector() {
        let state = MatrixProductState::product_state(2, 2)
            .to_state_vector()
            .expect("State vector failed");
        // Two sites with two states each span a 4-dimensional space.
        assert_eq!(state.len(), 4);
    }

    #[test]
    fn test_tensor_network_stats() {
        let mut network = TensorNetwork::new();
        network.add_tensor(real_tensor("A", vec![2, 3], vec![1.0; 6], &["i", "j"]));
        network.add_tensor(real_tensor("B", vec![3, 4], vec![1.0; 12], &["j", "k"]));
        let stats = network.stats();
        assert_eq!(stats.num_tensors, 2);
        assert_eq!(stats.total_elements, 18);
        assert_eq!(stats.max_rank, 2);
    }

    #[test]
    fn test_tensor_outer_product() {
        let a = real_tensor("A", vec![2], vec![1.0, 2.0], &["i"]);
        let b = real_tensor("B", vec![3], vec![1.0, 2.0, 3.0], &["j"]);
        let product = a.contract(&b).expect("Contraction failed");
        assert_eq!(product.indices.len(), 2);
        assert_eq!(product.bond_dims, vec![2, 3]);
    }

    #[test]
    fn test_mps_expectation() {
        let mps = MatrixProductState::product_state(2, 2);
        let real = |re: f64| Complex64::new(re, 0.0);
        let z_op = Array2::from_shape_vec((2, 2), vec![real(1.0), real(0.0), real(0.0), real(-1.0)])
            .expect("Operator creation failed");
        let exp_val = mps
            .expectation_local(0, &z_op)
            .expect("Expectation failed");
        assert_abs_diff_eq!(exp_val.re, 1.0, epsilon = 1e-6);
    }
}