use std::sync::{Arc, RwLock};
use std::thread;
use std::time::Instant;
use crossbeam_channel::{Receiver, RecvError, SendError, Sender, select};
use crate::circuit::cost::CircuitCost;
use crate::optimiser::pqueue::{Entry, StatePQueue};
/// A unit of work exchanged over the channels in this module.
///
/// This is an [`Entry`] carrying a state of type `S`, a cost of type `P` and a
/// `u64` as third type parameter (presumably the type of the state hash --
/// confirm against the definition of `Entry`).
pub type Work<S, P> = Entry<S, P, u64>;
/// Worker that owns a [`StatePQueue`] and services it from a dedicated thread.
///
/// The worker receives batches of work on `push`, stores them in `pq`, and
/// offers entries popped from `pq` on `pop`. Progress is reported on `log`.
#[derive(Debug, Clone)]
pub struct StatePQWorker<S, P: Ord> {
// Incoming batches of work. An empty batch makes the worker loop exit.
push: Receiver<Vec<Work<S, P>>>,
// Outgoing entries popped from `pq`.
pop: Sender<Work<S, P>>,
// Outgoing log messages (new best states and state counts).
log: Sender<LogMessage<S, P>>,
// Timestamp compared against the current time in `enqueue_circs` to decide
// whether a `StateCount` message is emitted.
last_progress_log: Instant,
// The priority queue holding the pending work.
pq: StatePQueue<S, P>,
// Lowest cost among all entries enqueued so far, `None` before the first one.
min_cost: Option<P>,
// Incremented once per call to `enqueue_circs` (i.e. once per received
// batch); reported as `processed_count` in `StateCount` messages.
state_cnt: usize,
// Cost bound shared with the `StatePQueueChannels` handles.
max_cost: Arc<RwLock<Option<P>>>,
// Worker-local copy of the last value written to `max_cost`.
local_max_cost: Option<P>,
}
/// Messages emitted by the worker on its logging channel.
#[derive(Debug, Clone)]
pub enum LogMessage<S, P> {
// An enqueued state had a lower cost than every state enqueued before it.
// Carries a clone of the state and its cost.
NewBestState(S, P),
// Snapshot of the worker's counters.
StateCount {
// Value of the worker's `state_cnt` counter (number of batches enqueued).
processed_count: usize,
// Value reported by the queue's `num_seen_hashes()`.
seen_count: usize,
// Value reported by the queue's `len()`.
queue_length: usize,
},
}
/// Client-side handle to the priority-queue worker.
///
/// Cloning the handle clones the underlying channel endpoints and the shared
/// `max_cost` pointer.
#[derive(Clone)]
pub struct StatePQueueChannels<S, P> {
// Sends batches of work to the worker. An empty batch asks it to stop.
push: Sender<Vec<Work<S, P>>>,
// Receives entries popped from the worker's queue.
pop: Receiver<Work<S, P>>,
// Cost bound shared with the worker; only read through this handle.
max_cost: Arc<RwLock<Option<P>>>,
}
impl<S, P: CircuitCost> StatePQueueChannels<S, P> {
    /// Asks the worker to shut down.
    ///
    /// An empty batch is used as the shutdown sentinel: the worker loop exits
    /// when it receives one.
    ///
    /// # Errors
    ///
    /// Returns the `SendError` of the underlying channel if the batch could
    /// not be sent.
    pub fn close(&self) -> Result<(), SendError<Vec<Work<S, P>>>> {
        let shutdown_sentinel = Vec::new();
        self.push.send(shutdown_sentinel)
    }

    /// Sends a batch of work to the worker.
    ///
    /// Empty batches are silently dropped, since an empty batch on the channel
    /// is what [`Self::close`] sends to stop the worker.
    ///
    /// # Errors
    ///
    /// Returns the `SendError` of the underlying channel if the batch could
    /// not be sent.
    pub fn send(&self, work: Vec<Work<S, P>>) -> Result<(), SendError<Vec<Work<S, P>>>> {
        if !work.is_empty() {
            self.push.send(work)?;
        }
        Ok(())
    }

    /// Blocks until the worker hands out the next entry from its queue.
    ///
    /// # Errors
    ///
    /// Returns the `RecvError` of the underlying channel if nothing can be
    /// received any more.
    pub fn recv(&self) -> Result<Work<S, P>, RecvError> {
        self.pop.recv()
    }

    /// Returns a copy of the shared cost bound.
    ///
    /// Yields `None` when no bound has been published, and also when the lock
    /// guarding it is poisoned.
    pub fn max_cost(&self) -> Option<P> {
        let guard = self.max_cost.read().ok()?;
        (*guard).clone()
    }
}
impl<S, P> StatePQWorker<S, P>
where
P: CircuitCost + Send + Sync + 'static,
S: Clone + Send + Sync + 'static,
{
/// Creates the channels, starts the worker thread and returns the client side.
///
/// The work channel and the log channel are unbounded. The channel handing
/// out popped entries has capacity zero, so an entry only leaves the queue
/// when a receiver is ready to take it.
///
/// Returns the handle used to talk to the worker, together with the receiving
/// end of the log channel.
///
/// # Panics
///
/// Panics if the worker thread cannot be spawned.
pub fn init(queue_capacity: usize) -> (StatePQueueChannels<S, P>, Receiver<LogMessage<S, P>>) {
    let (work_tx, work_rx) = crossbeam_channel::unbounded();
    let (state_tx, state_rx) = crossbeam_channel::bounded(0);
    let (log_tx, log_rx) = crossbeam_channel::unbounded();
    let shared_max_cost = Arc::new(RwLock::new(None));

    // The worker keeps its own reference to the shared bound and runs detached.
    Self::new(
        work_rx,
        state_tx,
        log_tx,
        Arc::clone(&shared_max_cost),
        queue_capacity,
    )
    .run();

    let channels = StatePQueueChannels {
        push: work_tx,
        pop: state_rx,
        max_cost: shared_max_cost,
    };
    (channels, log_rx)
}
/// Builds a worker around the given channel endpoints.
///
/// The queue is created with `queue_capacity`; no minimum cost, no local
/// maximum cost and a zero state count are recorded initially.
///
/// The progress-log timestamp is back-dated by 60 seconds so that the first
/// progress check sees a long-elapsed interval. `Instant - Duration` may panic
/// when the result cannot be represented (e.g. on platforms whose monotonic
/// clock starts near zero), so the subtraction is checked and falls back to
/// the current instant.
fn new(
    push: Receiver<Vec<Work<S, P>>>,
    pop: Sender<Work<S, P>>,
    log: Sender<LogMessage<S, P>>,
    max_cost: Arc<RwLock<Option<P>>>,
    queue_capacity: usize,
) -> Self {
    let now = Instant::now();
    let last_progress_log = now
        .checked_sub(std::time::Duration::from_secs(60))
        .unwrap_or(now);

    StatePQWorker {
        push,
        pop,
        log,
        last_progress_log,
        pq: StatePQueue::new(queue_capacity, None),
        min_cost: None,
        state_cnt: 0,
        max_cost,
        local_max_cost: None,
    }
}
/// Moves the worker onto its own thread, named `priority-channel`, and
/// returns immediately. The join handle is discarded.
///
/// The thread loops as follows:
/// * while the queue is empty it blocks on the work channel;
/// * otherwise it waits for whichever comes first: a new batch of work, or a
///   receiver ready to take the next popped entry.
///
/// The loop ends when the work channel is disconnected, when an empty batch
/// (the shutdown sentinel) arrives, or when a popped entry cannot be sent.
/// A final `StateCount` message is logged before the thread exits.
///
/// # Panics
///
/// Panics if the thread cannot be spawned. The thread itself panics if the
/// final log message cannot be sent.
fn run(mut self) {
    let worker_loop = move || {
        'main: loop {
            // Nothing to hand out yet: wait for work.
            while self.pq.is_empty() {
                match self.push.recv() {
                    Ok(batch) if !batch.is_empty() => self.enqueue_circs(batch),
                    // Disconnected channel or shutdown sentinel.
                    _ => break 'main,
                }
            }

            // The queue is non-empty here, so the `unwrap` on `pop()` holds.
            // NOTE(review): this relies on `select!` evaluating the message
            // expression only once the send operation has been chosen --
            // confirm against the crossbeam documentation.
            select! {
                recv(self.push) -> incoming => {
                    match incoming {
                        Ok(batch) if !batch.is_empty() => self.enqueue_circs(batch),
                        // Disconnected channel or shutdown sentinel.
                        _ => break 'main,
                    }
                }
                send(self.pop, self.pq.pop().unwrap()) -> outcome => {
                    if outcome.is_err() {
                        break 'main;
                    }
                    self.update_max_cost();
                }
            }
        }

        // Report the final counters before the thread ends.
        let final_report = LogMessage::StateCount {
            processed_count: self.state_cnt,
            seen_count: self.pq.num_seen_hashes(),
            queue_length: self.pq.len(),
        };
        self.log.send(final_report).unwrap();
    };

    let _ = thread::Builder::new()
        .name("priority-channel".into())
        .spawn(worker_loop)
        .unwrap();
}
#[tracing::instrument(target = "badger::metrics", skip(self, states))]
fn enqueue_circs(&mut self, states: Vec<Work<S, P>>) {
for Work { cost, hash, state } in states {
if self.min_cost.is_none() || Some(&cost) < self.min_cost.as_ref() {
self.min_cost = Some(cost.clone());
self.log
.send(LogMessage::NewBestState(state.clone(), cost.clone()))
.unwrap();
}
self.pq.push_unchecked(state, hash, cost);
}
self.update_max_cost();
self.state_cnt += 1;
if Instant::now() - self.last_progress_log > std::time::Duration::from_millis(100) {
self.log
.send(LogMessage::StateCount {
processed_count: self.state_cnt,
seen_count: self.pq.num_seen_hashes(),
queue_length: self.pq.len(),
})
.unwrap();
}
}
/// Publishes the queue's maximum cost to the shared `max_cost` bound.
///
/// Nothing happens unless the queue is full (and non-empty). When it is, the
/// queue's current maximum cost is published if no bound has been published
/// yet, or if it is strictly lower than the last published one. The bound
/// therefore only ever tightens.
///
/// # Panics
///
/// Panics if the lock guarding the shared bound is poisoned.
#[inline]
fn update_max_cost(&mut self) {
    if !self.pq.is_full() || self.pq.is_empty() {
        return;
    }
    let queue_max = self.pq.max_cost().unwrap().clone();

    // `None` means no bound was published so far: the first one must always
    // go through, otherwise the bound could never be set at all.
    let tightens_bound = match &self.local_max_cost {
        None => true,
        Some(current) => queue_max < *current,
    };
    if tightens_bound {
        self.local_max_cost = Some(queue_max.clone());
        *self.max_cost.write().unwrap() = Some(queue_max);
    }
}
}