1#![allow(dead_code)]
3
4use std::collections::{HashMap, HashSet, VecDeque};
5
/// The kind of a node in a [`PipelineGraph`].
///
/// Kinds are descriptive labels only: `PipelineGraph::connect` accepts an edge between
/// nodes of any two kinds, so fan-in/fan-out shapes are not enforced by the graph.
/// `Hash` is derived alongside `Eq` (matching `PipelineNodeId`) so kinds can be used as
/// keys in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PipelineNodeType {
    /// Source node kind; the only kind for which `is_source` returns `true`.
    Source,
    /// Transform node kind.
    Transform,
    /// Merge node kind.
    Merge,
    /// Split node kind.
    Split,
    /// Sink node kind; the only kind for which `is_terminal` returns `true`.
    Sink,
}
20
21impl PipelineNodeType {
22 #[must_use]
24 pub fn is_terminal(&self) -> bool {
25 matches!(self, Self::Sink)
26 }
27
28 #[must_use]
30 pub fn is_source(&self) -> bool {
31 matches!(self, Self::Source)
32 }
33}
34
/// Handle to a node in a [`PipelineGraph`].
///
/// Ids are normally obtained from `PipelineGraph::add_node`. Because the field is
/// public, an id can also be built by hand; if it names no node, `connect` returns
/// `Err` and `node_type`/`node_label` return `None`.
///
/// `PartialOrd`/`Ord` compare the raw `u32`, so ids can be sorted or used as
/// `BTreeMap`/`BTreeSet` keys for deterministic iteration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PipelineNodeId(pub u32);
38
39#[derive(Debug, Clone)]
41struct NodeMeta {
42 node_type: PipelineNodeType,
43 label: String,
44}
45
46pub struct PipelineGraph {
48 nodes: HashMap<PipelineNodeId, NodeMeta>,
49 edges: HashMap<PipelineNodeId, HashSet<PipelineNodeId>>,
51 next_id: u32,
52}
53
54impl PipelineGraph {
55 #[must_use]
57 pub fn new() -> Self {
58 Self {
59 nodes: HashMap::new(),
60 edges: HashMap::new(),
61 next_id: 0,
62 }
63 }
64
65 pub fn add_node(
67 &mut self,
68 node_type: PipelineNodeType,
69 label: impl Into<String>,
70 ) -> PipelineNodeId {
71 let id = PipelineNodeId(self.next_id);
72 self.next_id += 1;
73 self.nodes.insert(
74 id,
75 NodeMeta {
76 node_type,
77 label: label.into(),
78 },
79 );
80 self.edges.entry(id).or_default();
81 id
82 }
83
84 pub fn connect(&mut self, from: PipelineNodeId, to: PipelineNodeId) -> Result<(), String> {
88 if !self.nodes.contains_key(&from) {
89 return Err(format!("Source node {:?} not found", from));
90 }
91 if !self.nodes.contains_key(&to) {
92 return Err(format!("Destination node {:?} not found", to));
93 }
94 self.edges.entry(from).or_default().insert(to);
95 self.edges.entry(to).or_default();
97 Ok(())
98 }
99
100 #[must_use]
102 pub fn is_valid_dag(&self) -> bool {
103 let mut in_degree: HashMap<PipelineNodeId, usize> =
105 self.nodes.keys().map(|&k| (k, 0)).collect();
106 for neighbours in self.edges.values() {
107 for &n in neighbours {
108 *in_degree.entry(n).or_insert(0) += 1;
109 }
110 }
111 let mut queue: VecDeque<PipelineNodeId> = in_degree
112 .iter()
113 .filter(|(_, &d)| d == 0)
114 .map(|(&k, _)| k)
115 .collect();
116 let mut visited = 0usize;
117 while let Some(node) = queue.pop_front() {
118 visited += 1;
119 if let Some(neighbours) = self.edges.get(&node) {
120 for &n in neighbours {
121 let deg = in_degree.entry(n).or_insert(0);
122 *deg -= 1;
123 if *deg == 0 {
124 queue.push_back(n);
125 }
126 }
127 }
128 }
129 visited == self.nodes.len()
130 }
131
132 #[must_use]
134 pub fn sources(&self) -> Vec<PipelineNodeId> {
135 let mut has_incoming: HashSet<PipelineNodeId> = HashSet::new();
136 for neighbours in self.edges.values() {
137 for &n in neighbours {
138 has_incoming.insert(n);
139 }
140 }
141 self.nodes
142 .keys()
143 .copied()
144 .filter(|id| !has_incoming.contains(id))
145 .collect()
146 }
147
148 #[must_use]
150 pub fn sinks(&self) -> Vec<PipelineNodeId> {
151 self.edges
152 .iter()
153 .filter(|(id, neighbours)| self.nodes.contains_key(id) && neighbours.is_empty())
154 .map(|(&id, _)| id)
155 .collect()
156 }
157
158 #[must_use]
160 pub fn node_count(&self) -> usize {
161 self.nodes.len()
162 }
163
164 #[must_use]
166 pub fn edge_count(&self) -> usize {
167 self.edges.values().map(|s| s.len()).sum()
168 }
169
170 #[must_use]
172 pub fn node_type(&self, id: PipelineNodeId) -> Option<PipelineNodeType> {
173 self.nodes.get(&id).map(|m| m.node_type)
174 }
175
176 #[must_use]
178 pub fn node_label(&self, id: PipelineNodeId) -> Option<&str> {
179 self.nodes.get(&id).map(|m| m.label.as_str())
180 }
181}
182
183impl Default for PipelineGraph {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the two-node pipeline `src -> sink` and returns the graph and both ids.
    fn simple_graph() -> (PipelineGraph, PipelineNodeId, PipelineNodeId) {
        let mut g = PipelineGraph::new();
        let src = g.add_node(PipelineNodeType::Source, "src");
        let sink = g.add_node(PipelineNodeType::Sink, "sink");
        g.connect(src, sink).expect("connect should succeed");
        (g, src, sink)
    }

    // --- PipelineNodeType ---------------------------------------------------

    #[test]
    fn test_sink_is_terminal() {
        assert!(PipelineNodeType::Sink.is_terminal());
    }

    #[test]
    fn test_transform_not_terminal() {
        assert!(!PipelineNodeType::Transform.is_terminal());
    }

    #[test]
    fn test_source_is_source() {
        assert!(PipelineNodeType::Source.is_source());
    }

    #[test]
    fn test_merge_not_source() {
        assert!(!PipelineNodeType::Merge.is_source());
    }

    // --- Construction -------------------------------------------------------

    #[test]
    fn test_add_node_increments_count() {
        let mut g = PipelineGraph::new();
        g.add_node(PipelineNodeType::Source, "s");
        g.add_node(PipelineNodeType::Sink, "k");
        assert_eq!(g.node_count(), 2);
    }

    #[test]
    fn test_add_node_returns_distinct_ids() {
        let mut g = PipelineGraph::new();
        let a = g.add_node(PipelineNodeType::Source, "a");
        let b = g.add_node(PipelineNodeType::Transform, "b");
        let c = g.add_node(PipelineNodeType::Sink, "c");
        assert_ne!(a, b);
        assert_ne!(b, c);
        assert_ne!(a, c);
    }

    #[test]
    fn test_connect_increments_edge_count() {
        let (g, _, _) = simple_graph();
        assert_eq!(g.edge_count(), 1);
    }

    #[test]
    fn test_duplicate_connect_does_not_add_edge() {
        let (mut g, src, sink) = simple_graph();
        g.connect(src, sink).expect("repeat connect should succeed");
        assert_eq!(g.edge_count(), 1);
    }

    #[test]
    fn test_connect_nonexistent_node_returns_error() {
        let mut g = PipelineGraph::new();
        let src = g.add_node(PipelineNodeType::Source, "s");
        let ghost = PipelineNodeId(999);
        assert!(g.connect(src, ghost).is_err());
    }

    #[test]
    fn test_connect_nonexistent_source_returns_error() {
        let mut g = PipelineGraph::new();
        let sink = g.add_node(PipelineNodeType::Sink, "k");
        assert!(g.connect(PipelineNodeId(999), sink).is_err());
        // A rejected connect must not leave a partial edge behind.
        assert_eq!(g.edge_count(), 0);
    }

    // --- Cycle detection ----------------------------------------------------

    #[test]
    fn test_is_valid_dag_linear() {
        let (g, _, _) = simple_graph();
        assert!(g.is_valid_dag());
    }

    #[test]
    fn test_is_valid_dag_cycle_detected() {
        let mut g = PipelineGraph::new();
        let a = g.add_node(PipelineNodeType::Source, "a");
        let b = g.add_node(PipelineNodeType::Transform, "b");
        g.connect(a, b).expect("connect should succeed");
        g.connect(b, a).expect("connect should succeed");
        assert!(!g.is_valid_dag());
    }

    #[test]
    fn test_self_loop_is_not_valid_dag() {
        let mut g = PipelineGraph::new();
        let a = g.add_node(PipelineNodeType::Transform, "a");
        g.connect(a, a).expect("connect should succeed");
        assert!(!g.is_valid_dag());
    }

    #[test]
    fn test_cycle_downstream_of_source_detected() {
        // s -> a -> b -> c -> a : `s` is removable, the a/b/c ring is not.
        let mut g = PipelineGraph::new();
        let s = g.add_node(PipelineNodeType::Source, "s");
        let a = g.add_node(PipelineNodeType::Transform, "a");
        let b = g.add_node(PipelineNodeType::Transform, "b");
        let c = g.add_node(PipelineNodeType::Transform, "c");
        g.connect(s, a).expect("connect should succeed");
        g.connect(a, b).expect("connect should succeed");
        g.connect(b, c).expect("connect should succeed");
        g.connect(c, a).expect("connect should succeed");
        assert!(!g.is_valid_dag());
    }

    #[test]
    fn test_diamond_graph_is_valid_dag() {
        let mut g = PipelineGraph::new();
        let a = g.add_node(PipelineNodeType::Source, "a");
        let b = g.add_node(PipelineNodeType::Transform, "b");
        let c = g.add_node(PipelineNodeType::Transform, "c");
        let d = g.add_node(PipelineNodeType::Sink, "d");
        g.connect(a, b).expect("connect should succeed");
        g.connect(a, c).expect("connect should succeed");
        g.connect(b, d).expect("connect should succeed");
        g.connect(c, d).expect("connect should succeed");
        assert!(g.is_valid_dag());
    }

    #[test]
    fn test_empty_graph_is_valid_dag() {
        let g = PipelineGraph::new();
        assert!(g.is_valid_dag());
    }

    // --- Sources / sinks ----------------------------------------------------

    #[test]
    fn test_sources_returns_nodes_without_incoming() {
        let (g, src, _) = simple_graph();
        let sources = g.sources();
        assert_eq!(sources.len(), 1);
        assert!(sources.contains(&src));
    }

    #[test]
    fn test_sinks_returns_nodes_without_outgoing() {
        let (g, _, sink) = simple_graph();
        let sinks = g.sinks();
        assert_eq!(sinks.len(), 1);
        assert!(sinks.contains(&sink));
    }

    #[test]
    fn test_isolated_node_is_source_and_sink() {
        let (mut g, _, _) = simple_graph();
        let lone = g.add_node(PipelineNodeType::Transform, "lone");
        let sources = g.sources();
        let sinks = g.sinks();
        assert_eq!(sources.len(), 2);
        assert_eq!(sinks.len(), 2);
        assert!(sources.contains(&lone));
        assert!(sinks.contains(&lone));
    }

    #[test]
    fn test_multiple_sources_and_sinks() {
        let mut g = PipelineGraph::new();
        let s1 = g.add_node(PipelineNodeType::Source, "s1");
        let s2 = g.add_node(PipelineNodeType::Source, "s2");
        let m = g.add_node(PipelineNodeType::Merge, "m");
        let k1 = g.add_node(PipelineNodeType::Sink, "k1");
        let k2 = g.add_node(PipelineNodeType::Sink, "k2");
        g.connect(s1, m).expect("connect should succeed");
        g.connect(s2, m).expect("connect should succeed");
        g.connect(m, k1).expect("connect should succeed");
        g.connect(m, k2).expect("connect should succeed");
        assert_eq!(g.sources().len(), 2);
        assert_eq!(g.sinks().len(), 2);
    }

    // --- Queries ------------------------------------------------------------

    #[test]
    fn test_node_type_query() {
        let (g, src, _) = simple_graph();
        assert_eq!(g.node_type(src), Some(PipelineNodeType::Source));
    }

    #[test]
    fn test_node_label_query() {
        let (g, src, _) = simple_graph();
        assert_eq!(g.node_label(src), Some("src"));
    }

    #[test]
    fn test_unknown_id_queries_return_none() {
        let (g, _, _) = simple_graph();
        let ghost = PipelineNodeId(999);
        assert_eq!(g.node_type(ghost), None);
        assert_eq!(g.node_label(ghost), None);
    }
}