Skip to main content

neco_nodegraph/
graph.rs

1use alloc::collections::BTreeMap;
2use alloc::vec::Vec;
3
4use crate::edge::Edge;
5use crate::error::GraphError;
6use crate::id::{EdgeId, NodeId, PortId};
7use crate::node::Node;
8use crate::port::{Port, PortDirection};
9
10/// Pure node graph data model with typed ports.
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub struct NodeGraph<N, E> {
13    pub(crate) nodes: BTreeMap<NodeId, Node<N>>,
14    pub(crate) edges: BTreeMap<EdgeId, Edge<E>>,
15    pub(crate) next_node_id: u64,
16    pub(crate) next_edge_id: u64,
17}
18
19impl<N, E> Default for NodeGraph<N, E> {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl<N, E> NodeGraph<N, E> {
26    /// Creates an empty graph.
27    pub fn new() -> Self {
28        Self {
29            nodes: BTreeMap::new(),
30            edges: BTreeMap::new(),
31            next_node_id: 0,
32            next_edge_id: 0,
33        }
34    }
35
36    /// Adds a node without ports and returns its identifier.
37    pub fn add_node(&mut self, payload: N) -> NodeId {
38        self.add_node_with_ports(payload, Vec::new())
39    }
40
41    /// Adds a node with declared ports and returns its identifier.
42    pub fn add_node_with_ports(&mut self, payload: N, ports: Vec<Port>) -> NodeId {
43        let id = NodeId::new(self.next_node_id);
44        self.next_node_id += 1;
45        let node = Node::new(id, payload, ports);
46        self.nodes.insert(id, node);
47        id
48    }
49
50    /// Removes a node and all connected edges.
51    pub fn remove_node(&mut self, id: NodeId) -> Result<Node<N>, GraphError> {
52        let connected = self.connected_edge_ids_for_node(id)?;
53        let removed = self.nodes.remove(&id).ok_or(GraphError::NodeNotFound(id))?;
54        for edge_id in connected {
55            let _ = self.edges.remove(&edge_id);
56        }
57        Ok(removed)
58    }
59
60    /// Adds a validated directed edge between two ports.
61    pub fn add_edge(
62        &mut self,
63        from: (NodeId, PortId),
64        to: (NodeId, PortId),
65        payload: E,
66    ) -> Result<EdgeId, GraphError> {
67        if from.0 == to.0 {
68            return Err(GraphError::SelfLoop);
69        }
70
71        let from_port = self.resolve_port(&from)?;
72        let to_port = self.resolve_port(&to)?;
73
74        if from_port.direction() != PortDirection::Output
75            || to_port.direction() != PortDirection::Input
76        {
77            return Err(GraphError::PortDirectionMismatch);
78        }
79
80        if from_port.type_tag() != to_port.type_tag() {
81            return Err(GraphError::TypeTagMismatch {
82                expected: from_port.type_tag().into(),
83                actual: to_port.type_tag().into(),
84            });
85        }
86
87        if self
88            .edges
89            .values()
90            .any(|edge| edge.from() == &from && edge.to() == &to)
91        {
92            return Err(GraphError::DuplicateEdge);
93        }
94
95        let id = EdgeId::new(self.next_edge_id);
96        self.next_edge_id += 1;
97        self.edges.insert(id, Edge::new(id, from, to, payload));
98        Ok(id)
99    }
100
101    /// Removes an edge by identifier.
102    pub fn remove_edge(&mut self, id: EdgeId) -> Result<Edge<E>, GraphError> {
103        self.edges.remove(&id).ok_or(GraphError::EdgeNotFound(id))
104    }
105
106    /// Returns an immutable node reference.
107    pub fn node(&self, id: NodeId) -> Option<&Node<N>> {
108        self.nodes.get(&id)
109    }
110
111    /// Returns a mutable node reference.
112    pub fn node_mut(&mut self, id: NodeId) -> Option<&mut Node<N>> {
113        self.nodes.get_mut(&id)
114    }
115
116    /// Returns an immutable edge reference.
117    pub fn edge(&self, id: EdgeId) -> Option<&Edge<E>> {
118        self.edges.get(&id)
119    }
120
121    /// Iterates over nodes in identifier order.
122    pub fn nodes(&self) -> impl Iterator<Item = (&NodeId, &Node<N>)> {
123        self.nodes.iter()
124    }
125
126    /// Iterates over edges in identifier order.
127    pub fn edges(&self) -> impl Iterator<Item = (&EdgeId, &Edge<E>)> {
128        self.edges.iter()
129    }
130
131    /// Returns all edges connected to the given port.
132    pub fn connected(&self, port: (NodeId, PortId)) -> Vec<EdgeId> {
133        self.edges
134            .iter()
135            .filter_map(|(edge_id, edge)| {
136                ((edge.from() == &port) || (edge.to() == &port)).then_some(*edge_id)
137            })
138            .collect()
139    }
140
141    /// Returns all incoming edge identifiers for a node.
142    pub fn incoming(&self, node: NodeId) -> Vec<EdgeId> {
143        self.edges
144            .iter()
145            .filter_map(|(edge_id, edge)| (edge.to().0 == node).then_some(*edge_id))
146            .collect()
147    }
148
149    /// Returns all outgoing edge identifiers for a node.
150    pub fn outgoing(&self, node: NodeId) -> Vec<EdgeId> {
151        self.edges
152            .iter()
153            .filter_map(|(edge_id, edge)| (edge.from().0 == node).then_some(*edge_id))
154            .collect()
155    }
156
157    fn resolve_port(&self, endpoint: &(NodeId, PortId)) -> Result<&Port, GraphError> {
158        let node = self
159            .nodes
160            .get(&endpoint.0)
161            .ok_or(GraphError::NodeNotFound(endpoint.0))?;
162        node.port(&endpoint.1)
163            .ok_or_else(|| GraphError::PortNotFound {
164                node: endpoint.0,
165                port: endpoint.1.clone(),
166            })
167    }
168
169    fn connected_edge_ids_for_node(&self, node: NodeId) -> Result<Vec<EdgeId>, GraphError> {
170        if !self.nodes.contains_key(&node) {
171            return Err(GraphError::NodeNotFound(node));
172        }
173
174        Ok(self
175            .edges
176            .iter()
177            .filter_map(|(edge_id, edge)| {
178                ((edge.from().0 == node) || (edge.to().0 == node)).then_some(*edge_id)
179            })
180            .collect())
181    }
182}
183
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::NodeGraph;
    use crate::error::GraphError;
    use crate::id::PortId;
    use crate::port::{Port, PortDirection};

    /// Builds a port from literals that are known to be valid.
    fn port(id: &str, direction: PortDirection, type_tag: &str) -> Port {
        Port::new(PortId::new(id).expect("port id"), direction, type_tag).expect("port")
    }

    /// Builds a port identifier from a literal that is known to be valid.
    fn port_id(id: &str) -> PortId {
        PortId::new(id).expect("port id")
    }

    /// Adds a node with a `Fact` output, a node with a `Fact` input, and an edge
    /// between them. Returns `(source, target, edge)`.
    fn connected_pair(
        graph: &mut NodeGraph<(), ()>,
    ) -> (crate::NodeId, crate::NodeId, crate::EdgeId) {
        let source =
            graph.add_node_with_ports((), vec![port("out", PortDirection::Output, "Fact")]);
        let target = graph.add_node_with_ports((), vec![port("in", PortDirection::Input, "Fact")]);
        let edge = graph
            .add_edge((source, port_id("out")), (target, port_id("in")), ())
            .expect("valid edge");
        (source, target, edge)
    }

    #[test]
    fn node_ids_increment_from_zero() {
        let mut graph = NodeGraph::<(), ()>::new();
        let first = graph.add_node(());
        let second = graph.add_node(());

        assert_eq!(first.as_u64(), 0);
        assert_eq!(second.as_u64(), 1);
    }

    #[test]
    fn node_ids_are_not_reused_after_removal() {
        let mut graph = NodeGraph::<(), ()>::new();
        let _ = graph.add_node(());
        let second = graph.add_node(());
        let _ = graph.remove_node(second).expect("existing node");

        assert_eq!(graph.add_node(()).as_u64(), 2);
    }

    #[test]
    fn remove_missing_node_returns_error() {
        let mut graph = NodeGraph::<(), ()>::new();

        let error = graph
            .remove_node(crate::NodeId::new(9))
            .expect_err("missing node");

        assert_eq!(error, GraphError::NodeNotFound(crate::NodeId::new(9)));
    }

    #[test]
    fn remove_node_drops_connected_edges() {
        let mut graph = NodeGraph::<(), ()>::new();
        let (source, target, edge) = connected_pair(&mut graph);

        let _ = graph.remove_node(target).expect("existing node");

        assert!(graph.edge(edge).is_none());
        assert_eq!(graph.edges().count(), 0);
        assert!(graph.node(source).is_some());
        assert!(graph.outgoing(source).is_empty());
    }

    #[test]
    fn remove_missing_edge_returns_error() {
        let mut graph = NodeGraph::<(), ()>::new();

        let error = graph
            .remove_edge(crate::EdgeId::new(4))
            .expect_err("missing edge");

        assert_eq!(error, GraphError::EdgeNotFound(crate::EdgeId::new(4)));
    }

    #[test]
    fn edge_queries_follow_direction() {
        let mut graph = NodeGraph::<(), ()>::new();
        let (source, target, edge) = connected_pair(&mut graph);

        assert_eq!(graph.outgoing(source), vec![edge]);
        assert_eq!(graph.incoming(target), vec![edge]);
        assert!(graph.incoming(source).is_empty());
        assert!(graph.outgoing(target).is_empty());
        assert_eq!(graph.connected((source, port_id("out"))), vec![edge]);
        assert_eq!(graph.connected((target, port_id("in"))), vec![edge]);
        assert!(graph.connected((target, port_id("other"))).is_empty());
    }

    #[test]
    fn add_edge_requires_output_to_input() {
        let mut graph = NodeGraph::<(), ()>::new();
        let source = graph.add_node_with_ports((), vec![port("in", PortDirection::Input, "Fact")]);
        let target =
            graph.add_node_with_ports((), vec![port("out", PortDirection::Output, "Fact")]);

        let error = graph
            .add_edge(
                (source, PortId::new("in").unwrap()),
                (target, PortId::new("out").unwrap()),
                (),
            )
            .expect_err("direction mismatch");

        assert_eq!(error, GraphError::PortDirectionMismatch);
    }

    #[test]
    fn add_edge_rejects_self_loop() {
        let mut graph = NodeGraph::<(), ()>::new();
        let node = graph.add_node_with_ports(
            (),
            vec![
                port("in", PortDirection::Input, "Fact"),
                port("out", PortDirection::Output, "Fact"),
            ],
        );

        let error = graph
            .add_edge((node, port_id("out")), (node, port_id("in")), ())
            .expect_err("self loop");

        assert_eq!(error, GraphError::SelfLoop);
    }

    #[test]
    fn add_edge_rejects_duplicate_edge() {
        let mut graph = NodeGraph::<(), ()>::new();
        let (source, target, _) = connected_pair(&mut graph);

        let error = graph
            .add_edge((source, port_id("out")), (target, port_id("in")), ())
            .expect_err("duplicate edge");

        assert_eq!(error, GraphError::DuplicateEdge);
        assert_eq!(graph.edges().count(), 1);
    }

    #[test]
    fn add_edge_rejects_type_tag_mismatch() {
        let mut graph = NodeGraph::<(), ()>::new();
        let source =
            graph.add_node_with_ports((), vec![port("out", PortDirection::Output, "Fact")]);
        let target = graph.add_node_with_ports((), vec![port("in", PortDirection::Input, "Rule")]);

        let error = graph
            .add_edge((source, port_id("out")), (target, port_id("in")), ())
            .expect_err("type tag mismatch");

        assert!(matches!(error, GraphError::TypeTagMismatch { .. }));
        assert_eq!(graph.edges().count(), 0);
    }

    #[test]
    fn add_edge_reports_missing_target_node() {
        let mut graph = NodeGraph::<(), ()>::new();
        let source =
            graph.add_node_with_ports((), vec![port("out", PortDirection::Output, "Fact")]);

        let error = graph
            .add_edge(
                (source, port_id("out")),
                (crate::NodeId::new(42), port_id("in")),
                (),
            )
            .expect_err("missing node");

        assert_eq!(error, GraphError::NodeNotFound(crate::NodeId::new(42)));
    }
}