use delegate::delegate;
use fxhash::{FxHashMap, FxHashSet};
use priority_queue::DoublePriorityQueue;
use super::State;
/// A size-bounded priority queue of states, addressed by their `u64` hash.
///
/// Costs are kept in `queue` under the state's hash, while the states
/// themselves are stored separately in `hash_lookup` under the same hash.
#[derive(Debug, Clone, Default)]
pub(super) struct StatePQueue<S, P: Ord> {
/// Hashes of the currently enqueued states, prioritised by their cost.
queue: DoublePriorityQueue<u64, P>,
/// The states, looked up by hash.
hash_lookup: FxHashMap<u64, S>,
/// Hashes of all states that have been accepted by a push so far.
seen_hashes: FxHashSet<u64>,
/// Number of enqueued states at which a push evicts the maximum-cost state.
max_size: usize,
/// Optional second queue of hashes tracking the lowest-cost states pushed
/// so far; `None` when the queue was created without an `n_best` value.
all_time_best: Option<DoublePriorityQueue<u64, P>>,
}
/// A state together with its cost and hash, as handed out by the queue.
///
/// The type parameters let the same struct carry either references
/// (`Entry<&S, &P, u64>`) or owned values (`Entry<S, P, u64>`).
pub(super) struct Entry<S, P, H> {
/// The state itself.
pub state: S,
/// The cost the state is prioritised by.
pub cost: P,
/// The hash identifying the state.
pub hash: H,
}
impl<S: Clone, P: Clone + Ord> StatePQueue<S, P> {
    /// Creates an empty queue that keeps at most `max_size` states enqueued.
    ///
    /// When `n_best` is `Some(n)`, a second queue created with capacity `n`
    /// tracks the lowest-cost states pushed so far; they can be retrieved with
    /// [`Self::into_all_time_best`]. With `None`, no such tracking happens.
    pub fn new(max_size: usize, n_best: Option<usize>) -> Self {
        Self {
            queue: DoublePriorityQueue::with_capacity(max_size),
            hash_lookup: Default::default(),
            seen_hashes: Default::default(),
            max_size,
            all_time_best: n_best.map(|n| DoublePriorityQueue::with_capacity(n)),
        }
    }

    /// Returns the minimum-cost enqueued state without removing it, or `None`
    /// if the queue is empty.
    #[expect(unused)]
    pub fn peek(&self) -> Option<Entry<&S, &P, u64>> {
        let (&hash, cost) = self.queue.peek_min()?;
        let state = self.hash_lookup.get(&hash)?;
        Some(Entry { state, cost, hash })
    }

    /// Computes the hash and cost of `state` within `context` and pushes it.
    ///
    /// Returns `None` if the state yields no hash or no cost, or if
    /// [`Self::push_unchecked`] rejects it; otherwise returns the stored entry.
    pub fn push<Ctx>(&mut self, state: S, context: &Ctx) -> Option<Entry<&S, &P, u64>>
    where
        S: State<Ctx, Cost = P>,
    {
        let hash = state.hash(context)?;
        let cost = state.cost(context)?;
        self.push_unchecked(state, hash, cost)
    }

    /// Pushes `state` using a caller-provided `hash` and `cost`.
    ///
    /// Returns `None` when the cost is not accepted (see
    /// [`Self::check_accepted`]) or when the hash has already been seen.
    /// If the queue is full, the maximum-cost state is evicted first.
    pub fn push_unchecked(&mut self, state: S, hash: u64, cost: P) -> Option<Entry<&S, &P, u64>> {
        // `||` short-circuits, so the hash is only recorded as seen once the
        // cost has been accepted.
        if !self.check_accepted(&cost) || !self.seen_hashes.insert(hash) {
            return None;
        }
        if self.len() >= self.max_size {
            self.pop_max();
        }
        self.queue.push(hash, cost.clone());
        self.push_all_time_best(hash, cost);
        let (_, cost_ref) = self.queue.get(&hash).expect("just inserted");
        let state_ref = self.hash_lookup.entry(hash).or_insert(state);
        Some(Entry {
            state: state_ref,
            cost: cost_ref,
            hash,
        })
    }

    /// Offers `hash` to the all-time-best queue, if there is one.
    ///
    /// Once that queue is full, the new entry is only admitted when its cost is
    /// strictly lower than the current worst tracked cost, which is evicted.
    #[inline]
    fn push_all_time_best(&mut self, hash: u64, cost: P) {
        let Some(all_time_best) = self.all_time_best.as_ref() else {
            return;
        };
        // NOTE(review): this treats `capacity()` as the `n_best` limit, which
        // assumes `with_capacity(n)` reports exactly `n` — confirm against the
        // priority queue implementation.
        if all_time_best.len() == all_time_best.capacity() {
            let improves = all_time_best
                .peek_max()
                .is_some_and(|(_, worst)| cost < *worst);
            if !improves {
                return;
            }
            self.pop_max_all_time_best();
        }
        if let Some(all_time_best) = self.all_time_best.as_mut() {
            all_time_best.push(hash, cost);
        }
    }

    /// Whether the all-time-best queue currently references `hash`.
    ///
    /// Always `false` when all-time-best tracking is disabled.
    fn is_all_time_best(&self, hash: u64) -> bool {
        self.all_time_best
            .as_ref()
            .is_some_and(|all_time_best| all_time_best.contains(&hash))
    }

    /// Hands out the state for a hash that has just left `queue`.
    ///
    /// The state is cloned (and kept) only while the all-time-best queue still
    /// references it; otherwise it is moved out so `hash_lookup` does not
    /// accumulate states nothing refers to any more.
    fn release_state(&mut self, hash: u64) -> Option<S> {
        if self.is_all_time_best(hash) {
            self.hash_lookup.get(&hash).cloned()
        } else {
            self.hash_lookup.remove(&hash)
        }
    }

    /// Removes and returns the minimum-cost enqueued state, or `None` if the
    /// queue is empty.
    pub fn pop(&mut self) -> Option<Entry<S, P, u64>> {
        let (hash, cost) = self.queue.pop_min()?;
        let state = self.release_state(hash)?;
        Some(Entry { state, cost, hash })
    }

    /// Removes and returns the maximum-cost enqueued state, or `None` if the
    /// queue is empty.
    fn pop_max(&mut self) -> Option<Entry<S, P, u64>> {
        let (hash, cost) = self.queue.pop_max()?;
        let state = self.release_state(hash)?;
        Some(Entry { state, cost, hash })
    }

    /// Drops the maximum-cost entry of the all-time-best queue.
    ///
    /// Returns `false` if there is no such queue or it is empty. The state is
    /// removed from `hash_lookup` unless it is still enqueued.
    fn pop_max_all_time_best(&mut self) -> bool {
        let Some((hash, _)) = self.all_time_best.as_mut().and_then(|atb| atb.pop_max()) else {
            return false;
        };
        if !self.queue.contains(&hash) {
            self.hash_lookup.remove(&hash);
        }
        true
    }

    /// Discards maximum-cost states until at most `max_size` remain enqueued.
    ///
    /// States still referenced by the all-time-best queue stay in
    /// `hash_lookup`, so [`Self::into_all_time_best`] can still return them.
    #[expect(unused)]
    pub fn truncate(&mut self, max_size: usize) {
        while self.queue.len() > max_size {
            let Some((hash, _)) = self.queue.pop_max() else {
                break;
            };
            if !self.is_all_time_best(hash) {
                self.hash_lookup.remove(&hash);
            }
        }
    }

    /// The highest cost among the enqueued states, or `None` if empty.
    pub fn max_cost(&self) -> Option<&P> {
        self.queue.peek_max().map(|(_, cost)| cost)
    }

    /// Number of distinct hashes accepted by a push so far.
    pub fn num_seen_hashes(&self) -> usize {
        self.seen_hashes.len()
    }

    /// Whether a state with the given `cost` would currently be admitted.
    ///
    /// Nothing is accepted when `max_size` is zero; anything is accepted while
    /// the queue is below `max_size`; otherwise the cost must be strictly lower
    /// than the current maximum cost.
    pub fn check_accepted(&self, cost: &P) -> bool {
        if self.max_size == 0 {
            return false;
        }
        if self.len() < self.max_size {
            return true;
        }
        // The queue is non-empty here, since `len() >= max_size > 0`.
        self.max_cost().is_some_and(|max| cost < max)
    }

    /// Whether the number of enqueued states has reached `max_size`.
    pub fn is_full(&self) -> bool {
        self.queue.len() >= self.max_size
    }

    /// Consumes the queue and returns the tracked all-time-best states, in the
    /// order produced by the queue's `into_sorted_iter`.
    ///
    /// Returns `None` when all-time-best tracking is disabled.
    ///
    /// # Panics
    ///
    /// Panics if a tracked hash has no state in `hash_lookup`.
    pub fn into_all_time_best(self) -> Option<Vec<S>> {
        let Self {
            all_time_best,
            mut hash_lookup,
            ..
        } = self;
        let states = all_time_best?
            .into_sorted_iter()
            .map(|(hash, _)| {
                hash_lookup
                    .remove(&hash)
                    .expect("all-time-best states are kept in hash_lookup")
            })
            .collect();
        Some(states)
    }

    delegate! {
        to self.queue {
            pub fn len(&self) -> usize;
            pub fn is_empty(&self) -> bool;
        }
    }
}
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;

    /// Test state whose hash and cost are both the wrapped integer.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
    struct DummyState(usize);

    impl State<()> for DummyState {
        type Cost = usize;

        fn hash(&self, (): &()) -> Option<u64> {
            Some(self.0 as u64)
        }

        fn cost(&self, (): &()) -> Option<Self::Cost> {
            Some(self.0)
        }

        fn next_states(&self, (): &mut ()) -> Vec<Self> {
            unimplemented!()
        }
    }

    impl StatePQueue<DummyState, usize> {
        /// Looks up the state for every hash and returns them sorted.
        fn sorted_states<'a>(&self, hashes: impl Iterator<Item = &'a u64>) -> Vec<DummyState> {
            hashes
                .map(|hash| self.hash_lookup[hash])
                .sorted()
                .collect_vec()
        }

        /// Sorted states currently in the main queue.
        fn get_enqueued(&self) -> Vec<DummyState> {
            self.sorted_states(self.queue.iter().map(|(hash, _)| hash))
        }

        /// Sorted states currently in the all-time-best queue.
        fn get_all_time_best(&self) -> Vec<DummyState> {
            let tracked = self.all_time_best.as_ref().unwrap();
            self.sorted_states(tracked.iter().map(|(hash, _)| hash))
        }
    }

    #[test]
    fn test_queue_truncation() {
        let (max_size, n_best) = (5, 3);
        // The states `DummyState(0)..DummyState(count)`, in increasing order.
        let first_states = |count: usize| (0..count).map(DummyState).collect_vec();

        let mut pq = StatePQueue::new(max_size, Some(n_best));
        for value in (0..10).rev() {
            pq.push(DummyState(value), &());
        }

        // Only the cheapest states survive in either queue.
        assert_eq!(pq.len(), max_size);
        assert_eq!(pq.get_enqueued(), first_states(max_size));
        assert_eq!(pq.get_all_time_best(), first_states(n_best));

        // Popping from the expensive end leaves the all-time best untouched.
        for value in [4, 3, 2] {
            let popped = pq.pop_max().map(|entry| entry.state);
            assert_eq!(popped, Some(DummyState(value)));
        }
        assert_eq!(pq.get_enqueued(), first_states(2));
        assert_eq!(pq.get_all_time_best(), first_states(n_best));

        // Same when popping from the cheap end.
        let cheapest = pq.pop().map(|entry| entry.state);
        assert_eq!(cheapest, Some(DummyState(0)));
        assert_eq!(pq.get_enqueued(), [DummyState(1)]);
        assert_eq!(pq.get_all_time_best(), first_states(n_best));

        // Emptying the all-time best keeps the state that is still enqueued.
        for _ in 0..n_best {
            assert!(pq.pop_max_all_time_best());
        }
        assert_eq!(pq.get_enqueued(), [DummyState(1)]);
        assert_eq!(pq.hash_lookup.len(), 1);
        assert!(pq.all_time_best.as_ref().unwrap().is_empty());

        // Popping the last enqueued state releases its storage.
        assert!(pq.pop().is_some());
        assert!(pq.hash_lookup.is_empty());
    }
}