use std::thread::{self, JoinHandle};
use crate::circuit::cost::CircuitCost;
use crate::circuit::CircuitHash;
use crate::rewrite::strategy::RewriteStrategy;
use crate::rewrite::Rewriter;
use super::hugr_pchannel::{PriorityChannelCommunication, Work};
/// A worker that pulls circuits from a priority channel, rewrites them and
/// sends the resulting circuits back over the same channel.
///
/// `R` is the rewriter used to find rewrites, `S` the strategy used to apply
/// them, and `P` the cost type carried by the priority channel.
pub struct BadgerWorker<R, S, P>
where
    P: Ord,
{
    /// Numeric identifier of this worker.
    ///
    /// Stored on the worker but not read by the code in this module, hence
    /// the lint allowance.
    #[allow(unused)]
    id: usize,
    /// Channel used both to receive work items and to send back new circuits.
    priority_channel: PriorityChannelCommunication<P>,
    /// Produces the candidate rewrites for each received circuit.
    rewriter: R,
    /// Decides how candidate rewrites are applied to a circuit.
    strategy: S,
}
impl<R, S, P> BadgerWorker<R, S, P>
where
    R: Rewriter + Send + 'static,
    S: RewriteStrategy<Cost = P> + Send + 'static,
    P: CircuitCost + Send + Sync + 'static,
{
    /// Spawns a worker on a dedicated thread named `BadgerWorker-{id}`.
    ///
    /// The worker is built inside the new thread from the given channel,
    /// rewriter and strategy, and immediately enters its processing loop.
    /// The returned handle completes once that loop exits.
    ///
    /// # Panics
    ///
    /// Panics if the thread cannot be spawned.
    pub fn spawn(
        id: usize,
        priority_channel: PriorityChannelCommunication<P>,
        rewriter: R,
        strategy: S,
    ) -> JoinHandle<()> {
        let name = format!("BadgerWorker-{id}");
        thread::Builder::new()
            .name(name)
            .spawn(move || {
                let mut worker = Self {
                    id,
                    priority_channel,
                    rewriter,
                    strategy,
                };
                worker.run_loop()
            })
            .expect("failed to spawn BadgerWorker thread")
    }

    /// Main loop of the worker.
    ///
    /// Repeatedly receives a circuit with its cost, computes and applies the
    /// available rewrites, and sends the surviving rewritten circuits back.
    /// The loop ends as soon as receiving or sending on the priority channel
    /// fails (presumably because the other side has shut down).
    #[tracing::instrument(target = "badger::metrics", skip(self))]
    fn run_loop(&mut self) {
        while let Ok(Work { circ, cost, .. }) = self.priority_channel.recv() {
            let rewrites = self.rewriter.get_rewrites(&circ);
            // Read the cost bound once per received circuit; every candidate
            // produced from this circuit is compared against the same value.
            let max_cost = self.priority_channel.max_cost();
            let new_circs = self
                .strategy
                .apply_rewrites(rewrites, &circ)
                .filter_map(|r| {
                    let new_cost = cost.add_delta(&r.cost_delta);
                    // Only keep candidates strictly cheaper than the current
                    // bound, when there is one.
                    if matches!(&max_cost, Some(max) if &new_cost >= max) {
                        return None;
                    }
                    // Candidates whose hash cannot be computed are dropped.
                    let hash = r.circ.circuit_hash(r.circ.parent()).ok()?;
                    Some(Work {
                        cost: new_cost,
                        hash,
                        circ: r.circ,
                    })
                })
                .collect();
            let send = tracing::trace_span!(target: "badger::metrics", "BadgerWorker::send_result")
                .in_scope(|| self.priority_channel.send(new_circs));
            if send.is_err() {
                break;
            }
        }
    }
}