use std::collections::{HashMap, HashSet, VecDeque};
use crate::graph::Graph;
use crate::types::{ulid_encode, DbError, Value};
use super::{opt_direction, opt_str, opt_usize, require_node_idx, Direction, GraphSnapshot, Row};
pub fn run(graph: &Graph, params: &HashMap<String, Value>) -> Result<Vec<Row>, DbError> {
let algo = opt_str(params, "algorithm", "bfs")?;
if !matches!(algo, "bfs" | "dfs") {
return Err(DbError::Query(format!(
"unknown 'algorithm' value '{algo}': expected \"bfs\" or \"dfs\""
)));
}
let direction = opt_direction(params, "direction", Direction::Out)?;
let max_depth = opt_usize(params, "maxDepth", usize::MAX)?;
let _ = super::require_str(params, "start")?;
let snap = GraphSnapshot::build(graph, None);
if snap.n == 0 {
return Ok(vec![]);
}
let start = require_node_idx(params, "start", &snap)?;
match algo {
"bfs" => run_bfs(&snap, start, max_depth, direction),
_ => run_dfs(&snap, start, max_depth, direction),
}
}
/// Breadth-first traversal from `start`, returning one row per discovered node.
///
/// Rows come out in dequeue order. A node is marked seen when it is enqueued, so it is
/// emitted exactly once. Its `depth` is the BFS level at which it was first reached, which is
/// the fewest hops from `start` along `dir`. Its `predecessor` is the node that reached it,
/// or `None` for `start`. Nodes at level `max_depth` are emitted but not expanded, so
/// `max_depth == 0` yields just `start`.
fn run_bfs(
    snap: &GraphSnapshot,
    start: usize,
    max_depth: usize,
    dir: Direction,
) -> Result<Vec<Row>, DbError> {
    let mut out = Vec::new();
    let mut seen: HashSet<usize> = HashSet::from([start]);
    let mut frontier: VecDeque<(usize, usize, Option<usize>)> =
        VecDeque::from([(start, 0, None)]);

    while let Some((node, level, parent)) = frontier.pop_front() {
        out.push(make_row(snap, node, level, parent));
        // Depth budget exhausted: report the node but do not look past it.
        if level >= max_depth {
            continue;
        }
        let child_level = level + 1;
        for (nbr, _) in snap.neighbors(node, dir) {
            // `insert` returns false for nodes that were already queued (or emitted).
            if !seen.insert(nbr) {
                continue;
            }
            frontier.push_back((nbr, child_level, Some(node)));
        }
    }
    Ok(out)
}
/// Depth-first (pre-order) traversal from `start`, returning one row per visited node.
///
/// An explicit stack replaces recursion. Neighbours are pushed in reverse so they are popped
/// in the order `neighbors` yields them. A node is emitted the first time it is popped. Its
/// `depth` and `predecessor` describe the stack entry that reached it (the DFS-tree edge),
/// which is not necessarily a shortest path. Nodes popped at depth `max_depth` are emitted
/// but not expanded.
///
/// NOTE(review): a node is marked visited on its first pop and is never re-expanded. If DFS
/// first reaches it along a longer path, nodes that are within `max_depth` only through a
/// shorter route to it can be omitted. Example: with a->b->c, a->c, c->d and maxDepth 2,
/// `d` may be skipped depending on neighbour order. Confirm this is the intended semantics.
fn run_dfs(
    snap: &GraphSnapshot,
    start: usize,
    max_depth: usize,
    dir: Direction,
) -> Result<Vec<Row>, DbError> {
    let mut out = Vec::new();
    let mut done: HashSet<usize> = HashSet::new();
    let mut pending: Vec<(usize, usize, Option<usize>)> = vec![(start, 0, None)];

    while let Some((node, level, parent)) = pending.pop() {
        // A node may sit on the stack several times; only its first pop counts.
        let first_visit = done.insert(node);
        if !first_visit {
            continue;
        }
        out.push(make_row(snap, node, level, parent));
        if level >= max_depth {
            continue;
        }
        // Reverse so the first neighbour ends up on top of the stack and is explored first.
        let unvisited = snap
            .neighbors(node, dir)
            .into_iter()
            .rev()
            .map(|(nbr, _)| nbr)
            .filter(|nbr| !done.contains(nbr))
            .map(|nbr| (nbr, level + 1, Some(node)));
        pending.extend(unvisited);
    }
    Ok(out)
}
/// Builds the output row for one visited node.
///
/// Columns:
/// - `node`: the node's ULID string.
/// - `depth`: `Value::Int` of the traversal depth.
/// - `predecessor`: the ULID string of the node it was reached from, or `Value::Null`
///   when `pred` is `None` (the start node).
fn make_row(snap: &GraphSnapshot, node: usize, depth: usize, pred: Option<usize>) -> Row {
    let ulid_of = |idx: usize| Value::String(ulid_encode(snap.node_ids[idx].0));
    let node_ulid = ulid_of(node);
    let predecessor = pred.map_or(Value::Null, ulid_of);
    HashMap::from([
        ("node".to_string(), node_ulid),
        ("depth".to_string(), Value::Int(depth as i64)),
        ("predecessor".to_string(), predecessor),
    ])
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::Graph;
    use crate::types::{ulid_encode, Value};

    /// Inserts a node labelled `N` with a `name` property and returns its id.
    fn insert_node(g: &mut Graph, name: &str) -> crate::types::NodeId {
        let id = g.alloc_node_id();
        let node = crate::types::Node {
            id,
            labels: vec!["N".into()],
            properties: [("name".to_string(), Value::String(name.to_string()))]
                .into_iter()
                .collect(),
        };
        g.apply_insert_node(node);
        id
    }

    /// Inserts an edge labelled `E` from `from` to `to` with the given directedness.
    fn insert_edge_with(
        g: &mut Graph,
        from: crate::types::NodeId,
        to: crate::types::NodeId,
        directed: bool,
    ) {
        let id = g.alloc_edge_id();
        g.apply_insert_edge(crate::types::Edge {
            id,
            from_node: from,
            to_node: to,
            label: "E".into(),
            properties: Default::default(),
            directed,
        });
    }

    /// Inserts a directed edge labelled `E` from `from` to `to`.
    fn insert_edge(g: &mut Graph, from: crate::types::NodeId, to: crate::types::NodeId) {
        insert_edge_with(g, from, to, true);
    }

    /// Builds the path a -> b -> c -> d and returns it with the node ids in path order.
    fn linear_graph() -> (Graph, Vec<crate::types::NodeId>) {
        let mut g = Graph::new();
        let ids: Vec<_> = ["a", "b", "c", "d"]
            .iter()
            .map(|n| insert_node(&mut g, n))
            .collect();
        insert_edge(&mut g, ids[0], ids[1]);
        insert_edge(&mut g, ids[1], ids[2]);
        insert_edge(&mut g, ids[2], ids[3]);
        (g, ids)
    }

    /// Builds the directed 3-cycle a -> b -> c -> a and returns it with the start node `a`.
    fn triangle_graph() -> (Graph, crate::types::NodeId) {
        let mut g = Graph::new();
        let a = insert_node(&mut g, "a");
        let b = insert_node(&mut g, "b");
        let c = insert_node(&mut g, "c");
        insert_edge(&mut g, a, b);
        insert_edge(&mut g, b, c);
        insert_edge(&mut g, c, a);
        (g, a)
    }

    /// Builds a procedure parameter map from `(name, value)` pairs.
    fn make_params<const N: usize>(pairs: [(&str, Value); N]) -> HashMap<String, Value> {
        pairs.into_iter().map(|(k, v)| (k.to_string(), v)).collect()
    }

    /// A `start` parameter naming the node `id`.
    fn start_param(id: crate::types::NodeId) -> (&'static str, Value) {
        ("start", Value::String(ulid_encode(id.0)))
    }

    /// Extracts the `depth` column of every row, panicking if it is not an `Int`.
    fn depths(rows: &[Row]) -> Vec<i64> {
        rows.iter()
            .map(|r| match r["depth"] {
                Value::Int(d) => d,
                _ => panic!("`depth` column is not an Int"),
            })
            .collect()
    }

    #[test]
    fn bfs_linear_full_depth() {
        let (g, ids) = linear_graph();
        let rows = run(&g, &make_params([start_param(ids[0])])).unwrap();
        assert_eq!(rows.len(), 4);
        assert_eq!(depths(&rows), vec![0, 1, 2, 3]);
        for (row, id) in rows.iter().zip(&ids) {
            assert_eq!(row["node"], Value::String(ulid_encode(id.0)));
        }
        assert_eq!(rows[0]["predecessor"], Value::Null);
        assert_eq!(rows[1]["predecessor"], Value::String(ulid_encode(ids[0].0)));
    }

    #[test]
    fn bfs_max_depth_one() {
        let (g, ids) = linear_graph();
        let params = make_params([start_param(ids[0]), ("maxDepth", Value::Int(1))]);
        let rows = run(&g, &params).unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(depths(&rows), vec![0, 1]);
    }

    #[test]
    fn bfs_depth_zero_returns_start_only() {
        let (g, ids) = linear_graph();
        let params = make_params([start_param(ids[0]), ("maxDepth", Value::Int(0))]);
        let rows = run(&g, &params).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0]["depth"], Value::Int(0));
        assert_eq!(rows[0]["predecessor"], Value::Null);
    }

    #[test]
    fn dfs_visits_all_nodes() {
        let (g, ids) = linear_graph();
        let params = make_params([
            start_param(ids[0]),
            ("algorithm", Value::String("dfs".into())),
        ]);
        let rows = run(&g, &params).unwrap();
        assert_eq!(rows.len(), 4);
        assert_eq!(depths(&rows), vec![0, 1, 2, 3]);
    }

    #[test]
    fn dfs_respects_max_depth() {
        let (g, ids) = linear_graph();
        let params = make_params([
            start_param(ids[0]),
            ("algorithm", Value::String("dfs".into())),
            ("maxDepth", Value::Int(2)),
        ]);
        let rows = run(&g, &params).unwrap();
        assert_eq!(depths(&rows), vec![0, 1, 2]);
    }

    #[test]
    fn bfs_direction_any_undirected() {
        let mut g = Graph::new();
        let a = insert_node(&mut g, "a");
        let b = insert_node(&mut g, "b");
        insert_edge_with(&mut g, a, b, false);
        let params = make_params([start_param(b), ("direction", Value::String("any".into()))]);
        let rows = run(&g, &params).unwrap();
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn bfs_cycle_does_not_loop() {
        let (g, a) = triangle_graph();
        let rows = run(&g, &make_params([start_param(a)])).unwrap();
        assert_eq!(rows.len(), 3);
    }

    #[test]
    fn dfs_cycle_does_not_loop() {
        let (g, a) = triangle_graph();
        let params = make_params([start_param(a), ("algorithm", Value::String("dfs".into()))]);
        let rows = run(&g, &params).unwrap();
        assert_eq!(rows.len(), 3);
    }

    #[test]
    fn missing_start_param_errors() {
        let g = Graph::new();
        let params: HashMap<String, Value> = HashMap::new();
        assert!(run(&g, &params).is_err());
    }

    #[test]
    fn invalid_direction_errors() {
        let (g, ids) = linear_graph();
        let params = make_params([
            start_param(ids[0]),
            ("direction", Value::String("sideways".into())),
        ]);
        assert!(run(&g, &params).is_err());
    }

    #[test]
    fn unknown_algorithm_errors() {
        let (g, ids) = linear_graph();
        let params = make_params([
            start_param(ids[0]),
            ("algorithm", Value::String("astar".into())),
        ]);
        assert!(run(&g, &params).is_err());
    }

    #[test]
    fn empty_graph_returns_empty() {
        // Supply `start` so the call reaches the empty-snapshot path instead of failing
        // on the missing parameter.
        let g = Graph::new();
        let params = make_params([("start", Value::String("01ARZ3NDEKTSV4RRFFQ69G5FAV".into()))]);
        let rows = run(&g, &params).unwrap();
        assert!(rows.is_empty());
    }
}