use std::thread::{self, JoinHandle};
use crate::rewrite::Rewriter;
use crate::rewrite::strategy::RewriteStrategy;
use crate::{Circuit, circuit::cost::CircuitCost};
use super::pqueue_worker::{StatePQueueChannels, Work};
/// A worker that repeatedly pulls a circuit from a shared priority queue,
/// rewrites it, and pushes the resulting circuits back.
///
/// Type parameters:
/// - `R`: the rewriter used to find candidate rewrites of a circuit.
/// - `S`: the strategy that turns candidate rewrites into new circuits.
/// - `P`: the cost type used as the queue priority.
pub struct BadgerWorker<R, S, P>
where
    P: Ord,
{
    /// Identifier of this worker. It is stored but never read back from the
    /// struct, hence the `dead_code` expectation.
    #[expect(dead_code)]
    id: usize,
    /// Channels used to receive work items from, and send results to, the
    /// priority queue.
    priority_channel: StatePQueueChannels<Circuit, P>,
    /// Produces the rewrites that can be applied to a circuit.
    rewriter: R,
    /// Applies rewrites to a circuit, yielding new circuits with cost deltas.
    strategy: S,
}
impl<R, S, P> BadgerWorker<R, S, P>
where
    R: Rewriter + Send + 'static,
    S: RewriteStrategy<Cost = P> + Send + 'static,
    P: CircuitCost + Send + Sync + 'static,
{
    /// Spawns a worker on a new thread named `BadgerWorker-{id}`.
    ///
    /// The worker is built inside the new thread from the given channel,
    /// rewriter and strategy, and then runs its processing loop until the
    /// channel reports an error on receive or send.
    ///
    /// Returns the [`JoinHandle`] of the spawned thread.
    ///
    /// # Panics
    ///
    /// Panics if the operating system fails to create the thread.
    pub fn spawn(
        id: usize,
        priority_channel: StatePQueueChannels<Circuit, P>,
        rewriter: R,
        strategy: S,
    ) -> JoinHandle<()> {
        let name = format!("BadgerWorker-{id}");
        thread::Builder::new()
            .name(name)
            .spawn(move || {
                let mut worker = Self {
                    id,
                    priority_channel,
                    rewriter,
                    strategy,
                };
                worker.run_loop()
            })
            .expect("failed to spawn BadgerWorker thread")
    }

    /// Main processing loop of the worker.
    ///
    /// Each iteration receives one circuit with its cost, computes its
    /// rewrites, applies them through the strategy, and sends the surviving
    /// new circuits back over the channel. The loop ends as soon as either
    /// receiving or sending fails.
    #[tracing::instrument(target = "badger::metrics", skip(self))]
    fn run_loop(&mut self) {
        loop {
            // A receive error ends the worker (presumably the queue side has
            // shut down -- confirm against `StatePQueueChannels`).
            let Ok(Work {
                state: circ, cost, ..
            }) = self.priority_channel.recv()
            else {
                break;
            };

            let rewrites = self.rewriter.get_rewrites(&circ);

            // Read the cost bound once per work item rather than once per
            // rewrite.
            let max_cost = self.priority_channel.max_cost();

            let new_circs = self
                .strategy
                .apply_rewrites(rewrites, &circ)
                .filter_map(|r| {
                    let new_cost = cost.add_delta(&r.cost_delta);

                    // Skip results that do not beat the current cost bound,
                    // when there is one.
                    if max_cost.as_ref().is_some_and(|max| new_cost >= *max) {
                        return None;
                    }

                    // Circuits that cannot be hashed are dropped rather than
                    // aborting the whole batch.
                    let Ok(hash) = r.circ.circuit_hash(r.circ.parent()) else {
                        return None;
                    };

                    Some(Work {
                        cost: new_cost,
                        hash,
                        state: r.circ,
                    })
                })
                .collect();

            let send = tracing::trace_span!(target: "badger::metrics", "BadgerWorker::send_result")
                .in_scope(|| self.priority_channel.send(new_circs));
            if send.is_err() {
                break;
            }
        }
    }
}