use std::collections::{HashSet, VecDeque};
use petgraph::Direction;
use petgraph::graph::NodeIndex;
use petgraph::visit::EdgeRef;
use crate::graph::KnowledgeGraph;
use crate::model::{Item, ItemId, ItemType, RelationshipType};
/// Flat, breadth-first result of walking the graph from a single item.
#[derive(Debug, Clone)]
pub struct TraversalResult {
    /// The item the traversal started from.
    pub origin: ItemId,
    /// Items reported by the traversal, in breadth-first discovery order. The origin comes
    /// first when it passes the type filter.
    pub items: Vec<TraversalNode>,
    /// Largest `depth` among the reported `items`; 0 when only the origin (or nothing) was
    /// reported.
    pub max_depth: usize,
}
/// One item reached during a traversal.
#[derive(Debug, Clone)]
pub struct TraversalNode {
    /// Identifier of the reached item.
    pub item_id: ItemId,
    /// Number of hops from the traversal origin (0 for the origin itself).
    pub depth: usize,
    /// Relationship followed to reach this item. It is inverted when the stored edge points
    /// the other way, and is `None` for the origin.
    pub relationship: Option<RelationshipType>,
    /// Nearest ancestor on the traversal path that is itself part of the result.
    ///
    /// Ancestors excluded by the type filter are skipped, so this may be several hops away.
    /// It is `None` when no such ancestor exists.
    pub parent: Option<ItemId>,
}
/// Which way a traversal follows relationships.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TraversalDirection {
    /// Toward upstream items, e.g. from a scenario to the use case it refines.
    Upstream,
    /// Toward downstream items, e.g. from a solution to the use cases that refine it.
    Downstream,
}
/// Options for [`traverse_upstream`] and [`traverse_downstream`].
#[derive(Debug, Clone, Default)]
pub struct TraversalOptions {
    /// Maximum number of hops from the start item; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Item types to report; an empty list reports every type.
    ///
    /// Items of other types are still walked through, just not included in the result.
    pub type_filter: Vec<ItemType>,
}
impl TraversalOptions {
    /// Creates options with no depth limit and no type filter.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns these options with the traversal limited to `depth` hops from the start item.
    pub fn with_max_depth(self, depth: usize) -> Self {
        Self {
            max_depth: Some(depth),
            ..self
        }
    }

    /// Returns these options with reporting restricted to the given item types.
    ///
    /// An empty list means every type is reported.
    pub fn with_types(self, types: Vec<ItemType>) -> Self {
        Self {
            type_filter: types,
            ..self
        }
    }
}
/// Walks breadth-first from `start` along upstream relationships (e.g. `refines`).
///
/// Returns `None` when `start` is not present in `graph`.
pub fn traverse_upstream(
    graph: &KnowledgeGraph,
    start: &ItemId,
    options: &TraversalOptions,
) -> Option<TraversalResult> {
    let direction = TraversalDirection::Upstream;
    traverse_graph(graph, start, direction, options)
}
/// Walks breadth-first from `start` to the items that build on it, i.e. against the
/// upstream relationships.
///
/// Returns `None` when `start` is not present in `graph`.
pub fn traverse_downstream(
    graph: &KnowledgeGraph,
    start: &ItemId,
    options: &TraversalOptions,
) -> Option<TraversalResult> {
    let direction = TraversalDirection::Downstream;
    traverse_graph(graph, start, direction, options)
}
/// Breadth-first walk over `graph` from `start` in the given `direction`.
///
/// Each reachable item is visited once, at the shallowest depth where BFS discovers it.
/// Nodes more than `options.max_depth` hops away are never enqueued.
///
/// Items whose type is not in a non-empty `options.type_filter` are still walked through
/// but are left out of the result. Their descendants record the nearest *included* ancestor
/// as their `parent`.
///
/// Relationships may be stored in either orientation: an upstream relationship pointing
/// from the node, or a downstream relationship pointing at it. Both directions therefore
/// inspect outgoing and incoming edges. An edge traversed against its stored orientation is
/// reported via `inverse()`.
///
/// Returns `None` when `start` is not present in the graph.
fn traverse_graph(
    graph: &KnowledgeGraph,
    start: &ItemId,
    direction: TraversalDirection,
    options: &TraversalOptions,
) -> Option<TraversalResult> {
    let start_idx = graph.node_index(start)?;
    let inner = graph.inner();
    let mut visited: HashSet<NodeIndex> = HashSet::new();
    // Queue entries: (node, depth, relationship that led here, nearest included ancestor).
    let mut queue: VecDeque<(NodeIndex, usize, Option<RelationshipType>, Option<ItemId>)> =
        VecDeque::new();
    let mut result_items: Vec<TraversalNode> = Vec::new();
    let mut max_depth = 0;

    queue.push_back((start_idx, 0, None, None));
    visited.insert(start_idx);

    // Nodes are only enqueued when their depth is within `options.max_depth`, so every
    // dequeued node is already in range.
    while let Some((node_idx, depth, relationship, display_parent)) = queue.pop_front() {
        let Some(item) = inner.node_weight(node_idx) else {
            continue;
        };

        let matches_filter =
            options.type_filter.is_empty() || options.type_filter.contains(&item.item_type);

        // Filtered-out items pass their own display parent down, so descendants attach to
        // the nearest ancestor that is actually reported.
        let next_display_parent = if matches_filter {
            result_items.push(TraversalNode {
                item_id: item.id.clone(),
                depth,
                relationship,
                parent: display_parent,
            });
            max_depth = max_depth.max(depth);
            Some(item.id.clone())
        } else {
            display_parent
        };

        let next_depth = depth + 1;
        if options.max_depth.is_some_and(|max| next_depth > max) {
            // Neighbours would exceed the depth limit; skip collecting them at all.
            continue;
        }

        let mut edges: Vec<(NodeIndex, RelationshipType)> = Vec::new();
        match direction {
            TraversalDirection::Upstream => {
                // node --upstream--> target: target is upstream of node.
                for edge in inner.edges_directed(node_idx, Direction::Outgoing) {
                    if edge.weight().is_upstream() {
                        edges.push((edge.target(), *edge.weight()));
                    }
                }
                // source --downstream--> node: source is upstream of node. This mirrors the
                // incoming-edge handling of the Downstream branch.
                for edge in inner.edges_directed(node_idx, Direction::Incoming) {
                    if edge.weight().is_downstream() {
                        edges.push((edge.source(), edge.weight().inverse()));
                    }
                }
            }
            TraversalDirection::Downstream => {
                // source --upstream--> node: source is downstream of node.
                for edge in inner.edges_directed(node_idx, Direction::Incoming) {
                    if edge.weight().is_upstream() {
                        edges.push((edge.source(), edge.weight().inverse()));
                    }
                }
                // node --downstream--> target: target is downstream of node.
                for edge in inner.edges_directed(node_idx, Direction::Outgoing) {
                    if edge.weight().is_downstream() {
                        edges.push((edge.target(), *edge.weight()));
                    }
                }
            }
        }

        for (target_idx, rel_type) in edges {
            // `insert` returns false for already-seen nodes, which also dedupes edges that
            // reach the same target from both orientations.
            if visited.insert(target_idx) {
                queue.push_back((
                    target_idx,
                    next_depth,
                    Some(rel_type),
                    next_display_parent.clone(),
                ));
            }
        }
    }

    Some(TraversalResult {
        origin: start.clone(),
        items: result_items,
        max_depth,
    })
}
/// Thin wrapper around `KnowledgeGraph::parents` for `id`.
///
/// Presumably these are the items directly upstream of `id`; confirm against
/// `KnowledgeGraph::parents`.
pub fn get_upstream_parents<'a>(graph: &'a KnowledgeGraph, id: &ItemId) -> Vec<&'a Item> {
    graph.parents(id)
}
/// Thin wrapper around `KnowledgeGraph::children` for `id`.
///
/// Presumably these are the items directly downstream of `id`; confirm against
/// `KnowledgeGraph::children`.
pub fn get_downstream_children<'a>(graph: &'a KnowledgeGraph, id: &ItemId) -> Vec<&'a Item> {
    graph.children(id)
}
/// Nested view of a [`TraversalResult`], produced by [`TraversalResult::to_tree`].
#[derive(Debug, Clone)]
pub struct TraversalTree {
    /// The traversal origin.
    pub root: ItemId,
    /// Top-level subtrees directly beneath the root.
    pub children: Vec<TraversalTreeNode>,
}
/// A node in a [`TraversalTree`].
#[derive(Debug, Clone)]
pub struct TraversalTreeNode {
    /// Identifier of the item at this node.
    pub item_id: ItemId,
    /// Relationship that was followed to reach this item during the traversal.
    pub relationship: RelationshipType,
    /// Items nested beneath this one.
    pub children: Vec<TraversalTreeNode>,
}
impl TraversalResult {
    /// Converts the flat BFS result into a tree rooted at [`TraversalResult::origin`].
    ///
    /// Each node is nested under its recorded `parent`. Some items have no included
    /// ancestor (`parent == None`); this happens when the origin was excluded by a type
    /// filter. Those items become direct children of the root instead of being dropped.
    ///
    /// Returns `None` when the traversal reported no items.
    pub fn to_tree(&self, _graph: &KnowledgeGraph) -> Option<TraversalTree> {
        use std::collections::HashMap;

        // Parent id (`None` = no included ancestor) -> nodes listed beneath it.
        type ChildrenMap<'a> = HashMap<Option<&'a ItemId>, Vec<&'a TraversalNode>>;

        // Converts `node` and, recursively, everything beneath it.
        fn to_tree_node(node: &TraversalNode, children_map: &ChildrenMap<'_>) -> TraversalTreeNode {
            TraversalTreeNode {
                item_id: node.item_id.clone(),
                // Only the origin lacks a relationship, and it never appears as a tree node
                // here; the fallback is purely defensive.
                relationship: node.relationship.unwrap_or(RelationshipType::Refines),
                children: build_children(&node.item_id, children_map),
            }
        }

        // Builds subtrees for every node whose parent is `parent_id`.
        fn build_children(
            parent_id: &ItemId,
            children_map: &ChildrenMap<'_>,
        ) -> Vec<TraversalTreeNode> {
            children_map
                .get(&Some(parent_id))
                .map(|children| {
                    children
                        .iter()
                        .map(|node| to_tree_node(node, children_map))
                        .collect()
                })
                .unwrap_or_default()
        }

        if self.items.is_empty() {
            return None;
        }

        let mut children_map: ChildrenMap<'_> = HashMap::new();
        for node in &self.items {
            children_map.entry(node.parent.as_ref()).or_default().push(node);
        }

        // Parent-less items are either the origin itself (when it passed the filter) or,
        // when the origin was filtered out, the nearest reported descendants of the origin.
        // The latter hang directly off the root.
        let mut root_children: Vec<TraversalTreeNode> = children_map
            .get(&None)
            .into_iter()
            .flatten()
            .filter(|node| node.item_id != self.origin)
            .map(|node| to_tree_node(node, &children_map))
            .collect();
        // When the origin was reported, its own children are keyed under it.
        root_children.extend(build_children(&self.origin, &children_map));

        Some(TraversalTree {
            root: self.origin.clone(),
            children: root_children,
        })
    }

    /// Returns the nodes whose item, looked up in `graph`, has one of the given `types`.
    ///
    /// Nodes whose id is not found in `graph` are skipped.
    pub fn filter_by_type(
        &self,
        types: &[ItemType],
        graph: &KnowledgeGraph,
    ) -> Vec<&TraversalNode> {
        self.items
            .iter()
            .filter(|node| {
                graph
                    .get(&node.item_id)
                    .is_some_and(|item| types.contains(&item.item_type))
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::GraphBuilder;
    use crate::model::{ItemBuilder, SourceLocation, UpstreamRefs};
    use std::path::PathBuf;

    fn create_test_item(id: &str, item_type: ItemType) -> Item {
        let source = SourceLocation::new(PathBuf::from("/repo"), format!("{}.md", id), 1);
        let mut builder = ItemBuilder::new()
            .id(ItemId::new_unchecked(id))
            .item_type(item_type)
            .name(format!("Test {}", id))
            .source(source);
        if item_type.requires_specification() {
            builder = builder.specification("Test specification");
        }
        builder.build().unwrap()
    }

    fn create_test_item_with_upstream(
        id: &str,
        item_type: ItemType,
        upstream: UpstreamRefs,
    ) -> Item {
        let source = SourceLocation::new(PathBuf::from("/repo"), format!("{}.md", id), 1);
        let mut builder = ItemBuilder::new()
            .id(ItemId::new_unchecked(id))
            .item_type(item_type)
            .name(format!("Test {}", id))
            .source(source)
            .upstream(upstream);
        if item_type.requires_specification() {
            builder = builder.specification("Test specification");
        }
        builder.build().unwrap()
    }

    #[test]
    fn test_upstream_traversal() {
        let sol = create_test_item("SOL-001", ItemType::Solution);
        let uc = create_test_item_with_upstream(
            "UC-001",
            ItemType::UseCase,
            UpstreamRefs {
                refines: vec![ItemId::new_unchecked("SOL-001")],
                ..Default::default()
            },
        );
        let scen = create_test_item_with_upstream(
            "SCEN-001",
            ItemType::Scenario,
            UpstreamRefs {
                refines: vec![ItemId::new_unchecked("UC-001")],
                ..Default::default()
            },
        );
        let graph = GraphBuilder::new()
            .add_item(sol)
            .add_item(uc)
            .add_item(scen)
            .build()
            .unwrap();
        let result = traverse_upstream(
            &graph,
            &ItemId::new_unchecked("SCEN-001"),
            &TraversalOptions::new(),
        );
        assert!(result.is_some());
        let result = result.unwrap();
        assert_eq!(result.items.len(), 3);
        assert_eq!(result.max_depth, 2);
        // The origin is reported first, at depth 0, with no relationship or parent.
        assert_eq!(result.items[0].item_id, ItemId::new_unchecked("SCEN-001"));
        assert_eq!(result.items[0].depth, 0);
        assert!(result.items[0].relationship.is_none());
        assert_eq!(result.items[0].parent, None);
    }

    #[test]
    fn test_downstream_traversal() {
        let sol = create_test_item("SOL-001", ItemType::Solution);
        let uc = create_test_item_with_upstream(
            "UC-001",
            ItemType::UseCase,
            UpstreamRefs {
                refines: vec![ItemId::new_unchecked("SOL-001")],
                ..Default::default()
            },
        );
        let graph = GraphBuilder::new()
            .add_item(sol)
            .add_item(uc)
            .build()
            .unwrap();
        let result = traverse_downstream(
            &graph,
            &ItemId::new_unchecked("SOL-001"),
            &TraversalOptions::new(),
        );
        assert!(result.is_some());
        let result = result.unwrap();
        assert_eq!(result.items.len(), 2);
        assert_eq!(result.max_depth, 1);
        assert_eq!(result.items[1].item_id, ItemId::new_unchecked("UC-001"));
        assert_eq!(result.items[1].parent, Some(ItemId::new_unchecked("SOL-001")));
    }

    #[test]
    fn test_depth_limited_traversal() {
        let sol = create_test_item("SOL-001", ItemType::Solution);
        let uc = create_test_item_with_upstream(
            "UC-001",
            ItemType::UseCase,
            UpstreamRefs {
                refines: vec![ItemId::new_unchecked("SOL-001")],
                ..Default::default()
            },
        );
        let scen = create_test_item_with_upstream(
            "SCEN-001",
            ItemType::Scenario,
            UpstreamRefs {
                refines: vec![ItemId::new_unchecked("UC-001")],
                ..Default::default()
            },
        );
        let graph = GraphBuilder::new()
            .add_item(sol)
            .add_item(uc)
            .add_item(scen)
            .build()
            .unwrap();
        let result = traverse_upstream(
            &graph,
            &ItemId::new_unchecked("SCEN-001"),
            &TraversalOptions::new().with_max_depth(1),
        );
        assert!(result.is_some());
        let result = result.unwrap();
        assert_eq!(result.max_depth, 1);
        // Exactly the origin and its direct parent; SOL-001 is two hops away.
        assert_eq!(result.items.len(), 2);
        assert!(result
            .items
            .iter()
            .all(|node| node.item_id != ItemId::new_unchecked("SOL-001")));
    }

    #[test]
    fn test_type_filtered_traversal() {
        let sol = create_test_item("SOL-001", ItemType::Solution);
        let uc = create_test_item_with_upstream(
            "UC-001",
            ItemType::UseCase,
            UpstreamRefs {
                refines: vec![ItemId::new_unchecked("SOL-001")],
                ..Default::default()
            },
        );
        let graph = GraphBuilder::new()
            .add_item(sol)
            .add_item(uc)
            .build()
            .unwrap();
        let result = traverse_upstream(
            &graph,
            &ItemId::new_unchecked("UC-001"),
            &TraversalOptions::new().with_types(vec![ItemType::Solution]),
        );
        assert!(result.is_some());
        let result = result.unwrap();
        // The origin (a use case) is filtered out, so only the solution is reported. With
        // no included ancestor, it has no display parent.
        assert_eq!(result.items.len(), 1);
        assert_eq!(result.items[0].item_id, ItemId::new_unchecked("SOL-001"));
        assert_eq!(result.items[0].depth, 1);
        assert_eq!(result.items[0].parent, None);
        let filtered = result.filter_by_type(&[ItemType::Solution], &graph);
        assert_eq!(filtered.len(), 1);
    }

    #[test]
    fn test_to_tree_nests_by_parent() {
        let sol = create_test_item("SOL-001", ItemType::Solution);
        let uc = create_test_item_with_upstream(
            "UC-001",
            ItemType::UseCase,
            UpstreamRefs {
                refines: vec![ItemId::new_unchecked("SOL-001")],
                ..Default::default()
            },
        );
        let scen = create_test_item_with_upstream(
            "SCEN-001",
            ItemType::Scenario,
            UpstreamRefs {
                refines: vec![ItemId::new_unchecked("UC-001")],
                ..Default::default()
            },
        );
        let graph = GraphBuilder::new()
            .add_item(sol)
            .add_item(uc)
            .add_item(scen)
            .build()
            .unwrap();
        let result = traverse_upstream(
            &graph,
            &ItemId::new_unchecked("SCEN-001"),
            &TraversalOptions::new(),
        )
        .unwrap();
        let tree = result.to_tree(&graph).unwrap();
        assert_eq!(tree.root, ItemId::new_unchecked("SCEN-001"));
        assert_eq!(tree.children.len(), 1);
        assert_eq!(tree.children[0].item_id, ItemId::new_unchecked("UC-001"));
        assert_eq!(tree.children[0].children.len(), 1);
        assert_eq!(
            tree.children[0].children[0].item_id,
            ItemId::new_unchecked("SOL-001")
        );
        assert!(tree.children[0].children[0].children.is_empty());
    }
}