1use crate::id::{EdgeId, NodeId, PortId};
4use slotmap::SlotMap;
5
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone)]
21#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22#[cfg_attr(
23 feature = "serde",
24 serde(bound(
25 serialize = "N: Serialize, P: Serialize, E: Serialize",
26 deserialize = "N: Deserialize<'de>, P: Deserialize<'de>, E: Deserialize<'de>"
27 ))
28)]
29pub struct Graph<N, P, E> {
30 nodes: SlotMap<NodeId, NodeEntry<N>>,
31 ports: SlotMap<PortId, PortEntry<P>>,
32 edges: SlotMap<EdgeId, EdgeEntry<E>>,
33}
34
35#[derive(Debug, Clone)]
36#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
37struct NodeEntry<N> {
38 data: N,
39 ports: Vec<PortId>,
41}
42
43#[derive(Debug, Clone)]
44#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
45struct PortEntry<P> {
46 data: P,
47 node: NodeId,
48 edges: Vec<EdgeId>,
50}
51
52#[derive(Debug, Clone)]
53#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
54struct EdgeEntry<E> {
55 data: E,
56 ports: [PortId; 2],
58}
59
/// Error returned by [`Graph::add_port`] when the target node is not present
/// in the graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NodeMissing;

// `ConnectError` already provides `Display` + `Error`; give `NodeMissing` the
// same so callers can format it and propagate it as `Box<dyn Error>`.
impl core::fmt::Display for NodeMissing {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("node does not exist")
    }
}

impl std::error::Error for NodeMissing {}
63
/// Reasons [`Graph::connect`] and [`Graph::connect_with`] can refuse to add an edge.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectError {
    /// At least one of the two ports is not in the graph.
    PortMissing,
    /// Both endpoints are the same port.
    SelfLoop,
    /// The validator passed to [`Graph::connect_with`] returned `false`.
    Rejected,
}

impl core::fmt::Display for ConnectError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match *self {
            Self::PortMissing => "one or both ports do not exist",
            Self::SelfLoop => "cannot connect a port to itself",
            Self::Rejected => "connection rejected by validator",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ConnectError {}
88
89impl<N, P, E> Default for Graph<N, P, E> {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95impl<N, P, E> Graph<N, P, E> {
96 pub fn new() -> Self {
100 Graph {
101 nodes: SlotMap::with_key(),
102 ports: SlotMap::with_key(),
103 edges: SlotMap::with_key(),
104 }
105 }
106
107 pub fn add_node(&mut self, data: N) -> NodeId {
111 self.nodes.insert(NodeEntry {
112 data,
113 ports: Vec::new(),
114 })
115 }
116
117 pub fn add_port(&mut self, node: NodeId, data: P) -> Result<PortId, NodeMissing> {
120 if !self.nodes.contains_key(node) {
121 return Err(NodeMissing);
122 }
123 let port = self.ports.insert(PortEntry {
124 data,
125 node,
126 edges: Vec::new(),
127 });
128 self.nodes[node].ports.push(port);
129 Ok(port)
130 }
131
132 pub fn connect(&mut self, a: PortId, b: PortId, data: E) -> Result<EdgeId, ConnectError> {
137 self.connect_with(a, b, data, |_, _| true)
138 }
139
140 pub fn connect_with<F>(
158 &mut self,
159 a: PortId,
160 b: PortId,
161 data: E,
162 check: F,
163 ) -> Result<EdgeId, ConnectError>
164 where
165 F: FnOnce(&P, &P) -> bool,
166 {
167 if a == b {
168 return Err(ConnectError::SelfLoop);
169 }
170 let (pa, pb) = match (self.ports.get(a), self.ports.get(b)) {
171 (Some(pa), Some(pb)) => (pa, pb),
172 _ => return Err(ConnectError::PortMissing),
173 };
174 if !check(&pa.data, &pb.data) {
175 return Err(ConnectError::Rejected);
176 }
177 let edge = self.edges.insert(EdgeEntry {
178 data,
179 ports: [a, b],
180 });
181 self.ports[a].edges.push(edge);
182 self.ports[b].edges.push(edge);
183 Ok(edge)
184 }
185
186 pub fn disconnect(&mut self, edge: EdgeId) -> Option<E> {
188 let entry = self.edges.remove(edge)?;
189 for p in entry.ports {
190 if let Some(port) = self.ports.get_mut(p) {
191 port.edges.retain(|&e| e != edge);
192 }
193 }
194 Some(entry.data)
195 }
196
197 pub fn remove_port(&mut self, port: PortId) -> Option<P> {
199 let entry = self.ports.remove(port)?;
200 for edge in entry.edges {
201 if let Some(e) = self.edges.remove(edge) {
204 for p in e.ports {
205 if p != port
206 && let Some(other) = self.ports.get_mut(p)
207 {
208 other.edges.retain(|&x| x != edge);
209 }
210 }
211 }
212 }
213 if let Some(node) = self.nodes.get_mut(entry.node) {
214 node.ports.retain(|&p| p != port);
215 }
216 Some(entry.data)
217 }
218
219 pub fn remove_node(&mut self, node: NodeId) -> Option<N> {
221 let entry = self.nodes.remove(node)?;
222 for port in entry.ports {
223 self.remove_port(port);
226 }
227 Some(entry.data)
228 }
229
230 pub fn clear(&mut self) {
232 self.nodes.clear();
233 self.ports.clear();
234 self.edges.clear();
235 }
236
237 pub fn node(&self, id: NodeId) -> Option<&N> {
242 self.nodes.get(id).map(|n| &n.data)
243 }
244
245 pub fn node_mut(&mut self, id: NodeId) -> Option<&mut N> {
247 self.nodes.get_mut(id).map(|n| &mut n.data)
248 }
249
250 pub fn port(&self, id: PortId) -> Option<&P> {
252 self.ports.get(id).map(|p| &p.data)
253 }
254
255 pub fn port_mut(&mut self, id: PortId) -> Option<&mut P> {
257 self.ports.get_mut(id).map(|p| &mut p.data)
258 }
259
260 pub fn edge(&self, id: EdgeId) -> Option<&E> {
262 self.edges.get(id).map(|e| &e.data)
263 }
264
265 pub fn edge_mut(&mut self, id: EdgeId) -> Option<&mut E> {
267 self.edges.get_mut(id).map(|e| &mut e.data)
268 }
269
270 pub fn port_node(&self, port: PortId) -> Option<NodeId> {
274 self.ports.get(port).map(|p| p.node)
275 }
276
277 pub fn edge_endpoints(&self, edge: EdgeId) -> Option<(PortId, PortId)> {
279 self.edges.get(edge).map(|e| (e.ports[0], e.ports[1]))
280 }
281
282 pub fn edge_nodes(&self, edge: EdgeId) -> Option<(NodeId, NodeId)> {
284 let (a, b) = self.edge_endpoints(edge)?;
285 Some((self.port_node(a)?, self.port_node(b)?))
286 }
287
288 pub fn opposite(&self, edge: EdgeId, port: PortId) -> Option<PortId> {
290 let (a, b) = self.edge_endpoints(edge)?;
291 if port == a {
292 Some(b)
293 } else if port == b {
294 Some(a)
295 } else {
296 None
297 }
298 }
299
300 pub fn ports(&self, node: NodeId) -> impl Iterator<Item = PortId> + '_ {
302 self.nodes
303 .get(node)
304 .map(|n| n.ports.as_slice())
305 .unwrap_or(&[])
306 .iter()
307 .copied()
308 }
309
310 pub fn find_port<F>(&self, node: NodeId, mut pred: F) -> Option<PortId>
326 where
327 F: FnMut(&P) -> bool,
328 {
329 self.ports(node)
330 .find(|&p| self.port(p).is_some_and(&mut pred))
331 }
332
333 pub fn find_ports<'a, F>(
335 &'a self,
336 node: NodeId,
337 mut pred: F,
338 ) -> impl Iterator<Item = PortId> + 'a
339 where
340 F: FnMut(&P) -> bool + 'a,
341 {
342 self.ports(node)
343 .filter(move |&p| self.port(p).is_some_and(&mut pred))
344 }
345
346 pub fn port_edges(&self, port: PortId) -> impl Iterator<Item = EdgeId> + '_ {
348 self.ports
349 .get(port)
350 .map(|p| p.edges.as_slice())
351 .unwrap_or(&[])
352 .iter()
353 .copied()
354 }
355
356 pub fn node_edges(&self, node: NodeId) -> impl Iterator<Item = EdgeId> + '_ {
358 self.ports(node).flat_map(|p| self.port_edges(p))
359 }
360
361 pub fn neighbors(&self, node: NodeId) -> impl Iterator<Item = NodeId> + '_ {
364 self.ports(node).flat_map(move |p| {
365 self.port_edges(p)
366 .filter_map(move |e| self.opposite(e, p).and_then(|q| self.port_node(q)))
367 })
368 }
369
370 pub fn edges_between(&self, a: PortId, b: PortId) -> impl Iterator<Item = EdgeId> + '_ {
372 self.port_edges(a)
373 .filter(move |&e| self.opposite(e, a) == Some(b))
374 }
375
376 pub fn nodes(&self) -> impl Iterator<Item = (NodeId, &N)> + '_ {
381 self.nodes.iter().map(|(id, n)| (id, &n.data))
382 }
383
384 pub fn all_ports(&self) -> impl Iterator<Item = (PortId, &P)> + '_ {
386 self.ports.iter().map(|(id, p)| (id, &p.data))
387 }
388
389 pub fn all_edges(&self) -> impl Iterator<Item = (EdgeId, &E)> + '_ {
391 self.edges.iter().map(|(id, e)| (id, &e.data))
392 }
393
394 pub fn node_count(&self) -> usize {
396 self.nodes.len()
397 }
398
399 pub fn port_count(&self) -> usize {
401 self.ports.len()
402 }
403
404 pub fn edge_count(&self) -> usize {
406 self.edges.len()
407 }
408
409 pub fn contains_node(&self, id: NodeId) -> bool {
411 self.nodes.contains_key(id)
412 }
413
414 pub fn contains_port(&self, id: PortId) -> bool {
416 self.ports.contains_key(id)
417 }
418
419 pub fn contains_edge(&self, id: EdgeId) -> bool {
421 self.edges.contains_key(id)
422 }
423}
424
425impl<N, P, E> core::ops::Index<NodeId> for Graph<N, P, E> {
428 type Output = N;
429 fn index(&self, id: NodeId) -> &N {
430 &self.nodes[id].data
431 }
432}
433impl<N, P, E> core::ops::IndexMut<NodeId> for Graph<N, P, E> {
434 fn index_mut(&mut self, id: NodeId) -> &mut N {
435 &mut self.nodes[id].data
436 }
437}
438impl<N, P, E> core::ops::Index<PortId> for Graph<N, P, E> {
439 type Output = P;
440 fn index(&self, id: PortId) -> &P {
441 &self.ports[id].data
442 }
443}
444impl<N, P, E> core::ops::IndexMut<PortId> for Graph<N, P, E> {
445 fn index_mut(&mut self, id: PortId) -> &mut P {
446 &mut self.ports[id].data
447 }
448}
449impl<N, P, E> core::ops::Index<EdgeId> for Graph<N, P, E> {
450 type Output = E;
451 fn index(&self, id: EdgeId) -> &E {
452 &self.edges[id].data
453 }
454}
455impl<N, P, E> core::ops::IndexMut<EdgeId> for Graph<N, P, E> {
456 fn index_mut(&mut self, id: EdgeId) -> &mut E {
457 &mut self.edges[id].data
458 }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture with nodes, ports and one edge.
    ///
    /// Contents:
    /// - node `R1` with ports `a` and `b`;
    /// - node `C1` with port `+`;
    /// - edge `w1` joining `R1.b` to `C1.+`.
    ///
    /// Returns `(graph, R1, R1.a, R1.b, w1)`.
    fn rc_pair() -> (
        Graph<&'static str, &'static str, &'static str>,
        NodeId,
        PortId,
        PortId,
        EdgeId,
    ) {
        let mut g = Graph::new();
        let r = g.add_node("R1");
        let ra = g.add_port(r, "a").unwrap();
        let rb = g.add_port(r, "b").unwrap();
        let c = g.add_node("C1");
        let cp = g.add_port(c, "+").unwrap();
        let w = g.connect(rb, cp, "w1").unwrap();
        (g, r, ra, rb, w)
    }

    #[test]
    fn cascade_node_removal() {
        let (mut g, r, ra, rb, w) = rc_pair();
        g.remove_node(r);
        assert!(!g.contains_node(r));
        assert!(!g.contains_port(ra));
        assert!(!g.contains_port(rb));
        assert!(!g.contains_edge(w));
        assert_eq!(g.node_count(), 1);
        assert_eq!(g.edge_count(), 0);
    }

    #[test]
    fn cascade_port_removal_cleans_other_endpoint() {
        let (mut g, _r, _ra, rb, w) = rc_pair();
        g.remove_port(rb);
        assert!(!g.contains_edge(w));
        // No surviving port may still reference a removed edge.
        for (p, _) in g.all_ports() {
            assert!(g.port_edges(p).all(|e| g.contains_edge(e)));
        }
    }

    #[test]
    fn stale_ids_miss_quietly() {
        let (mut g, r, ra, _rb, w) = rc_pair();
        g.remove_node(r);
        assert_eq!(g.node(r), None);
        assert_eq!(g.port(ra), None);
        assert_eq!(g.edge(w), None);
        assert_eq!(g.ports(r).count(), 0);
    }

    #[test]
    fn self_loop_rejected() {
        let (mut g, _r, ra, _rb, _w) = rc_pair();
        assert_eq!(g.connect(ra, ra, "x"), Err(ConnectError::SelfLoop));
    }

    #[test]
    fn connect_to_removed_port_reports_missing() {
        let (mut g, _r, ra, rb, _w) = rc_pair();
        g.remove_port(rb);
        assert_eq!(g.connect(ra, rb, "x"), Err(ConnectError::PortMissing));
    }

    #[test]
    fn add_port_to_removed_node_fails() {
        let (mut g, r, _ra, _rb, _w) = rc_pair();
        g.remove_node(r);
        assert_eq!(g.add_port(r, "z"), Err(NodeMissing));
        // Only `C1.+` survives; the failed call must not have inserted a port.
        assert_eq!(g.port_count(), 1);
    }

    #[test]
    fn validator_sees_port_data_and_can_reject() {
        let (mut g, _r, ra, rb, _w) = rc_pair();
        let edges_before = g.edge_count();
        assert_eq!(
            g.connect_with(ra, rb, "x", |_, _| false),
            Err(ConnectError::Rejected)
        );
        assert_eq!(g.edge_count(), edges_before);
        // The validator receives the port data in (a, b) order.
        let e = g
            .connect_with(ra, rb, "y", |pa, pb| *pa == "a" && *pb == "b")
            .unwrap();
        assert_eq!(g.edge_endpoints(e), Some((ra, rb)));
    }

    #[test]
    fn disconnect_detaches_both_ports() {
        let (mut g, _r, _ra, rb, w) = rc_pair();
        let (_, cp) = g.edge_endpoints(w).unwrap();
        assert_eq!(g.disconnect(w), Some("w1"));
        assert_eq!(g.port_edges(rb).count(), 0);
        assert_eq!(g.port_edges(cp).count(), 0);
        assert_eq!(g.disconnect(w), None);
        assert_eq!(g.port_count(), 3);
    }

    #[test]
    fn navigation_follows_edges() {
        let (g, r, ra, rb, w) = rc_pair();
        let (from, cp) = g.edge_endpoints(w).unwrap();
        assert_eq!(from, rb);
        let c = g.port_node(cp).unwrap();
        assert_eq!(g.opposite(w, rb), Some(cp));
        assert_eq!(g.opposite(w, cp), Some(rb));
        assert_eq!(g.opposite(w, ra), None);
        assert_eq!(g.edge_nodes(w), Some((r, c)));
        assert_eq!(g.neighbors(r).collect::<Vec<_>>(), vec![c]);
        assert_eq!(g.edges_between(rb, cp).collect::<Vec<_>>(), vec![w]);
        assert_eq!(g.find_port(r, |p| *p == "b"), Some(rb));
        assert_eq!(g[c], "C1");
    }
}