use std::collections::HashMap;
use crate::graph::Graph;
use crate::types::{ulid_decode, DbError, NodeId, Value};
pub mod bfs;
pub mod centrality;
pub mod community;
pub mod components;
pub mod flow;
pub mod louvain;
pub mod shortest_path;
pub mod similarity;
pub mod triangle;
/// A single result row: string keys mapped to [`Value`]s.
///
/// [`dispatch_call`] returns a `Vec` of these.
pub type Row = HashMap<String, Value>;
/// Selects which adjacency list(s) of a node are consulted during traversal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    /// Use the outgoing adjacency list only.
    Out,
    /// Use the incoming adjacency list only.
    In,
    /// Use both the outgoing and the incoming adjacency lists.
    Any,
}

impl Direction {
    /// Parses a direction keyword, ignoring ASCII case.
    ///
    /// Accepted spellings are `out`/`outgoing`, `in`/`incoming` and
    /// `any`/`both`.
    ///
    /// # Errors
    ///
    /// Returns [`DbError::Query`] for any other input. The message quotes the
    /// ASCII-lowercased form of `s`, not the text exactly as supplied.
    fn from_str(s: &str) -> Result<Self, DbError> {
        // Every accepted keyword paired with the variant it selects.
        const ALIASES: [(&str, Direction); 6] = [
            ("out", Direction::Out),
            ("outgoing", Direction::Out),
            ("in", Direction::In),
            ("incoming", Direction::In),
            ("any", Direction::Any),
            ("both", Direction::Any),
        ];

        ALIASES
            .iter()
            .find(|(alias, _)| s.eq_ignore_ascii_case(alias))
            .map(|&(_, direction)| direction)
            .ok_or_else(|| {
                let other = s.to_ascii_lowercase();
                DbError::Query(format!(
                    "invalid direction '{other}': expected \"out\", \"in\", or \"any\""
                ))
            })
    }
}
/// Index-based adjacency view of a [`Graph`], built once by
/// [`GraphSnapshot::build`].
///
/// Nodes are addressed by their position in `node_ids`; every adjacency entry
/// is a `(neighbor index, edge weight)` pair.
pub struct GraphSnapshot {
    /// Node ids in index order: `node_ids[i]` is the id of node `i`.
    pub node_ids: Vec<NodeId>,
    /// Reverse lookup from node id to its index in `node_ids`.
    pub id_to_idx: HashMap<NodeId, usize>,
    /// `adj_out[i]` lists the targets of edges leaving node `i`. Undirected
    /// edges are recorded in both directions.
    pub adj_out: Vec<Vec<(usize, f64)>>,
    /// `adj_in[i]` lists the sources of edges entering node `i`. Undirected
    /// edges are recorded in both directions.
    pub adj_in: Vec<Vec<(usize, f64)>>,
    /// Number of nodes captured by `build`.
    pub n: usize,
}

impl GraphSnapshot {
    /// Captures the nodes and edges of `graph` into index-based adjacency lists.
    ///
    /// When `weight_prop` is `Some(name)`, an edge's weight is read from the
    /// property `name` if that property is a `Value::Float` or `Value::Int`.
    /// In every other case (no property name given, property absent, or a
    /// different value type) the weight is `1.0`.
    ///
    /// Edges with an endpoint that is not among the graph's nodes are skipped.
    pub fn build(graph: &Graph, weight_prop: Option<&str>) -> Self {
        let nodes = graph.all_nodes();
        let n = nodes.len();
        let node_ids: Vec<NodeId> = nodes.iter().map(|nd| nd.id).collect();

        let mut id_to_idx: HashMap<NodeId, usize> = HashMap::with_capacity(n);
        for (idx, id) in node_ids.iter().enumerate() {
            id_to_idx.insert(*id, idx);
        }

        let mut adj_out: Vec<Vec<(usize, f64)>> = (0..n).map(|_| Vec::new()).collect();
        let mut adj_in: Vec<Vec<(usize, f64)>> = (0..n).map(|_| Vec::new()).collect();

        for edge in graph.all_edges() {
            // Ignore edges that reference a node missing from the snapshot.
            let Some(&src) = id_to_idx.get(&edge.from_node) else {
                continue;
            };
            let Some(&dst) = id_to_idx.get(&edge.to_node) else {
                continue;
            };

            let weight = match weight_prop.and_then(|name| edge.properties.get(name)) {
                Some(Value::Float(f)) => *f,
                Some(Value::Int(i)) => *i as f64,
                _ => 1.0,
            };

            adj_out[src].push((dst, weight));
            adj_in[dst].push((src, weight));
            if !edge.directed {
                // An undirected edge is also stored as its own reverse.
                adj_out[dst].push((src, weight));
                adj_in[src].push((dst, weight));
            }
        }

        Self {
            node_ids,
            id_to_idx,
            adj_out,
            adj_in,
            n,
        }
    }

    /// Returns a copy of the `(neighbor index, weight)` pairs of node `i`.
    ///
    /// For [`Direction::Any`] the result is the outgoing list followed by the
    /// incoming list with no de-duplication, so a neighbor present in both
    /// lists appears more than once.
    ///
    /// # Panics
    ///
    /// Panics if `i` is out of bounds for the adjacency list being read.
    pub fn neighbors(&self, i: usize, dir: Direction) -> Vec<(usize, f64)> {
        match dir {
            Direction::Out => self.adj_out[i].to_vec(),
            Direction::In => self.adj_in[i].to_vec(),
            Direction::Any => self.adj_out[i]
                .iter()
                .chain(self.adj_in[i].iter())
                .copied()
                .collect(),
        }
    }

    /// Returns the distinct neighbor indices of node `i`, in order of first
    /// appearance in [`GraphSnapshot::neighbors`]; weights are dropped.
    pub fn unique_neighbor_indices(&self, i: usize, dir: Direction) -> Vec<usize> {
        let mut seen = std::collections::HashSet::new();
        let mut unique = Vec::new();
        for (j, _) in self.neighbors(i, dir) {
            if seen.insert(j) {
                unique.push(j);
            }
        }
        unique
    }
}
/// Fetches the mandatory string parameter `key` from `params`.
///
/// # Errors
///
/// Returns [`DbError::Query`] when the parameter is absent or holds a
/// non-string value.
pub fn require_str<'a>(
    params: &'a HashMap<String, Value>,
    key: &str,
) -> Result<&'a str, DbError> {
    let Some(value) = params.get(key) else {
        return Err(DbError::Query(format!(
            "required parameter '{key}' is missing"
        )));
    };

    if let Value::String(text) = value {
        Ok(text.as_str())
    } else {
        Err(DbError::Query(format!(
            "parameter '{key}' must be a string, got {value:?}"
        )))
    }
}
/// Resolves the mandatory parameter `key` to a node index in `snap`.
///
/// The parameter must be a string that `ulid_decode` accepts and whose decoded
/// id is present in `snap.id_to_idx`.
///
/// # Errors
///
/// Returns [`DbError::Query`] when the parameter is missing, is not a string,
/// fails ULID decoding, or names a node that is not in the snapshot.
pub fn require_node_idx(
    params: &HashMap<String, Value>,
    key: &str,
    snap: &GraphSnapshot,
) -> Result<usize, DbError> {
    let text = require_str(params, key)?;

    let node_id = match ulid_decode(text) {
        Ok(raw) => NodeId(raw),
        Err(e) => {
            return Err(DbError::Query(format!(
                "parameter '{key}' is not a valid ULID: {e}"
            )))
        }
    };

    match snap.id_to_idx.get(&node_id) {
        Some(&idx) => Ok(idx),
        None => Err(DbError::Query(format!(
            "parameter '{key}': node '{text}' not found in graph"
        ))),
    }
}
/// Resolves the optional parameter `key` to a node index in `snap`.
///
/// Returns `Ok(None)` when the parameter is absent. When present it must be a
/// string that `ulid_decode` accepts and whose decoded id is present in
/// `snap.id_to_idx`.
///
/// # Errors
///
/// Returns [`DbError::Query`] when the parameter is present but is not a
/// string, fails ULID decoding, or names a node that is not in the snapshot.
pub fn opt_node_idx(
    params: &HashMap<String, Value>,
    key: &str,
    snap: &GraphSnapshot,
) -> Result<Option<usize>, DbError> {
    let Some(value) = params.get(key) else {
        return Ok(None);
    };
    let Value::String(s) = value else {
        return Err(DbError::Query(format!(
            "parameter '{key}' must be a string (ULID), got {value:?}"
        )));
    };

    let raw = match ulid_decode(s) {
        Ok(raw) => raw,
        Err(e) => {
            return Err(DbError::Query(format!(
                "parameter '{key}' is not a valid ULID: {e}"
            )))
        }
    };

    match snap.id_to_idx.get(&NodeId(raw)).copied() {
        Some(idx) => Ok(Some(idx)),
        None => Err(DbError::Query(format!(
            "parameter '{key}': node '{s}' not found in graph"
        ))),
    }
}
/// Reads the optional numeric parameter `key`, falling back to `default` when
/// it is absent.
///
/// `Value::Float` is returned as-is if finite; `Value::Int` is converted with
/// an `as f64` cast.
///
/// # Errors
///
/// Returns [`DbError::Query`] when the value is a NaN or infinite float, or is
/// neither a float nor an integer.
pub fn opt_f64(
    params: &HashMap<String, Value>,
    key: &str,
    default: f64,
) -> Result<f64, DbError> {
    let Some(value) = params.get(key) else {
        return Ok(default);
    };

    match value {
        Value::Float(f) if f.is_finite() => Ok(*f),
        Value::Float(_) => Err(DbError::Query(format!(
            "parameter '{key}' must be a finite number"
        ))),
        Value::Int(i) => Ok(*i as f64),
        other => Err(DbError::Query(format!(
            "parameter '{key}' must be a number, got {other:?}"
        ))),
    }
}
/// Reads the optional non-negative integer parameter `key`, falling back to
/// `default` when it is absent.
///
/// The integer is converted to `usize` with a checked conversion, so a value
/// too large for the target's pointer width is reported as an error instead of
/// being silently truncated by an `as` cast.
///
/// # Errors
///
/// Returns [`DbError::Query`] when the value is negative, does not fit in
/// `usize`, or is not an integer.
pub fn opt_usize(
    params: &HashMap<String, Value>,
    key: &str,
    default: usize,
) -> Result<usize, DbError> {
    match params.get(key) {
        None => Ok(default),
        Some(Value::Int(i)) if *i < 0 => Err(DbError::Query(format!(
            "parameter '{key}' must be a non-negative integer, got {i}"
        ))),
        // The trait is named by its full path so the call does not depend on
        // `TryFrom` being in the edition prelude.
        Some(Value::Int(i)) => <usize as std::convert::TryFrom<_>>::try_from(*i).map_err(|_| {
            DbError::Query(format!(
                "parameter '{key}' does not fit in this platform's usize, got {i}"
            ))
        }),
        Some(other) => Err(DbError::Query(format!(
            "parameter '{key}' must be an integer, got {other:?}"
        ))),
    }
}
/// Reads the optional boolean parameter `key`, falling back to `default` when
/// it is absent.
///
/// # Errors
///
/// Returns [`DbError::Query`] when the parameter is present but not a boolean.
pub fn opt_bool(
    params: &HashMap<String, Value>,
    key: &str,
    default: bool,
) -> Result<bool, DbError> {
    let Some(value) = params.get(key) else {
        return Ok(default);
    };

    if let Value::Bool(flag) = value {
        Ok(*flag)
    } else {
        Err(DbError::Query(format!(
            "parameter '{key}' must be a boolean, got {value:?}"
        )))
    }
}
/// Reads the optional direction parameter `key`, falling back to `default`
/// when it is absent.
///
/// A present string value is parsed with `Direction::from_str`.
///
/// # Errors
///
/// Returns [`DbError::Query`] when the parameter is present but is not a
/// string, or when `Direction::from_str` rejects the string.
pub fn opt_direction(
    params: &HashMap<String, Value>,
    key: &str,
    default: Direction,
) -> Result<Direction, DbError> {
    let Some(value) = params.get(key) else {
        return Ok(default);
    };

    if let Value::String(s) = value {
        return Direction::from_str(s);
    }

    Err(DbError::Query(format!(
        "parameter '{key}' must be a string (\"out\", \"in\", or \"any\"), got {value:?}"
    )))
}
/// Reads the optional string parameter `key`, falling back to `default` when
/// it is absent.
///
/// # Errors
///
/// Returns [`DbError::Query`] when the parameter is present but not a string.
pub fn opt_str<'a>(
    params: &'a HashMap<String, Value>,
    key: &str,
    default: &'a str,
) -> Result<&'a str, DbError> {
    params.get(key).map_or(Ok(default), |value| match value {
        Value::String(text) => Ok(text.as_str()),
        other => Err(DbError::Query(format!(
            "parameter '{key}' must be a string, got {other:?}"
        ))),
    })
}
/// Reads the optional `"weight"` parameter, which names a property.
///
/// Returns `Ok(None)` when the parameter is absent and `Ok(Some(name))` when
/// it is a string.
///
/// # Errors
///
/// Returns [`DbError::Query`] when `"weight"` is present but not a string.
pub fn opt_weight_prop<'a>(
    params: &'a HashMap<String, Value>,
) -> Result<Option<&'a str>, DbError> {
    params
        .get("weight")
        .map(|value| match value {
            Value::String(name) => Ok(name.as_str()),
            other => Err(DbError::Query(format!(
                "parameter 'weight' must be a string property name, got {other:?}"
            ))),
        })
        .transpose()
}
/// Routes an algorithm call to the submodule function registered for `name`.
///
/// `graph` and `params` are passed through unchanged. Names are matched
/// exactly, except that page rank is accepted as either `pageRank` or
/// `pagerank`.
///
/// # Errors
///
/// Returns [`DbError::Query`] listing the available algorithms when `name` is
/// not recognised; otherwise returns whatever the selected algorithm returns.
pub fn dispatch_call(
    graph: &Graph,
    name: &str,
    params: &HashMap<String, Value>,
) -> Result<Vec<Row>, DbError> {
    // Names advertised in the "unknown algorithm" error message.
    const AVAILABLE: &str = "bfs, dfs, shortestPath, wcc, scc, pageRank, \
        betweennessCentrality, closenessCentrality, degreeCentrality, \
        triangleCount, jaccardSimilarity, labelPropagation, \
        louvain, leiden, maxFlow";

    match name {
        // Traversal and paths.
        // NOTE(review): "dfs" is routed to the same entry point as "bfs" and
        // the name is not forwarded — confirm `bfs::run` can tell them apart.
        "bfs" | "dfs" => bfs::run(graph, params),
        "shortestPath" => shortest_path::run(graph, params),

        // Connected components.
        "wcc" => components::run_wcc(graph, params),
        "scc" => components::run_scc(graph, params),

        // Centrality measures.
        "pageRank" | "pagerank" => centrality::run_pagerank(graph, params),
        "betweennessCentrality" => centrality::run_betweenness(graph, params),
        "closenessCentrality" => centrality::run_closeness(graph, params),
        "degreeCentrality" => centrality::run_degree(graph, params),

        // Local structure and similarity.
        "triangleCount" => triangle::run(graph, params),
        "jaccardSimilarity" => similarity::run_jaccard(graph, params),

        // Community detection.
        "labelPropagation" => community::run_label_propagation(graph, params),
        "louvain" => louvain::run_louvain(graph, params),
        "leiden" => louvain::run_leiden(graph, params),

        // Network flow.
        "maxFlow" => flow::run_max_flow(graph, params),

        unknown => Err(DbError::Query(format!(
            "unknown algorithm '{unknown}'. Available algorithms: {AVAILABLE}"
        ))),
    }
}