use crate::entity::Entity;
use crate::relation::{Direction, Relation};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
/// Parameters for a graph traversal executed by `TraversalEngine::execute`.
///
/// Build one with `TraversalQuery::new` plus the builder methods, or deserialize it;
/// every field except `start` has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraversalQuery {
/// Name of the entity the traversal starts from.
pub start: String,
/// When set, the engine searches for a path from `start` to this entity;
/// when `None`, it explores everything reachable within `max_depth` hops.
#[serde(skip_serializing_if = "Option::is_none")]
pub target: Option<String>,
/// Maximum number of hops from `start` for the breadth-first modes (default 10).
/// NOTE(review): the weighted (`use_weights`) search in this module does not read this.
#[serde(default = "default_depth")]
pub max_depth: u32,
/// Which relation orientations may be followed out of a node.
/// NOTE(review): serde fills this with `Direction::default()`, while
/// `TraversalQuery::default()` uses `Direction::Both` — confirm they agree.
#[serde(default)]
pub direction: Direction,
/// Entity types a neighbor must have to be entered; empty allows every type.
/// The start node is not checked, and names missing from the entity map are not rejected.
#[serde(default)]
pub entity_type_filter: Vec<String>,
/// Relation types that may be walked; empty allows every type.
#[serde(default)]
pub relation_type_filter: Vec<String>,
/// With a `target`, search by summed relation weight (a missing weight counts
/// as 1.0) instead of by hop count.
#[serde(default)]
pub use_weights: bool,
/// Request multiple paths.
/// NOTE(review): not read by `TraversalEngine` in this module; at most one path is returned.
#[serde(default)]
pub all_paths: bool,
/// Upper bound on returned paths when `all_paths` is set (default 5).
/// NOTE(review): not read by `TraversalEngine` in this module.
#[serde(default = "default_max_paths")]
pub max_paths: usize,
}
/// Serde/`Default` fallback for `TraversalQuery::max_depth`: ten hops.
fn default_depth() -> u32 {
    const DEFAULT_MAX_DEPTH: u32 = 10;
    DEFAULT_MAX_DEPTH
}
/// Serde/`Default` fallback for `TraversalQuery::max_paths`: five paths.
fn default_max_paths() -> usize {
    const DEFAULT_MAX_PATHS: usize = 5;
    DEFAULT_MAX_PATHS
}
impl Default for TraversalQuery {
fn default() -> Self {
Self {
start: String::new(),
target: None,
max_depth: default_depth(),
direction: Direction::Both,
entity_type_filter: Vec::new(),
relation_type_filter: Vec::new(),
use_weights: false,
all_paths: false,
max_paths: default_max_paths(),
}
}
}
impl TraversalQuery {
    /// Creates a query rooted at `start`, with every other option at its default.
    pub fn new(start: impl Into<String>) -> Self {
        let start = start.into();
        Self {
            start,
            ..Self::default()
        }
    }

    /// Sets the entity to search for, switching execution into path-finding mode.
    pub fn find_path_to(self, target: impl Into<String>) -> Self {
        Self {
            target: Some(target.into()),
            ..self
        }
    }

    /// Replaces the hop limit (`max_depth`).
    pub fn with_depth(self, depth: u32) -> Self {
        Self {
            max_depth: depth,
            ..self
        }
    }

    /// Replaces which relation orientations may be followed.
    pub fn with_direction(self, direction: Direction) -> Self {
        Self { direction, ..self }
    }

    /// Restricts entered neighbors to the given entity types (empty = any).
    pub fn filter_entity_types(self, types: Vec<String>) -> Self {
        Self {
            entity_type_filter: types,
            ..self
        }
    }

    /// Restricts walkable relations to the given relation types (empty = any).
    pub fn filter_relation_types(self, types: Vec<String>) -> Self {
        Self {
            relation_type_filter: types,
            ..self
        }
    }

    /// Searches by summed relation weight instead of hop count.
    pub fn weighted(self) -> Self {
        Self {
            use_weights: true,
            ..self
        }
    }

    /// Requests up to `max` paths.
    ///
    /// NOTE(review): `TraversalEngine` in this module does not read `all_paths` or
    /// `max_paths` yet, so at most one path is returned.
    pub fn all_paths(self, max: usize) -> Self {
        Self {
            all_paths: true,
            max_paths: max,
            ..self
        }
    }
}
/// A path produced by the traversal engine, ordered from the query's start toward its target.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphPath {
/// Entity names along the path, start first.
pub nodes: Vec<String>,
/// Relations walked, in path order; `edges[i]` links `nodes[i]` and `nodes[i + 1]`.
pub edges: Vec<PathEdge>,
/// Sum of the edges' weights, counting a missing weight as 1.0.
pub total_weight: f64,
/// Number of edges (hops) in the path.
pub length: usize,
}
/// One relation walked along a `GraphPath`.
///
/// `from`/`to` copy the relation's stored orientation. That may be opposite to the
/// direction the path actually walked it (e.g. when following incoming relations).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PathEdge {
/// Source entity name as stored on the relation.
pub from: String,
/// Destination entity name as stored on the relation.
pub to: String,
/// The relation's type label.
pub relation_type: String,
/// The relation's weight, if it has one.
pub weight: Option<f64>,
}
/// Output of `TraversalEngine::execute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraversalResult {
/// Echo of the query's start entity.
pub start: String,
/// Echo of the query's target, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub target: Option<String>,
/// Paths found (currently zero or one; empty when no target was requested).
pub paths: Vec<GraphPath>,
/// Names of every entity the search reached. For weighted search, this means
/// every node that received a tentative distance.
pub visited_entities: Vec<String>,
/// Entity records for visited names that exist in the input entity map.
pub entities: Vec<Entity>,
/// Every input relation whose two endpoints were both visited, regardless of
/// relation-type filter or direction.
pub relations: Vec<Relation>,
/// Counters describing the search.
pub stats: TraversalStats,
}
/// Counters collected while a traversal runs.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TraversalStats {
/// Number of nodes the search dequeued for processing.
pub nodes_visited: usize,
/// Candidate relations examined, including those rejected by the filters.
pub edges_traversed: usize,
/// Deepest hop count dequeued by the breadth-first modes; weighted search leaves this at 0.
pub max_depth_reached: u32,
/// Whether the target was reached (always false when no target was requested).
pub path_found: bool,
}
/// Priority-queue entry for Dijkstra's algorithm: a tentative `cost` to reach `node`.
///
/// `BinaryHeap` is a max-heap, so the ordering on `cost` is reversed and the
/// cheapest entry pops first. Ties are broken on `node`, with the lexicographically
/// smaller name popping first. This makes the order total and keeps it consistent
/// with `==`, as `Ord`/`Eq` require.
#[derive(Debug, Clone)]
struct DijkstraState {
    cost: f64,
    node: String,
}

impl PartialEq for DijkstraState {
    // Defined through `cmp` so equality can never disagree with the ordering
    // (a derived impl would treat NaN != NaN and -0.0 == 0.0 differently).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DijkstraState {}

impl Ord for DijkstraState {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands swapped on purpose to turn the max-heap into a min-heap.
        // `total_cmp` is a total order even for NaN, unlike `partial_cmp`.
        other
            .cost
            .total_cmp(&self.cost)
            .then_with(|| other.node.cmp(&self.node))
    }
}

impl PartialOrd for DijkstraState {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Stateless executor that runs a `TraversalQuery` over an in-memory graph given
/// as an entity map plus a relation list (see `TraversalEngine::execute`).
pub struct TraversalEngine;
impl TraversalEngine {
    /// Runs `query` over the graph formed by `entities` and `relations`.
    ///
    /// The strategy is chosen from the query:
    /// * `target` set and `use_weights` — Dijkstra shortest path, where walking a
    ///   relation costs its `weight` (1.0 when absent). `max_depth` is not enforced here.
    /// * `target` set, unweighted — breadth-first fewest-hop path, limited to
    ///   `max_depth` hops.
    /// * no `target` — breadth-first exploration of everything reachable from
    ///   `start` within `max_depth` hops; `paths` is empty.
    ///
    /// In every mode:
    /// * `direction` decides which relations may be walked out of a node.
    /// * `relation_type_filter` restricts walkable relation types.
    /// * `entity_type_filter` rejects neighbors whose entity type is not listed.
    ///
    /// The filter never rejects the start node or names missing from `entities`.
    /// An empty filter allows everything. At most one path is returned;
    /// `all_paths`/`max_paths` are not consulted.
    pub fn execute(
        query: &TraversalQuery,
        entities: &HashMap<String, Entity>,
        relations: &[Relation],
    ) -> TraversalResult {
        tracing::debug!(
            "Executing traversal: start={}, target={:?}, depth={}, direction={:?}",
            query.start,
            query.target,
            query.max_depth,
            query.direction
        );
        match query.target.as_deref() {
            Some(target) if query.use_weights => {
                Self::dijkstra_path(query, target, entities, relations)
            }
            Some(target) => Self::bfs_path(query, target, entities, relations),
            None => Self::filtered_bfs(query, entities, relations),
        }
    }

    /// Breadth-first search for the fewest-hop path from `query.start` to `target`.
    ///
    /// Nodes at `query.max_depth` hops are not expanded. The search stops as soon
    /// as `target` is dequeued.
    fn bfs_path(
        query: &TraversalQuery,
        target: &str,
        entities: &HashMap<String, Entity>,
        relations: &[Relation],
    ) -> TraversalResult {
        let adjacency = Self::build_adjacency(&query.direction, relations);
        let mut visited: HashSet<String> = HashSet::new();
        // child -> (parent, relation walked to reach child); used to rebuild the path.
        let mut parent: HashMap<String, (String, PathEdge)> = HashMap::new();
        let mut queue: VecDeque<(String, u32)> = VecDeque::new();
        let mut stats = TraversalStats::default();
        queue.push_back((query.start.clone(), 0));
        visited.insert(query.start.clone());
        while let Some((current, depth)) = queue.pop_front() {
            stats.nodes_visited += 1;
            stats.max_depth_reached = stats.max_depth_reached.max(depth);
            if current == target {
                stats.path_found = true;
                tracing::debug!("BFS found path at depth {}", depth);
                break;
            }
            if depth >= query.max_depth {
                continue;
            }
            for &rel in adjacency.get(current.as_str()).into_iter().flatten() {
                // Counts every candidate relation, including ones the filters reject.
                stats.edges_traversed += 1;
                let next = match Self::step(query, entities, rel, &current) {
                    Some(next) => next,
                    None => continue,
                };
                // `insert` returns false when `next` was already visited.
                if visited.insert(next.clone()) {
                    parent.insert(next.clone(), (current.clone(), Self::path_edge(rel)));
                    queue.push_back((next.clone(), depth + 1));
                }
            }
        }
        let paths = if stats.path_found {
            vec![Self::reconstruct_path(&query.start, target, &parent)]
        } else {
            vec![]
        };
        Self::build_result(query, paths, &visited, entities, relations, stats)
    }

    /// Dijkstra search for the cheapest path from `query.start` to `target`.
    ///
    /// Each relation costs `weight.unwrap_or(1.0)`. `query.max_depth` is not
    /// enforced. The returned `visited_entities` contains every node that
    /// received a tentative distance.
    fn dijkstra_path(
        query: &TraversalQuery,
        target: &str,
        entities: &HashMap<String, Entity>,
        relations: &[Relation],
    ) -> TraversalResult {
        let adjacency = Self::build_adjacency(&query.direction, relations);
        // Best known cost from `query.start` to each discovered node.
        let mut dist: HashMap<String, f64> = HashMap::new();
        let mut parent: HashMap<String, (String, PathEdge)> = HashMap::new();
        let mut heap = BinaryHeap::new();
        let mut stats = TraversalStats::default();
        dist.insert(query.start.clone(), 0.0);
        heap.push(DijkstraState {
            cost: 0.0,
            node: query.start.clone(),
        });
        while let Some(DijkstraState { cost, node }) = heap.pop() {
            // Stale entry: a cheaper route to `node` was pushed after this one.
            // Skip it before counting so `nodes_visited` is not inflated by duplicates.
            if cost > dist.get(&node).copied().unwrap_or(f64::INFINITY) {
                continue;
            }
            stats.nodes_visited += 1;
            if node == target {
                stats.path_found = true;
                tracing::debug!("Dijkstra found path with cost {}", cost);
                break;
            }
            for &rel in adjacency.get(node.as_str()).into_iter().flatten() {
                stats.edges_traversed += 1;
                let next = match Self::step(query, entities, rel, &node) {
                    Some(next) => next,
                    None => continue,
                };
                // NOTE(review): Dijkstra assumes non-negative weights; negative
                // `rel.weight` values can produce non-optimal paths.
                let new_cost = cost + rel.weight.unwrap_or(1.0);
                if new_cost < dist.get(next).copied().unwrap_or(f64::INFINITY) {
                    dist.insert(next.clone(), new_cost);
                    parent.insert(next.clone(), (node.clone(), Self::path_edge(rel)));
                    heap.push(DijkstraState {
                        cost: new_cost,
                        node: next.clone(),
                    });
                }
            }
        }
        let paths = if stats.path_found {
            vec![Self::reconstruct_path(&query.start, target, &parent)]
        } else {
            vec![]
        };
        let visited: HashSet<String> = dist.keys().cloned().collect();
        Self::build_result(query, paths, &visited, entities, relations, stats)
    }

    /// Breadth-first exploration of everything reachable from `query.start`
    /// within `query.max_depth` hops, honoring direction and filters.
    fn filtered_bfs(
        query: &TraversalQuery,
        entities: &HashMap<String, Entity>,
        relations: &[Relation],
    ) -> TraversalResult {
        let adjacency = Self::build_adjacency(&query.direction, relations);
        let mut visited: HashSet<String> = HashSet::new();
        let mut queue: VecDeque<(String, u32)> = VecDeque::new();
        let mut stats = TraversalStats::default();
        queue.push_back((query.start.clone(), 0));
        visited.insert(query.start.clone());
        while let Some((current, depth)) = queue.pop_front() {
            stats.nodes_visited += 1;
            stats.max_depth_reached = stats.max_depth_reached.max(depth);
            if depth >= query.max_depth {
                continue;
            }
            for &rel in adjacency.get(current.as_str()).into_iter().flatten() {
                stats.edges_traversed += 1;
                let next = match Self::step(query, entities, rel, &current) {
                    Some(next) => next,
                    None => continue,
                };
                if visited.insert(next.clone()) {
                    queue.push_back((next.clone(), depth + 1));
                }
            }
        }
        tracing::debug!(
            "Filtered BFS visited {} nodes, traversed {} edges",
            stats.nodes_visited,
            stats.edges_traversed
        );
        Self::build_result(query, vec![], &visited, entities, relations, stats)
    }

    /// Groups `relations` by the node they may be walked out of under `direction`.
    /// Within each group, relations keep their input order.
    ///
    /// The index is built once per traversal, so each expansion costs
    /// O(degree) instead of a scan over all relations.
    fn build_adjacency<'a>(
        direction: &Direction,
        relations: &'a [Relation],
    ) -> HashMap<&'a str, Vec<&'a Relation>> {
        let mut adjacency: HashMap<&'a str, Vec<&'a Relation>> = HashMap::new();
        for rel in relations {
            let from = rel.from_name.as_str();
            let to = rel.to_name.as_str();
            match direction {
                Direction::Outgoing => adjacency.entry(from).or_default().push(rel),
                Direction::Incoming => adjacency.entry(to).or_default().push(rel),
                Direction::Both => {
                    adjacency.entry(from).or_default().push(rel);
                    // A self-loop is listed only once for its node.
                    if to != from {
                        adjacency.entry(to).or_default().push(rel);
                    }
                }
            }
        }
        adjacency
    }

    /// Returns the node reached by walking `rel` away from `current`. Returns
    /// `None` when the relation type or the neighbor's entity type is filtered out.
    ///
    /// Neighbors absent from `entities` are never rejected by the entity-type filter.
    fn step<'a>(
        query: &TraversalQuery,
        entities: &HashMap<String, Entity>,
        rel: &'a Relation,
        current: &str,
    ) -> Option<&'a String> {
        if !query.relation_type_filter.is_empty()
            && !query.relation_type_filter.contains(&rel.relation_type)
        {
            return None;
        }
        let next = if rel.from_name == current {
            &rel.to_name
        } else {
            &rel.from_name
        };
        if let Some(entity) = entities.get(next) {
            if !query.entity_type_filter.is_empty()
                && !query.entity_type_filter.contains(&entity.entity_type.0)
            {
                return None;
            }
        }
        Some(next)
    }

    /// Snapshots `rel` as a `PathEdge`, keeping its stored orientation.
    fn path_edge(rel: &Relation) -> PathEdge {
        PathEdge {
            from: rel.from_name.clone(),
            to: rel.to_name.clone(),
            relation_type: rel.relation_type.clone(),
            weight: rel.weight,
        }
    }

    /// Walks the `parent` links back from `end` to `start` and returns the path
    /// in forward order. Missing weights count as 1.0 in `total_weight`.
    fn reconstruct_path(
        start: &str,
        end: &str,
        parent: &HashMap<String, (String, PathEdge)>,
    ) -> GraphPath {
        let mut nodes = vec![end.to_string()];
        let mut edges = Vec::new();
        let mut total_weight = 0.0;
        let mut current: &str = end;
        while current != start {
            match parent.get(current) {
                Some((prev, edge)) => {
                    total_weight += edge.weight.unwrap_or(1.0);
                    edges.push(edge.clone());
                    nodes.push(prev.clone());
                    current = prev.as_str();
                }
                // Broken chain: return what was collected instead of looping forever.
                None => break,
            }
        }
        nodes.reverse();
        edges.reverse();
        GraphPath {
            length: edges.len(),
            nodes,
            edges,
            total_weight,
        }
    }

    /// Packages search output into a `TraversalResult`.
    ///
    /// `relations` in the result are all input relations whose endpoints were
    /// both visited.
    fn build_result(
        query: &TraversalQuery,
        paths: Vec<GraphPath>,
        visited: &HashSet<String>,
        entities: &HashMap<String, Entity>,
        relations: &[Relation],
        stats: TraversalStats,
    ) -> TraversalResult {
        let mut visited_entities: Vec<String> = visited.iter().cloned().collect();
        // `HashSet` iteration order varies between runs; sort for reproducible output.
        visited_entities.sort_unstable();
        let result_entities: Vec<Entity> = visited_entities
            .iter()
            .filter_map(|name| entities.get(name).cloned())
            .collect();
        let result_relations: Vec<Relation> = relations
            .iter()
            .filter(|r| visited.contains(&r.from_name) && visited.contains(&r.to_name))
            .cloned()
            .collect();
        TraversalResult {
            start: query.start.clone(),
            target: query.target.clone(),
            paths,
            visited_entities,
            entities: result_entities,
            relations: result_relations,
            stats,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::project::ProjectId;

    /// Six "node" entities A..F joined by weighted "connects" relations:
    /// A-B(1), B-C(2), C-D(1), B-E(1), C-F(1), E-F(3).
    fn create_test_graph() -> (HashMap<String, Entity>, Vec<Relation>) {
        let project_id = ProjectId::new();
        let mut entities = HashMap::new();
        for name in ["A", "B", "C", "D", "E", "F"] {
            let entity = Entity::new(project_id.clone(), name, "node");
            entities.insert(name.to_string(), entity);
        }
        let relations = vec![
            Relation::from_names(project_id.clone(), "A", "B", "connects").with_weight(1.0),
            Relation::from_names(project_id.clone(), "B", "C", "connects").with_weight(2.0),
            Relation::from_names(project_id.clone(), "C", "D", "connects").with_weight(1.0),
            Relation::from_names(project_id.clone(), "B", "E", "connects").with_weight(1.0),
            Relation::from_names(project_id.clone(), "C", "F", "connects").with_weight(1.0),
            Relation::from_names(project_id.clone(), "E", "F", "connects").with_weight(3.0),
        ];
        (entities, relations)
    }

    #[test]
    fn test_bfs_shortest_path() {
        let (entities, relations) = create_test_graph();
        let query = TraversalQuery::new("A").find_path_to("D");
        let result = TraversalEngine::execute(&query, &entities, &relations);
        assert!(result.stats.path_found);
        assert_eq!(result.paths.len(), 1);
        assert_eq!(result.paths[0].nodes, vec!["A", "B", "C", "D"]);
        assert_eq!(result.paths[0].length, 3);
    }

    #[test]
    fn test_dijkstra_weighted_path() {
        let (entities, relations) = create_test_graph();
        let query = TraversalQuery::new("A").find_path_to("F").weighted();
        let result = TraversalEngine::execute(&query, &entities, &relations);
        assert!(result.stats.path_found);
        assert_eq!(result.paths[0].nodes, vec!["A", "B", "C", "F"]);
        assert!((result.paths[0].total_weight - 4.0).abs() < 0.001);
    }

    #[test]
    fn test_filtered_traversal() {
        let (entities, relations) = create_test_graph();
        let query = TraversalQuery::new("A").with_depth(2);
        let result = TraversalEngine::execute(&query, &entities, &relations);
        assert!(result.visited_entities.contains(&"A".to_string()));
        assert!(result.visited_entities.contains(&"B".to_string()));
        assert!(result.visited_entities.contains(&"C".to_string()));
        assert!(result.visited_entities.contains(&"E".to_string()));
    }

    #[test]
    fn test_no_path_found() {
        let project_id = ProjectId::new();
        let mut entities = HashMap::new();
        entities.insert("A".to_string(), Entity::new(project_id.clone(), "A", "node"));
        entities.insert("B".to_string(), Entity::new(project_id.clone(), "B", "node"));
        let query = TraversalQuery::new("A").find_path_to("B");
        let result = TraversalEngine::execute(&query, &entities, &[]);
        assert!(!result.stats.path_found);
        assert!(result.paths.is_empty());
    }

    #[test]
    fn test_direction_filtering() {
        let (entities, relations) = create_test_graph();
        let outgoing = TraversalQuery::new("B")
            .with_direction(Direction::Outgoing)
            .with_depth(1);
        let result = TraversalEngine::execute(&outgoing, &entities, &relations);
        assert!(result.visited_entities.contains(&"C".to_string()));
        assert!(result.visited_entities.contains(&"E".to_string()));
        assert!(!result.visited_entities.contains(&"A".to_string()));
        let incoming = TraversalQuery::new("B")
            .with_direction(Direction::Incoming)
            .with_depth(1);
        let result = TraversalEngine::execute(&incoming, &entities, &relations);
        assert!(result.visited_entities.contains(&"A".to_string()));
        assert!(!result.visited_entities.contains(&"C".to_string()));
    }

    #[test]
    fn test_relation_type_filter() {
        let project_id = ProjectId::new();
        let mut entities = HashMap::new();
        for name in ["A", "B", "C"] {
            entities.insert(
                name.to_string(),
                Entity::new(project_id.clone(), name, "node"),
            );
        }
        let relations = vec![
            Relation::from_names(project_id.clone(), "A", "B", "works_at"),
            Relation::from_names(project_id.clone(), "B", "C", "knows"),
        ];
        let query = TraversalQuery::new("A")
            .with_depth(2)
            .filter_relation_types(vec!["works_at".to_string()]);
        let result = TraversalEngine::execute(&query, &entities, &relations);
        assert!(result.visited_entities.contains(&"B".to_string()));
        assert!(!result.visited_entities.contains(&"C".to_string()));
    }

    /// A path to the start node itself is found immediately and has no edges.
    #[test]
    fn test_start_equals_target() {
        let (entities, relations) = create_test_graph();
        let query = TraversalQuery::new("A").find_path_to("A");
        let result = TraversalEngine::execute(&query, &entities, &relations);
        assert!(result.stats.path_found);
        assert_eq!(result.paths.len(), 1);
        assert_eq!(result.paths[0].nodes, vec!["A"]);
        assert_eq!(result.paths[0].length, 0);
        assert!(result.paths[0].edges.is_empty());
    }

    /// Neighbors whose entity type is not allowed are never entered, so the
    /// only route through them is blocked.
    #[test]
    fn test_entity_type_filter() {
        let project_id = ProjectId::new();
        let mut entities = HashMap::new();
        for (name, kind) in [("A", "node"), ("B", "hidden"), ("C", "node")] {
            entities.insert(
                name.to_string(),
                Entity::new(project_id.clone(), name, kind),
            );
        }
        let relations = vec![
            Relation::from_names(project_id.clone(), "A", "B", "connects"),
            Relation::from_names(project_id.clone(), "B", "C", "connects"),
        ];
        let query = TraversalQuery::new("A")
            .find_path_to("C")
            .filter_entity_types(vec!["node".to_string()]);
        let result = TraversalEngine::execute(&query, &entities, &relations);
        assert!(!result.stats.path_found);
        assert!(!result.visited_entities.contains(&"B".to_string()));
    }

    /// Weighted search honors `direction`: F has no outgoing relations.
    #[test]
    fn test_dijkstra_respects_direction() {
        let (entities, relations) = create_test_graph();
        let query = TraversalQuery::new("F")
            .find_path_to("A")
            .weighted()
            .with_direction(Direction::Outgoing);
        let result = TraversalEngine::execute(&query, &entities, &relations);
        assert!(!result.stats.path_found);
        assert!(result.paths.is_empty());
    }
}