use std::cmp::Ordering;
use std::collections::BinaryHeap;
use ndarray::Array2;
use crate::cost::CostField;
use crate::error::{Error, Result};
/// Output of a least-cost solve over a grid.
///
/// Cells are addressed either as `(row, col)` or by the row-major flat index
/// `row * width + col`, which is how `predecessors` and `finalized` are laid out.
pub struct SolveResult {
/// Accumulated cost from the nearest source to each cell, shaped `(height, width)`.
/// Cells that were never finalized hold `f64::INFINITY`.
pub(crate) distance: Array2<f64>,
/// For each flat cell index, the flat index of the previous cell on its
/// least-cost path, or `None` for sources and for cells that were not finalized.
pub(crate) predecessors: Vec<Option<usize>>,
/// For each flat cell index, whether the search settled that cell's distance.
pub(crate) finalized: Vec<bool>,
/// Number of grid columns; used to convert between flat indices and `(row, col)`.
pub(crate) width: usize,
}
impl SolveResult {
    /// Returns the accumulated-cost grid, indexed as `[[row, col]]`.
    ///
    /// Cells the search did not finalize hold `f64::INFINITY`.
    pub fn distance(&self) -> &Array2<f64> {
        &self.distance
    }

    /// Reconstructs the least-cost path ending at `(target_row, target_col)`.
    ///
    /// The returned [`Path`] lists cells in travel order, beginning at the cell
    /// whose predecessor chain ends and finishing at the target; its `cost` is
    /// the target's accumulated distance.
    ///
    /// # Errors
    ///
    /// * [`Error::OutOfBounds`] if the target lies outside the grid.
    /// * [`Error::NoPathFound`] if the target's distance is infinite or the
    ///   target was never finalized by the search.
    pub fn path_to(&self, target_row: usize, target_col: usize) -> Result<Path> {
        let (height, _) = self.distance.dim();
        let width = self.width;
        if target_row >= height || target_col >= width {
            return Err(Error::OutOfBounds {
                row: target_row,
                col: target_col,
                height,
                width,
            });
        }

        let goal = target_row * width + target_col;
        let cost = self.distance[[target_row, target_col]];
        if cost.is_infinite() || !self.finalized[goal] {
            return Err(Error::NoPathFound);
        }

        // Follow predecessor links from the target until a cell without a
        // predecessor is reached; the chain is therefore collected backwards.
        let mut cells: Vec<(usize, usize)> =
            std::iter::successors(Some(goal), |&cell| self.predecessors[cell])
                .map(|cell| (cell / width, cell % width))
                .collect();
        cells.reverse();

        Ok(Path { cells, cost })
    }
}
/// A reconstructed route through the grid together with its total cost.
#[derive(Clone, Debug)]
pub struct Path {
/// Visited cells as `(row, col)` pairs, ordered from the start of the route to the target.
pub cells: Vec<(usize, usize)>,
/// Accumulated cost of reaching the final cell in `cells`.
pub cost: f64,
}
/// Computes least-cost distances from a single `source` cell to every reachable cell.
///
/// Equivalent to [`solve_multi`] with a one-element source list.
///
/// # Errors
///
/// Returns an error if the cost field is empty, or if `source` is outside the
/// grid or sits on a cell whose cost is not strictly positive.
pub fn solve(cost: &CostField, source: (usize, usize)) -> Result<SolveResult> {
    let sources = [source];
    solve_multi(cost, &sources)
}
/// Computes least-cost distances from the nearest of several `sources`.
///
/// Every source starts at distance zero and the search runs over the whole
/// grid (no early-exit target).
///
/// # Errors
///
/// Returns an error if the cost field is empty, if `sources` is empty, or if
/// any source is outside the grid or sits on a cell whose cost is not strictly
/// positive.
pub fn solve_multi(cost: &CostField, sources: &[(usize, usize)]) -> Result<SolveResult> {
    let no_target = None;
    solve_inner(cost, sources, no_target)
}
/// Computes least-cost distances from `source`, stopping once `target` is settled.
///
/// Because the search ends early, cells that were not finalized before the
/// target report an infinite distance in the result.
///
/// # Errors
///
/// Returns an error if the cost field is empty, if `source` or `target` is
/// outside the grid, or if `source` sits on a cell whose cost is not strictly
/// positive.
pub fn solve_to(
    cost: &CostField,
    source: (usize, usize),
    target: (usize, usize),
) -> Result<SolveResult> {
    let sources = [source];
    solve_inner(cost, &sources, Some(target))
}
/// Shared Dijkstra implementation behind `solve`, `solve_multi` and `solve_to`.
///
/// Validates the inputs, seeds every source at distance zero, and then
/// repeatedly settles the cheapest frontier cell, relaxing the neighbours
/// listed in `NEIGHBORS`. When `target` is given, the search stops as soon as
/// that cell is settled.
///
/// # Errors
///
/// * `Error::InvalidParameter` if the grid is empty, `sources` is empty, a
///   source cell has a cost `<= 0.0`, the grid size overflows `usize`, or an
///   edge/path cost stops being finite.
/// * `Error::OutOfBounds` if a source or the target lies outside the grid.
fn solve_inner(
    cost: &CostField,
    sources: &[(usize, usize)],
    target: Option<(usize, usize)>,
) -> Result<SolveResult> {
    let (h, w) = cost.dim();
    if h == 0 || w == 0 {
        return Err(Error::InvalidParameter("cost field must be non-empty"));
    }
    if sources.is_empty() {
        return Err(Error::InvalidParameter("at least one source is required"));
    }

    let outside = |row: usize, col: usize| row >= h || col >= w;
    for &(sr, sc) in sources {
        if outside(sr, sc) {
            return Err(Error::OutOfBounds {
                row: sr,
                col: sc,
                height: h,
                width: w,
            });
        }
        if cost.at(sr, sc) <= 0.0 {
            return Err(Error::InvalidParameter("source cell must be traversable"));
        }
    }
    if let Some((tr, tc)) = target {
        if outside(tr, tc) {
            return Err(Error::OutOfBounds {
                row: tr,
                col: tc,
                height: h,
                width: w,
            });
        }
    }

    let cell_count = grid_len(h, w)?;
    // Row-major flattening shared by every per-cell buffer below.
    let flat = |row: usize, col: usize| row * w + col;
    // The target's flat index never changes, so resolve it once up front.
    let goal = target.map(|(tr, tc)| flat(tr, tc));

    let mut dist = vec![f64::INFINITY; cell_count];
    let mut pred: Vec<Option<usize>> = vec![None; cell_count];
    let mut settled = vec![false; cell_count];
    let mut frontier = BinaryHeap::with_capacity(cell_count / 4);
    for &(sr, sc) in sources {
        let start = flat(sr, sc);
        dist[start] = 0.0;
        frontier.push(Node {
            cost: 0.0,
            idx: start,
        });
    }

    while let Some(node) = frontier.pop() {
        let current = node.idx;
        // Stale heap entries for already-settled cells are simply discarded.
        if settled[current] {
            continue;
        }
        settled[current] = true;
        if goal == Some(current) {
            break;
        }

        let (row, col) = (current / w, current % w);
        // `current` is settled, so its distance cannot change while relaxing.
        let base = dist[current];
        for &(dr, dc) in &NEIGHBORS {
            let (nr, nc) = (row as isize + dr, col as isize + dc);
            if nr < 0 || nc < 0 || nr >= h as isize || nc >= w as isize {
                continue;
            }
            let (nr, nc) = (nr as usize, nc as usize);
            let next = flat(nr, nc);
            if settled[next] {
                continue;
            }

            let diagonal = dr.unsigned_abs() + dc.unsigned_abs() == 2;
            // `None` means one endpoint is impassable: no edge to relax.
            let step = match edge_cost(cost, row, col, nr, nc, diagonal)? {
                Some(step) => step,
                None => continue,
            };

            let candidate = base + step;
            if !candidate.is_finite() {
                return Err(Error::InvalidParameter(
                    "graph path costs must remain finite",
                ));
            }
            if candidate < dist[next] {
                dist[next] = candidate;
                pred[next] = Some(current);
                frontier.push(Node {
                    cost: candidate,
                    idx: next,
                });
            }
        }
    }

    result_from_parts(h, w, dist, pred, settled)
}
/// Packages the raw search buffers into a [`SolveResult`].
///
/// Tentative values left on cells that were never finalized are scrubbed
/// first: their distance becomes `f64::INFINITY` and their predecessor `None`,
/// so callers only ever observe settled results.
///
/// # Panics
///
/// Panics if `dist` or `pred` is shorter than `finalized`, or if `dist` does
/// not contain exactly `h * w` elements.
fn result_from_parts(
    h: usize,
    w: usize,
    mut dist: Vec<f64>,
    mut pred: Vec<Option<usize>>,
    finalized: Vec<bool>,
) -> Result<SolveResult> {
    let unsettled = finalized
        .iter()
        .enumerate()
        .filter_map(|(cell, &done)| (!done).then_some(cell));
    for cell in unsettled {
        dist[cell] = f64::INFINITY;
        pred[cell] = None;
    }

    Ok(SolveResult {
        // The caller is expected to size `dist` to `h * w`, so the reshape
        // only fails on a broken internal invariant.
        distance: Array2::from_shape_vec((h, w), dist).unwrap(),
        predecessors: pred,
        finalized,
        width: w,
    })
}
fn edge_cost(
cost: &CostField,
row: usize,
col: usize,
next_row: usize,
next_col: usize,
diagonal: bool,
) -> Result<Option<f64>> {
let from_cost = cost.at(row, col);
let to_cost = cost.at(next_row, next_col);
if from_cost <= 0.0 || to_cost <= 0.0 {
return Ok(None);
}
let scalar_cost = (from_cost + to_cost) * 0.5;
if !scalar_cost.is_finite() {
return Err(Error::InvalidParameter(
"graph edge costs must remain finite",
));
}
let geom_dist = if diagonal {
std::f64::consts::SQRT_2
} else {
1.0
};
let edge_cost = geom_dist * scalar_cost;
if !edge_cost.is_finite() {
return Err(Error::InvalidParameter(
"graph edge costs must remain finite",
));
}
Ok(Some(edge_cost))
}
fn grid_len(height: usize, width: usize) -> Result<usize> {
height
.checked_mul(width)
.ok_or(Error::InvalidParameter("grid dimensions are too large"))
}
/// Offsets of the eight cells adjacent to a grid cell, as `(row delta, column delta)`.
///
/// The first four entries are the axis-aligned moves and the last four are the
/// diagonals. The solver relaxes neighbours in this order.
const NEIGHBORS: [(isize, isize); 8] = [
// Axis-aligned steps.
(-1, 0),
(1, 0),
(0, -1),
(0, 1),
// Diagonal steps.
(-1, -1),
(-1, 1),
(1, -1),
(1, 1),
];
/// Frontier entry for the priority queue: a cell and its tentative path cost.
///
/// `std::collections::BinaryHeap` is a max-heap, so the ordering is reversed:
/// the entry with the *smallest* cost compares as the greatest and is popped
/// first. Entries with equal cost are ordered by index (lower index pops
/// first), which makes the pop order deterministic.
///
/// Equality is defined through [`Ord::cmp`] so that `a == b` holds exactly
/// when `a.cmp(&b) == Ordering::Equal`, as the `Ord` contract requires.
#[derive(Clone, Debug)]
struct Node {
    /// Tentative accumulated cost of reaching `idx`.
    cost: f64,
    /// Row-major flat index of the cell.
    idx: usize,
}

impl PartialEq for Node {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Node {}

impl PartialOrd for Node {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Node {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` gives a genuine total order over every f64 (NaN
        // included); the operands are swapped to turn the max-heap into a
        // min-heap on cost, and likewise on index for ties.
        other
            .cost
            .total_cmp(&self.cost)
            .then_with(|| other.idx.cmp(&self.idx))
    }
}
// Unit tests for the grid solver: input validation, path reconstruction,
// edge-cost behaviour, obstacle handling, multi-source and early-exit solves.
#[cfg(test)]
mod tests {
use super::*;
// The source cell has zero distance and its neighbour a positive one.
#[test]
fn solve_single_source_flat() {
let cost = CostField::uniform(10, 10);
let result = solve(&cost, (0, 0)).unwrap();
assert_eq!(result.distance()[[0, 0]], 0.0);
assert!(result.distance()[[0, 1]] > 0.0);
}
// A source row equal to the grid height is rejected.
#[test]
fn solve_source_out_of_bounds() {
let cost = CostField::uniform(10, 10);
assert!(solve(&cost, (10, 0)).is_err());
}
// An empty source list is rejected.
#[test]
fn empty_sources_rejected() {
let cost = CostField::uniform(10, 10);
assert!(solve_multi(&cost, &[]).is_err());
}
// A source placed on a zero-cost (impassable) cell is rejected.
#[test]
fn impassable_source_rejected() {
let mut data = Array2::ones((5, 5));
data[[2, 2]] = 0.0;
let cost = CostField::from_array(data).unwrap();
assert!(solve(&cost, (2, 2)).is_err());
}
// The path to a horizontally adjacent cell starts at the source, ends at
// the target, and costs one unit.
#[test]
fn solve_path_to_adjacent() {
let cost = CostField::uniform(5, 5);
let result = solve(&cost, (2, 2)).unwrap();
let path = result.path_to(2, 3).unwrap();
assert_eq!(path.cells.first(), Some(&(2, 2)));
assert_eq!(path.cells.last(), Some(&(2, 3)));
assert!((path.cost - 1.0).abs() < 1e-10);
}
// A single diagonal step costs sqrt(2).
#[test]
fn solve_path_to_diagonal() {
let cost = CostField::uniform(5, 5);
let result = solve(&cost, (0, 0)).unwrap();
let path = result.path_to(1, 1).unwrap();
assert!((path.cost - std::f64::consts::SQRT_2).abs() < 1e-10);
}
// The edge between cells of cost 1 and 3 costs their mean (2) in both directions.
#[test]
fn scalar_edge_costs_are_symmetric() {
let data = Array2::from_shape_vec((1, 2), vec![1.0, 3.0]).unwrap();
let cost = CostField::from_array(data).unwrap();
let forward = solve(&cost, (0, 0)).unwrap();
let reverse = solve(&cost, (0, 1)).unwrap();
assert!((forward.distance()[[0, 1]] - 2.0).abs() < 1e-10);
assert!((forward.distance()[[0, 1]] - reverse.distance()[[0, 0]]).abs() < 1e-10);
}
// Summing two f64::MAX cell costs overflows, which must surface as an error.
#[test]
fn overflowing_scalar_edge_cost_rejected() {
let data = Array2::from_shape_vec((1, 2), vec![f64::MAX, f64::MAX]).unwrap();
let cost = CostField::from_array(data).unwrap();
assert!(solve(&cost, (0, 0)).is_err());
}
// A full column of impassable cells separates the source from the target.
#[test]
fn solve_unreachable_target() {
let mut data = Array2::ones((5, 5));
for r in 0..5 {
data[[r, 2]] = 0.0;
}
let cost = CostField::from_array(data).unwrap();
let result = solve(&cost, (2, 0)).unwrap();
assert!(result.path_to(2, 4).is_err());
}
// The wall has a single gap at (2, 2); the path must pass through it.
#[test]
fn solve_routes_around_obstacle() {
let mut data = Array2::ones((5, 5));
for r in 0..5 {
data[[r, 2]] = 0.0;
}
data[[2, 2]] = 1.0;
let cost = CostField::from_array(data).unwrap();
let result = solve(&cost, (0, 0)).unwrap();
let path = result.path_to(0, 4).unwrap();
assert!(path.cells.iter().any(|&(r, c)| r == 2 && c == 2));
}
// Both sources have zero distance; a cell between them has a positive one.
#[test]
fn solve_multi_sources() {
let cost = CostField::uniform(10, 10);
let result = solve_multi(&cost, &[(0, 0), (9, 9)]).unwrap();
let d_center = result.distance()[[5, 5]];
let d_corner = result.distance()[[0, 0]];
assert_eq!(d_corner, 0.0);
assert!(d_center > 0.0);
let d_far = result.distance()[[9, 9]];
assert_eq!(d_far, 0.0);
}
// Stopping at the target leaves farther cells unfinalized: they report an
// infinite distance and no path.
#[test]
fn solve_to_early_termination() {
let cost = CostField::uniform(100, 100);
let result = solve_to(&cost, (0, 0), (5, 5)).unwrap();
assert!(result.distance()[[5, 5]].is_finite());
let path = result.path_to(5, 5).unwrap();
assert_eq!(path.cells.first(), Some(&(0, 0)));
assert_eq!(path.cells.last(), Some(&(5, 5)));
assert!(result.distance()[[0, 8]].is_infinite());
assert!(matches!(result.path_to(0, 8), Err(Error::NoPathFound)));
}
// Row 2 is expensive except for a cheap cell at (2, 2); the path uses it.
#[test]
fn solve_high_cost_region_avoided() {
let mut data = Array2::ones((5, 5));
for c in 0..5 {
data[[2, c]] = 100.0;
}
data[[2, 2]] = 1.0;
let cost = CostField::from_array(data).unwrap();
let result = solve(&cost, (0, 2)).unwrap();
let path = result.path_to(4, 2).unwrap();
assert!(path.cells.iter().any(|&(r, c)| r == 2 && c == 2));
}
// On a uniform grid, distances mirror around a central source.
#[test]
fn distance_field_symmetry() {
let cost = CostField::uniform(11, 11);
let result = solve(&cost, (5, 5)).unwrap();
let d = result.distance();
assert!((d[[4, 5]] - d[[6, 5]]).abs() < 1e-10);
assert!((d[[5, 4]] - d[[5, 6]]).abs() < 1e-10);
assert!((d[[4, 4]] - d[[6, 6]]).abs() < 1e-10);
}
}