use std::collections::{HashMap, HashSet};
use crate::index::AnnotationId;
/// Error returned by `detect_cycles` when the annotation parent graph contains a cycle.
///
/// The `Display` message is rendered by `format_path`, which joins the ids with
/// `" -> "` and repeats the first id at the end to close the loop.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
#[error("annotation parent graph contains a cycle: {}", format_path(path))]
pub struct CycleError {
/// The ids on the cycle in traversal order, each listed once (the closing
/// repetition of the first id is added only when the error is displayed).
pub path: Vec<AnnotationId>,
}
/// Renders a cycle for the `CycleError` message, e.g. `a -> b -> c -> a`.
///
/// Ids are joined with `" -> "` and the first id is repeated at the end so the
/// output visibly closes the loop; a single-element path renders as `a -> a`.
/// An empty path renders as an empty string.
fn format_path(path: &[AnnotationId]) -> String {
    let mut rendered = String::new();
    // Chaining `path.first()` appends the closing id exactly when the path is
    // non-empty, so no length-based branching or direct indexing is needed.
    for (i, id) in path.iter().chain(path.first()).enumerate() {
        if i > 0 {
            rendered.push_str(" -> ");
        }
        rendered.push_str(id.as_str());
    }
    rendered
}
/// Checks the annotation parent graph for cycles.
///
/// `parents` maps each annotation id to the ids it lists as parents. Every key
/// is used as a depth-first starting point unless an earlier traversal already
/// reached it. Ids that appear only as values (no entry of their own) are
/// treated as leaves.
///
/// # Errors
///
/// Returns a [`CycleError`] carrying the first cycle encountered. When the
/// graph holds several cycles, which one is reported depends on the map's
/// iteration order.
#[aristo::intent(
    "One cycle reported per call, then return. This is intentional, not \
     incomplete — extending to enumerate all cycles would multiply \
     diagnostic noise without helping the fix.",
    verify = "neural",
    id = "detect_cycles_returns_first_cycle_only"
)]
pub fn detect_cycles(parents: &HashMap<AnnotationId, Vec<AnnotationId>>) -> Result<(), CycleError> {
    // Traversal state shared across all starting points.
    let mut seen: HashSet<AnnotationId> = HashSet::new();
    let mut active: HashSet<AnnotationId> = HashSet::new();
    let mut trail: Vec<AnnotationId> = Vec::new();

    // `find_map` stops at the first starting point whose traversal yields a cycle.
    let first_cycle = parents.keys().find_map(|root| {
        if seen.contains(root) {
            None
        } else {
            dfs(root, parents, &mut seen, &mut active, &mut trail)
        }
    });

    match first_cycle {
        Some(path) => Err(CycleError { path }),
        None => Ok(()),
    }
}
/// Depth-first search from `node`, looking for an edge back into the current branch.
///
/// * `parents` – adjacency map; an id with no entry has no outgoing edges.
/// * `visited` – every id entered so far; such ids are never entered again.
/// * `on_stack` / `path` – the ids on the current branch, as a set for membership
///   tests and as an ordered list for reporting.
///
/// Returns `Some(cycle)` as soon as an edge points at an id that is still on the
/// current branch; `cycle` is the part of `path` from that id onwards. Returns
/// `None` once everything reachable from `node` has been explored, at which point
/// every id this call pushed onto `on_stack` and `path` has been removed again.
///
/// The traversal keeps its own stack of edge iterators instead of recursing, so
/// a very long parent chain cannot overflow the call stack.
fn dfs(
    node: &AnnotationId,
    parents: &HashMap<AnnotationId, Vec<AnnotationId>>,
    visited: &mut HashSet<AnnotationId>,
    on_stack: &mut HashSet<AnnotationId>,
    path: &mut Vec<AnnotationId>,
) -> Option<Vec<AnnotationId>> {
    // One frame per id this call has pushed onto `path`: its not-yet-followed edges.
    let mut frames: Vec<std::slice::Iter<'_, AnnotationId>> = Vec::new();
    // The id to enter on the next loop turn (the equivalent of a recursive call).
    let mut pending: Option<&AnnotationId> = Some(node);

    loop {
        if let Some(current) = pending.take() {
            visited.insert(current.clone());
            on_stack.insert(current.clone());
            path.push(current.clone());
            let edges: &[AnnotationId] = parents.get(current).map(Vec::as_slice).unwrap_or(&[]);
            frames.push(edges.iter());
        }

        let next_edge = match frames.last_mut() {
            Some(frame) => frame.next(),
            // The starting node has been fully explored without finding a cycle.
            None => return None,
        };

        match next_edge {
            Some(child) if on_stack.contains(child) => {
                // Back edge: the cycle is the tail of `path` starting at `child`.
                let start = path
                    .iter()
                    .position(|p| p == child)
                    .expect("on_stack ⇒ in path");
                return Some(path[start..].to_vec());
            }
            Some(child) => {
                if !visited.contains(child) {
                    pending = Some(child);
                }
            }
            None => {
                // All edges of the innermost id are done: leave it.
                frames.pop();
                if let Some(finished) = path.pop() {
                    on_stack.remove(&finished);
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a fixture string into an `AnnotationId`, panicking on invalid input.
    fn id(s: &str) -> AnnotationId {
        AnnotationId::parse(s).expect("test id must parse")
    }

    /// Builds a parent map from `(node, parents)` string fixtures.
    fn graph(edges: &[(&str, &[&str])]) -> HashMap<AnnotationId, Vec<AnnotationId>> {
        let mut adjacency = HashMap::new();
        for &(node, parent_names) in edges {
            let key = id(node);
            let parent_ids: Vec<AnnotationId> = parent_names.iter().copied().map(id).collect();
            adjacency.insert(key, parent_ids);
        }
        adjacency
    }

    /// Panics unless `parents` is reported as cycle-free.
    #[track_caller]
    fn assert_acyclic(parents: &HashMap<AnnotationId, Vec<AnnotationId>>) {
        detect_cycles(parents).unwrap();
    }

    /// Returns the cycle reported for `parents`, panicking if none is reported.
    #[track_caller]
    fn cycle_in(parents: &HashMap<AnnotationId, Vec<AnnotationId>>) -> CycleError {
        detect_cycles(parents).unwrap_err()
    }

    /// A graph with no nodes has no cycle.
    #[test]
    fn empty_graph_is_acyclic() {
        assert_acyclic(&HashMap::new());
    }

    /// A straight chain a -> b -> c -> d has no cycle.
    #[test]
    fn linear_chain_is_acyclic() {
        let parents = graph(&[("a", &["b"]), ("b", &["c"]), ("c", &["d"]), ("d", &[])]);
        assert_acyclic(&parents);
    }

    /// A node listing itself as parent is a one-element cycle rendered as `a -> a`.
    #[test]
    fn self_cycle_detected() {
        let parents = graph(&[("a", &["a"])]);
        let cycle = cycle_in(&parents);
        assert_eq!(cycle.path, vec![id("a")]);
        assert!(cycle.to_string().contains("a -> a"));
    }

    /// Two nodes pointing at each other form a cycle containing both.
    #[test]
    fn two_cycle_detected() {
        let parents = graph(&[("a", &["b"]), ("b", &["a"])]);
        let cycle = cycle_in(&parents);
        assert_eq!(cycle.path.len(), 2);
        assert!(cycle.path.contains(&id("a")));
        assert!(cycle.path.contains(&id("b")));
    }

    /// A three-node ring is reported with all three nodes.
    #[test]
    fn three_cycle_detected() {
        let parents = graph(&[("a", &["b"]), ("b", &["c"]), ("c", &["a"])]);
        let cycle = cycle_in(&parents);
        assert_eq!(cycle.path.len(), 3);
    }

    /// Two routes converging on the same node are not a cycle.
    #[test]
    fn diamond_is_acyclic() {
        let parents = graph(&[("a", &["b", "c"]), ("b", &["d"]), ("c", &["d"]), ("d", &[])]);
        assert_acyclic(&parents);
    }

    /// Several unconnected chains are all explored and found clean.
    #[test]
    fn multiple_disconnected_components_acyclic() {
        let parents = graph(&[
            ("a", &["b"]),
            ("b", &[]),
            ("c", &["d"]),
            ("d", &["e"]),
            ("e", &[]),
        ]);
        assert_acyclic(&parents);
    }

    /// A cycle is found even when another component of the graph is clean.
    #[test]
    fn cycle_in_one_component_detected_when_other_is_clean() {
        let parents = graph(&[("a", &["b"]), ("b", &[]), ("c", &["d"]), ("d", &["c"])]);
        let cycle = cycle_in(&parents);
        assert!(cycle.path.contains(&id("c")));
        assert!(cycle.path.contains(&id("d")));
    }

    /// A parent id with no entry of its own is treated as a leaf, not an error.
    #[test]
    fn orphan_parent_reference_is_not_a_cycle_error() {
        let parents = graph(&[("a", &["nonexistent_parent"])]);
        assert_acyclic(&parents);
    }

    /// Only the cyclic branch of a multi-parent node ends up in the reported path.
    #[test]
    fn multi_parent_node_with_one_cycle_branch_detected() {
        let parents = graph(&[("a", &["b", "c"]), ("b", &["a"]), ("c", &[])]);
        let cycle = cycle_in(&parents);
        assert_eq!(cycle.path.len(), 2);
    }

    /// A ten-node ring n0 -> n1 -> … -> n9 -> n0 is reported in full.
    #[test]
    fn long_cycle_detected() {
        let ring: HashMap<AnnotationId, Vec<AnnotationId>> = (0..10)
            .map(|i| {
                let node = id(&format!("n{i}"));
                let parent = id(&format!("n{}", (i + 1) % 10));
                (node, vec![parent])
            })
            .collect();
        let cycle = cycle_in(&ring);
        assert_eq!(cycle.path.len(), 10);
    }

    /// The message for a three-node cycle contains three arrows (including the closing one).
    #[test]
    fn cycle_error_message_renders_path_with_arrows() {
        let parents = graph(&[("a", &["b"]), ("b", &["c"]), ("c", &["a"])]);
        let msg = cycle_in(&parents).to_string();
        assert_eq!(msg.matches(" -> ").count(), 3, "got: {msg}");
    }

    /// Namespaced ids such as `aristos:foo` behave like any other id.
    #[test]
    fn aristos_namespace_ids_work() {
        let parents = graph(&[
            ("aristos:foo", &["aristos:bar"]),
            ("aristos:bar", &["aristos:foo"]),
        ]);
        let cycle = cycle_in(&parents);
        assert_eq!(cycle.path.len(), 2);
    }
}