use core::f64::consts::{FRAC_1_SQRT_2, PI};
use alloc::boxed::Box;
use num_complex::Complex64;
use super::circuit::Circuit;
/// A quantum gate that can be applied in place to a slice of state amplitudes.
///
/// Composite gates are expressed by nesting: `Controlled` wraps another gate and
/// `Circuit` embeds an entire [`Circuit`].
#[cfg_attr(test, derive(PartialEq))]
#[derive(Debug, Clone)]
pub enum Gate {
    /// Hadamard gate; mixes amplitudes 0 and 1 of the slice it is applied to.
    Hadamard,
    /// NOT gate; exchanges amplitudes 0 and 1 of the slice it is applied to.
    Not,
    /// Phase gate; the stored factor multiplies amplitude 1 of the slice.
    Phase(Complex64),
    /// SWAP gate; exchanges amplitudes 1 and 2 of the slice it is applied to.
    Swap,
    /// The wrapped gate, applied only to the upper half of the slice.
    Controlled(Box<Gate>),
    /// A whole sub-circuit used as a single gate.
    Circuit(Circuit),
}
impl Gate {
#[inline]
pub fn hadamard() -> Self {
Self::Hadamard
}
#[inline]
pub fn not() -> Self {
Self::Not
}
#[inline]
pub fn phase_radians(angle: f64) -> Self {
Self::Phase(Complex64::new(0.0, angle).exp())
}
#[inline]
pub fn phase_fraction(fraction: f64) -> Self {
Self::phase_radians(2.0 * PI * fraction)
}
#[inline]
pub fn swap() -> Self {
Self::Swap
}
#[inline]
pub fn control(self) -> Self {
Self::Controlled(Box::new(self))
}
#[inline]
pub fn multi_control(self, n: u8) -> Self {
(0..n).fold(self, |gate, _| gate.control())
}
#[inline]
pub fn circuit(circuit: Circuit) -> Self {
Self::Circuit(circuit)
}
#[inline]
pub fn cnot() -> Self {
Self::Not.control()
}
#[inline]
pub fn toffoli() -> Self {
Self::cnot().control()
}
#[inline]
pub fn fredkin() -> Self {
Self::swap().control()
}
#[inline]
pub fn count_controlled(&self) -> u8 {
if let Self::Controlled(inner) = self {
inner.count_controlled() + 1
} else {
0
}
}
#[inline]
pub fn inverse(&self) -> Self {
match self {
Self::Hadamard => Self::Hadamard,
Self::Not => Self::Not,
Self::Swap => Self::Swap,
Self::Phase(c) => Self::Phase(c.conj()),
Self::Controlled(inner_gate) => Self::Controlled(Box::new(inner_gate.inverse())),
Self::Circuit(circuit) => Self::Circuit(circuit.inverse()),
}
}
pub fn apply(&self, state: &mut [Complex64]) {
match self {
Self::Hadamard => {
debug_assert!(
state.len() >= 2,
"Hadamard gate requires at least 2 amplitudes"
);
let (a, b) = (state[0], state[1]);
state[0] = FRAC_1_SQRT_2 * (a + b);
state[1] = FRAC_1_SQRT_2 * (a - b);
}
Self::Not => {
debug_assert!(state.len() >= 2, "NOT gate requires at least 2 amplitudes");
state.swap(0, 1);
}
Self::Swap => {
debug_assert!(state.len() >= 3, "SWAP gate requires at least 3 amplitudes");
state.swap(1, 2);
}
Self::Phase(phase) => {
debug_assert!(
state.len() >= 2,
"Phase gate requires at least 2 amplitudes"
);
state[1] *= phase;
}
Self::Controlled(gate) => {
let substate_middle = state.len() >> 1;
gate.apply(&mut state[substate_middle..])
}
Self::Circuit(circuit) => {
circuit.apply(state);
}
}
}
}
#[cfg(test)]
pub(crate) mod tests {
    use std::println;

    use crate::tests::{
        assert_state_eq,
        vectors::{self, TestVector},
    };

    use super::*;

    /// Maximum allowed deviation of a fixture state's squared norm from 1.
    const STATE_NORM_TOLERANCE: f64 = 1E-6;

    /// Returns `true` when the squared norm of `state` is within `tolerance` of 1.
    ///
    /// Marked `#[must_use]` because silently discarding the result turns the
    /// sanity check into a no-op.
    #[must_use]
    fn validate_state_normalization(state: &[Complex64], tolerance: f64) -> bool {
        let norm_squared: f64 = state.iter().map(|amp| amp.norm_sqr()).sum();
        (norm_squared - 1.0).abs() < tolerance
    }

    /// Runs every test vector through a gate built per-vector by `build_gate`.
    ///
    /// For each vector this:
    /// 1. asserts that both the initial and expected states are normalized, to
    ///    catch malformed fixtures;
    /// 2. wraps the gate in a circuit over the vector's qubits and checks that
    ///    applying it yields the expected state;
    /// 3. checks that applying the inverse circuit restores the initial state.
    pub(crate) fn run_parameterized_gate_tests<'a, Args: 'a>(
        title: &str,
        build_gate: impl Fn(&'a TestVector<Args>) -> Gate,
        test_vectors: impl IntoIterator<Item = &'a TestVector<Args>>,
    ) {
        for (i, test_vector) in test_vectors.into_iter().enumerate() {
            println!("{title} circuit test case #{i}");
            let circuit =
                Circuit::from_gate(build_gate(test_vector), test_vector.qubits.iter().copied())
                    .unwrap();

            // Previously these checks' results were discarded, so a non-normalized
            // fixture would pass unnoticed.
            assert!(
                validate_state_normalization(&test_vector.expected_state, STATE_NORM_TOLERANCE),
                "{title} test case #{i}: expected state is not normalized"
            );
            assert!(
                validate_state_normalization(&test_vector.initial_state, STATE_NORM_TOLERANCE),
                "{title} test case #{i}: initial state is not normalized"
            );

            let mut state = test_vector.initial_state.clone();
            circuit.apply(&mut state);
            assert_state_eq(&state, &test_vector.expected_state);

            let inverse_circuit = circuit.inverse();
            inverse_circuit.apply(&mut state);
            assert_state_eq(&state, &test_vector.initial_state);
        }
    }

    /// Runs every test vector through the same fixed `gate`.
    pub(crate) fn run_gate_tests<'a, Args: 'a>(
        title: &str,
        gate: Gate,
        test_vectors: impl IntoIterator<Item = &'a TestVector<Args>>,
    ) {
        run_parameterized_gate_tests(title, |_| gate.clone(), test_vectors)
    }

    #[test]
    fn hadamard_gate() {
        run_gate_tests("HADAMARD", Gate::hadamard(), &*vectors::HADAMARD_TESTS);
    }

    #[test]
    fn not_gate() {
        run_gate_tests("NOT", Gate::not(), &*vectors::NOT_TESTS);
    }

    #[test]
    fn phase_gate() {
        run_parameterized_gate_tests(
            "PHASE",
            |test_vector| Gate::phase_fraction(test_vector.args.fraction),
            &*vectors::PHASE_TESTS,
        );
    }

    #[test]
    fn cnot_gate() {
        run_gate_tests("CNOT", Gate::cnot(), &*vectors::CNOT_TESTS);
    }

    #[test]
    fn swap_gate() {
        run_gate_tests("SWAP", Gate::swap(), &*vectors::SWAP_TESTS);
    }

    #[test]
    fn toffoli_gate() {
        run_gate_tests("TOFFOLI", Gate::toffoli(), &*vectors::TOFFOLI_TESTS);
    }

    #[test]
    fn fredkin_gate() {
        run_gate_tests("FREDKIN", Gate::fredkin(), &*vectors::FREDKIN_TESTS);
    }
}