use std::ops::{Mul, Sub};
use dashmap::DashMap;
use ndarray::Array2;
use num::Complex;
use rayon::prelude::*;
/// Magnitude threshold used by `SMat::set_unchecked`: a real or imaginary
/// component whose absolute value is not greater than this is treated as zero
/// and not stored.
pub const EPSILON: f32 = 0.000_000_001;
/// Matrix dimension above which `SMat::dot` / `SMat::kron` switch to their
/// rayon-parallel variants: two raised to the power
/// `crate::state_vector::MULTITHREADING_THRESHOLD`.
pub const MULTITHREADING_THRESHOLD: u32 =
    1_u32 << (crate::state_vector::MULTITHREADING_THRESHOLD as u32);
#[derive(Clone, Debug)]
pub struct SMat {
pub data: DashMap<(u32, u32), (Option<f32>, Option<f32>)>,
pub cols: u32,
pub rows: u32,
}
impl SMat {
pub fn get_unchecked(&self, i: u32, j: u32) -> Complex<f32> {
let g = self.data.get(&(i, j));
if let Some(g) = g {
Complex::new(g.0.unwrap_or(0.0), g.1.unwrap_or(0.0))
} else {
Complex::new(0.0, 0.0)
}
}
pub fn get(&self, i: u32, j: u32) -> Option<Complex<f32>> {
if i > self.rows || j > self.cols {
return None;
}
Some(self.get_unchecked(i, j))
}
pub fn set_unchecked(&self, i: u32, j: u32, v: Complex<f32>) {
self.data.remove(&(i, j));
let mut nr = None;
let mut ni = None;
if v.re.abs() > EPSILON {
nr = Some(v.re);
}
if v.im.abs() > EPSILON {
ni = Some(v.im);
}
if nr.is_some() || ni.is_some() {
self.data.insert((i, j), (nr, ni));
}
}
pub fn zeros(dim: (u32, u32)) -> Self {
let data = DashMap::new();
let rows = dim.0;
let cols = dim.1;
Self { data, rows, cols }
}
pub fn from_dense(d: &Array2<Complex<f32>>) -> Self {
let shape = d.shape();
let s = Self::zeros((shape[0] as u32, shape[1] as u32));
for (i, v) in d.indexed_iter() {
s.set_unchecked(i.0 as u32, i.1 as u32, *v);
}
s
}
pub fn to_dense(&self) -> Array2<Complex<f32>> {
let mut res = Array2::zeros((self.rows as usize, self.cols as usize));
self.data.iter().for_each(|x| {
let key = x.key();
let val = x.value();
*res.get_mut((key.0 as usize, key.1 as usize)).unwrap() =
Complex::new(val.0.unwrap_or(0.0), val.1.unwrap_or(0.0));
});
res
}
pub fn conj_trans(&self) -> Self {
let r = SMat::zeros((self.cols, self.rows));
self.data.iter().for_each(|x| {
let key = x.key();
let val = x.value();
r.set_unchecked(
key.1,
key.0,
Complex::new(val.0.unwrap_or(0.0), val.1.unwrap_or(0.0)),
);
});
r
}
pub fn eye(n: u32) -> Self {
let data = DashMap::with_capacity(n as usize);
for i in 0..n {
data.insert((i, i), (Some(1.0), Some(0.0)));
}
Self {
data,
rows: n,
cols: n,
}
}
pub fn dot_par(&self, other: &Self) -> Self {
if self.cols != other.rows {
panic!("dim mismatch when dotting");
}
let result = Self::zeros((self.rows, other.cols));
self.data.iter().par_bridge().for_each(|thingy1| {
let (i, k) = *thingy1.key();
let (real, imag) = *thingy1.value();
let (real, imag) = (real.unwrap_or(0.0), imag.unwrap_or(0.0));
other
.data
.iter()
.par_bridge()
.filter(|t| t.key().0 == k)
.for_each(|thingy2| {
let (_, j) = *thingy2.key();
let (oreal, oimag) = *thingy2.value();
let (oreal, oimag) = (oreal.unwrap_or(0.0), oimag.unwrap_or(0.0));
result.set_unchecked(
i,
j,
result.get_unchecked(i, j)
+ Complex::new(
real * oreal - imag * oimag,
real * oimag + imag * oreal,
),
);
});
});
result
}
pub fn dot_single(&self, other: &Self) -> Self {
if self.cols != other.rows {
panic!("dim mismatch when dotting");
}
let result = Self::zeros((self.rows, other.cols));
self.data.iter().for_each(|thingy1| {
let (i, k) = *thingy1.key();
let (real, imag) = *thingy1.value();
let (real, imag) = (real.unwrap_or(0.0), imag.unwrap_or(0.0));
other
.data
.iter()
.filter(|t| t.key().0 == k)
.for_each(|thingy2| {
let (_, j) = *thingy2.key();
let (oreal, oimag) = *thingy2.value();
let (oreal, oimag) = (oreal.unwrap_or(0.0), oimag.unwrap_or(0.0));
result.set_unchecked(
i,
j,
result.get_unchecked(i, j)
+ Complex::new(
real * oreal - imag * oimag,
real * oimag + imag * oreal,
),
);
});
});
result
}
pub fn dot(&self, other: &Self) -> Self {
if self.rows > MULTITHREADING_THRESHOLD || self.cols > MULTITHREADING_THRESHOLD {
self.dot_par(other)
} else {
self.dot_single(other)
}
}
pub fn kron(&self, other: &Self) -> Self {
if self.rows > MULTITHREADING_THRESHOLD || self.cols > MULTITHREADING_THRESHOLD {
self.kron_par(other)
} else {
self.kron_single(other)
}
}
pub fn forbenius_norm(&self) -> Complex<f32> {
self.data
.iter()
.map(|x| {
let value = x.value();
Complex::new(value.0.unwrap_or(0.0), value.1.unwrap_or(0.0)).powu(2)
})
.sum::<Complex<f32>>()
.sqrt()
}
pub fn transpose(&self) -> Self {
let nd = self.clone();
self.data.iter().for_each(|k| {
let (key, val) = (k.key(), *k.value());
nd.data.insert((key.1, key.0), val);
});
nd
}
pub fn kron_single(&self, other: &Self) -> Self {
let result = Self::zeros((self.rows * other.rows, self.cols * other.cols));
let offset = (other.rows, other.cols);
self.data.iter().for_each(|thingy1| {
let (i, j) = *thingy1.key();
let (real, imag) = *thingy1.value();
let (real, imag) = (real.unwrap_or(0.0), imag.unwrap_or(0.0));
let c = Complex::new(real, imag);
other.data.iter().for_each(|thingy2| {
let (oi, oj) = *thingy2.key();
let (oreal, oimag) = *thingy2.value();
let (oreal, oimag) = (oreal.unwrap_or(0.0), oimag.unwrap_or(0.0));
let oc = Complex::new(oreal, oimag);
let offset_index = (i * offset.0 + oi, j * offset.1 + oj);
result.set_unchecked(offset_index.0, offset_index.1, c * oc);
});
});
result
}
pub fn kron_par(&self, other: &Self) -> Self {
let result = Self::zeros((self.rows * other.rows, self.cols * other.cols));
let offset = (other.rows, other.cols);
self.data.iter().par_bridge().for_each(|thingy1| {
let (i, j) = *thingy1.key();
let (real, imag) = *thingy1.value();
let (real, imag) = (real.unwrap_or(0.0), imag.unwrap_or(0.0));
let c = Complex::new(real, imag);
other.data.iter().par_bridge().for_each(|thingy2| {
let (oi, oj) = *thingy2.key();
let (oreal, oimag) = *thingy2.value();
let (oreal, oimag) = (oreal.unwrap_or(0.0), oimag.unwrap_or(0.0));
let oc = Complex::new(oreal, oimag);
let offset_index = (i * offset.0 + oi, j * offset.1 + oj);
result.set_unchecked(offset_index.0, offset_index.1, c * oc);
});
});
result
}
pub fn norm_1d(&self) -> f32 {
assert!(
self.rows == 1,
"norm_1d can only be called on matrices with 1 row"
);
self.data
.iter()
.map(|x| {
let v = x.value();
let c = Complex::new(v.0.unwrap_or(0.0), v.1.unwrap_or(0.0));
c.norm()
})
.sum::<f32>()
}
pub fn normalize_1d(&self) -> Self {
assert!(
self.rows == 1,
"normalize_1d can only be called on matrices with 1 row"
);
let norm = self.norm_1d();
self.clone() * norm.recip()
}
}
impl Sub for SMat {
    type Output = Self;

    /// Element-wise difference `self - rhs`; results below `EPSILON` are dropped
    /// by `set_unchecked`.
    ///
    /// # Panics
    /// If the operands do not have the same shape (consistent with `dot`, which
    /// also panics on a dimension mismatch).
    fn sub(self, rhs: Self) -> Self::Output {
        assert!(
            self.rows == rhs.rows && self.cols == rhs.cols,
            "dim mismatch when subtracting"
        );
        // `self` is owned, so it is updated in place rather than cloned first.
        for entry in rhs.data.iter() {
            let (i, j) = *entry.key();
            let (re, im) = *entry.value();
            let diff =
                self.get_unchecked(i, j) - Complex::new(re.unwrap_or(0.0), im.unwrap_or(0.0));
            self.set_unchecked(i, j, diff);
        }
        self
    }
}
impl Mul<Complex<f32>> for SMat {
type Output = Self;
fn mul(self, rhs: Complex<f32>) -> Self::Output {
let s = self.clone();
s.data.iter_mut().for_each(|mut x| {
let v = x.value_mut();
*v = (
v.0.and_then(|x| Some(x * rhs.re)),
v.1.and_then(|x| Some(x * rhs.im)),
);
});
s
}
}
impl Mul<f32> for SMat {
type Output = Self;
fn mul(self, rhs: f32) -> Self::Output {
let s = self.clone();
s.data.iter_mut().for_each(|mut x| {
let v = x.value_mut();
*v = (
v.0.and_then(|x| Some(x * rhs)),
v.1.and_then(|x| Some(x * rhs)),
);
});
s
}
}
impl PartialEq for SMat {
    /// Two matrices are equal when they have the same shape and the same stored
    /// keys, with each `(re, im)` pair comparing equal under `==`.
    fn eq(&self, other: &Self) -> bool {
        // A mismatch in *either* dimension means the shapes differ.
        if self.rows != other.rows || self.cols != other.cols {
            return false;
        }
        if self.data.len() != other.data.len() {
            return false;
        }
        // Snapshot our entries first so no guard on `self.data` is held while
        // `other.data` is locked (they are the same map when comparing `x == x`).
        let ours: Vec<((u32, u32), (Option<f32>, Option<f32>))> = self
            .data
            .iter()
            .map(|e| (*e.key(), *e.value()))
            .collect();
        // Keys are unique and the counts match, so "every one of our entries is present
        // in `other` with an equal value" is equivalent to the entry sets being identical.
        ours.iter().all(|(key, value)| {
            other
                .data
                .get(key)
                .map_or(false, |theirs| *theirs.value() == *value)
        })
    }
}
#[cfg(test)]
mod test_sparse_mat {
    use super::SMat;
    use crate::instruction::gate_matrices::{PAULI_X, PAULI_Z};

    /// Sparse copies of the Pauli X and Pauli Z gate matrices.
    fn sparse_x_z() -> (SMat, SMat) {
        (
            SMat::from_dense(&PAULI_X.clone()),
            SMat::from_dense(&PAULI_Z.clone()),
        )
    }

    /// Reference X·Z computed densely by ndarray, then converted to sparse form.
    fn expected_xz_product() -> SMat {
        let (x, z) = (PAULI_X.clone(), PAULI_Z.clone());
        SMat::from_dense(&x.dot(&z))
    }

    #[test]
    fn mat_mul() {
        let (sparse_x, sparse_z) = sparse_x_z();
        assert_eq!(expected_xz_product(), sparse_x.dot(&sparse_z));
    }

    #[test]
    fn mat_mul_par() {
        let (sparse_x, sparse_z) = sparse_x_z();
        assert_eq!(expected_xz_product(), sparse_x.dot_par(&sparse_z));
    }

    #[test]
    fn kron() {
        let (x, z) = (PAULI_X.clone(), PAULI_Z.clone());
        let expected = SMat::from_dense(&ndarray::linalg::kron(&x, &z));
        let (sparse_x, sparse_z) = sparse_x_z();
        assert_eq!(expected, sparse_x.kron(&sparse_z));
    }
}
pub fn kron(a: &SMat, b: &SMat) -> SMat {
a.kron(b)
}