1use alloc::collections::BTreeMap;
2use alloc::vec::Vec;
3
4use crate::edge::Edge;
5use crate::error::GraphError;
6use crate::id::{EdgeId, NodeId, PortId};
7use crate::node::Node;
8use crate::port::{Port, PortDirection};
9
/// A directed graph of payload-carrying nodes connected by port-to-port edges.
///
/// `N` is the node payload type and `E` the edge payload type. Nodes and edges
/// are stored in `BTreeMap`s keyed by their ids, so iteration follows key order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NodeGraph<N, E> {
    /// Every live node, keyed by its id.
    pub(crate) nodes: BTreeMap<NodeId, Node<N>>,
    /// Every live edge, keyed by its id.
    pub(crate) edges: BTreeMap<EdgeId, Edge<E>>,
    /// Raw value of the next `NodeId` to hand out; advanced on each node insertion.
    pub(crate) next_node_id: u64,
    /// Raw value of the next `EdgeId` to hand out; advanced on each successful edge insertion.
    pub(crate) next_edge_id: u64,
}
18
19impl<N, E> Default for NodeGraph<N, E> {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl<N, E> NodeGraph<N, E> {
26 pub fn new() -> Self {
28 Self {
29 nodes: BTreeMap::new(),
30 edges: BTreeMap::new(),
31 next_node_id: 0,
32 next_edge_id: 0,
33 }
34 }
35
36 pub fn add_node(&mut self, payload: N) -> NodeId {
38 self.add_node_with_ports(payload, Vec::new())
39 }
40
41 pub fn add_node_with_ports(&mut self, payload: N, ports: Vec<Port>) -> NodeId {
43 let id = NodeId::new(self.next_node_id);
44 self.next_node_id += 1;
45 let node = Node::new(id, payload, ports);
46 self.nodes.insert(id, node);
47 id
48 }
49
50 pub fn remove_node(&mut self, id: NodeId) -> Result<Node<N>, GraphError> {
52 let connected = self.connected_edge_ids_for_node(id)?;
53 let removed = self.nodes.remove(&id).ok_or(GraphError::NodeNotFound(id))?;
54 for edge_id in connected {
55 let _ = self.edges.remove(&edge_id);
56 }
57 Ok(removed)
58 }
59
60 pub fn add_edge(
62 &mut self,
63 from: (NodeId, PortId),
64 to: (NodeId, PortId),
65 payload: E,
66 ) -> Result<EdgeId, GraphError> {
67 if from.0 == to.0 {
68 return Err(GraphError::SelfLoop);
69 }
70
71 let from_port = self.resolve_port(&from)?;
72 let to_port = self.resolve_port(&to)?;
73
74 if from_port.direction() != PortDirection::Output
75 || to_port.direction() != PortDirection::Input
76 {
77 return Err(GraphError::PortDirectionMismatch);
78 }
79
80 if from_port.type_tag() != to_port.type_tag() {
81 return Err(GraphError::TypeTagMismatch {
82 expected: from_port.type_tag().into(),
83 actual: to_port.type_tag().into(),
84 });
85 }
86
87 if self
88 .edges
89 .values()
90 .any(|edge| edge.from() == &from && edge.to() == &to)
91 {
92 return Err(GraphError::DuplicateEdge);
93 }
94
95 let id = EdgeId::new(self.next_edge_id);
96 self.next_edge_id += 1;
97 self.edges.insert(id, Edge::new(id, from, to, payload));
98 Ok(id)
99 }
100
101 pub fn remove_edge(&mut self, id: EdgeId) -> Result<Edge<E>, GraphError> {
103 self.edges.remove(&id).ok_or(GraphError::EdgeNotFound(id))
104 }
105
106 pub fn node(&self, id: NodeId) -> Option<&Node<N>> {
108 self.nodes.get(&id)
109 }
110
111 pub fn node_mut(&mut self, id: NodeId) -> Option<&mut Node<N>> {
113 self.nodes.get_mut(&id)
114 }
115
116 pub fn edge(&self, id: EdgeId) -> Option<&Edge<E>> {
118 self.edges.get(&id)
119 }
120
121 pub fn nodes(&self) -> impl Iterator<Item = (&NodeId, &Node<N>)> {
123 self.nodes.iter()
124 }
125
126 pub fn edges(&self) -> impl Iterator<Item = (&EdgeId, &Edge<E>)> {
128 self.edges.iter()
129 }
130
131 pub fn connected(&self, port: (NodeId, PortId)) -> Vec<EdgeId> {
133 self.edges
134 .iter()
135 .filter_map(|(edge_id, edge)| {
136 ((edge.from() == &port) || (edge.to() == &port)).then_some(*edge_id)
137 })
138 .collect()
139 }
140
141 pub fn incoming(&self, node: NodeId) -> Vec<EdgeId> {
143 self.edges
144 .iter()
145 .filter_map(|(edge_id, edge)| (edge.to().0 == node).then_some(*edge_id))
146 .collect()
147 }
148
149 pub fn outgoing(&self, node: NodeId) -> Vec<EdgeId> {
151 self.edges
152 .iter()
153 .filter_map(|(edge_id, edge)| (edge.from().0 == node).then_some(*edge_id))
154 .collect()
155 }
156
157 fn resolve_port(&self, endpoint: &(NodeId, PortId)) -> Result<&Port, GraphError> {
158 let node = self
159 .nodes
160 .get(&endpoint.0)
161 .ok_or(GraphError::NodeNotFound(endpoint.0))?;
162 node.port(&endpoint.1)
163 .ok_or_else(|| GraphError::PortNotFound {
164 node: endpoint.0,
165 port: endpoint.1.clone(),
166 })
167 }
168
169 fn connected_edge_ids_for_node(&self, node: NodeId) -> Result<Vec<EdgeId>, GraphError> {
170 if !self.nodes.contains_key(&node) {
171 return Err(GraphError::NodeNotFound(node));
172 }
173
174 Ok(self
175 .edges
176 .iter()
177 .filter_map(|(edge_id, edge)| {
178 ((edge.from().0 == node) || (edge.to().0 == node)).then_some(*edge_id)
179 })
180 .collect())
181 }
182}
183
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::NodeGraph;
    use crate::error::GraphError;
    use crate::id::{EdgeId, NodeId, PortId};
    use crate::port::{Port, PortDirection};

    /// Parses a fixture port id, panicking on invalid fixture data.
    fn pid(id: &str) -> PortId {
        PortId::new(id).expect("port id")
    }

    /// Builds a fixture port, panicking on invalid fixture data.
    fn port(id: &str, direction: PortDirection, type_tag: &str) -> Port {
        Port::new(pid(id), direction, type_tag).expect("port")
    }

    /// Adds `source.out -> sink.in` (both tagged "Fact") and returns
    /// `(source, sink, edge)`.
    fn linked_pair(graph: &mut NodeGraph<(), ()>) -> (NodeId, NodeId, EdgeId) {
        let source =
            graph.add_node_with_ports((), vec![port("out", PortDirection::Output, "Fact")]);
        let sink = graph.add_node_with_ports((), vec![port("in", PortDirection::Input, "Fact")]);
        let edge = graph
            .add_edge((source, pid("out")), (sink, pid("in")), ())
            .expect("valid edge");
        (source, sink, edge)
    }

    #[test]
    fn node_ids_increment_from_zero() {
        let mut graph = NodeGraph::<(), ()>::new();
        let first = graph.add_node(());
        let second = graph.add_node(());

        assert_eq!(first.as_u64(), 0);
        assert_eq!(second.as_u64(), 1);
    }

    #[test]
    fn remove_missing_node_returns_error() {
        let mut graph = NodeGraph::<(), ()>::new();

        let error = graph.remove_node(NodeId::new(9)).expect_err("missing node");

        assert_eq!(error, GraphError::NodeNotFound(NodeId::new(9)));
    }

    #[test]
    fn remove_missing_edge_returns_error() {
        let mut graph = NodeGraph::<(), ()>::new();

        let error = graph.remove_edge(EdgeId::new(4)).expect_err("missing edge");

        assert_eq!(error, GraphError::EdgeNotFound(EdgeId::new(4)));
    }

    #[test]
    fn add_edge_requires_output_to_input() {
        let mut graph = NodeGraph::<(), ()>::new();
        let input_only =
            graph.add_node_with_ports((), vec![port("in", PortDirection::Input, "Fact")]);
        let output_only =
            graph.add_node_with_ports((), vec![port("out", PortDirection::Output, "Fact")]);

        let error = graph
            .add_edge((input_only, pid("in")), (output_only, pid("out")), ())
            .expect_err("direction mismatch");

        assert_eq!(error, GraphError::PortDirectionMismatch);
    }

    #[test]
    fn add_edge_rejects_self_loop() {
        let mut graph = NodeGraph::<(), ()>::new();
        let node = graph.add_node_with_ports(
            (),
            vec![
                port("out", PortDirection::Output, "Fact"),
                port("in", PortDirection::Input, "Fact"),
            ],
        );

        let error = graph
            .add_edge((node, pid("out")), (node, pid("in")), ())
            .expect_err("self loop");

        assert_eq!(error, GraphError::SelfLoop);
    }

    #[test]
    fn add_edge_rejects_duplicate() {
        let mut graph = NodeGraph::<(), ()>::new();
        let (source, sink, _) = linked_pair(&mut graph);

        let error = graph
            .add_edge((source, pid("out")), (sink, pid("in")), ())
            .expect_err("duplicate edge");

        assert_eq!(error, GraphError::DuplicateEdge);
    }

    #[test]
    fn add_edge_rejects_type_tag_mismatch() {
        let mut graph = NodeGraph::<(), ()>::new();
        let source =
            graph.add_node_with_ports((), vec![port("out", PortDirection::Output, "Fact")]);
        let sink = graph.add_node_with_ports((), vec![port("in", PortDirection::Input, "Rule")]);

        let error = graph
            .add_edge((source, pid("out")), (sink, pid("in")), ())
            .expect_err("type tag mismatch");

        assert!(
            matches!(error, GraphError::TypeTagMismatch { .. }),
            "unexpected error: {error:?}"
        );
    }

    #[test]
    fn add_edge_reports_missing_node_and_port() {
        let mut graph = NodeGraph::<(), ()>::new();
        let sink = graph.add_node_with_ports((), vec![port("in", PortDirection::Input, "Fact")]);
        let source =
            graph.add_node_with_ports((), vec![port("out", PortDirection::Output, "Fact")]);
        let ghost = NodeId::new(7);

        assert_eq!(
            graph.add_edge((ghost, pid("out")), (sink, pid("in")), ()),
            Err(GraphError::NodeNotFound(ghost))
        );
        assert_eq!(
            graph.add_edge((source, pid("missing")), (sink, pid("in")), ()),
            Err(GraphError::PortNotFound {
                node: source,
                port: pid("missing"),
            })
        );
    }

    #[test]
    fn remove_node_drops_connected_edges() {
        let mut graph = NodeGraph::<(), ()>::new();
        let (source, sink, edge) = linked_pair(&mut graph);

        let removed = graph.remove_node(source).expect("source exists");

        assert!(removed.port(&pid("out")).is_some());
        assert!(graph.node(source).is_none());
        assert!(graph.node(sink).is_some());
        assert!(graph.edge(edge).is_none());
        assert_eq!(graph.edges().count(), 0);
        assert!(graph.incoming(sink).is_empty());
    }

    #[test]
    fn adjacency_queries_report_edge() {
        let mut graph = NodeGraph::<(), ()>::new();
        let (source, sink, edge) = linked_pair(&mut graph);

        assert_eq!(graph.outgoing(source), vec![edge]);
        assert_eq!(graph.incoming(sink), vec![edge]);
        assert!(graph.incoming(source).is_empty());
        assert!(graph.outgoing(sink).is_empty());
        assert_eq!(graph.connected((source, pid("out"))), vec![edge]);
        assert_eq!(graph.connected((sink, pid("in"))), vec![edge]);
    }
}
246}