use std::iter;
use std::{collections::HashSet, fmt::Debug};
use derive_more::From;
use hugr::ops::OpType;
use hugr::{HugrView, Node};
use itertools::Itertools;
use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, CostDelta, LexicographicCost};
use crate::{op_matches, Circuit, Tk2Op};
use super::trace::RewriteTrace;
use super::CircuitRewrite;
/// A strategy deciding which [`CircuitRewrite`]s to apply to a circuit, and
/// how circuits and rewrites are costed.
pub trait RewriteStrategy {
    /// The cost type used to compare circuits.
    type Cost: CircuitCost;

    /// Applies (a selection of) `rewrites` to `circ`.
    ///
    /// Returns an iterator over candidate circuits, each paired with its cost
    /// delta relative to `circ`. How many candidates are produced, and which
    /// rewrites each contains, is up to the implementation.
    fn apply_rewrites(
        &self,
        rewrites: impl IntoIterator<Item = CircuitRewrite>,
        circ: &Circuit,
    ) -> impl Iterator<Item = RewriteResult<Self::Cost>>;

    /// The cost of a single operation under this strategy.
    fn op_cost(&self, op: &OpType) -> Self::Cost;

    /// The cost of a whole circuit, obtained from [`Self::op_cost`] through
    /// `Circuit::circuit_cost`.
    #[inline]
    fn circuit_cost(&self, circ: &Circuit<impl HugrView<Node = Node>>) -> Self::Cost {
        let per_op = |op: &OpType| self.op_cost(op);
        circ.circuit_cost(per_op)
    }

    /// The cost, in `circ`, of the nodes matched by the rewrite's subcircuit.
    #[inline]
    fn pre_rewrite_cost(&self, rw: &CircuitRewrite, circ: &Circuit) -> Self::Cost {
        let per_op = |op: &OpType| self.op_cost(op);
        circ.nodes_cost(rw.subcircuit().nodes().iter().copied(), per_op)
    }

    /// The cost of the rewrite's replacement circuit.
    fn post_rewrite_cost(&self, rw: &CircuitRewrite) -> Self::Cost {
        let per_op = |op: &OpType| self.op_cost(op);
        rw.replacement().circuit_cost(per_op)
    }
}
/// One candidate produced by [`RewriteStrategy::apply_rewrites`]: a rewritten
/// circuit together with the change in cost that the strategy attributes to it.
#[derive(Debug, Clone)]
pub struct RewriteResult<C: CircuitCost> {
/// The rewritten circuit.
pub circ: Circuit,
/// The cost change computed by the producing strategy (e.g. summed
/// `target.sub_cost(pattern)` deltas, or summed node-count deltas for the
/// greedy strategy).
pub cost_delta: C::CostDelta,
}
/// Builds a [`RewriteResult`] from a `(circuit, cost_delta)` pair, converting
/// the circuit into an owned [`Circuit`] via `Circuit::to_owned`.
impl<C: CircuitCost, T: HugrView<Node = Node>> From<(Circuit<T>, C::CostDelta)>
    for RewriteResult<C>
{
    #[inline]
    fn from(pair: (Circuit<T>, C::CostDelta)) -> Self {
        let (source, cost_delta) = pair;
        let circ = source.to_owned();
        Self { circ, cost_delta }
    }
}
/// Greedily applies non-overlapping rewrites that strictly reduce the node
/// count, largest reduction first, and yields a single resulting circuit.
///
/// Under this strategy every operation costs 1, so a circuit's cost is its
/// number of operations.
#[derive(Debug, Copy, Clone)]
pub struct GreedyRewriteStrategy;
impl RewriteStrategy for GreedyRewriteStrategy {
    type Cost = usize;

    /// Greedily applies the node-count-reducing rewrites, most reducing first.
    ///
    /// Rewrites are sorted by [`CircuitRewrite::node_count_delta`], and only
    /// those with a strictly negative delta are considered. A rewrite is skipped
    /// when its invalidation set intersects the nodes already invalidated by a
    /// previously applied rewrite. The iterator yields exactly one result: the
    /// rewritten circuit together with the summed node-count delta.
    ///
    /// # Panics
    ///
    /// Panics if applying a selected rewrite fails.
    #[tracing::instrument(skip_all)]
    fn apply_rewrites(
        &self,
        rewrites: impl IntoIterator<Item = CircuitRewrite>,
        circ: &Circuit,
    ) -> impl Iterator<Item = RewriteResult<Self::Cost>> {
        let rewrites = rewrites
            .into_iter()
            .sorted_by_key(|rw| rw.node_count_delta())
            .take_while(|rw| rw.node_count_delta() < 0);
        let mut changed_nodes = HashSet::new();
        let mut cost_delta = 0;
        let mut circ = circ.clone();
        for rewrite in rewrites {
            // Conflicts are detected on the full invalidation set, not just the
            // matched subcircuit. This is the same check `ExhaustiveGreedyStrategy`
            // uses. NOTE(review): the invalidation set presumably also covers the
            // neighbouring nodes a replacement rewires (see hugr's
            // `SimpleReplacement::invalidation_set`). Without this check, two
            // disjoint but adjacent matches could both be applied, with the later
            // one referring to nodes removed by the earlier one.
            if rewrite
                .invalidation_set()
                .any(|n| changed_nodes.contains(&n))
            {
                continue;
            }
            changed_nodes.extend(rewrite.invalidation_set());
            cost_delta += rewrite.node_count_delta();
            rewrite
                .apply(&mut circ)
                .expect("Could not perform rewrite in greedy strategy");
        }
        iter::once((circ, cost_delta).into())
    }

    /// A circuit's cost is its number of operations.
    fn circuit_cost(&self, circ: &Circuit<impl HugrView<Node = Node>>) -> Self::Cost {
        circ.num_operations()
    }

    /// Every operation costs 1.
    fn op_cost(&self, _op: &OpType) -> Self::Cost {
        1
    }
}
/// Exhaustive greedy strategy.
///
/// Rewrites accepted by the cost threshold are sorted by cost delta. For each
/// starting position in that order, one circuit is produced by greedily
/// composing that rewrite with every later rewrite that does not conflict with
/// the ones already applied.
#[derive(Debug, Copy, Clone, From)]
pub struct ExhaustiveGreedyStrategy<T> {
/// Per-operation cost function and acceptance threshold (see [`StrategyCost`]).
pub strat_cost: T,
}
impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
    type Cost = T::OpCost;

    /// Produces one candidate circuit per rewrite accepted by the threshold.
    ///
    /// Rewrites rejected by [`StrategyCost::under_threshold`] are discarded, and
    /// the rest are stably sorted by ascending cost delta. The `i`-th candidate
    /// starts from the `i`-th rewrite in that order. It applies that rewrite and
    /// every later one whose invalidation set does not intersect the nodes
    /// already invalidated. Each candidate gets a single [`RewriteTrace`]
    /// recording how many rewrites it composed.
    ///
    /// # Panics
    ///
    /// Panics if applying a selected rewrite fails.
    #[tracing::instrument(skip_all)]
    fn apply_rewrites(
        &self,
        rewrites: impl IntoIterator<Item = CircuitRewrite>,
        circ: &Circuit,
    ) -> impl Iterator<Item = RewriteResult<Self::Cost>> {
        let mut candidates = Vec::new();
        for rw in rewrites {
            let before = self.pre_rewrite_cost(&rw, circ);
            let after = self.post_rewrite_cost(&rw);
            if self.strat_cost.under_threshold(&before, &after) {
                let delta = after.sub_cost(&before);
                candidates.push((rw, delta));
            }
        }
        // `sort_by` is stable, so rewrites with equal deltas keep their input order.
        candidates.sort_by(|(_, lhs), (_, rhs)| lhs.cmp(rhs));

        (0..candidates.len()).map(move |start| {
            let mut rewritten = circ.clone();
            let mut invalidated = HashSet::new();
            let mut total_delta = Default::default();
            let mut applied = 0;
            for (rewrite, delta) in &candidates[start..] {
                let conflicts = !invalidated.is_empty()
                    && rewrite
                        .invalidation_set()
                        .any(|n| invalidated.contains(&n));
                if conflicts {
                    continue;
                }
                invalidated.extend(rewrite.invalidation_set());
                total_delta += delta.clone();
                applied += 1;
                rewrite
                    .clone()
                    .apply_notrace(&mut rewritten)
                    .expect("Could not perform rewrite in exhaustive greedy strategy");
            }
            rewritten.add_rewrite_trace(RewriteTrace::new(applied));
            (rewritten, total_delta).into()
        })
    }

    /// Delegates to the configured [`StrategyCost`].
    #[inline]
    fn op_cost(&self, op: &OpType) -> Self::Cost {
        self.strat_cost.op_cost(op)
    }
}
/// Exhaustive threshold strategy: produces one circuit per rewrite accepted by
/// the cost threshold, each with only that single rewrite applied.
#[derive(Debug, Copy, Clone, From)]
pub struct ExhaustiveThresholdStrategy<T> {
/// Per-operation cost function and acceptance threshold (see [`StrategyCost`]).
pub strat_cost: T,
}
impl<T: StrategyCost> RewriteStrategy for ExhaustiveThresholdStrategy<T> {
    type Cost = T::OpCost;

    /// Lazily yields one candidate per rewrite accepted by
    /// [`StrategyCost::under_threshold`]. Each candidate is a copy of `circ`
    /// with only that rewrite applied. Rejected rewrites yield nothing.
    ///
    /// # Panics
    ///
    /// Panics if applying an accepted rewrite fails.
    #[tracing::instrument(skip_all)]
    fn apply_rewrites(
        &self,
        rewrites: impl IntoIterator<Item = CircuitRewrite>,
        circ: &Circuit,
    ) -> impl Iterator<Item = RewriteResult<Self::Cost>> {
        rewrites.into_iter().filter_map(move |rw| {
            let before = self.pre_rewrite_cost(&rw, circ);
            let after = self.post_rewrite_cost(&rw);
            if self.strat_cost.under_threshold(&before, &after) {
                let mut rewritten = circ.clone();
                rw.apply(&mut rewritten).expect("invalid pattern match");
                Some((rewritten, after.sub_cost(&before)).into())
            } else {
                None
            }
        })
    }

    /// Delegates to the configured [`StrategyCost`].
    #[inline]
    fn op_cost(&self, op: &OpType) -> Self::Cost {
        self.strat_cost.op_cost(op)
    }
}
/// Per-operation cost plus the rule deciding whether a rewrite is acceptable.
pub trait StrategyCost {
    /// The cost assigned to a single operation.
    type OpCost: CircuitCost;

    /// Whether a rewrite from a pattern of cost `pattern_cost` to a replacement
    /// of cost `target_cost` should be accepted.
    ///
    /// By default, accepts when `target_cost.sub_cost(pattern_cost).as_isize()`
    /// is not positive, i.e. the rewrite does not increase that cost measure.
    #[inline]
    fn under_threshold(&self, pattern_cost: &Self::OpCost, target_cost: &Self::OpCost) -> bool {
        let delta = target_cost.sub_cost(pattern_cost);
        0 >= delta.as_isize()
    }

    /// The cost of a single operation.
    fn op_cost(&self, op: &OpType) -> Self::OpCost;
}
/// A cost made of `N` per-operation cost functions, combined into a
/// [`LexicographicCost`].
///
/// The functions are evaluated in array order. Presumably earlier entries
/// take priority in comparisons (lexicographic ordering). TODO confirm
/// against `LexicographicCost`.
#[derive(Debug, Clone)]
pub struct LexicographicCostFunction<F, const N: usize> {
/// The per-operation cost functions, one per component of the cost.
cost_fns: [F; N],
}
impl<F, const N: usize> StrategyCost for LexicographicCostFunction<F, N>
where
    F: Fn(&OpType) -> usize,
{
    type OpCost = LexicographicCost<usize, N>;

    /// Evaluates every cost function on `op`. Component `i` of the result is
    /// `cost_fns[i](op)`.
    #[inline]
    fn op_cost(&self, op: &OpType) -> Self::OpCost {
        let components: [usize; N] = std::array::from_fn(|idx| (self.cost_fns[idx])(op));
        components.into()
    }
}
impl LexicographicCostFunction<fn(&OpType) -> usize, 2> {
    /// The default strategy: [`ExhaustiveGreedyStrategy`] over [`Self::cx_count`].
    pub fn default_cx_strategy() -> ExhaustiveGreedyStrategy<Self> {
        let cost = Self::cx_count();
        cost.into_greedy_strategy()
    }

    /// Two-component cost: the number of CX gates, then the number of quantum
    /// operations.
    #[inline]
    pub fn cx_count() -> Self {
        let cx: fn(&OpType) -> usize = |op| is_cx(op) as usize;
        let quantum: fn(&OpType) -> usize = |op| is_quantum(op) as usize;
        Self {
            cost_fns: [cx, quantum],
        }
    }

    /// Two-component cost: the number of `Rz` gates, then the number of quantum
    /// operations.
    #[inline]
    pub fn rz_count() -> Self {
        let rz: fn(&OpType) -> usize = |op| op_matches(op, Tk2Op::Rz) as usize;
        let quantum: fn(&OpType) -> usize = |op| is_quantum(op) as usize;
        Self {
            cost_fns: [rz, quantum],
        }
    }

    /// Wraps this cost function in an [`ExhaustiveGreedyStrategy`].
    pub fn into_greedy_strategy(self) -> ExhaustiveGreedyStrategy<Self> {
        self.into()
    }

    /// Wraps this cost function in an [`ExhaustiveThresholdStrategy`].
    pub fn into_threshold_strategy(self) -> ExhaustiveThresholdStrategy<Self> {
        self.into()
    }
}
impl Default for LexicographicCostFunction<fn(&OpType) -> usize, 2> {
fn default() -> Self {
LexicographicCostFunction::cx_count()
}
}
/// A `usize` per-operation cost with a multiplicative acceptance threshold.
///
/// When `C: Fn(&OpType) -> usize`, a rewrite is accepted if
/// `target_cost < gamma * pattern_cost`, compared as `f64`.
#[derive(Debug, Clone)]
pub struct GammaStrategyCost<C> {
/// Threshold factor: the replacement's cost must be strictly below `gamma`
/// times the cost of the matched pattern.
pub gamma: f64,
/// Per-operation cost function.
pub op_cost: C,
}
impl<C: Fn(&OpType) -> usize> StrategyCost for GammaStrategyCost<C> {
    type OpCost = usize;

    /// Accepts when `target_cost < gamma * pattern_cost`, with both costs
    /// converted to `f64`.
    #[inline]
    fn under_threshold(&self, pattern_cost: &Self::OpCost, target_cost: &Self::OpCost) -> bool {
        let allowed = self.gamma * (*pattern_cost as f64);
        (*target_cost as f64) < allowed
    }

    /// Evaluates the stored per-operation cost function.
    #[inline]
    fn op_cost(&self, op: &OpType) -> Self::OpCost {
        let cost_fn = &self.op_cost;
        cost_fn(op)
    }
}
impl<C> GammaStrategyCost<C> {
#[inline]
pub fn with_cost(op_cost: C) -> ExhaustiveThresholdStrategy<Self> {
Self {
gamma: 1.0001,
op_cost,
}
.into()
}
#[inline]
pub fn new(gamma: f64, op_cost: C) -> ExhaustiveThresholdStrategy<Self> {
Self { gamma, op_cost }.into()
}
}
impl GammaStrategyCost<fn(&OpType) -> usize> {
    /// Exhaustive threshold strategy counting CX gates, using the default gamma
    /// from [`GammaStrategyCost::with_cost`].
    #[inline]
    pub fn exhaustive_cx() -> ExhaustiveThresholdStrategy<Self> {
        Self::with_cost(|op| is_cx(op) as usize)
    }

    /// Exhaustive threshold strategy counting CX gates, with a custom `gamma`.
    #[inline]
    pub fn exhaustive_cx_with_gamma(gamma: f64) -> ExhaustiveThresholdStrategy<Self> {
        Self::new(gamma, |op| is_cx(op) as usize)
    }
}
#[cfg(test)]
mod tests {
// Unit tests for the rewrite strategies, built on small two-qubit all-CX circuits.
use super::*;
use hugr::Node;
use itertools::Itertools;
use crate::rewrite::trace::REWRITE_TRACING_ENABLED;
use crate::{
circuit::Circuit,
rewrite::{CircuitRewrite, Subcircuit},
utils::build_simple_circuit,
};
/// Builds a 2-qubit circuit consisting of `n_gates` CX gates on qubits (0, 1).
fn n_cx(n_gates: usize) -> Circuit {
let qbs = [0, 1];
build_simple_circuit(2, |circ| {
for _ in 0..n_gates {
circ.append(Tk2Op::CX, qbs).unwrap();
}
Ok(())
})
.unwrap_or_else(|e| panic!("{}", e))
}
/// Rewrite replacing the given CX nodes of `circ` with an empty circuit.
fn rw_to_empty(circ: &Circuit, cx_nodes: impl Into<Vec<Node>>) -> CircuitRewrite {
let subcirc = Subcircuit::try_from_nodes(cx_nodes, circ).unwrap();
subcirc
.create_rewrite(circ, n_cx(0))
.unwrap_or_else(|e| panic!("{}", e))
}
/// Rewrite replacing the given CX nodes of `circ` with a 10-CX circuit.
fn rw_to_full(circ: &Circuit, cx_nodes: impl Into<Vec<Node>>) -> CircuitRewrite {
let subcirc = Subcircuit::try_from_nodes(cx_nodes, circ).unwrap();
subcirc
.create_rewrite(circ, n_cx(10))
.unwrap_or_else(|e| panic!("{}", e))
}
#[test]
fn test_greedy_strategy() {
let mut circ = n_cx(10);
let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec();
assert!(circ.rewrite_trace().is_none());
circ.enable_rewrite_tracing();
match REWRITE_TRACING_ENABLED {
true => assert_eq!(circ.rewrite_trace().unwrap().collect_vec(), []),
false => assert!(circ.rewrite_trace().is_none()),
}
let rws = [
rw_to_empty(&circ, cx_gates[0..2].to_vec()),
rw_to_full(&circ, cx_gates[4..7].to_vec()),
rw_to_empty(&circ, cx_gates[4..6].to_vec()),
rw_to_empty(&circ, cx_gates[9..10].to_vec()),
];
// The `rw_to_full` rewrite grows the circuit (+7 nodes), so `take_while`
// drops it. The three remaining rewrites have disjoint matches and remove
// 2 + 2 + 1 CX gates: 10 - 5 = 5 operations, one trace entry per rewrite.
let strategy = GreedyRewriteStrategy;
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
assert_eq!(rewritten.len(), 1);
assert_eq!(rewritten[0].circ.num_operations(), 5);
if REWRITE_TRACING_ENABLED {
assert_eq!(rewritten[0].circ.rewrite_trace().unwrap().count(), 3);
}
}
#[test]
fn test_exhaustive_default_strategy() {
let mut circ = n_cx(10);
let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec();
circ.enable_rewrite_tracing();
let rws = [
rw_to_empty(&circ, cx_gates[0..2].to_vec()),
rw_to_full(&circ, cx_gates[4..7].to_vec()),
rw_to_empty(&circ, cx_gates[4..8].to_vec()),
rw_to_empty(&circ, cx_gates[9..10].to_vec()),
];
// `rw_to_full` raises the CX count and fails the threshold. The rest, sorted
// by delta, are 4..8 (-4), 0..2 (-2), 9..10 (-1). Result `i` composes
// rewrites `i..`, giving 3, 7 and 9 operations with 3, 2 and 1 rewrites.
let strategy = LexicographicCostFunction::cx_count().into_greedy_strategy();
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
let exp_circ_lens = HashSet::from_iter([3, 7, 9]);
let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_operations()).collect();
assert_eq!(circ_lens, exp_circ_lens);
if REWRITE_TRACING_ENABLED {
assert_eq!(
rewritten[0].circ.rewrite_trace().unwrap().collect_vec(),
vec![RewriteTrace::new(3)]
);
assert_eq!(
rewritten[1].circ.rewrite_trace().unwrap().collect_vec(),
vec![RewriteTrace::new(2)]
);
assert_eq!(
rewritten[2].circ.rewrite_trace().unwrap().collect_vec(),
vec![RewriteTrace::new(1)]
);
}
}
#[test]
fn test_exhaustive_gamma_strategy() {
let circ = n_cx(10);
let cx_gates = circ.commands().map(|cmd| cmd.node()).collect_vec();
let rws = [
rw_to_empty(&circ, cx_gates[0..2].to_vec()),
rw_to_full(&circ, cx_gates[4..7].to_vec()),
rw_to_empty(&circ, cx_gates[4..8].to_vec()),
rw_to_empty(&circ, cx_gates[9..10].to_vec()),
];
// With gamma = 10 every rewrite satisfies `target < 10 * pattern`
// (0 < 20, 10 < 30, 0 < 40, 0 < 10). Each is applied alone:
// 10 - 2 = 8, 10 - 3 + 10 = 17, 10 - 4 = 6, 10 - 1 = 9.
let strategy = GammaStrategyCost::exhaustive_cx_with_gamma(10.);
let rewritten = strategy.apply_rewrites(rws, &circ);
let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]);
let circ_lens: HashSet<_> = rewritten.map(|r| r.circ.num_operations()).collect();
assert_eq!(circ_lens, exp_circ_lens);
}
#[test]
fn test_exhaustive_default_cx_cost() {
// The cost is (CX count, quantum operation count).
let strat = LexicographicCostFunction::cx_count().into_greedy_strategy();
let circ = n_cx(3);
assert_eq!(strat.circuit_cost(&circ), (3, 3).into());
let circ = build_simple_circuit(2, |circ| {
circ.append(Tk2Op::CX, [0, 1])?;
circ.append(Tk2Op::X, [0])?;
circ.append(Tk2Op::X, [1])?;
Ok(())
})
.unwrap();
assert_eq!(strat.circuit_cost(&circ), (1, 3).into());
}
#[test]
fn test_exhaustive_default_cx_threshold() {
// These cases are consistent with the first (CX) component deciding the
// threshold: an increase in the second component alone is accepted, and a
// CX increase is rejected even when the second component drops.
let strat = LexicographicCostFunction::cx_count();
assert!(strat.under_threshold(&(3, 0).into(), &(3, 0).into()));
assert!(strat.under_threshold(&(3, 0).into(), &(3, 5).into()));
assert!(!strat.under_threshold(&(3, 10).into(), &(4, 0).into()));
assert!(strat.under_threshold(&(3, 0).into(), &(1, 5).into()));
}
}