use std::collections::{BinaryHeap, HashMap};
use std::cmp::Reverse;
use crate::graph::Graph;
use crate::types::{ulid_encode, DbError, Value};
use super::{
opt_direction, opt_f64, opt_node_idx, opt_weight_prop, require_node_idx, Direction,
GraphSnapshot, Row,
};
/// Computes weighted shortest paths from a single source node using Dijkstra's algorithm.
///
/// Parameters read from `params`:
/// - `source` (required): the start node, resolved via `require_node_idx`.
/// - `target` (optional): when present, only the path to this node is reported.
/// - `direction` (optional, default `Direction::Out`): which edges the search follows.
/// - `maxCost` (optional, default unbounded): paths whose total cost would exceed this
///   bound are pruned.
/// - the edge-weight property, as parsed by `opt_weight_prop` (the tests pass it under
///   the key `weight`).
///
/// Returns one row per reached target with `source`, `target`, `cost` and `path` columns.
/// An empty graph yields no rows; this check runs before `source` is resolved.
///
/// # Errors
/// Returns `DbError::Query` if `maxCost` is negative or NaN, or if the search encounters a
/// negative edge weight. Errors from the parameter helpers (for example a missing or unknown
/// `source`) are propagated unchanged.
pub fn run(graph: &Graph, params: &HashMap<String, Value>) -> Result<Vec<Row>, DbError> {
    let weight_prop = opt_weight_prop(params)?;
    let snap = GraphSnapshot::build(graph, weight_prop);
    if snap.n == 0 {
        return Ok(vec![]);
    }
    let source = require_node_idx(params, "source", &snap)?;
    let target = opt_node_idx(params, "target", &snap)?;
    let direction = opt_direction(params, "direction", Direction::Out)?;
    let max_cost = opt_f64(params, "maxCost", f64::INFINITY)?;
    // NaN fails every ordered comparison. A NaN bound would slip past `< 0.0` and then prune
    // every relaxation (`new_dist <= NaN` is false), silently returning nothing. Reject it
    // together with negative values.
    if max_cost.is_nan() || max_cost < 0.0 {
        return Err(DbError::Query("parameter 'maxCost' must be non-negative".into()));
    }
    dijkstra(&snap, source, target, direction, max_cost)
}
/// Runs Dijkstra's algorithm over `snap` from `source`, following edges in direction `dir`.
///
/// Relaxations whose resulting cost would exceed `max_cost` are discarded, so nodes only
/// reachable above that bound are treated as unreachable. If `target` is given, the search
/// stops as soon as that node is settled and at most one row (for `target`) is produced.
/// Otherwise one row is produced for every reached node other than `source`.
///
/// Each row holds `source`, `target` (ULID strings), `cost` (float) and `path` (list of ULID
/// strings from `source` to `target`).
///
/// # Errors
/// Returns `DbError::Query` on the first negative edge weight found while expanding a
/// settled node. Edges out of nodes that are never settled are not inspected.
fn dijkstra(
    snap: &GraphSnapshot,
    source: usize,
    target: Option<usize>,
    dir: Direction,
    max_cost: f64,
) -> Result<Vec<Row>, DbError> {
    let node_count = snap.n;
    // Best-known distance from `source`; INFINITY marks "not reached (yet)".
    let mut best = vec![f64::INFINITY; node_count];
    // Predecessor of each node on its current best path.
    let mut parent: Vec<Option<usize>> = vec![None; node_count];
    best[source] = 0.0;

    // Min-heap keyed on the raw bits of the distance. For non-negative floats `to_bits` is
    // monotonic, so ordering the u64 keys orders entries by distance. Ties go to the lower
    // node index.
    let mut frontier: BinaryHeap<Reverse<(u64, usize)>> = BinaryHeap::new();
    frontier.push(Reverse((0u64, source)));

    while let Some(Reverse((key, node))) = frontier.pop() {
        let popped = f64::from_bits(key);
        if popped > best[node] {
            // Stale entry: a cheaper route to `node` was recorded after this push.
            continue;
        }
        if Some(node) == target {
            // Settled nodes have final distances, so the requested target is done.
            break;
        }
        let base = best[node];
        for (v, w) in snap.neighbors(node, dir) {
            if w < 0.0 {
                return Err(DbError::Query(format!(
                    "shortestPath: negative edge weight {w} encountered. \
                     Dijkstra requires non-negative weights. \
                     Consider using absolute weights or a Bellman-Ford variant."
                )));
            }
            let candidate = base + w;
            let within_budget = candidate <= max_cost;
            if candidate < best[v] && within_budget {
                best[v] = candidate;
                parent[v] = Some(node);
                frontier.push(Reverse((candidate.to_bits(), v)));
            }
        }
    }

    let reached: Vec<usize> = if let Some(t) = target {
        if best[t].is_finite() {
            vec![t]
        } else {
            vec![]
        }
    } else {
        (0..node_count)
            .filter(|&i| i != source && best[i].is_finite())
            .collect()
    };

    let source_id = Value::String(ulid_encode(snap.node_ids[source].0));
    let rows: Vec<Row> = reached
        .into_iter()
        .map(|t| {
            let path = reconstruct_path(&parent, source, t, snap);
            let mut row = HashMap::new();
            row.insert("source".to_string(), source_id.clone());
            row.insert("target".to_string(), Value::String(ulid_encode(snap.node_ids[t].0)));
            row.insert("cost".to_string(), Value::Float(best[t]));
            row.insert("path".to_string(), Value::List(path));
            row
        })
        .collect();
    Ok(rows)
}
/// Rebuilds the path ending at `target` by following `prev` links back toward `source`.
///
/// Returns the nodes as ULID-encoded `Value::String`s, ordered from the start of the chain
/// to `target`. The walk stops at `source`, or earlier if a node has no predecessor, in
/// which case the returned path does not begin at `source`.
fn reconstruct_path(
    prev: &[Option<usize>],
    source: usize,
    target: usize,
    snap: &GraphSnapshot,
) -> Vec<Value> {
    // Walk target -> ... -> source, stopping once `source` has been emitted.
    let mut chain: Vec<usize> = std::iter::successors(Some(target), |&node| {
        if node == source {
            None
        } else {
            prev[node]
        }
    })
    .collect();
    chain.reverse();
    chain
        .into_iter()
        .map(|node| Value::String(ulid_encode(snap.node_ids[node].0)))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::Graph;
    use crate::types::{ulid_encode, Edge, Node, NodeId, Value};

    /// Inserts a node labelled `N` with a `name` property and returns its id.
    fn make_node(g: &mut Graph, name: &str) -> NodeId {
        let id = g.alloc_node_id();
        g.apply_insert_node(Node {
            id,
            labels: vec!["N".into()],
            properties: [("name".to_string(), Value::String(name.to_string()))]
                .into_iter()
                .collect(),
        });
        id
    }

    /// Inserts a directed `E` edge whose weight is stored in property `w`.
    fn make_edge(g: &mut Graph, from: NodeId, to: NodeId, weight: f64) {
        let id = g.alloc_edge_id();
        g.apply_insert_edge(Edge {
            id,
            from_node: from,
            to_node: to,
            label: "E".into(),
            properties: [("w".to_string(), Value::Float(weight))].into_iter().collect(),
            directed: true,
        });
    }

    /// ULID-string form of a node id, as accepted in params and emitted in result rows.
    fn node_value(id: NodeId) -> Value {
        Value::String(ulid_encode(id.0))
    }

    /// Builds a procedure parameter map from `(key, value)` pairs.
    fn make_params(pairs: Vec<(&str, Value)>) -> HashMap<String, Value> {
        pairs.into_iter().map(|(k, v)| (k.to_string(), v)).collect()
    }

    #[test]
    fn unweighted_direct_path() {
        let mut g = Graph::new();
        let a = make_node(&mut g, "a");
        let b = make_node(&mut g, "b");
        make_edge(&mut g, a, b, 1.0);
        let params = make_params(vec![("source", node_value(a)), ("target", node_value(b))]);
        let rows = run(&g, &params).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0]["cost"], Value::Float(1.0));
    }

    #[test]
    fn weighted_chooses_cheapest_path() {
        let mut g = Graph::new();
        let a = make_node(&mut g, "a");
        let b = make_node(&mut g, "b");
        let c = make_node(&mut g, "c");
        make_edge(&mut g, a, b, 10.0);
        make_edge(&mut g, b, c, 1.0);
        make_edge(&mut g, a, c, 3.0);
        let params = make_params(vec![
            ("source", node_value(a)),
            ("target", node_value(c)),
            ("weight", Value::String("w".into())),
        ]);
        let rows = run(&g, &params).unwrap();
        assert_eq!(rows[0]["cost"], Value::Float(3.0));
        if let Value::List(p) = &rows[0]["path"] {
            assert_eq!(p.len(), 2);
        } else {
            panic!("path not a list");
        }
    }

    #[test]
    fn path_runs_from_source_to_target() {
        let mut g = Graph::new();
        let a = make_node(&mut g, "a");
        let b = make_node(&mut g, "b");
        let c = make_node(&mut g, "c");
        make_edge(&mut g, a, b, 1.0);
        make_edge(&mut g, b, c, 1.0);
        let params = make_params(vec![
            ("source", node_value(a)),
            ("target", node_value(c)),
            ("weight", Value::String("w".into())),
        ]);
        let rows = run(&g, &params).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(
            rows[0]["path"],
            Value::List(vec![node_value(a), node_value(b), node_value(c)])
        );
    }

    #[test]
    fn no_path_returns_empty() {
        let mut g = Graph::new();
        let a = make_node(&mut g, "a");
        let b = make_node(&mut g, "b");
        let params = make_params(vec![("source", node_value(a)), ("target", node_value(b))]);
        let rows = run(&g, &params).unwrap();
        assert!(rows.is_empty());
    }

    #[test]
    fn negative_weight_errors() {
        let mut g = Graph::new();
        let a = make_node(&mut g, "a");
        let b = make_node(&mut g, "b");
        make_edge(&mut g, a, b, -5.0);
        let params = make_params(vec![
            ("source", node_value(a)),
            ("target", node_value(b)),
            ("weight", Value::String("w".into())),
        ]);
        assert!(run(&g, &params).is_err());
    }

    #[test]
    fn negative_max_cost_errors() {
        let mut g = Graph::new();
        let a = make_node(&mut g, "a");
        let params = make_params(vec![
            ("source", node_value(a)),
            ("maxCost", Value::Float(-1.0)),
        ]);
        assert!(run(&g, &params).is_err());
    }

    #[test]
    fn all_targets_when_target_omitted() {
        let mut g = Graph::new();
        let a = make_node(&mut g, "a");
        let b = make_node(&mut g, "b");
        let c = make_node(&mut g, "c");
        make_edge(&mut g, a, b, 1.0);
        make_edge(&mut g, a, c, 2.0);
        let params = make_params(vec![("source", node_value(a))]);
        let rows = run(&g, &params).unwrap();
        // The source itself is not reported when no target is given.
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn max_cost_prunes_long_paths() {
        let mut g = Graph::new();
        let a = make_node(&mut g, "a");
        let b = make_node(&mut g, "b");
        let c = make_node(&mut g, "c");
        make_edge(&mut g, a, b, 1.0);
        make_edge(&mut g, b, c, 5.0);
        let params = make_params(vec![
            ("source", node_value(a)),
            ("weight", Value::String("w".into())),
            ("maxCost", Value::Float(3.0)),
        ]);
        let rows = run(&g, &params).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0]["cost"], Value::Float(1.0));
    }

    #[test]
    fn empty_graph_returns_empty() {
        let g = Graph::new();
        let params: HashMap<String, Value> = HashMap::new();
        assert!(run(&g, &params).is_ok());
    }
}