use petgraph::stable_graph::{EdgeIndex, NodeIndex, StableGraph};
use petgraph::Directed;
use serde::{Deserialize, Serialize};
/// Payload stored on every edge of a [`StableDag`].
///
/// Edges created by [`StableDag::add_edge`] carry `label: None`; edges created by
/// [`StableDag::add_edge_with_label`] carry `Some(label)`.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct DagEdge {
    /// Optional free-form label attached to the edge.
    pub label: Option<String>,
}
/// Directed graph built on petgraph's [`StableGraph`], with [`DagEdge`] edge payloads.
///
/// Removing a node or edge does not invalidate the indices of other nodes and edges,
/// so `NodeIndex` values can be held long-term as stable handles.
///
/// NOTE(review): the name implies acyclicity, but nothing in this type enforces it;
/// confirm that callers maintain the DAG invariant.
#[derive(Clone, Debug)]
pub struct StableDag<N> {
    /// Underlying storage. Kept private so every mutation goes through the wrapper's API.
    inner: StableGraph<N, DagEdge, Directed>,
}
impl<N: Serialize> Serialize for StableDag<N> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.inner.serialize(serializer)
}
}
impl<'de, N: Deserialize<'de>> Deserialize<'de> for StableDag<N> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let inner = StableGraph::<N, DagEdge, Directed>::deserialize(deserializer)?;
Ok(Self { inner })
}
}
impl<N> StableDag<N> {
pub fn new() -> Self {
Self {
inner: StableGraph::new(),
}
}
pub fn with_capacity(nodes: usize, edges: usize) -> Self {
Self {
inner: StableGraph::with_capacity(nodes, edges),
}
}
#[inline]
pub fn node_count(&self) -> usize {
self.inner.node_count()
}
#[inline]
pub fn edge_count(&self) -> usize {
self.inner.edge_count()
}
pub fn add_node(&mut self, weight: N) -> NodeIndex {
self.inner.add_node(weight)
}
pub fn node_weight(&self, idx: NodeIndex) -> Option<&N> {
self.inner.node_weight(idx)
}
pub fn node_weight_mut(&mut self, idx: NodeIndex) -> Option<&mut N> {
self.inner.node_weight_mut(idx)
}
pub fn remove_node(&mut self, idx: NodeIndex) -> Option<N> {
self.inner.remove_node(idx)
}
pub fn contains_node(&self, idx: NodeIndex) -> bool {
self.inner.contains_node(idx)
}
pub fn add_edge(&mut self, a: NodeIndex, b: NodeIndex) -> Option<EdgeIndex> {
if !self.contains_node(a) || !self.contains_node(b) {
return None;
}
Some(self.inner.add_edge(a, b, DagEdge::default()))
}
pub fn add_edge_with_label(
&mut self,
a: NodeIndex,
b: NodeIndex,
label: impl Into<String>,
) -> Option<EdgeIndex> {
if !self.contains_node(a) || !self.contains_node(b) {
return None;
}
Some(self.inner.add_edge(
a,
b,
DagEdge {
label: Some(label.into()),
},
))
}
pub fn has_edge(&self, a: NodeIndex, b: NodeIndex) -> bool {
self.inner.find_edge(a, b).is_some()
}
pub fn remove_edge(&mut self, a: NodeIndex, b: NodeIndex) -> bool {
if let Some(edge) = self.inner.find_edge(a, b) {
self.inner.remove_edge(edge);
true
} else {
false
}
}
pub fn neighbors_directed(
&self,
idx: NodeIndex,
direction: petgraph::Direction,
) -> impl Iterator<Item = NodeIndex> + '_ {
self.inner.neighbors_directed(idx, direction)
}
pub fn incoming_neighbors(&self, idx: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
self.neighbors_directed(idx, petgraph::Direction::Incoming)
}
pub fn outgoing_neighbors(&self, idx: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
self.neighbors_directed(idx, petgraph::Direction::Outgoing)
}
}
/// Produces an empty graph, equivalent to [`StableDag::new`].
impl<N> Default for StableDag<N> {
    fn default() -> Self {
        Self::with_capacity(0, 0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stable_graph_new_creates_empty_graph() {
        let graph: StableDag<String> = StableDag::new();
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_stable_graph_with_capacity() {
        let graph: StableDag<String> = StableDag::with_capacity(10, 20);
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_stable_graph_default_is_empty() {
        let graph: StableDag<String> = StableDag::default();
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_add_node_returns_stable_index() {
        let mut graph: StableDag<String> = StableDag::new();
        let idx1 = graph.add_node("msg-001".to_string());
        let idx2 = graph.add_node("msg-002".to_string());
        assert_eq!(idx1.index(), 0);
        assert_eq!(idx2.index(), 1);
        assert_eq!(graph.node_count(), 2);
    }

    #[test]
    fn test_add_node_increments_count() {
        let mut graph: StableDag<String> = StableDag::new();
        assert_eq!(graph.node_count(), 0);
        graph.add_node("a".to_string());
        assert_eq!(graph.node_count(), 1);
        graph.add_node("b".to_string());
        assert_eq!(graph.node_count(), 2);
    }

    #[test]
    fn test_node_weight_retrieves_value() {
        let mut graph: StableDag<String> = StableDag::new();
        let idx = graph.add_node("test-value".to_string());
        assert_eq!(graph.node_weight(idx), Some(&"test-value".to_string()));
    }

    #[test]
    fn test_node_weight_mut_allows_modification() {
        let mut graph: StableDag<String> = StableDag::new();
        let idx = graph.add_node("original".to_string());
        if let Some(weight) = graph.node_weight_mut(idx) {
            *weight = "modified".to_string();
        }
        assert_eq!(graph.node_weight(idx), Some(&"modified".to_string()));
    }

    #[test]
    fn test_remove_node_preserves_other_indices() {
        let mut graph: StableDag<String> = StableDag::new();
        let idx0 = graph.add_node("msg-001".to_string());
        let idx1 = graph.add_node("msg-002".to_string());
        let idx2 = graph.add_node("msg-003".to_string());
        graph.remove_node(idx1);

        // The removed slot is vacant...
        assert_eq!(graph.node_weight(idx1), None);
        // ...while the surviving nodes keep both their weights and their original indices.
        assert_eq!(graph.node_weight(idx0), Some(&"msg-001".to_string()));
        assert_eq!(graph.node_weight(idx2), Some(&"msg-003".to_string()));
        assert_eq!(idx0.index(), 0);
        assert_eq!(idx2.index(), 2);
    }

    #[test]
    fn test_remove_node_decrements_count() {
        let mut graph: StableDag<String> = StableDag::new();
        let idx = graph.add_node("test".to_string());
        assert_eq!(graph.node_count(), 1);
        graph.remove_node(idx);
        assert_eq!(graph.node_count(), 0);
    }

    #[test]
    fn test_remove_node_returns_weight() {
        let mut graph: StableDag<String> = StableDag::new();
        let idx = graph.add_node("value".to_string());
        let removed = graph.remove_node(idx);
        assert_eq!(removed, Some("value".to_string()));
    }

    #[test]
    fn test_remove_node_missing_returns_none() {
        let mut graph: StableDag<String> = StableDag::new();
        let idx = graph.add_node("value".to_string());
        graph.remove_node(idx);
        let removed = graph.remove_node(idx);
        assert_eq!(removed, None);
    }

    #[test]
    fn test_remove_node_removes_incident_edges() {
        let mut graph: StableDag<String> = StableDag::new();
        let a = graph.add_node("a".to_string());
        let b = graph.add_node("b".to_string());
        let c = graph.add_node("c".to_string());
        graph.add_edge(a, b);
        graph.add_edge(b, c);
        assert_eq!(graph.edge_count(), 2);

        graph.remove_node(b);
        assert_eq!(graph.edge_count(), 0);
        assert_eq!(graph.outgoing_neighbors(a).count(), 0);
        assert_eq!(graph.incoming_neighbors(c).count(), 0);
    }

    #[test]
    fn test_contains_node_after_removal() {
        let mut graph: StableDag<String> = StableDag::new();
        let idx = graph.add_node("test".to_string());
        assert!(graph.contains_node(idx));
        graph.remove_node(idx);
        assert!(!graph.contains_node(idx));
    }

    #[test]
    fn test_add_edge_creates_directed_edge() {
        let mut graph: StableDag<String> = StableDag::new();
        let a = graph.add_node("a".to_string());
        let b = graph.add_node("b".to_string());
        let edge = graph.add_edge(a, b);
        assert!(edge.is_some());
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_add_edge_invalid_node_returns_none() {
        let mut graph: StableDag<String> = StableDag::new();
        let a = graph.add_node("a".to_string());
        let invalid = NodeIndex::new(999);
        let edge = graph.add_edge(a, invalid);
        assert!(edge.is_none());
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_add_edge_to_removed_node_returns_none() {
        let mut graph: StableDag<String> = StableDag::new();
        let a = graph.add_node("a".to_string());
        let b = graph.add_node("b".to_string());
        graph.remove_node(b);
        assert!(graph.add_edge(a, b).is_none());
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_add_edge_with_label_creates_edge() {
        let mut graph: StableDag<String> = StableDag::new();
        let a = graph.add_node("a".to_string());
        let b = graph.add_node("b".to_string());
        let edge = graph.add_edge_with_label(a, b, "depends-on");
        assert!(edge.is_some());
        assert!(graph.has_edge(a, b));
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_add_edge_with_label_invalid_node_returns_none() {
        let mut graph: StableDag<String> = StableDag::new();
        let a = graph.add_node("a".to_string());
        let edge = graph.add_edge_with_label(NodeIndex::new(999), a, "depends-on");
        assert!(edge.is_none());
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_has_edge_checks_existence() {
        let mut graph: StableDag<String> = StableDag::new();
        let a = graph.add_node("a".to_string());
        let b = graph.add_node("b".to_string());
        let c = graph.add_node("c".to_string());
        graph.add_edge(a, b);
        assert!(graph.has_edge(a, b));
        // Edges are directed: the reverse direction is not implied.
        assert!(!graph.has_edge(b, a));
        assert!(!graph.has_edge(a, c));
    }

    #[test]
    fn test_neighbors_follow_edge_direction() {
        let mut graph: StableDag<String> = StableDag::new();
        let a = graph.add_node("a".to_string());
        let b = graph.add_node("b".to_string());
        let c = graph.add_node("c".to_string());
        graph.add_edge(a, b);
        graph.add_edge(a, c);
        graph.add_edge(c, b);

        // Neighbour iteration order is an implementation detail, so compare sorted.
        let mut out_of_a: Vec<_> = graph.outgoing_neighbors(a).collect();
        out_of_a.sort();
        assert_eq!(out_of_a, vec![b, c]);

        let mut into_b: Vec<_> = graph.incoming_neighbors(b).collect();
        into_b.sort();
        assert_eq!(into_b, vec![a, c]);

        assert_eq!(graph.incoming_neighbors(a).count(), 0);
        assert_eq!(graph.outgoing_neighbors(b).count(), 0);
    }

    #[test]
    fn test_remove_edge_works() {
        let mut graph: StableDag<String> = StableDag::new();
        let a = graph.add_node("a".to_string());
        let b = graph.add_node("b".to_string());
        graph.add_edge(a, b);
        assert!(graph.has_edge(a, b));
        assert_eq!(graph.edge_count(), 1);
        let removed = graph.remove_edge(a, b);
        assert!(removed);
        assert!(!graph.has_edge(a, b));
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_remove_edge_nonexistent_returns_false() {
        let mut graph: StableDag<String> = StableDag::new();
        let a = graph.add_node("a".to_string());
        let b = graph.add_node("b".to_string());
        let removed = graph.remove_edge(a, b);
        assert!(!removed);
    }
}