use petgraph::Direction;
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::EdgeRef;
use std::collections::HashMap;
use crate::model::{Item, ItemId, ItemType, RelationshipType};
/// In-memory graph of items keyed by [`ItemId`], with typed relationships as edges.
///
/// Items are stored as node weights of a petgraph [`DiGraph`] and relationships as
/// edge weights; `index` maps each id to its node so id lookups avoid a linear scan.
#[derive(Debug)]
pub struct KnowledgeGraph {
/// Backing directed graph: nodes are items, edges are relationships.
graph: DiGraph<Item, RelationshipType>,
/// Id -> node lookup table, written only by `add_item`. When the same id is added
/// more than once, only the most recently added node is recorded here.
index: HashMap<ItemId, NodeIndex>,
/// Strictness flag set at construction and exposed via `is_strict_mode`. None of the
/// graph operations in this module consult it (presumably read by validation code
/// elsewhere — confirm).
strict_mode: bool,
}
impl KnowledgeGraph {
pub fn new(strict_mode: bool) -> Self {
Self {
graph: DiGraph::new(),
index: HashMap::new(),
strict_mode,
}
}
pub fn is_strict_mode(&self) -> bool {
self.strict_mode
}
pub fn item_count(&self) -> usize {
self.graph.node_count()
}
pub fn relationship_count(&self) -> usize {
self.graph.edge_count()
}
pub fn add_item(&mut self, item: Item) -> NodeIndex {
let id = item.id.clone();
let idx = self.graph.add_node(item);
self.index.insert(id, idx);
idx
}
pub fn add_relationship(
&mut self,
from: &ItemId,
to: &ItemId,
rel_type: RelationshipType,
) -> Option<()> {
let from_idx = self.index.get(from)?;
let to_idx = self.index.get(to)?;
self.graph.add_edge(*from_idx, *to_idx, rel_type);
Some(())
}
pub fn get(&self, id: &ItemId) -> Option<&Item> {
let idx = self.index.get(id)?;
self.graph.node_weight(*idx)
}
pub fn get_mut(&mut self, id: &ItemId) -> Option<&mut Item> {
let idx = self.index.get(id)?;
self.graph.node_weight_mut(*idx)
}
pub fn contains(&self, id: &ItemId) -> bool {
self.index.contains_key(id)
}
pub fn items(&self) -> impl Iterator<Item = &Item> {
self.graph.node_weights()
}
pub fn item_ids(&self) -> impl Iterator<Item = &ItemId> {
self.index.keys()
}
pub fn items_by_type(&self, item_type: ItemType) -> Vec<&Item> {
self.graph
.node_weights()
.filter(|item| item.item_type == item_type)
.collect()
}
pub fn count_by_type(&self) -> HashMap<ItemType, usize> {
let mut counts = HashMap::new();
for item in self.graph.node_weights() {
*counts.entry(item.item_type).or_insert(0) += 1;
}
counts
}
pub fn parents(&self, id: &ItemId) -> Vec<&Item> {
let Some(idx) = self.index.get(id) else {
return Vec::new();
};
self.graph
.edges_directed(*idx, Direction::Outgoing)
.filter(|edge| edge.weight().is_upstream())
.filter_map(|edge| self.graph.node_weight(edge.target()))
.collect()
}
pub fn children(&self, id: &ItemId) -> Vec<&Item> {
let Some(idx) = self.index.get(id) else {
return Vec::new();
};
self.graph
.edges_directed(*idx, Direction::Incoming)
.filter(|edge| edge.weight().is_upstream())
.filter_map(|edge| self.graph.node_weight(edge.source()))
.collect()
}
pub fn orphans(&self) -> Vec<&Item> {
self.graph
.node_weights()
.filter(|item| {
if item.item_type.is_root() {
return false;
}
item.upstream.is_empty()
})
.collect()
}
pub fn inner(&self) -> &DiGraph<Item, RelationshipType> {
&self.graph
}
pub fn inner_mut(&mut self) -> &mut DiGraph<Item, RelationshipType> {
&mut self.graph
}
pub fn node_index(&self, id: &ItemId) -> Option<NodeIndex> {
self.index.get(id).copied()
}
pub fn has_cycles(&self) -> bool {
petgraph::algo::is_cyclic_directed(&self.graph)
}
pub fn relationships(&self) -> Vec<(ItemId, ItemId, RelationshipType)> {
self.graph
.edge_references()
.filter_map(|edge| {
let from = self.graph.node_weight(edge.source())?;
let to = self.graph.node_weight(edge.target())?;
Some((from.id.clone(), to.id.clone(), *edge.weight()))
})
.collect()
}
}
impl Default for KnowledgeGraph {
    /// Builds an empty graph with strict mode turned off.
    fn default() -> Self {
        let strict_mode = false;
        Self::new(strict_mode)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{ItemBuilder, SourceLocation};
    use std::path::PathBuf;

    // Builds a minimal item with the given id and type. A specification is attached
    // only for types that require one.
    fn create_test_item(id: &str, item_type: ItemType) -> Item {
        let source = SourceLocation::new(PathBuf::from("/repo"), format!("{}.md", id), 1);
        let base = ItemBuilder::new()
            .id(ItemId::new_unchecked(id))
            .item_type(item_type)
            .name(format!("Test {}", id))
            .source(source);
        let builder = if item_type.requires_specification() {
            base.specification("Test specification")
        } else {
            base
        };
        builder.build().unwrap()
    }

    // An added item is reachable by its id and keeps its name.
    #[test]
    fn test_add_and_get_item() {
        let mut graph = KnowledgeGraph::new(false);
        graph.add_item(create_test_item("SOL-001", ItemType::Solution));
        let id = ItemId::new_unchecked("SOL-001");
        assert!(graph.contains(&id));
        let stored = graph.get(&id).unwrap();
        assert_eq!(stored.name, "Test SOL-001");
    }

    // Type filtering returns exactly the items of the requested type.
    #[test]
    fn test_items_by_type() {
        let mut graph = KnowledgeGraph::new(false);
        let fixtures = [
            ("SOL-001", ItemType::Solution),
            ("UC-001", ItemType::UseCase),
            ("UC-002", ItemType::UseCase),
        ];
        for (id, item_type) in fixtures {
            graph.add_item(create_test_item(id, item_type));
        }
        assert_eq!(graph.items_by_type(ItemType::Solution).len(), 1);
        assert_eq!(graph.items_by_type(ItemType::UseCase).len(), 2);
    }

    // The item count starts at zero and grows by one per insertion.
    #[test]
    fn test_item_count() {
        let mut graph = KnowledgeGraph::new(false);
        assert_eq!(graph.item_count(), 0);
        let fixtures = [("SOL-001", ItemType::Solution), ("UC-001", ItemType::UseCase)];
        for (already_added, (id, item_type)) in fixtures.into_iter().enumerate() {
            graph.add_item(create_test_item(id, item_type));
            assert_eq!(graph.item_count(), already_added + 1);
        }
    }
}