use ndarray::{Array2, arr2};
use num_complex::Complex;
use num_traits::{Float, FromPrimitive, One, Zero};
/// A named quantum gate, stored as a complex matrix with scalar type `T`.
///
/// `Gate::new` rejects matrices that are not square or not unitary, and the
/// predefined constructors (`i`, `x`, `y`, `z`, `h`, `s`, `t`, `cnot`) go
/// through it. Both fields are public, however, so a `Gate` built or mutated
/// directly through its fields is not validated.
#[derive(Clone, Debug)]
pub struct Gate<T> {
    /// Human-readable label such as `"H"` or `"CNOT"`.
    pub name: String,
    /// The gate's matrix; expected to be square and unitary.
    pub matrix: Array2<Complex<T>>,
}
impl<T> Gate<T>
where
T: Float + FromPrimitive + Copy + PartialOrd + FromPrimitive + std::fmt::LowerExp + 'static,
{
pub fn new(name: String, matrix: Array2<Complex<T>>) -> Result<Self, String> {
let rows = matrix.shape()[0];
let cols = matrix.shape()[1];
if rows != cols {
return Err(format!(
"Matrix for gate '{}' must be square, but has dimensions {}x{}.",
name, rows, cols
));
}
if !is_unitary(&matrix) {
Err(format!(
"Matrix for gate '{}' is not unitary. Max difference from identity:",
name
))
} else {
Ok(Self { name, matrix })
}
}
pub fn i() -> Self {
Self::new(
"I".to_string(),
arr2(&[
[Complex::one(), Complex::zero()],
[Complex::zero(), Complex::one()],
]),
)
.unwrap()
}
pub fn x() -> Self {
Self::new(
"X".to_string(),
arr2(&[
[Complex::zero(), Complex::one()],
[Complex::one(), Complex::zero()],
]),
)
.unwrap()
}
pub fn y() -> Self {
Self::new(
"Y".to_string(),
arr2(&[
[Complex::zero(), -Complex::i()],
[Complex::i(), Complex::zero()],
]),
)
.unwrap()
}
pub fn z() -> Self {
Self::new(
"Z".to_string(),
arr2(&[
[Complex::one(), Complex::zero()],
[Complex::zero(), -Complex::one()],
]),
)
.unwrap()
}
pub fn h() -> Self {
let factor = T::one() / T::from(2.0).unwrap().sqrt();
Self::new(
"H".to_string(),
arr2(&[
[Complex::one() * factor, Complex::one() * factor],
[Complex::one() * factor, -Complex::one() * factor],
]),
)
.unwrap()
}
pub fn s() -> Self {
Self::new(
"S".to_string(),
arr2(&[
[Complex::one(), Complex::zero()],
[Complex::zero(), Complex::i()],
]),
)
.unwrap()
}
pub fn t() -> Self {
let pi = T::from(std::f64::consts::PI).unwrap();
let angle = pi / T::from(4.0).unwrap();
Self::new(
"T".to_string(),
arr2(&[
[Complex::one(), Complex::zero()],
[Complex::zero(), Complex::new(angle.cos(), angle.sin())],
]),
)
.unwrap()
}
pub fn cnot() -> Self {
Self::new(
"CNOT".to_string(),
arr2(&[
[
Complex::one(),
Complex::zero(),
Complex::zero(),
Complex::zero(),
],
[
Complex::zero(),
Complex::one(),
Complex::zero(),
Complex::zero(),
],
[
Complex::zero(),
Complex::zero(),
Complex::zero(),
Complex::one(),
],
[
Complex::zero(),
Complex::zero(),
Complex::one(),
Complex::zero(),
],
]),
)
.unwrap()
}
}
/// Returns `true` when `matrix · matrix†` equals the identity to within `1e-6`
/// in every entry, measured by complex modulus.
///
/// For a square matrix this is the unitarity condition. Every entry must be
/// provably within tolerance, so NaN entries (including those arising from
/// infinite inputs) make the check fail instead of being ignored.
fn is_unitary<T>(matrix: &Array2<Complex<T>>) -> bool
where
    T: Float + 'static,
{
    let epsilon = T::from(1e-6).expect("1e-6 must be representable in T");
    let adjoint = matrix.t().mapv(|c| c.conj());
    let product = matrix.dot(&adjoint);
    let identity = Array2::<Complex<T>>::eye(matrix.shape()[0]);
    let deviation = &product - &identity;
    // `NaN <= epsilon` is false, so any NaN entry rejects the matrix. The
    // previous `fold(_, T::max)` discarded NaN and could accept it.
    deviation.iter().all(|c| c.norm() <= epsilon)
}