use std::sync::{Arc, RwLock};
use std::thread;
use std::time::Instant;
use crossbeam_channel::{select, Receiver, RecvError, SendError, Sender};
use fxhash::FxHashSet;
use crate::circuit::cost::CircuitCost;
use crate::Circuit;
use super::hugr_pqueue::{Entry, HugrPQ};
/// A unit of work exchanged over the priority channel: an [`Entry`] holding a
/// [`Circuit`], its cost of type `P`, and a `u64` hash.
pub type Work<P> = Entry<Circuit, P, u64>;
/// State of the priority channel, owned by the queueing thread.
///
/// Circuits arrive in batches on `push`, are deduplicated by hash, stored in a
/// priority queue, and are handed out one at a time on `pop`. `C` is the cost
/// function type and `P` the cost (priority) type.
#[derive(Debug, Clone)]
pub struct HugrPriorityChannel<C, P: Ord> {
/// Incoming batches of work items to add to the queue.
/// An empty batch is treated as the end-of-data signal.
push: Receiver<Vec<Work<P>>>,
/// Outgoing channel on which items popped from the queue are sent.
pop: Sender<Work<P>>,
/// Outgoing channel for [`PriorityChannelLog`] messages.
log: Sender<PriorityChannelLog<P>>,
/// Timestamp compared against the current time before a `CircuitCount`
/// progress message is emitted.
last_progress_log: Instant,
/// The priority queue holding the pending circuits.
pq: HugrPQ<P, C>,
/// Hashes of all circuits accepted so far; used to skip duplicates.
seen_hashes: FxHashSet<u64>,
/// The lowest cost seen among enqueued circuits, if any.
min_cost: Option<P>,
/// Number of batches enqueued so far; reported as `processed_count`.
circ_cnt: usize,
/// Maximum-cost value shared with the [`PriorityChannelCommunication`] handles.
max_cost: Arc<RwLock<Option<P>>>,
/// Local copy of the value last written to `max_cost`.
local_max_cost: Option<P>,
}
/// Messages emitted by the priority channel on its log channel.
#[derive(Debug, Clone)]
pub enum PriorityChannelLog<P> {
/// A circuit with a cost lower than any seen before was enqueued.
/// Carries the circuit and its cost.
NewBestCircuit(Circuit, P),
/// Progress counters of the priority channel.
CircuitCount {
/// Number of batches that have been enqueued.
processed_count: usize,
/// Number of distinct circuit hashes seen.
seen_count: usize,
/// Current number of circuits in the priority queue.
queue_length: usize,
},
}
/// Handle used to communicate with the priority channel.
///
/// Holds the opposite ends of the `push` and `pop` channels of
/// [`HugrPriorityChannel`], plus the shared maximum-cost value.
#[derive(Clone)]
pub struct PriorityChannelCommunication<P> {
/// Channel on which batches of work items are sent to the queue.
push: Sender<Vec<Work<P>>>,
/// Channel from which single work items are received from the queue.
pop: Receiver<Work<P>>,
/// Maximum-cost value shared with the queueing side.
max_cost: Arc<RwLock<Option<P>>>,
}
impl<P: CircuitCost> PriorityChannelCommunication<P> {
    /// Signals the end of the data by pushing an empty batch.
    ///
    /// # Errors
    ///
    /// Returns the `SendError` produced by the push channel if the batch
    /// could not be sent.
    pub fn close(&self) -> Result<(), SendError<Vec<Work<P>>>> {
        let end_marker: Vec<Work<P>> = Vec::new();
        self.push.send(end_marker)
    }

    /// Pushes a batch of work items onto the channel.
    ///
    /// An empty batch is not forwarded, since an empty batch is what
    /// [`Self::close`] sends; in that case `Ok(())` is returned immediately.
    ///
    /// # Errors
    ///
    /// Returns the `SendError` produced by the push channel if the batch
    /// could not be sent.
    pub fn send(&self, work: Vec<Work<P>>) -> Result<(), SendError<Vec<Work<P>>>> {
        if work.is_empty() {
            Ok(())
        } else {
            self.push.send(work)
        }
    }

    /// Blocks until the next work item is available on the pop channel.
    ///
    /// # Errors
    ///
    /// Returns the `RecvError` produced by the pop channel if no item could
    /// be received.
    pub fn recv(&self) -> Result<Work<P>, RecvError> {
        self.pop.recv()
    }

    /// Returns a copy of the shared maximum cost.
    ///
    /// Yields `None` when no value has been stored, or when the lock guarding
    /// the value is poisoned.
    pub fn max_cost(&self) -> Option<P> {
        match self.max_cost.read() {
            Ok(guard) => (*guard).clone(),
            Err(_) => None,
        }
    }
}
impl<C, P> HugrPriorityChannel<C, P>
where
C: Fn(&Circuit) -> P + Send + Sync + 'static,
P: CircuitCost + Send + Sync + 'static,
{
pub fn init(
cost_fn: C,
queue_capacity: usize,
) -> (
PriorityChannelCommunication<P>,
Receiver<PriorityChannelLog<P>>,
) {
let max_cost = Arc::new(RwLock::new(None));
let (tx_push, rx_push) = crossbeam_channel::unbounded();
let (tx_pop, rx_pop) = crossbeam_channel::bounded(0);
let (tx_log, rx_log) = crossbeam_channel::unbounded();
let pq = HugrPriorityChannel::new(
rx_push,
tx_pop,
tx_log,
max_cost.clone(),
cost_fn,
queue_capacity,
);
pq.run();
(
PriorityChannelCommunication {
push: tx_push,
pop: rx_pop,
max_cost,
},
rx_log,
)
}
/// Builds the channel state from its communication endpoints.
///
/// The priority queue is created from `cost_fn` and `queue_capacity`; the set
/// of seen hashes starts empty, no minimum or maximum cost is recorded, and
/// the processed counter starts at zero.
fn new(
    push: Receiver<Vec<Work<P>>>,
    pop: Sender<Work<P>>,
    log: Sender<PriorityChannelLog<P>>,
    max_cost: Arc<RwLock<Option<P>>>,
    cost_fn: C,
    queue_capacity: usize,
) -> Self {
    // Back-date the timestamp by a minute so that the first progress check
    // succeeds. `Instant - Duration` panics when the result cannot be
    // represented, so use the checked variant and fall back to the current
    // instant.
    let now = Instant::now();
    let last_progress_log = now
        .checked_sub(std::time::Duration::from_secs(60))
        .unwrap_or(now);

    HugrPriorityChannel {
        push,
        pop,
        log,
        last_progress_log,
        pq: HugrPQ::new(cost_fn, queue_capacity),
        seen_hashes: FxHashSet::default(),
        min_cost: None,
        circ_cnt: 0,
        max_cost,
        local_max_cost: None,
    }
}
/// Spawns the queueing thread, named `priority-channel`, and moves `self`
/// into it. The join handle is discarded.
///
/// The thread loops until the push channel fails to deliver a batch, an empty
/// batch is received, or sending on the pop channel fails. While the queue is
/// empty it blocks on incoming batches; otherwise it either accepts a new
/// batch or hands the next queued item to a receiver, whichever is ready.
/// A final `CircuitCount` message is logged before the thread exits.
///
/// # Panics
///
/// Panics if the thread cannot be spawned. The spawned thread panics if a
/// message cannot be sent on the log channel.
fn run(mut self) {
    let _ = thread::Builder::new()
        .name("priority-channel".into())
        .spawn(move || {
            'main: loop {
                // Nothing to hand out yet: block until work arrives.
                while self.pq.is_empty() {
                    match self.push.recv() {
                        Ok(new_circs) if !new_circs.is_empty() => {
                            self.enqueue_circs(new_circs);
                        }
                        // Receive failure or an empty batch: stop.
                        _ => break 'main,
                    }
                }
                // NOTE(review): this relies on `select!` evaluating the
                // message expression of `send` only when that operation is
                // the one selected; otherwise the popped item would be lost
                // whenever the `recv` branch wins. Confirm against the
                // crossbeam-channel documentation.
                select! {
                    recv(self.push) -> result => {
                        let Ok(new_circs) = result else {
                            break 'main;
                        };
                        if new_circs.is_empty() {
                            break 'main;
                        }
                        self.enqueue_circs(new_circs);
                    }
                    // The queue is non-empty here, so `pop` yields an item.
                    send(self.pop, self.pq.pop().unwrap()) -> result => {
                        if result.is_err() {
                            break 'main;
                        }
                        self.update_max_cost();
                    }
                }
            }
            // Send a last set of counters before terminating.
            self.log
                .send(PriorityChannelLog::CircuitCount {
                    processed_count: self.circ_cnt,
                    seen_count: self.seen_hashes.len(),
                    queue_length: self.pq.len(),
                })
                .unwrap();
        })
        .unwrap();
}
#[tracing::instrument(target = "badger::metrics", skip(self, circs))]
fn enqueue_circs(&mut self, circs: Vec<Work<P>>) {
for Work { cost, hash, circ } in circs {
if !self.seen_hashes.insert(hash) {
continue;
}
if self.min_cost.is_none() || Some(&cost) < self.min_cost.as_ref() {
self.min_cost = Some(cost.clone());
self.log
.send(PriorityChannelLog::NewBestCircuit(
circ.clone(),
cost.clone(),
))
.unwrap();
}
self.pq.push_unchecked(circ, hash, cost);
}
self.update_max_cost();
self.circ_cnt += 1;
if Instant::now() - self.last_progress_log > std::time::Duration::from_millis(100) {
self.log
.send(PriorityChannelLog::CircuitCount {
processed_count: self.circ_cnt,
seen_count: self.seen_hashes.len(),
queue_length: self.pq.len(),
})
.unwrap();
}
}
/// Publishes the queue's maximum cost to the shared `max_cost` value.
///
/// Does nothing unless the queue is full and non-empty. Otherwise the queue's
/// maximum cost is written to both the local copy and the shared value when
/// no value has been published yet, or when it is lower than the one
/// published last. The local copy avoids taking the write lock when nothing
/// changes.
///
/// # Panics
///
/// Panics if the lock guarding the shared value is poisoned.
#[inline]
fn update_max_cost(&mut self) {
    if !self.pq.is_full() || self.pq.is_empty() {
        return;
    }
    // The queue is non-empty, so it has a maximum cost.
    let queue_max = self.pq.max_cost().unwrap().clone();
    // The first observation must be published too: the local copy starts as
    // `None`, so requiring a previous value would never publish anything.
    let should_publish = self
        .local_max_cost
        .as_ref()
        .map_or(true, |published| queue_max < *published);
    if should_publish {
        self.local_max_cost = Some(queue_max.clone());
        *self.max_cost.write().unwrap() = Some(queue_max);
    }
}
}