use std::cmp::Ordering;
use std::collections::BinaryHeap;
use serde::{Deserialize, Serialize};
use crate::grid::{GridPos, NavGrid};
use crate::incremental::{IncrementalGridPath, IncrementalStatus};
use crate::path::PathResult;
#[cfg(feature = "logging")]
use tracing::instrument;
/// Opaque handle identifying one request submitted to a [`PathBatcher`].
///
/// Ids are handed out by [`PathBatcher::enqueue`] and carried back in
/// [`BatchedResult::id`] when the corresponding search finishes.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub struct PathRequestId(pub u32);
/// Scheduling priority of a queued path request.
///
/// Lower values are promoted to the active set first, so
/// `RequestPriority(0)` is serviced before `RequestPriority(10)`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[derive(Serialize, Deserialize)]
pub struct RequestPriority(pub u32);
/// A request waiting in a [`PathBatcher`] until an active slot frees up.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct QueuedRequest {
    /// Id returned to the caller by [`PathBatcher::enqueue`].
    pub id: PathRequestId,
    /// Grid position the path starts from.
    pub start: GridPos,
    /// Grid position the path should reach.
    pub goal: GridPos,
    /// Scheduling priority; lower values are promoted first.
    pub priority: RequestPriority,
}
/// Entry in the batcher's priority heap.
///
/// Only the id and priority are stored; the full request lives in the
/// batcher's queue and is looked up by id when this entry is popped.
struct PendingEntry {
    id: PathRequestId,
    priority: RequestPriority,
}

impl PartialEq for PendingEntry {
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `Ord::cmp`: entries are equal only when both the
        // priority and the id match.
        self.priority == other.priority && self.id == other.id
    }
}

impl Eq for PendingEntry {}

impl PartialOrd for PendingEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for PendingEntry {
    /// Orders entries for a max-heap so that the next request to run compares greatest.
    ///
    /// The comparison is reversed, so the numerically smallest priority
    /// pops first. Among equal priorities the smaller (earlier-issued) id pops
    /// first, giving FIFO order instead of the heap's unspecified tie order.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .priority
            .cmp(&self.priority)
            .then_with(|| other.id.0.cmp(&self.id.0))
    }
}
/// A finished search handed back by [`PathBatcher::process`].
#[derive(Debug)]
#[derive(Clone)]
pub struct BatchedResult {
    /// Id the request was given by [`PathBatcher::enqueue`].
    pub id: PathRequestId,
    /// Outcome of the search, whether a path was found or not.
    pub result: PathResult,
}
/// Time-sliced scheduler for incremental grid path searches.
///
/// Requests are queued with [`PathBatcher::enqueue`], promoted in priority
/// order into a bounded set of active searches, and stepped by
/// [`PathBatcher::process`] under a per-call iteration budget.
#[derive(Debug)]
pub struct PathBatcher {
    /// Priority heap of ids awaiting promotion. It may still hold entries for
    /// cancelled requests; those are skipped when popped.
    pending: BinaryHeap<PendingEntry>,
    /// Searches currently being stepped, paired with their request id.
    active: Vec<(PathRequestId, IncrementalGridPath)>,
    /// Full data for requests that have not been promoted yet.
    queued: Vec<QueuedRequest>,
    /// Maximum number of searches promoted into `active` at once.
    max_active: usize,
    /// Raw value of the id assigned to the next enqueued request.
    next_id: u32,
}
impl std::fmt::Debug for PendingEntry {
    /// Formats as `PendingEntry { id: .., priority: .. }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { id, priority } = self;
        let mut builder = f.debug_struct("PendingEntry");
        builder.field("id", id);
        builder.field("priority", priority);
        builder.finish()
    }
}
impl PathBatcher {
#[cfg_attr(feature = "logging", instrument)]
#[must_use]
pub fn new() -> Self {
Self {
pending: BinaryHeap::new(),
active: Vec::new(),
queued: Vec::new(),
max_active: 4,
next_id: 0,
}
}
#[cfg_attr(feature = "logging", instrument)]
#[must_use]
pub fn with_max_active(max_active: usize) -> Self {
Self {
pending: BinaryHeap::new(),
active: Vec::new(),
queued: Vec::new(),
max_active: max_active.max(1),
next_id: 0,
}
}
#[cfg_attr(feature = "logging", instrument(skip(self, grid)))]
pub fn enqueue(
&mut self,
start: GridPos,
goal: GridPos,
priority: RequestPriority,
grid: &NavGrid,
) -> Option<PathRequestId> {
if !grid.is_walkable(start.x, start.y) || !grid.is_walkable(goal.x, goal.y) {
return None;
}
let id = PathRequestId(self.next_id);
self.next_id += 1;
self.queued.push(QueuedRequest {
id,
start,
goal,
priority,
});
self.pending.push(PendingEntry { id, priority });
Some(id)
}
pub fn cancel(&mut self, id: PathRequestId) -> bool {
if let Some(pos) = self.queued.iter().position(|r| r.id == id) {
self.queued.swap_remove(pos);
return true;
}
if let Some(pos) = self.active.iter().position(|(rid, _)| *rid == id) {
self.active.swap_remove(pos);
return true;
}
false
}
#[must_use]
pub fn request_count(&self) -> usize {
self.queued.len() + self.active.len()
}
#[must_use]
pub fn active_count(&self) -> usize {
self.active.len()
}
#[must_use]
pub fn queued_count(&self) -> usize {
self.queued.len()
}
#[cfg_attr(feature = "logging", instrument(skip(self, grid)))]
#[must_use]
pub fn process(&mut self, grid: &NavGrid, max_iterations: u32) -> Vec<BatchedResult> {
self.promote_queued(grid);
if self.active.is_empty() {
return Vec::new();
}
let mut completed = Vec::new();
let mut remaining_budget = max_iterations;
let per_query = (remaining_budget / self.active.len() as u32).max(1);
let mut i = 0;
while i < self.active.len() {
if remaining_budget == 0 {
break;
}
let budget = per_query.min(remaining_budget);
let (id, ref mut query) = self.active[i];
let status = query.step(grid, budget);
remaining_budget = remaining_budget.saturating_sub(budget);
match status {
IncrementalStatus::Found | IncrementalStatus::NotFound => {
let result = query.to_path_result(grid);
completed.push(BatchedResult { id, result });
self.active.swap_remove(i);
}
IncrementalStatus::InProgress => {
i += 1;
}
}
}
self.promote_queued(grid);
completed
}
fn promote_queued(&mut self, grid: &NavGrid) {
while self.active.len() < self.max_active {
let entry = match self.pending.pop() {
Some(e) => e,
None => break,
};
let pos = match self.queued.iter().position(|r| r.id == entry.id) {
Some(p) => p,
None => continue, };
let request = self.queued.swap_remove(pos);
if let Some(query) = IncrementalGridPath::new(grid, request.start, request.goal) {
self.active.push((request.id, query));
}
}
}
pub fn clear(&mut self) {
self.pending.clear();
self.active.clear();
self.queued.clear();
}
}
impl Default for PathBatcher {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::path::PathStatus;

    /// Submits `start -> goal` with a raw priority value and returns whatever
    /// `PathBatcher::enqueue` returns.
    fn submit(
        batcher: &mut PathBatcher,
        grid: &NavGrid,
        start: GridPos,
        goal: GridPos,
        priority: u32,
    ) -> Option<PathRequestId> {
        batcher.enqueue(start, goal, RequestPriority(priority), grid)
    }

    /// A single request on an open grid completes within one generous call.
    #[test]
    fn batcher_enqueue_process() {
        let grid = NavGrid::new(20, 20, 1.0);
        let mut batcher = PathBatcher::new();
        let request = submit(&mut batcher, &grid, GridPos::new(0, 0), GridPos::new(19, 19), 0);
        assert!(request.is_some());
        assert_eq!(batcher.request_count(), 1);

        let finished = batcher.process(&grid, 10000);
        assert_eq!(finished.len(), 1);
        assert_eq!(finished[0].result.status, PathStatus::Found);
        assert_eq!(batcher.request_count(), 0);
    }

    /// With one active slot, the lower priority value is searched first.
    #[test]
    fn batcher_priority_order() {
        let grid = NavGrid::new(10, 10, 1.0);
        let mut batcher = PathBatcher::with_max_active(1);
        let _background =
            submit(&mut batcher, &grid, GridPos::new(0, 0), GridPos::new(9, 0), 10).unwrap();
        let urgent =
            submit(&mut batcher, &grid, GridPos::new(0, 0), GridPos::new(0, 9), 0).unwrap();

        let finished = batcher.process(&grid, 10000);
        assert!(!finished.is_empty());
        assert_eq!(finished[0].id, urgent);
    }

    /// A tiny per-call budget still finishes a long path within 100 calls.
    #[test]
    fn batcher_multi_frame() {
        let grid = NavGrid::new(50, 50, 1.0);
        let mut batcher = PathBatcher::new();
        submit(&mut batcher, &grid, GridPos::new(0, 0), GridPos::new(49, 49), 0);

        let first_batch = (0..100)
            .map(|_| batcher.process(&grid, 10))
            .find(|batch| !batch.is_empty())
            .unwrap_or_default();
        assert_eq!(first_batch.len(), 1);
        assert_eq!(first_batch[0].result.status, PathStatus::Found);
    }

    /// Cancelling a queued request removes it from the counts.
    #[test]
    fn batcher_cancel() {
        let grid = NavGrid::new(10, 10, 1.0);
        let mut batcher = PathBatcher::new();
        let id = submit(&mut batcher, &grid, GridPos::new(0, 0), GridPos::new(9, 9), 0).unwrap();
        assert!(batcher.cancel(id));
        assert_eq!(batcher.request_count(), 0);
    }

    /// Requests starting on a blocked cell are refused.
    #[test]
    fn batcher_unwalkable_rejected() {
        let mut grid = NavGrid::new(10, 10, 1.0);
        grid.set_walkable(0, 0, false);
        let mut batcher = PathBatcher::new();
        let rejected = submit(&mut batcher, &grid, GridPos::new(0, 0), GridPos::new(9, 9), 0);
        assert!(rejected.is_none());
    }

    /// A full wall between start and goal yields `NotFound`.
    #[test]
    fn batcher_no_path() {
        let mut grid = NavGrid::new(10, 10, 1.0);
        for row in 0..10 {
            grid.set_walkable(5, row, false);
        }
        let mut batcher = PathBatcher::new();
        submit(&mut batcher, &grid, GridPos::new(0, 0), GridPos::new(9, 9), 0);

        let finished = batcher.process(&grid, 100_000);
        assert_eq!(finished.len(), 1);
        assert_eq!(finished[0].result.status, PathStatus::NotFound);
    }

    /// More requests than active slots all eventually complete.
    #[test]
    fn batcher_multiple_requests() {
        let grid = NavGrid::new(10, 10, 1.0);
        let mut batcher = PathBatcher::new();
        for i in 0..5 {
            submit(&mut batcher, &grid, GridPos::new(0, 0), GridPos::new(9, i), i as u32);
        }
        assert_eq!(batcher.request_count(), 5);

        let mut finished = Vec::new();
        for _ in 0..100 {
            finished.extend(batcher.process(&grid, 1000));
            if batcher.request_count() == 0 {
                break;
            }
        }
        assert_eq!(finished.len(), 5);
        assert!(finished.iter().all(|r| r.result.status == PathStatus::Found));
    }

    /// `clear` drops every outstanding request.
    #[test]
    fn batcher_clear() {
        let grid = NavGrid::new(10, 10, 1.0);
        let mut batcher = PathBatcher::new();
        submit(&mut batcher, &grid, GridPos::new(0, 0), GridPos::new(9, 9), 0);
        submit(&mut batcher, &grid, GridPos::new(1, 1), GridPos::new(8, 8), 1);
        batcher.clear();
        assert_eq!(batcher.request_count(), 0);
    }

    /// Promotion never exceeds the configured active limit.
    #[test]
    fn batcher_max_active_limit() {
        let grid = NavGrid::new(50, 50, 1.0);
        let mut batcher = PathBatcher::with_max_active(2);
        for i in 0..5 {
            submit(&mut batcher, &grid, GridPos::new(0, 0), GridPos::new(49, i), i as u32);
        }
        let _ = batcher.process(&grid, 1);
        assert!(batcher.active_count() <= 2);
    }

    #[test]
    fn request_priority_serde_roundtrip() {
        let json = serde_json::to_string(&RequestPriority(5)).unwrap();
        let back: RequestPriority = serde_json::from_str(&json).unwrap();
        assert_eq!(back.0, 5);
    }

    #[test]
    fn path_request_id_serde_roundtrip() {
        let json = serde_json::to_string(&PathRequestId(42)).unwrap();
        let back: PathRequestId = serde_json::from_str(&json).unwrap();
        assert_eq!(back.0, 42);
    }
}