Skip to main content

serdes_ai_graph/
edge.rs

//! Graph edge types: directed connections between named nodes, optionally guarded by a
//! predicate on the graph state.
2
3use std::fmt;
4use std::sync::Arc;
5
/// A directed transition from one named node to another.
///
/// The transition is guarded by `condition`, a predicate over a state of type `S`.
pub struct Edge<S> {
    /// Name of the node this edge leaves.
    pub from: String,
    /// Name of the node this edge enters.
    pub to: String,
    /// Predicate evaluated against the state; held in an `Arc` so clones share it.
    pub condition: Arc<dyn Fn(&S) -> bool + Send + Sync>,
    /// Optional human-readable description of the edge.
    pub label: Option<String>,
}
17
18impl<S> Edge<S> {
19    /// Create a new conditional edge.
20    pub fn new<F>(from: impl Into<String>, to: impl Into<String>, condition: F) -> Self
21    where
22        F: Fn(&S) -> bool + Send + Sync + 'static,
23    {
24        Self {
25            from: from.into(),
26            to: to.into(),
27            condition: Arc::new(condition),
28            label: None,
29        }
30    }
31
32    /// Create an unconditional edge (always true).
33    pub fn unconditional(from: impl Into<String>, to: impl Into<String>) -> Self {
34        Self {
35            from: from.into(),
36            to: to.into(),
37            condition: Arc::new(|_| true),
38            label: None,
39        }
40    }
41
42    /// Add a label to the edge.
43    pub fn with_label(mut self, label: impl Into<String>) -> Self {
44        self.label = Some(label.into());
45        self
46    }
47
48    /// Check if the condition is satisfied.
49    pub fn matches(&self, state: &S) -> bool {
50        (self.condition)(state)
51    }
52}
53
54impl<S> fmt::Debug for Edge<S> {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        f.debug_struct("Edge")
57            .field("from", &self.from)
58            .field("to", &self.to)
59            .field("label", &self.label)
60            .finish()
61    }
62}
63
64impl<S> Clone for Edge<S> {
65    fn clone(&self) -> Self {
66        Self {
67            from: self.from.clone(),
68            to: self.to.clone(),
69            condition: Arc::clone(&self.condition),
70            label: self.label.clone(),
71        }
72    }
73}
74
/// Fluent builder for [`Edge`].
///
/// Start with [`EdgeBuilder::from`], set the target with `to` (required) and optionally
/// `label`, then finish with `when` (conditional) or `always` (unconditional).
#[must_use = "an `EdgeBuilder` does nothing until finished with `when` or `always`"]
pub struct EdgeBuilder<S> {
    from: String,
    to: Option<String>,
    label: Option<String>,
    // `fn() -> S` rather than `S`: the builder never stores an `S`, so it should not
    // inherit `S`'s auto traits (`Send`/`Sync`) or drop-check obligations, while
    // remaining covariant in `S`.
    _phantom: std::marker::PhantomData<fn() -> S>,
}

/// Manual `Debug` so that `S` need not implement `Debug`. A derive would add an
/// `S: Debug` bound even though no `S` value is ever stored.
impl<S> fmt::Debug for EdgeBuilder<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("EdgeBuilder")
            .field("from", &self.from)
            .field("to", &self.to)
            .field("label", &self.label)
            .finish()
    }
}
83
84impl<S> EdgeBuilder<S> {
85    /// Start building an edge from a node.
86    pub fn from(name: impl Into<String>) -> Self {
87        Self {
88            from: name.into(),
89            to: None,
90            label: None,
91            _phantom: std::marker::PhantomData,
92        }
93    }
94
95    /// Set the target node.
96    pub fn to(mut self, name: impl Into<String>) -> Self {
97        self.to = Some(name.into());
98        self
99    }
100
101    /// Set the edge label.
102    pub fn label(mut self, label: impl Into<String>) -> Self {
103        self.label = Some(label.into());
104        self
105    }
106
107    /// Build with a condition.
108    pub fn when<F>(self, condition: F) -> Edge<S>
109    where
110        F: Fn(&S) -> bool + Send + Sync + 'static,
111    {
112        let to = self.to.expect("Target node required");
113        let mut edge = Edge::new(self.from, to, condition);
114        edge.label = self.label;
115        edge
116    }
117
118    /// Build as unconditional.
119    pub fn always(self) -> Edge<S> {
120        let to = self.to.expect("Target node required");
121        let mut edge = Edge::unconditional(self.from, to);
122        edge.label = self.label;
123        edge
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone)]
    struct TestState {
        value: i32,
    }

    #[test]
    fn test_conditional_edge() {
        let edge = Edge::new("a", "b", |s: &TestState| s.value > 0);

        assert!(edge.matches(&TestState { value: 1 }));
        assert!(!edge.matches(&TestState { value: -1 }));
        // Zero is the boundary of `> 0` and must not match.
        assert!(!edge.matches(&TestState { value: 0 }));
    }

    #[test]
    fn test_unconditional_edge() {
        let edge: Edge<TestState> = Edge::unconditional("a", "b");

        assert!(edge.matches(&TestState { value: 0 }));
        assert!(edge.matches(&TestState { value: -100 }));
        assert_eq!(edge.label, None);
    }

    #[test]
    fn test_with_label_replaces_previous_label() {
        let edge: Edge<TestState> = Edge::unconditional("a", "b")
            .with_label("first")
            .with_label("second");

        assert_eq!(edge.label.as_deref(), Some("second"));
    }

    #[test]
    fn test_edge_builder() {
        let edge = EdgeBuilder::<TestState>::from("start")
            .to("end")
            .label("always")
            .always();

        assert_eq!(edge.from, "start");
        assert_eq!(edge.to, "end");
        assert_eq!(edge.label, Some("always".to_string()));
    }

    #[test]
    fn test_edge_builder_conditional() {
        let edge = EdgeBuilder::<TestState>::from("a")
            .to("b")
            .label("answer")
            .when(|s| s.value == 42);

        assert_eq!(edge.from, "a");
        assert_eq!(edge.to, "b");
        assert_eq!(edge.label.as_deref(), Some("answer"));
        assert!(edge.matches(&TestState { value: 42 }));
        assert!(!edge.matches(&TestState { value: 0 }));
    }

    #[test]
    #[should_panic(expected = "Target node required")]
    fn test_edge_builder_missing_target_panics() {
        let _ = EdgeBuilder::<TestState>::from("a").always();
    }

    #[test]
    fn test_edge_clone() {
        let edge = Edge::new("a", "b", |s: &TestState| s.value > 0).with_label("l");
        let cloned = edge.clone();

        assert_eq!(cloned.from, "a");
        assert_eq!(cloned.to, "b");
        assert_eq!(cloned.label, edge.label);
        // The condition must be shared, not rebuilt.
        assert!(Arc::ptr_eq(&edge.condition, &cloned.condition));
        assert!(cloned.matches(&TestState { value: 1 }));
        assert!(!cloned.matches(&TestState { value: -1 }));
    }

    #[test]
    fn test_edge_debug_omits_condition() {
        let edge: Edge<TestState> = Edge::unconditional("a", "b").with_label("l");
        let rendered = format!("{:?}", edge);

        assert!(rendered.starts_with("Edge {"));
        assert!(rendered.contains("from: \"a\""));
        assert!(rendered.contains("to: \"b\""));
        assert!(rendered.contains("label: Some(\"l\")"));
        assert!(!rendered.contains("condition"));
    }
}