use std::collections::{HashSet, VecDeque};
use petgraph::Direction;
use petgraph::graph::NodeIndex;
use petgraph::visit::EdgeRef;
use crate::graph::KnowledgeGraph;
use crate::model::{ItemId, ItemType, RelationshipType};
/// Outcome of a traversal that started from a single item.
#[derive(Debug, Clone)]
pub struct TraversalResult {
/// Identifier of the item the traversal was started from.
pub origin: ItemId,
/// Recorded items, in the order the breadth-first walk dequeued them.
pub items: Vec<TraversalNode>,
/// Largest `depth` among the recorded items; stays 0 when nothing deeper
/// than the origin was recorded.
pub max_depth: usize,
}
/// One item recorded by a traversal, together with how it was reached.
#[derive(Debug, Clone)]
pub struct TraversalNode {
/// Identifier of the recorded item.
pub item_id: ItemId,
/// Number of edges followed from the origin to reach this item (0 for the origin).
pub depth: usize,
/// Relationship attached to the edge this item was reached through;
/// `None` for the origin. When a downstream walk follows an upstream edge
/// backwards, this is that relationship's `inverse()`.
pub relationship: Option<RelationshipType>,
/// Closest item earlier on the path that itself passed the type filter;
/// `None` for the origin or when no such item exists.
pub parent: Option<ItemId>,
}
/// Direction in which relationships are followed during a traversal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraversalDirection {
/// Follow outgoing edges whose relationship reports `is_upstream()`.
Upstream,
/// Follow incoming edges whose relationship reports `is_upstream()` (walked
/// backwards) plus outgoing edges whose relationship reports `is_downstream()`.
Downstream,
}
/// Knobs that restrict how far a traversal goes and which items it records.
///
/// The derived `Default` yields no depth limit and an empty type filter.
#[derive(Debug, Clone, Default)]
pub struct TraversalOptions {
/// Maximum depth (in edges from the origin) that is still visited;
/// `None` means unlimited.
pub max_depth: Option<usize>,
/// Item types to record. An empty list records every visited item; items of
/// other types are still walked through but left out of the result.
pub type_filter: Vec<ItemType>,
}
impl TraversalOptions {
    /// Creates options equal to `TraversalOptions::default()`: no depth limit
    /// and an empty type filter.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns these options with the depth limit set to `depth`.
    pub fn with_max_depth(self, depth: usize) -> Self {
        Self {
            max_depth: Some(depth),
            ..self
        }
    }

    /// Returns these options with the type filter replaced by `types`.
    pub fn with_types(self, types: Vec<ItemType>) -> Self {
        Self {
            type_filter: types,
            ..self
        }
    }
}
/// Walks the graph from `start` in the [`TraversalDirection::Upstream`] direction.
///
/// Thin wrapper around `traverse_graph`; the result is whatever that function
/// returns for the same `graph`, `start` and `options`.
pub fn traverse_upstream(
    graph: &KnowledgeGraph,
    start: &ItemId,
    options: &TraversalOptions,
) -> Option<TraversalResult> {
    let direction = TraversalDirection::Upstream;
    traverse_graph(graph, start, direction, options)
}
/// Walks the graph from `start` in the [`TraversalDirection::Downstream`] direction.
///
/// Thin wrapper around `traverse_graph`; the result is whatever that function
/// returns for the same `graph`, `start` and `options`.
pub fn traverse_downstream(
    graph: &KnowledgeGraph,
    start: &ItemId,
    options: &TraversalOptions,
) -> Option<TraversalResult> {
    let direction = TraversalDirection::Downstream;
    traverse_graph(graph, start, direction, options)
}
/// Breadth-first walk of `graph` starting at `start`, following edges in `direction`.
///
/// Every reachable item within `options.max_depth` is visited once. An item is
/// *recorded* in the result only when `options.type_filter` is empty or contains
/// its type; items that do not match are still expanded so that matching items
/// behind them are found. A recorded item's `parent` is the closest recorded
/// item on the path that led to it.
///
/// Returns `None` when `graph.node_index(start)` yields no index for `start`.
fn traverse_graph(
    graph: &KnowledgeGraph,
    start: &ItemId,
    direction: TraversalDirection,
    options: &TraversalOptions,
) -> Option<TraversalResult> {
    let start_idx = graph.node_index(start)?;
    let inner = graph.inner();

    let mut visited: HashSet<NodeIndex> = HashSet::new();
    // Queue entries: (node, depth, relationship used to reach it, nearest recorded ancestor).
    let mut queue: VecDeque<(NodeIndex, usize, Option<RelationshipType>, Option<ItemId>)> =
        VecDeque::new();
    let mut result_items: Vec<TraversalNode> = Vec::new();
    let mut max_depth: usize = 0;

    visited.insert(start_idx);
    queue.push_back((start_idx, 0, None, None));

    // Entries are only enqueued when they respect the depth limit (and the
    // origin sits at depth 0), so no depth check is needed when dequeuing.
    while let Some((node_idx, depth, relationship, display_parent)) = queue.pop_front() {
        let Some(item) = inner.node_weight(node_idx) else {
            continue;
        };

        let matches_filter =
            options.type_filter.is_empty() || options.type_filter.contains(&item.item_type);

        // A recorded item becomes the display parent of everything reached
        // through it; a filtered-out item just hands its own parent down.
        let next_display_parent = if matches_filter {
            result_items.push(TraversalNode {
                item_id: item.id.clone(),
                depth,
                relationship,
                parent: display_parent,
            });
            max_depth = max_depth.max(depth);
            Some(item.id.clone())
        } else {
            display_parent
        };

        // Do not bother collecting neighbours that the depth limit would discard.
        let next_depth = depth + 1;
        if options.max_depth.is_some_and(|max| next_depth > max) {
            continue;
        }

        let edges: Vec<(NodeIndex, RelationshipType)> = match direction {
            TraversalDirection::Upstream => inner
                .edges_directed(node_idx, Direction::Outgoing)
                .filter(|e| e.weight().is_upstream())
                .map(|e| (e.target(), *e.weight()))
                .collect(),
            // Downstream neighbours are the sources of upstream edges pointing
            // at this node (reported with the inverse relationship), followed
            // by the targets of this node's own downstream edges.
            TraversalDirection::Downstream => inner
                .edges_directed(node_idx, Direction::Incoming)
                .filter(|e| e.weight().is_upstream())
                .map(|e| (e.source(), e.weight().inverse()))
                .chain(
                    inner
                        .edges_directed(node_idx, Direction::Outgoing)
                        .filter(|e| e.weight().is_downstream())
                        .map(|e| (e.target(), *e.weight())),
                )
                .collect(),
        };

        for (target_idx, rel_type) in edges {
            // `insert` returns false for nodes that were already seen.
            if visited.insert(target_idx) {
                queue.push_back((
                    target_idx,
                    next_depth,
                    Some(rel_type),
                    next_display_parent.clone(),
                ));
            }
        }
    }

    Some(TraversalResult {
        origin: start.clone(),
        items: result_items,
        max_depth,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::KnowledgeGraphBuilder;
    use crate::model::{Relationship, RelationshipType};
    use crate::test_utils::{create_test_item, create_test_item_with_relationships};

    /// Graph with two items: `UC-001` refines `SOL-001`.
    fn solution_use_case_graph() -> KnowledgeGraph {
        let sol = create_test_item("SOL-001", ItemType::Solution);
        let uc = create_test_item_with_relationships(
            "UC-001",
            ItemType::UseCase,
            vec![Relationship::new(
                ItemId::new_unchecked("SOL-001"),
                RelationshipType::Refines,
            )],
        );
        KnowledgeGraphBuilder::new()
            .add_item(sol)
            .add_item(uc)
            .build()
            .unwrap()
    }

    /// Graph with a three-item chain: `SCEN-001` refines `UC-001` refines `SOL-001`.
    fn refinement_chain_graph() -> KnowledgeGraph {
        let sol = create_test_item("SOL-001", ItemType::Solution);
        let uc = create_test_item_with_relationships(
            "UC-001",
            ItemType::UseCase,
            vec![Relationship::new(
                ItemId::new_unchecked("SOL-001"),
                RelationshipType::Refines,
            )],
        );
        let scen = create_test_item_with_relationships(
            "SCEN-001",
            ItemType::Scenario,
            vec![Relationship::new(
                ItemId::new_unchecked("UC-001"),
                RelationshipType::Refines,
            )],
        );
        KnowledgeGraphBuilder::new()
            .add_item(sol)
            .add_item(uc)
            .add_item(scen)
            .build()
            .unwrap()
    }

    #[test]
    fn test_upstream_traversal() {
        let graph = refinement_chain_graph();

        let result = traverse_upstream(
            &graph,
            &ItemId::new_unchecked("SCEN-001"),
            &TraversalOptions::new(),
        );

        assert!(result.is_some());
        let result = result.unwrap();
        assert_eq!(result.items.len(), 3);
        assert_eq!(result.max_depth, 2);

        // Breadth-first order: the origin first, then one item per level.
        assert_eq!(result.items[0].item_id, ItemId::new_unchecked("SCEN-001"));
        assert_eq!(result.items[0].depth, 0);
        assert!(result.items[0].parent.is_none());
        assert_eq!(result.items[1].item_id, ItemId::new_unchecked("UC-001"));
        assert_eq!(result.items[1].depth, 1);
        assert_eq!(result.items[2].item_id, ItemId::new_unchecked("SOL-001"));
        assert_eq!(result.items[2].depth, 2);
    }

    #[test]
    fn test_downstream_traversal() {
        let graph = solution_use_case_graph();

        let result = traverse_downstream(
            &graph,
            &ItemId::new_unchecked("SOL-001"),
            &TraversalOptions::new(),
        );

        assert!(result.is_some());
        let result = result.unwrap();
        assert_eq!(result.items.len(), 2);
        assert_eq!(result.max_depth, 1);
    }

    #[test]
    fn test_depth_limited_traversal() {
        let graph = refinement_chain_graph();

        let result = traverse_upstream(
            &graph,
            &ItemId::new_unchecked("SCEN-001"),
            &TraversalOptions::new().with_max_depth(1),
        );

        assert!(result.is_some());
        let result = result.unwrap();
        assert_eq!(result.max_depth, 1);
        // Exactly the origin and its direct upstream item; SOL-001 sits at depth 2.
        assert_eq!(result.items.len(), 2);
        assert!(result.items.iter().all(|node| node.depth <= 1));
    }

    #[test]
    fn test_type_filtered_traversal() {
        let graph = solution_use_case_graph();

        let result = traverse_upstream(
            &graph,
            &ItemId::new_unchecked("UC-001"),
            &TraversalOptions::new().with_types(vec![ItemType::Solution]),
        );

        assert!(result.is_some());
        let result = result.unwrap();
        assert_eq!(result.items.len(), 1);
        assert_eq!(result.items[0].item_id, ItemId::new_unchecked("SOL-001"));
        // The filtered-out origin is walked through but never becomes a parent.
        assert_eq!(result.items[0].depth, 1);
        assert!(result.items[0].parent.is_none());
    }
}