use std::collections::{HashMap, HashSet, VecDeque};
use ainl_contracts::{Feature, FeatureId};
use thiserror::Error;
/// Errors returned by [`detect_cycle`] and [`topological_layers`] when the
/// feature precondition graph cannot be validated or ordered.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum DagError {
/// A feature lists a precondition id that matches no feature in the input.
/// Carries the unknown precondition id.
#[error("unknown precondition feature {0}")]
UnknownPrecondition(String),
/// Some features could never be ordered because their preconditions form a
/// cycle. Carries the sorted, `", "`-separated ids of every feature left with
/// unsatisfied preconditions (this can include features that merely depend on
/// a cycle, not only the cycle's members).
#[error("cycle detected involving features: {0}")]
CycleDetected(String),
/// The input slice of features was empty.
#[error("no features provided")]
Empty,
}
/// Checks that the precondition graph described by `features` is acyclic.
///
/// Every distinct `feature_id` is a node, and every entry of a feature's
/// `preconditions` is an edge from that precondition to the feature. Entries
/// that share a `feature_id` are merged into a single node whose preconditions
/// are the union of theirs.
///
/// # Errors
///
/// - [`DagError::Empty`] if `features` is empty.
/// - [`DagError::UnknownPrecondition`] for the first precondition (in input
///   order) whose id matches no feature in `features`.
/// - [`DagError::CycleDetected`] if Kahn's algorithm cannot order every node.
///   The payload lists, sorted and `", "`-separated, every feature that still
///   has unsatisfied preconditions. This includes features downstream of a
///   cycle, not only its members.
pub fn detect_cycle(features: &[Feature]) -> Result<(), DagError> {
    if features.is_empty() {
        return Err(DagError::Empty);
    }

    let ids: HashSet<&str> = features.iter().map(|f| f.feature_id.as_str()).collect();
    // indegree[x]: number of precondition edges into `x` not yet satisfied.
    let mut indegree: HashMap<&str, usize> = ids.iter().map(|id| (*id, 0)).collect();
    // adj[p]: features that list `p` as a precondition (one entry per edge).
    let mut adj: HashMap<&str, Vec<&str>> = ids.iter().map(|id| (*id, Vec::new())).collect();

    for f in features {
        let fid = f.feature_id.as_str();
        for pre in &f.preconditions {
            let pid = pre.as_str();
            if !ids.contains(pid) {
                return Err(DagError::UnknownPrecondition(pid.to_string()));
            }
            *indegree
                .get_mut(fid)
                .expect("every feature id was inserted as a node") += 1;
            adj.get_mut(pid)
                .expect("precondition id was checked against the node set")
                .push(fid);
        }
    }

    // Kahn's algorithm: repeatedly release nodes whose preconditions are all met.
    let mut queue: VecDeque<&str> = indegree
        .iter()
        .filter(|(_, &d)| d == 0)
        .map(|(id, _)| *id)
        .collect();
    let mut visited = 0usize;
    while let Some(id) = queue.pop_front() {
        visited += 1;
        if let Some(children) = adj.get(id) {
            for &child in children {
                let d = indegree
                    .get_mut(child)
                    .expect("every dependent is a node");
                *d -= 1;
                if *d == 0 {
                    queue.push_back(child);
                }
            }
        }
    }

    // Compare against the number of distinct nodes, not `features.len()`.
    // Duplicate ids collapse into one node above, so counting them separately
    // would report a "cycle" whose list of blocked features is empty.
    if visited == indegree.len() {
        return Ok(());
    }

    let mut stuck: Vec<&str> = indegree
        .iter()
        .filter(|(_, &d)| d > 0)
        .map(|(id, _)| *id)
        .collect();
    stuck.sort_unstable();
    Err(DagError::CycleDetected(stuck.join(", ")))
}
/// Groups `features` into dependency layers that can be processed in order.
///
/// Layer 0 holds every feature with no preconditions. Layer `k` holds the
/// features whose preconditions are all in layers `0..k`, with at least one
/// in layer `k - 1`. No feature depends on another feature in the same layer.
/// Each layer is sorted by id (byte-wise `str` ordering), so the result is
/// deterministic.
///
/// The graph built here merges entries sharing a `feature_id` into one node,
/// which then appears once in the output. Whether such input is accepted at
/// all is decided by [`detect_cycle`], which runs first.
///
/// # Errors
///
/// Propagates any error from [`detect_cycle`]: [`DagError::Empty`],
/// [`DagError::UnknownPrecondition`] or [`DagError::CycleDetected`]. The same
/// conditions are re-checked defensively while the layers are built.
pub fn topological_layers(features: &[Feature]) -> Result<Vec<Vec<FeatureId>>, DagError> {
    detect_cycle(features)?;
    if features.is_empty() {
        return Err(DagError::Empty);
    }

    let ids: HashSet<&str> = features.iter().map(|f| f.feature_id.as_str()).collect();
    // indegree[x]: number of precondition edges into `x` not yet satisfied.
    let mut indegree: HashMap<&str, usize> = ids.iter().map(|id| (*id, 0)).collect();
    // dependents[p]: features that list `p` as a precondition (one entry per edge).
    let mut dependents: HashMap<&str, Vec<&str>> =
        ids.iter().map(|id| (*id, Vec::new())).collect();

    for f in features {
        let fid = f.feature_id.as_str();
        for pre in &f.preconditions {
            let pid = pre.as_str();
            match dependents.get_mut(pid) {
                Some(children) => children.push(fid),
                None => return Err(DagError::UnknownPrecondition(pid.to_string())),
            }
            // Count edges (not `preconditions.len()` of one entry) so the
            // decrements below balance exactly.
            *indegree
                .get_mut(fid)
                .expect("every feature id was inserted as a node") += 1;
        }
    }

    // Kahn's algorithm, layer by layer. Only a node whose in-degree just reached
    // zero can join the next layer, so the whole map is never rescanned.
    let mut frontier: Vec<&str> = indegree
        .iter()
        .filter(|(_, &d)| d == 0)
        .map(|(id, _)| *id)
        .collect();
    let mut layers: Vec<Vec<FeatureId>> = Vec::new();
    let mut placed = 0usize;
    while !frontier.is_empty() {
        frontier.sort_unstable();
        let mut next: Vec<&str> = Vec::new();
        for &id in &frontier {
            if let Some(children) = dependents.get(id) {
                for &child in children {
                    let d = indegree
                        .get_mut(child)
                        .expect("every dependent is a node");
                    *d -= 1;
                    if *d == 0 {
                        next.push(child);
                    }
                }
            }
        }
        placed += frontier.len();
        layers.push(frontier.iter().map(|id| FeatureId((*id).to_string())).collect());
        frontier = next;
    }

    if placed != indegree.len() {
        let mut stuck: Vec<&str> = indegree
            .iter()
            .filter(|(_, &d)| d > 0)
            .map(|(id, _)| *id)
            .collect();
        stuck.sort_unstable();
        return Err(DagError::CycleDetected(stuck.join(", ")));
    }
    Ok(layers)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ainl_contracts::FeatureStatus;

    fn feat(id: &str, pre: &[&str]) -> Feature {
        Feature {
            feature_id: FeatureId(id.into()),
            description: id.into(),
            status: FeatureStatus::Pending,
            milestone: None,
            skill_name: None,
            touches_files: vec![],
            preconditions: pre.iter().map(|p| FeatureId((*p).into())).collect(),
            expected_behavior: vec![],
            verification_steps: vec![],
            fulfills: vec![],
            snapshot: None,
        }
    }

    /// Flattens layers to plain ids so whole layouts can be compared at once.
    fn names(layers: &[Vec<FeatureId>]) -> Vec<Vec<&str>> {
        layers
            .iter()
            .map(|layer| layer.iter().map(|id| id.as_str()).collect())
            .collect()
    }

    #[test]
    fn linear_chain_layers() {
        let features = vec![feat("a", &[]), feat("b", &["a"]), feat("c", &["b"])];
        let layers = topological_layers(&features).unwrap();
        assert_eq!(names(&layers), vec![vec!["a"], vec!["b"], vec!["c"]]);
    }

    #[test]
    fn diamond_dag_two_layers_middle() {
        let features = vec![
            feat("root", &[]),
            feat("left", &["root"]),
            feat("right", &["root"]),
            feat("merge", &["left", "right"]),
        ];
        let layers = topological_layers(&features).unwrap();
        assert_eq!(
            names(&layers),
            vec![vec!["root"], vec!["left", "right"], vec!["merge"]]
        );
    }

    #[test]
    fn independent_features_share_a_sorted_layer() {
        let features = vec![feat("z", &[]), feat("m", &[]), feat("a", &[])];
        let layers = topological_layers(&features).unwrap();
        assert_eq!(names(&layers), vec![vec!["a", "m", "z"]]);
    }

    #[test]
    fn cycle_reported() {
        let features = vec![feat("a", &["b"]), feat("b", &["a"])];
        assert_eq!(
            detect_cycle(&features),
            Err(DagError::CycleDetected("a, b".into()))
        );
        assert_eq!(
            topological_layers(&features).err(),
            Some(DagError::CycleDetected("a, b".into()))
        );
    }

    #[test]
    fn cycle_report_includes_downstream_features() {
        let features = vec![feat("a", &["b"]), feat("b", &["a"]), feat("c", &["a"])];
        assert_eq!(
            detect_cycle(&features),
            Err(DagError::CycleDetected("a, b, c".into()))
        );
    }

    #[test]
    fn unknown_precondition() {
        let features = vec![feat("a", &["missing"])];
        assert_eq!(
            detect_cycle(&features),
            Err(DagError::UnknownPrecondition("missing".into()))
        );
        assert_eq!(
            topological_layers(&features).err(),
            Some(DagError::UnknownPrecondition("missing".into()))
        );
    }

    #[test]
    fn empty_input_rejected() {
        let features: Vec<Feature> = Vec::new();
        assert_eq!(detect_cycle(&features), Err(DagError::Empty));
        assert_eq!(topological_layers(&features).err(), Some(DagError::Empty));
    }
}