use std::collections::{HashMap, HashSet};
use std::mem;
use crate::db::GraphDB;
use crate::error::{GraphError, Result};
use crate::model::{Edge, EdgeDirection, EdgeId, Node, NodeId, PropertyValue};
/// Boxed predicate used by a `QueryOp::FilterNodes` step to decide which nodes to keep.
type NodeFilter = Box<dyn Fn(&Node) -> bool>;
/// Boxed predicate used by `QueryOp::Traverse` steps to decide which edges may be followed.
type EdgeFilter = Box<dyn Fn(&Edge) -> bool>;
/// Parameters of a single level-by-level expansion step recorded by `QueryBuilder::traverse`.
#[derive(Clone, Debug)]
struct TraversalSpec {
    /// Edge types to follow; forwarded to `GraphDB::get_neighbors_with_edges_by_type`.
    edge_types: Vec<String>,
    /// Direction in which edges are followed from each frontier node.
    direction: EdgeDirection,
    /// Maximum number of hops; `0` leaves the node set unchanged.
    depth: usize,
}
/// Source of a query's initial node set.
#[derive(Clone, Debug)]
enum StartSpec {
    /// A caller-supplied list of node ids, used as-is.
    Explicit(Vec<NodeId>),
    /// The nodes returned by `GraphDB::get_nodes_by_label` for this label.
    Label(String),
    /// The nodes returned by `GraphDB::find_nodes_by_property(label, key, value)`.
    Property {
        label: String,
        key: String,
        value: PropertyValue,
    },
}
/// One recorded step of a query; `QueryBuilder::get_ids` consumes the steps in order.
enum QueryOp {
    /// Resolve the initial node set, replacing the current one.
    Start(StartSpec),
    /// Keep only the current nodes accepted by the predicate.
    FilterNodes(NodeFilter),
    /// Register an edge predicate that applies to every subsequent traversal.
    FilterEdges(EdgeFilter),
    /// Expand the current node set along matching edges.
    Traverse(TraversalSpec),
}
/// Output of an executed query (see `QueryBuilder::get_ids`).
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// The node set entering the first traversal step. If the query has no traversal,
    /// this is the deduplicated final node set, taken before the limit was applied.
    pub start_nodes: Vec<NodeId>,
    /// Final node ids, deduplicated in first-seen order and truncated to the limit.
    pub node_ids: Vec<NodeId>,
    /// Nodes for `node_ids`, in the same order. Ids for which `GraphDB::get_node`
    /// returns `None` are skipped, so this can be shorter than `node_ids`.
    pub nodes: Vec<Node>,
    /// Traversed edges whose source and target are both in `node_ids`, sorted by edge id.
    pub edges: Vec<Edge>,
    /// `true` if `node_ids` was truncated because it exceeded the configured limit.
    pub limited: bool,
}
/// Fluent builder for graph queries, usually obtained from `GraphDB::query`.
///
/// The builder methods only record steps; the database is not read until
/// `get_ids` or `get_nodes` is called, so a dropped builder does no work.
#[must_use = "a QueryBuilder does nothing until `get_ids` or `get_nodes` is called"]
pub struct QueryBuilder<'db> {
    /// Database the query runs against.
    db: &'db mut GraphDB,
    /// Recorded steps, in execution order.
    ops: Vec<QueryOp>,
    /// Maximum number of nodes to return, if any.
    limit: Option<usize>,
}
impl<'db> QueryBuilder<'db> {
pub fn new(db: &'db mut GraphDB) -> Self {
Self {
db,
ops: Vec::new(),
limit: None,
}
}
pub fn start_from(mut self, node_ids: Vec<NodeId>) -> Self {
self.remove_existing_start();
self.ops.push(QueryOp::Start(StartSpec::Explicit(node_ids)));
self
}
pub fn start_from_label(mut self, label: &str) -> Self {
self.remove_existing_start();
self.ops
.push(QueryOp::Start(StartSpec::Label(label.to_string())));
self
}
pub fn start_from_property(mut self, label: &str, key: &str, value: PropertyValue) -> Self {
self.remove_existing_start();
self.ops.push(QueryOp::Start(StartSpec::Property {
label: label.to_string(),
key: key.to_string(),
value,
}));
self
}
pub fn filter_nodes<F>(mut self, filter: F) -> Self
where
F: Fn(&Node) -> bool + 'static,
{
self.ops.push(QueryOp::FilterNodes(Box::new(filter)));
self
}
pub fn filter_edges<F>(mut self, filter: F) -> Self
where
F: Fn(&Edge) -> bool + 'static,
{
self.ops.push(QueryOp::FilterEdges(Box::new(filter)));
self
}
pub fn traverse(mut self, edge_types: &[&str], direction: EdgeDirection, depth: usize) -> Self {
self.ops.push(QueryOp::Traverse(TraversalSpec {
edge_types: edge_types.iter().map(|s| (*s).to_string()).collect(),
direction,
depth,
}));
self
}
pub fn limit(mut self, n: usize) -> Self {
self.limit = Some(match self.limit {
Some(existing) => existing.min(n),
None => n,
});
self
}
pub fn get_ids(mut self) -> Result<QueryResult> {
if !self.ops.iter().any(|op| matches!(op, QueryOp::Start(_))) {
return Err(GraphError::InvalidArgument(
"QueryBuilder requires a starting point".into(),
));
}
let mut current_nodes: Vec<NodeId> = Vec::new();
let mut edge_filters: Vec<EdgeFilter> = Vec::new();
let mut captured_start: Option<Vec<NodeId>> = None;
let mut collected_edges: HashMap<EdgeId, Edge> = HashMap::new();
let ops = mem::take(&mut self.ops);
for op in ops {
match op {
QueryOp::Start(spec) => {
current_nodes = self.resolve_start(spec)?;
}
QueryOp::FilterNodes(filter) => {
current_nodes = self.apply_node_filter(current_nodes, filter)?;
}
QueryOp::FilterEdges(filter) => {
edge_filters.push(filter);
}
QueryOp::Traverse(spec) => {
if captured_start.is_none() {
captured_start = Some(current_nodes.clone());
}
let (nodes, edges) =
self.execute_traversal(¤t_nodes, &spec, &edge_filters)?;
current_nodes = nodes;
for edge in edges {
collected_edges.entry(edge.id).or_insert(edge);
}
}
}
}
let mut node_ids = if current_nodes.is_empty() {
Vec::new()
} else {
let mut seen = HashSet::new();
let mut ordered = Vec::new();
for node_id in current_nodes {
if seen.insert(node_id) {
ordered.push(node_id);
}
}
ordered
};
if captured_start.is_none() {
captured_start = Some(node_ids.clone());
}
let mut limited = false;
if let Some(limit) = self.limit {
if node_ids.len() > limit {
node_ids.truncate(limit);
limited = true;
}
}
let node_id_set: HashSet<NodeId> = node_ids.iter().copied().collect();
let mut edges: Vec<Edge> = collected_edges
.into_values()
.filter(|edge| {
node_id_set.contains(&edge.source_node_id)
&& node_id_set.contains(&edge.target_node_id)
})
.collect();
edges.sort_by_key(|edge| edge.id);
let mut nodes = Vec::with_capacity(node_ids.len());
for node_id in &node_ids {
if let Some(node) = self.db.get_node(*node_id)? {
nodes.push(node);
}
}
Ok(QueryResult {
start_nodes: captured_start.unwrap_or_default(),
node_ids,
nodes,
edges,
limited,
})
}
pub fn get_nodes(self) -> Result<Vec<Node>> {
let result = self.get_ids()?;
Ok(result.nodes)
}
fn resolve_start(&mut self, spec: StartSpec) -> Result<Vec<NodeId>> {
match spec {
StartSpec::Explicit(nodes) => Ok(nodes),
StartSpec::Label(label) => self.db.get_nodes_by_label(&label),
StartSpec::Property { label, key, value } => {
self.db.find_nodes_by_property(&label, &key, &value)
}
}
}
fn apply_node_filter(&mut self, nodes: Vec<NodeId>, filter: NodeFilter) -> Result<Vec<NodeId>> {
let mut result = Vec::new();
let predicate = &filter;
for node_id in nodes {
if let Some(node) = self.db.get_node(node_id)? {
if predicate(&node) {
result.push(node_id);
}
}
}
Ok(result)
}
fn execute_traversal(
&mut self,
start_nodes: &[NodeId],
spec: &TraversalSpec,
edge_filters: &[EdgeFilter],
) -> Result<(Vec<NodeId>, Vec<Edge>)> {
if spec.depth == 0 {
return Ok((start_nodes.to_vec(), Vec::new()));
}
let mut visited: HashSet<NodeId> = HashSet::new();
let mut ordered: Vec<NodeId> = Vec::new();
let mut frontier: Vec<NodeId> = Vec::new();
let mut edges = Vec::new();
for &node_id in start_nodes {
if visited.insert(node_id) {
ordered.push(node_id);
frontier.push(node_id);
}
}
let edge_type_refs: Vec<&str> = spec.edge_types.iter().map(|ty| ty.as_str()).collect();
let mut depth_remaining = spec.depth;
while depth_remaining > 0 && !frontier.is_empty() {
depth_remaining -= 1;
let mut next_frontier = Vec::new();
for node_id in frontier {
for (neighbor, edge) in self.db.get_neighbors_with_edges_by_type(
node_id,
&edge_type_refs,
spec.direction,
)? {
if !self.edge_passes_filters(&edge, edge_filters) {
continue;
}
if visited.insert(neighbor) {
ordered.push(neighbor);
next_frontier.push(neighbor);
}
edges.push(edge);
}
}
frontier = next_frontier;
}
let mut seen_edge_ids = HashSet::new();
edges.retain(|edge| seen_edge_ids.insert(edge.id));
Ok((ordered, edges))
}
fn edge_passes_filters(&self, edge: &Edge, filters: &[EdgeFilter]) -> bool {
filters.iter().all(|filter| filter(edge))
}
fn remove_existing_start(&mut self) {
self.ops.retain(|op| !matches!(op, QueryOp::Start(_)));
}
}
impl GraphDB {
    /// Begins a new query against this database.
    ///
    /// The returned builder holds a mutable borrow of the database until it is executed
    /// (`get_ids` / `get_nodes` consume it) or dropped.
    #[must_use = "the query does nothing until `get_ids` or `get_nodes` is called"]
    pub fn query(&mut self) -> QueryBuilder<'_> {
        QueryBuilder::new(self)
    }
}