use hugr::ops::OpType;
use itertools::izip;
use std::fmt::Debug;
use std::iter::Sum;
use std::num::NonZeroUsize;
use std::ops::{Add, AddAssign};
use crate::ops::op_matches;
use crate::Tk2Op;
/// A cost metric that can be attached to a circuit.
///
/// Costs can be added and summed, are totally ordered, and expose a signed
/// difference type ([`CircuitCost::CostDelta`]) for comparing two costs.
pub trait CircuitCost: Add<Output = Self> + Sum<Self> + Debug + Default + Clone + Ord {
    /// The signed type produced by [`CircuitCost::sub_cost`].
    type CostDelta: CostDelta;

    /// The difference `self - other`, expressed as a signed delta.
    fn sub_cost(&self, other: &Self) -> Self::CostDelta;

    /// Applies a signed `delta` to this cost, producing a new cost.
    fn add_delta(&self, delta: &Self::CostDelta) -> Self;

    /// Divides the cost by `n`.
    ///
    /// The implementations in this module round up.
    fn div_cost(&self, n: NonZeroUsize) -> Self;

    /// A single scalar summary of the cost.
    ///
    /// The implementations in this module return the value itself for
    /// `usize`, or the leading component for lexicographic costs.
    fn as_usize(&self) -> usize;
}

/// A signed difference between two [`CircuitCost`] values.
pub trait CostDelta:
    AddAssign + Add<Output = Self> + Sum<Self> + Debug + Default + Clone + Ord
{
    /// A single scalar summary of the delta.
    fn as_isize(&self) -> isize;
}
/// A cost made of `N` components that is compared lexicographically.
///
/// The derived `Ord` compares the inner arrays, i.e. the first component,
/// then the second, and so on. The first component is therefore the most
/// significant one.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct LexicographicCost<T, const N: usize>([T; N]);

/// A two-component lexicographic cost: a major cost followed by a minor one.
pub type MajorMinorCost<T = usize> = LexicographicCost<T, 2>;

impl<const N: usize, V, T> From<V> for LexicographicCost<T, N>
where
    V: Into<[T; N]>,
{
    /// Builds a cost from anything convertible into an array of `N` components.
    fn from(v: V) -> Self {
        Self(v.into())
    }
}

impl<const N: usize, T: Default> Default for LexicographicCost<T, N> {
    /// A cost where every component is `T::default()`.
    fn default() -> Self {
        // `array::from_fn` only needs `T: Default`, whereas the `[expr; N]`
        // repeat syntax would force `T: Copy`. `#[derive(Default)]` is not an
        // option because std only implements `Default` for arrays up to 32.
        Self(std::array::from_fn(|_| T::default()))
    }
}
/// Serialises a `usize` lexicographic cost as a string holding its `Debug`
/// rendering (for example `"[10, 2]"` for a two-component cost).
impl<const N: usize> serde::Serialize for LexicographicCost<usize, N> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let rendered = format!("{self:?}");
        serializer.serialize_str(&rendered)
    }
}
/// Formats only the inner component array, without the type name.
impl<T: Debug, const N: usize> Debug for LexicographicCost<T, N> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write!` formats with default options, so flags on `f` (such as
        // `{:#?}`) are intentionally not forwarded to the array.
        let components = &self.0;
        write!(f, "{components:?}")
    }
}
/// Component-wise addition of two lexicographic costs.
impl<T: Add<Output = T> + Copy, const N: usize> Add for LexicographicCost<T, N> {
    type Output = Self;

    fn add(mut self, rhs: Self) -> Self::Output {
        for (acc, addend) in self.0.iter_mut().zip(rhs.0) {
            *acc = *acc + addend;
        }
        self
    }
}
/// In-place component-wise addition of another lexicographic cost.
impl<T: AddAssign + Copy, const N: usize> AddAssign for LexicographicCost<T, N> {
    fn add_assign(&mut self, rhs: Self) {
        self.0
            .iter_mut()
            .zip(rhs.0)
            .for_each(|(acc, addend)| *acc += addend);
    }
}
/// Sums costs component-wise. An empty iterator yields the `Default` cost.
impl<T: Add<Output = T> + Default + Copy, const N: usize> Sum for LexicographicCost<T, N> {
    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
        let Some(first) = iter.next() else {
            return Self::default();
        };
        iter.fold(first, |total, cost| total + cost)
    }
}
/// Signed lexicographic deltas, as produced by `sub_cost` on
/// `LexicographicCost<usize, N>`.
impl<const N: usize> CostDelta for LexicographicCost<isize, N> {
    /// The leading (most significant) component, or `0` when `N == 0`.
    #[inline]
    fn as_isize(&self) -> isize {
        self.0.first().copied().unwrap_or(0)
    }
}
/// Unsigned lexicographic costs; every operation acts component by component.
impl<const N: usize> CircuitCost for LexicographicCost<usize, N> {
    type CostDelta = LexicographicCost<isize, N>;

    /// The leading (most significant) component, or `0` when `N == 0`.
    #[inline]
    fn as_usize(&self) -> usize {
        match self.0.as_slice() {
            [major, ..] => *major,
            [] => 0,
        }
    }

    /// Component-wise `self - other`.
    ///
    /// Each component is cast with `as isize` before subtracting, so values
    /// above `isize::MAX` wrap.
    #[inline]
    fn sub_cost(&self, other: &Self) -> Self::CostDelta {
        LexicographicCost(std::array::from_fn(|i| {
            (self.0[i] as isize) - (other.0[i] as isize)
        }))
    }

    /// Component-wise application of `delta`, saturating at `0` and
    /// `usize::MAX`.
    #[inline]
    fn add_delta(&self, delta: &Self::CostDelta) -> Self {
        Self(std::array::from_fn(|i| {
            self.0[i].saturating_add_signed(delta.0[i])
        }))
    }

    /// Component-wise ceiling division by `n`.
    ///
    /// Note that a zero component becomes `1`, not `0`, because the formula
    /// is `(a - 1) / n + 1` with saturating subtraction.
    /// NOTE(review): confirm callers rely on this zero case.
    #[inline]
    fn div_cost(&self, n: NonZeroUsize) -> Self {
        let divisor = n.get();
        Self(self.0.map(|a| a.saturating_sub(1) / divisor + 1))
    }
}
/// A bare `isize` is its own scalar delta.
impl CostDelta for isize {
    #[inline]
    fn as_isize(&self) -> isize {
        *self
    }
}

/// A bare `usize` acts as a single-component cost.
impl CircuitCost for usize {
    type CostDelta = isize;

    /// Returns the cost unchanged.
    #[inline]
    fn as_usize(&self) -> usize {
        *self
    }

    /// Signed difference `self - other`.
    ///
    /// Both sides are cast with `as isize` first, so values above
    /// `isize::MAX` wrap.
    #[inline]
    fn sub_cost(&self, other: &Self) -> Self::CostDelta {
        let (lhs, rhs) = (*self as isize, *other as isize);
        lhs - rhs
    }

    /// Applies `delta`, saturating at `0` and `usize::MAX`.
    #[inline]
    fn add_delta(&self, delta: &Self::CostDelta) -> Self {
        usize::saturating_add_signed(*self, *delta)
    }

    /// Ceiling division by `n`.
    ///
    /// Note that a cost of `0` yields `1`, not `0`.
    #[inline]
    fn div_cost(&self, n: NonZeroUsize) -> Self {
        let quotient = self.saturating_sub(1) / n.get();
        quotient + 1
    }
}
/// Returns `true` when `op` matches [`Tk2Op::CX`] according to `op_matches`.
pub fn is_cx(op: &OpType) -> bool {
    let target = Tk2Op::CX;
    op_matches(op, target)
}
/// Returns `true` when `op` casts to a [`Tk2Op`] that reports itself as
/// quantum. Any op that does not cast to a `Tk2Op` yields `false`.
pub fn is_quantum(op: &OpType) -> bool {
    let tk2_op: Option<Tk2Op> = op.cast();
    match tk2_op {
        Some(tk2_op) => tk2_op.is_quantum(),
        None => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a non-zero divisor in the `div_cost` tests.
    fn nz(n: usize) -> NonZeroUsize {
        NonZeroUsize::new(n).unwrap()
    }

    #[test]
    fn major_minor() {
        let a: MajorMinorCost = LexicographicCost([10, 2]);
        let b: MajorMinorCost = LexicographicCost([20, 1]);
        assert!(a < b);
        assert_eq!(a + b, LexicographicCost([30, 3]));
        assert_eq!(a.sub_cost(&b).as_isize(), -10);
        assert_eq!(b.sub_cost(&a).as_isize(), 10);
        assert_eq!(a.div_cost(nz(2)), LexicographicCost([5, 1]));
        assert_eq!(a.div_cost(nz(3)), LexicographicCost([4, 1]));
        assert_eq!(a.div_cost(nz(1)), LexicographicCost([10, 2]));
    }

    #[test]
    fn zero_dim_cost() {
        let a = LexicographicCost::<usize, 0>([]);
        let b = LexicographicCost::<usize, 0>([]);
        assert_eq!(a, b);
        assert_eq!(a + b, LexicographicCost::<usize, 0>([]));
        assert_eq!(a.sub_cost(&b).as_isize(), 0);
        assert_eq!(b.sub_cost(&a).as_isize(), 0);
        assert_eq!(a.div_cost(nz(2)), a);
        assert_eq!(a.div_cost(nz(3)), a);
        assert_eq!(a.div_cost(nz(1)), a);
    }

    #[test]
    fn as_usize() {
        let a: MajorMinorCost = LexicographicCost([10, 2]);
        assert_eq!(a.as_usize(), 10);
        let a = LexicographicCost::<usize, 0>([]);
        assert_eq!(a.as_usize(), 0);
    }

    #[test]
    fn add_delta_round_trips_and_saturates() {
        let a: MajorMinorCost = LexicographicCost([10, 2]);
        let b: MajorMinorCost = LexicographicCost([20, 1]);
        // Applying the difference recovers the other cost.
        assert_eq!(a.add_delta(&b.sub_cost(&a)), b);
        // Negative deltas clamp at zero instead of wrapping.
        assert_eq!(a.add_delta(&LexicographicCost([-20, 3])), LexicographicCost([0, 5]));
    }

    #[test]
    fn sum_and_add_assign() {
        let a: MajorMinorCost = LexicographicCost([10, 2]);
        let b: MajorMinorCost = LexicographicCost([20, 1]);
        let total: MajorMinorCost = vec![a, b].into_iter().sum();
        assert_eq!(total, LexicographicCost([30, 3]));
        let empty: MajorMinorCost = std::iter::empty().sum();
        assert_eq!(empty, LexicographicCost([0, 0]));
        let mut acc = a;
        acc += b;
        assert_eq!(acc, a + b);
    }

    #[test]
    fn scalar_cost() {
        assert_eq!(7usize.as_usize(), 7);
        assert_eq!(7usize.sub_cost(&10), -3);
        assert_eq!((-3isize).as_isize(), -3);
        assert_eq!(3usize.add_delta(&-5), 0);
        assert_eq!(7usize.div_cost(nz(2)), 4);
        assert_eq!(6usize.div_cost(nz(3)), 2);
    }

    #[test]
    fn serde_serialize() {
        let a: MajorMinorCost = LexicographicCost([10, 2]);
        let s = serde_json::to_string(&a).unwrap();
        assert_eq!(s, "\"[10, 2]\"");
    }
}