use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::collections::VecDeque;
use ordered_float::OrderedFloat;
use rustc_hash::FxHashSet;
use crate::semiring::Semiring;
use crate::wfst::StateId;
/// Selects which queue discipline a shortest-distance computation should use.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum QueueType {
    /// Automatic selection; this is the default variant.
    /// NOTE(review): the selection policy is not defined here — confirm with callers.
    #[default]
    Auto,
    /// First-in, first-out order (see `FifoQueue`).
    Fifo,
    /// A precomputed topological order (see `TopologicalQueue`).
    Topological,
    /// Smallest tentative distance first (see `ShortestFirstQueue`).
    ShortestFirst,
}
/// Queue discipline used by shortest-distance algorithms to choose the next state to process.
///
/// `insert` and `update` receive the state's current distance. An implementation may use it
/// as a priority or ignore it entirely.
pub trait ShortestDistanceQueue<W: Semiring> {
    /// Creates an empty queue, using `capacity` as a preallocation hint.
    fn with_capacity(capacity: usize) -> Self;

    /// Creates an empty queue without preallocating; equivalent to `with_capacity(0)`.
    fn new() -> Self
    where
        Self: Sized,
    {
        Self::with_capacity(0)
    }

    /// Enqueues `state`, whose current distance is `distance`.
    fn insert(&mut self, state: StateId, distance: &W);

    /// Removes and returns the next state to process, or `None` once nothing is waiting.
    fn pop(&mut self) -> Option<StateId>;

    /// Informs the queue that `state` now has distance `distance`.
    fn update(&mut self, state: StateId, distance: &W);

    /// Returns `true` when no state is waiting.
    fn is_empty(&self) -> bool;

    /// Returns the number of waiting states.
    fn len(&self) -> usize;

    /// Returns `true` if `state` is currently waiting.
    fn contains(&self, state: StateId) -> bool;

    /// Discards every waiting state.
    fn clear(&mut self);
}
/// First-in, first-out queue of states that never holds the same state twice.
#[derive(Clone, Debug)]
pub struct FifoQueue {
    /// Waiting states in arrival order.
    queue: VecDeque<StateId>,
    /// Same contents as `queue`, used for O(1) membership tests and de-duplication.
    in_queue: FxHashSet<StateId>,
}
impl FifoQueue {
    /// Creates an empty queue without preallocating.
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty queue, preallocating room for `capacity` states.
    pub fn with_capacity(capacity: usize) -> Self {
        let queue = VecDeque::with_capacity(capacity);
        let in_queue = FxHashSet::with_capacity_and_hasher(capacity, Default::default());
        Self { queue, in_queue }
    }

    /// Appends `state` at the back unless it is already waiting.
    pub fn insert_state(&mut self, state: StateId) {
        let newly_added = self.in_queue.insert(state);
        if newly_added {
            self.queue.push_back(state);
        }
    }

    /// Removes and returns the front state, or `None` when the queue is empty.
    pub fn pop(&mut self) -> Option<StateId> {
        let front = self.queue.pop_front();
        if let Some(state) = front {
            self.in_queue.remove(&state);
        }
        front
    }

    /// FIFO order has no priorities, so an update is simply an insert-if-absent.
    pub fn update_state(&mut self, state: StateId) {
        self.insert_state(state);
    }

    /// Returns `true` when no state is waiting.
    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }

    /// Returns the number of waiting states.
    pub fn len(&self) -> usize {
        self.queue.len()
    }

    /// Returns `true` if `state` is currently waiting.
    pub fn contains(&self, state: StateId) -> bool {
        self.in_queue.contains(&state)
    }

    /// Discards every waiting state, keeping allocated capacity.
    pub fn clear(&mut self) {
        self.queue.clear();
        self.in_queue.clear();
    }
}
impl Default for FifoQueue {
fn default() -> Self {
Self::new()
}
}
/// Weight-agnostic adapter: FIFO order ignores the distance argument entirely.
impl<W: Semiring> ShortestDistanceQueue<W> for FifoQueue {
    fn with_capacity(capacity: usize) -> Self {
        Self::with_capacity(capacity)
    }
    fn insert(&mut self, state: StateId, _weight: &W) {
        self.insert_state(state);
    }
    fn pop(&mut self) -> Option<StateId> {
        Self::pop(self)
    }
    fn update(&mut self, state: StateId, _weight: &W) {
        self.update_state(state);
    }
    fn is_empty(&self) -> bool {
        Self::is_empty(self)
    }
    fn len(&self) -> usize {
        Self::len(self)
    }
    fn contains(&self, state: StateId) -> bool {
        Self::contains(self, state)
    }
    fn clear(&mut self) {
        Self::clear(self)
    }
}
/// Queue that hands states out in a fixed, precomputed order (typically a topological
/// order of an acyclic machine).
///
/// The order is supplied through [`TopologicalQueue::from_order`]. `pop` always returns the
/// waiting state that appears earliest in that order. States that do not occur in the order
/// are silently ignored by `insert_state`/`update_state`. A queue built with `new` or
/// `with_capacity` therefore has an empty order and never holds anything.
#[derive(Clone, Debug)]
pub struct TopologicalQueue {
    /// States in the order they must be processed.
    order: Vec<StateId>,
    /// Scan cursor into `order`; invariant: no enqueued position lies before it.
    current_pos: usize,
    /// Position of each state id within `order`, or `usize::MAX` if the state is absent.
    state_to_pos: Vec<usize>,
    /// `enqueued[pos]` is `true` while `order[pos]` is waiting.
    enqueued: Vec<bool>,
    /// Number of `true` entries in `enqueued`.
    count: usize,
}
impl TopologicalQueue {
    /// Creates a queue with an empty order; it ignores every insertion.
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates a queue with an empty order, preallocating room for `capacity` states.
    ///
    /// Like [`TopologicalQueue::new`], the result ignores every insertion. Use
    /// [`TopologicalQueue::from_order`] to obtain a usable queue.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            order: Vec::with_capacity(capacity),
            current_pos: 0,
            state_to_pos: Vec::with_capacity(capacity),
            enqueued: Vec::with_capacity(capacity),
            count: 0,
        }
    }

    /// Builds an empty queue that will release states in the sequence given by `order`.
    ///
    /// If a state id occurs more than once, its last occurrence determines its position.
    pub fn from_order(order: Vec<StateId>) -> Self {
        let n = order.iter().map(|&s| s as usize + 1).max().unwrap_or(0);
        let mut state_to_pos = vec![usize::MAX; n];
        for (pos, &state) in order.iter().enumerate() {
            state_to_pos[state as usize] = pos;
        }
        Self {
            enqueued: vec![false; order.len()],
            order,
            current_pos: 0,
            state_to_pos,
            count: 0,
        }
    }

    /// Returns the position of `state` in the order, or `None` if it does not occur there.
    fn position_of(&self, state: StateId) -> Option<usize> {
        self.state_to_pos
            .get(state as usize)
            .copied()
            .filter(|&pos| pos < self.enqueued.len())
    }

    /// Enqueues `state`. This is a no-op if it is already waiting or absent from the order.
    pub fn insert_state(&mut self, state: StateId) {
        if let Some(pos) = self.position_of(state) {
            if !self.enqueued[pos] {
                self.enqueued[pos] = true;
                self.count += 1;
            }
            // The cursor may already be past this position (e.g. the state is re-enqueued
            // after being popped). Rewind it so `pop` can reach the state again; otherwise
            // `count` stays non-zero while `pop` keeps returning `None`.
            if pos < self.current_pos {
                self.current_pos = pos;
            }
        }
    }

    /// Removes and returns the waiting state that comes earliest in the order.
    pub fn pop(&mut self) -> Option<StateId> {
        while self.current_pos < self.order.len() {
            let pos = self.current_pos;
            self.current_pos += 1;
            if self.enqueued[pos] {
                self.enqueued[pos] = false;
                self.count -= 1;
                return Some(self.order[pos]);
            }
        }
        None
    }

    /// The order is fixed, so an update is simply an insert-if-absent.
    pub fn update_state(&mut self, state: StateId) {
        self.insert_state(state);
    }

    /// Returns `true` when no state is waiting.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Returns the number of waiting states.
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` if `state` is currently waiting.
    pub fn contains(&self, state: StateId) -> bool {
        matches!(self.position_of(state), Some(pos) if self.enqueued[pos])
    }

    /// Discards every waiting state and resets the cursor; the order itself is kept.
    pub fn clear(&mut self) {
        self.enqueued.fill(false);
        self.current_pos = 0;
        self.count = 0;
    }
}
impl Default for TopologicalQueue {
fn default() -> Self {
Self::new()
}
}
/// Weight-agnostic adapter: processing order comes solely from the precomputed order.
/// Note that `with_capacity` here yields a queue with an empty order.
impl<W: Semiring> ShortestDistanceQueue<W> for TopologicalQueue {
    fn with_capacity(capacity: usize) -> Self {
        Self::with_capacity(capacity)
    }
    fn insert(&mut self, state: StateId, _weight: &W) {
        self.insert_state(state);
    }
    fn pop(&mut self) -> Option<StateId> {
        Self::pop(self)
    }
    fn update(&mut self, state: StateId, _weight: &W) {
        self.update_state(state);
    }
    fn is_empty(&self) -> bool {
        Self::is_empty(self)
    }
    fn len(&self) -> usize {
        Self::len(self)
    }
    fn contains(&self, state: StateId) -> bool {
        Self::contains(self, state)
    }
    fn clear(&mut self) {
        Self::clear(self)
    }
}
/// Heap entry used by `ShortestFirstQueue`.
///
/// [`BinaryHeap`] is a max-heap, so the distance is stored negated: the entry with the
/// smallest distance has the largest `neg_distance` and is popped first. Equality and
/// ordering both consider the distance *and* the state. As a result `==` agrees with `cmp`,
/// and equal distances are popped in ascending state-id order rather than an unspecified one.
#[derive(Clone, Debug)]
struct ShortestFirstEntry {
    /// State this entry refers to.
    state: StateId,
    /// Negated distance of `state` at the time the entry was pushed.
    neg_distance: OrderedFloat<f64>,
}
impl PartialEq for ShortestFirstEntry {
    fn eq(&self, other: &Self) -> bool {
        self.neg_distance == other.neg_distance && self.state == other.state
    }
}
impl Eq for ShortestFirstEntry {}
impl PartialOrd for ShortestFirstEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for ShortestFirstEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        self.neg_distance
            .cmp(&other.neg_distance)
            // Reversed on purpose: among equal distances the smaller state id compares
            // greater, so the max-heap releases it first.
            .then_with(|| other.state.cmp(&self.state))
    }
}
/// Priority queue that releases the state with the smallest distance first.
///
/// Decrease-key is lazy. Every strict improvement pushes a fresh heap entry, and `pop`
/// discards entries whose distance no longer matches the best one recorded for that state.
#[derive(Clone, Debug)]
pub struct ShortestFirstQueue {
    /// Max-heap over negated distances; may hold stale entries for improved states.
    heap: BinaryHeap<ShortestFirstEntry>,
    /// States that currently have a live (non-stale) entry in `heap`.
    in_queue: FxHashSet<StateId>,
    /// Best distance offered per state id since the last `clear` (`+inf` if none).
    distances: Vec<OrderedFloat<f64>>,
}
impl ShortestFirstQueue {
    /// Creates an empty queue without preallocating.
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty queue, preallocating room for `capacity` entries/states.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            heap: BinaryHeap::with_capacity(capacity),
            in_queue: FxHashSet::with_capacity_and_hasher(capacity, Default::default()),
            distances: Vec::with_capacity(capacity),
        }
    }

    /// Ensures the distance table covers every state id below `num_states`.
    ///
    /// Existing entries are kept. The table is never shrunk: truncating it would strand
    /// heap entries and `in_queue` members for larger state ids. `pop` would then skip
    /// them while `is_empty` stayed `false`.
    pub fn init_distances(&mut self, num_states: usize) {
        if num_states > self.distances.len() {
            self.distances.resize(num_states, OrderedFloat(f64::INFINITY));
        }
    }

    /// Offers `dist` as a new distance for `state`.
    ///
    /// The state is (re-)enqueued only if `dist` is strictly smaller than the best distance
    /// recorded since the last `clear`. Otherwise the call is a no-op. In particular,
    /// `+inf` is never enqueued, and neither is NaN, because `OrderedFloat` orders NaN
    /// above every number.
    pub fn insert_with_distance(&mut self, state: StateId, dist: f64) {
        let idx = state as usize;
        if idx >= self.distances.len() {
            self.distances.resize(idx + 1, OrderedFloat(f64::INFINITY));
        }
        let candidate = OrderedFloat(dist);
        if candidate < self.distances[idx] {
            self.distances[idx] = candidate;
            self.heap.push(ShortestFirstEntry {
                state,
                neg_distance: OrderedFloat(-dist),
            });
            self.in_queue.insert(state);
        }
    }

    /// Removes and returns the waiting state with the smallest distance.
    pub fn pop(&mut self) -> Option<StateId> {
        while let Some(entry) = self.heap.pop() {
            let entry_dist = OrderedFloat(-entry.neg_distance.0);
            // Only the entry carrying the state's best distance is live. Older, larger-
            // distance entries left behind by improvements are dropped here.
            if self.distances.get(entry.state as usize) == Some(&entry_dist) {
                self.in_queue.remove(&entry.state);
                return Some(entry.state);
            }
        }
        None
    }

    /// Returns `true` when no live entry remains (stale heap entries may still exist).
    pub fn is_empty(&self) -> bool {
        self.in_queue.is_empty()
    }

    /// Returns the number of states with a live entry.
    pub fn len(&self) -> usize {
        self.in_queue.len()
    }

    /// Returns `true` if `state` currently has a live entry.
    pub fn contains(&self, state: StateId) -> bool {
        self.in_queue.contains(&state)
    }

    /// Discards all entries and forgets every recorded distance (the table length is kept).
    pub fn clear(&mut self) {
        self.heap.clear();
        self.in_queue.clear();
        self.distances.fill(OrderedFloat(f64::INFINITY));
    }
}
impl Default for ShortestFirstQueue {
fn default() -> Self {
Self::new()
}
}
impl<W: Semiring> ShortestDistanceQueue<W> for ShortestFirstQueue {
fn with_capacity(capacity: usize) -> Self {
ShortestFirstQueue::with_capacity(capacity)
}
fn insert(&mut self, state: StateId, distance: &W) {
let bytes = distance.to_bytes();
let dist = if bytes.len() >= 8 {
f64::from_le_bytes(bytes[..8].try_into().unwrap_or([0; 8]))
} else {
0.0
};
self.insert_with_distance(state, dist);
}
fn pop(&mut self) -> Option<StateId> {
ShortestFirstQueue::pop(self)
}
fn update(&mut self, state: StateId, distance: &W) {
self.insert(state, distance);
}
fn is_empty(&self) -> bool {
ShortestFirstQueue::is_empty(self)
}
fn len(&self) -> usize {
ShortestFirstQueue::len(self)
}
fn contains(&self, state: StateId) -> bool {
ShortestFirstQueue::contains(self, state)
}
fn clear(&mut self) {
ShortestFirstQueue::clear(self)
}
}
/// Queue whose discipline is chosen at runtime. Every operation is forwarded to the
/// wrapped concrete queue.
#[derive(Clone, Debug)]
pub enum AutoQueue {
    /// First-in, first-out processing.
    Fifo(FifoQueue),
    /// Processing in a precomputed state order.
    Topological(TopologicalQueue),
    /// Smallest distance first.
    ShortestFirst(ShortestFirstQueue),
}
impl Default for AutoQueue {
    /// An empty FIFO queue.
    fn default() -> Self {
        Self::Fifo(FifoQueue::default())
    }
}
impl AutoQueue {
    /// Uses a topological queue when `order` is provided; otherwise falls back to FIFO.
    pub fn with_topological_order(order: Option<Vec<StateId>>) -> Self {
        if let Some(order) = order {
            Self::Topological(TopologicalQueue::from_order(order))
        } else {
            Self::default()
        }
    }

    /// Builds a shortest-first queue whose distance table already covers `num_states` states.
    pub fn shortest_first(num_states: usize) -> Self {
        let mut inner = ShortestFirstQueue::with_capacity(num_states);
        inner.init_distances(num_states);
        Self::ShortestFirst(inner)
    }

    /// Removes and returns the next state according to the wrapped discipline.
    pub fn pop(&mut self) -> Option<StateId> {
        match self {
            Self::Fifo(inner) => inner.pop(),
            Self::Topological(inner) => inner.pop(),
            Self::ShortestFirst(inner) => inner.pop(),
        }
    }

    /// Returns `true` when the wrapped queue has nothing waiting.
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Fifo(inner) => inner.is_empty(),
            Self::Topological(inner) => inner.is_empty(),
            Self::ShortestFirst(inner) => inner.is_empty(),
        }
    }

    /// Returns the number of states waiting in the wrapped queue.
    pub fn len(&self) -> usize {
        match self {
            Self::Fifo(inner) => inner.len(),
            Self::Topological(inner) => inner.len(),
            Self::ShortestFirst(inner) => inner.len(),
        }
    }

    /// Returns `true` if `state` is waiting in the wrapped queue.
    pub fn contains(&self, state: StateId) -> bool {
        match self {
            Self::Fifo(inner) => inner.contains(state),
            Self::Topological(inner) => inner.contains(state),
            Self::ShortestFirst(inner) => inner.contains(state),
        }
    }

    /// Clears the wrapped queue.
    pub fn clear(&mut self) {
        match self {
            Self::Fifo(inner) => inner.clear(),
            Self::Topological(inner) => inner.clear(),
            Self::ShortestFirst(inner) => inner.clear(),
        }
    }
}
impl<W: Semiring> ShortestDistanceQueue<W> for AutoQueue {
fn with_capacity(capacity: usize) -> Self {
AutoQueue::Fifo(FifoQueue::with_capacity(capacity))
}
fn insert(&mut self, state: StateId, distance: &W) {
match self {
AutoQueue::Fifo(q) => q.insert(state, distance),
AutoQueue::Topological(q) => q.insert(state, distance),
AutoQueue::ShortestFirst(q) => q.insert(state, distance),
}
}
fn pop(&mut self) -> Option<StateId> {
AutoQueue::pop(self)
}
fn update(&mut self, state: StateId, distance: &W) {
match self {
AutoQueue::Fifo(q) => q.update(state, distance),
AutoQueue::Topological(q) => q.update(state, distance),
AutoQueue::ShortestFirst(q) => q.update(state, distance),
}
}
fn is_empty(&self) -> bool {
AutoQueue::is_empty(self)
}
fn len(&self) -> usize {
AutoQueue::len(self)
}
fn contains(&self, state: StateId) -> bool {
AutoQueue::contains(self, state)
}
fn clear(&mut self) {
AutoQueue::clear(self)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::semiring::TropicalWeight;

    /// Inserts `state` through the generic trait API, pinning the semiring to tropical.
    fn insert_tropical(queue: &mut AutoQueue, state: StateId, weight: TropicalWeight) {
        <AutoQueue as ShortestDistanceQueue<TropicalWeight>>::insert(queue, state, &weight);
    }

    /// Calls `pop` until it yields `None` and returns everything it produced, in order.
    fn drain(pop: impl FnMut() -> Option<StateId>) -> Vec<StateId> {
        std::iter::from_fn(pop).collect()
    }

    #[test]
    fn test_fifo_queue_basic() {
        let mut queue = FifoQueue::new();
        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);
        for state in [0, 1, 2] {
            queue.insert_state(state);
        }
        assert!(!queue.is_empty());
        assert_eq!(queue.len(), 3);
        assert!([0, 1, 2].iter().all(|&state| queue.contains(state)));
        assert_eq!(drain(|| queue.pop()), vec![0, 1, 2]);
        assert!(queue.is_empty());
    }

    #[test]
    fn test_fifo_queue_no_duplicates() {
        let mut queue = FifoQueue::new();
        queue.insert_state(0);
        queue.insert_state(0);
        queue.insert_state(1);
        assert_eq!(queue.len(), 2);
        assert_eq!(drain(|| queue.pop()), vec![0, 1]);
    }

    #[test]
    fn test_topological_queue_basic() {
        let mut queue = TopologicalQueue::from_order(vec![0, 1, 2, 3]);
        for state in [2, 0, 1] {
            queue.insert_state(state);
        }
        assert_eq!(drain(|| queue.pop()), vec![0, 1, 2]);
    }

    #[test]
    fn test_shortest_first_queue_basic() {
        let mut queue = ShortestFirstQueue::new();
        queue.init_distances(10);
        queue.insert_with_distance(0, 5.0);
        queue.insert_with_distance(1, 1.0);
        queue.insert_with_distance(2, 3.0);
        assert_eq!(drain(|| queue.pop()), vec![1, 2, 0]);
    }

    #[test]
    fn test_shortest_first_queue_update() {
        let mut queue = ShortestFirstQueue::new();
        queue.init_distances(10);
        queue.insert_with_distance(0, 5.0);
        queue.insert_with_distance(1, 10.0);
        queue.insert_with_distance(1, 2.0);
        assert_eq!(drain(|| queue.pop()), vec![1, 0]);
    }

    #[test]
    fn test_auto_queue_fifo_fallback() {
        let mut queue: AutoQueue = AutoQueue::with_topological_order(None);
        insert_tropical(&mut queue, 0, TropicalWeight::new(1.0));
        insert_tropical(&mut queue, 1, TropicalWeight::new(2.0));
        assert_eq!(drain(|| queue.pop()), vec![0, 1]);
    }

    #[test]
    fn test_auto_queue_topological() {
        let mut queue: AutoQueue = AutoQueue::with_topological_order(Some(vec![2, 0, 1]));
        insert_tropical(&mut queue, 0, TropicalWeight::new(1.0));
        insert_tropical(&mut queue, 1, TropicalWeight::new(2.0));
        insert_tropical(&mut queue, 2, TropicalWeight::new(3.0));
        assert_eq!(drain(|| queue.pop()), vec![2, 0, 1]);
    }

    #[test]
    fn test_queue_clear() {
        let mut queue = FifoQueue::new();
        queue.insert_state(0);
        queue.insert_state(1);
        assert!(!queue.is_empty());
        queue.clear();
        assert!(queue.is_empty());
        assert!(!queue.contains(0));
    }
}