1crate::ix!();
3
4pub fn wire_up_network<NetworkItem>(
7 net: &mut Network<NetworkItem>
8) -> NetResult<()>
9where
10 NetworkItem: Debug + Send + Sync + Default,
11{
12 for (node_idx, node) in net.nodes_mut().iter_mut().enumerate() {
15 let out_count = node.operator().output_count();
16 if out_count > 4 {
17 return Err(NetworkError::InvalidConfiguration {
18 details: format!(
19 "Node #{} operator output_count={} exceeds the max of 4 ports",
20 node_idx, out_count
21 ),
22 });
23 }
24 for port_idx in 0..out_count {
26 node.outputs_mut()[port_idx] = Some(Arc::new(AsyncRwLock::new(NetworkItem::default())));
27 }
28 for port_idx in out_count..4 {
30 node.outputs_mut()[port_idx] = None;
31 }
32 }
33
34 let mut used_input_count: Vec<usize> = vec![0; net.nodes().len()];
37 let mut input_usage = vec![[0usize; 4]; net.nodes().len()]; let mut output_usage = vec![[0usize; 4]; net.nodes().len()]; let edges = net.edges().clone();
43
44 for (edge_idx, edge) in edges.iter().enumerate() {
47 let src = edge.source_index();
48 let so = edge.source_output_idx();
49 let dst = edge.dest_index();
50 let di = edge.dest_input_idx();
51
52 if *src >= net.nodes().len() || *dst >= net.nodes().len() {
54 return Err(NetworkError::InvalidConfiguration {
55 details: format!(
56 "Edge #{} references invalid node index (src={}, dst={}, node_count={})",
57 edge_idx, src, dst, net.nodes().len()
58 ),
59 });
60 }
61 let src_op = &net.nodes()[*src].operator();
62 let dst_op = &net.nodes()[*dst].operator();
63
64 if *so >= src_op.output_count() {
66 return Err(NetworkError::InvalidConfiguration {
67 details: format!(
68 "Edge #{} references source node {} output port {}, but operator only has {} outputs",
69 edge_idx, src, so, src_op.output_count()
70 ),
71 });
72 }
73 if *di >= dst_op.input_count() {
75 return Err(NetworkError::InvalidConfiguration {
76 details: format!(
77 "Edge #{} references dest node {} input port {}, but operator only has {} inputs",
78 edge_idx, dst, di, dst_op.input_count()
79 ),
80 });
81 }
82
83 if *so >= 4 || *di >= 4 {
85 return Err(NetworkError::InvalidConfiguration {
86 details: format!(
87 "Edge #{} references port so={}, di={} >= 4, out of range for 4 ports",
88 edge_idx, so, di
89 ),
90 });
91 }
92
93 let out_str = match src_op.output_port_type_str(*so) {
96 Some(s) => s,
97 None => {
98 return Err(NetworkError::InvalidConfiguration {
99 details: format!(
100 "Edge #{} source node {} output port {} has no declared type string",
101 edge_idx, src, so
102 ),
103 });
104 }
105 };
106 let in_str = match dst_op.input_port_type_str(*di) {
107 Some(s) => s,
108 None => {
109 return Err(NetworkError::InvalidConfiguration {
110 details: format!(
111 "Edge #{} dest node {} input port {} has no declared type string",
112 edge_idx, dst, di
113 ),
114 });
115 }
116 };
117 if out_str != in_str {
118 return Err(NetworkError::InvalidConfiguration {
119 details: format!(
120 "Edge #{} type mismatch: source node {} output port {} => {:?} != dest node {} input port {} => {:?}",
121 edge_idx, src, so, out_str, dst, di, in_str
122 ),
123 });
124 }
125
126 let arc_opt = net.nodes()[*src].outputs()[*so].clone();
128 if arc_opt.is_none() {
129 return Err(NetworkError::InvalidConfiguration {
130 details: format!(
131 "Edge #{} references source node {} output port {}, but that port is None",
132 edge_idx, src, so
133 ),
134 });
135 }
136 net.nodes_mut()[*dst].inputs_mut()[*di] = arc_opt;
137
138 used_input_count[*dst] += 1;
140 input_usage[*dst][*di] += 1;
141 output_usage[*src][*so] += 1;
142 }
143
144 for (node_idx, node) in net.nodes().iter().enumerate() {
150 let op = node.operator();
151 let in_count = op.input_count();
152 let out_count = op.output_count();
153 let used = used_input_count[node_idx];
154
155 if used != in_count {
157 return Err(NetworkError::InvalidConfiguration {
158 details: format!(
159 "Node #{} operator expects {} inputs, but wired edges used {}",
160 node_idx, in_count, used
161 ),
162 });
163 }
164
165 for i in 0..in_count {
167 if input_usage[node_idx][i] > 1 {
168 return Err(NetworkError::InvalidConfiguration {
169 details: format!(
170 "Node #{} input port {} is fed by multiple edges ({}), which is not allowed",
171 node_idx, i, input_usage[node_idx][i]
172 ),
173 });
174 }
175 if op.input_port_connection_required(i) && input_usage[node_idx][i] == 0 {
176 return Err(NetworkError::InvalidConfiguration {
177 details: format!(
178 "Node #{} input port {} is required but has no incoming edge",
179 node_idx, i
180 ),
181 });
182 }
183 }
184
185 for o in 0..out_count {
188 if op.output_port_connection_required(o) && output_usage[node_idx][o] == 0 {
189 return Err(NetworkError::InvalidConfiguration {
190 details: format!(
191 "Node #{} output port {} is required but has no downstream edge",
192 node_idx, o
193 ),
194 });
195 }
196 }
197 }
198
199 Ok(())
200}
201
#[cfg(test)]
mod wire_up_network_tests {
    use super::*;

    /// Asserts that `res` is an `InvalidConfiguration` error whose details
    /// contain `needle`. Any other outcome, including `Ok`, fails the test
    /// instead of passing silently.
    fn assert_invalid_config(res: NetResult<()>, needle: &str) {
        match res {
            Err(NetworkError::InvalidConfiguration { details }) => assert!(
                details.contains(needle),
                "Expected error containing {:?}, got: {}",
                needle,
                details
            ),
            other => panic!(
                "Expected InvalidConfiguration containing {:?}, got: {:?}",
                needle, other
            ),
        }
    }

    /// Asserts that `consumer`'s input slot shares the very same buffer as
    /// `producer`'s output slot (the wiring must share, not copy).
    fn assert_shares_buffer(
        net: &Network<TestWireIO<i32>>,
        producer: (usize, usize),
        consumer: (usize, usize),
    ) {
        let produced = net.nodes()[producer.0].outputs()[producer.1]
            .as_ref()
            .expect("producer output should be allocated");
        let consumed = net.nodes()[consumer.0].inputs()[consumer.1]
            .as_ref()
            .expect("consumer input should be wired");
        assert!(
            Arc::ptr_eq(produced, consumed),
            "node {} input {} does not share node {} output {}",
            consumer.0,
            consumer.1,
            producer.0,
            producer.1
        );
    }

    #[test]
    fn test_wire_up_single_constantop_ok() -> Result<(), NetworkError> {
        let n0: NetworkNode<TestWireIO<i32>> = node!(0 => ConstantOp::new(100));

        let mut net = NetworkBuilder::<TestWireIO<i32>>::default()
            .nodes(vec![n0])
            .edges(vec![])
            .build()
            .unwrap();

        wire_up_network(&mut net)?;

        assert!(net.nodes()[0].outputs()[0].is_some());
        for port in 1..4 {
            assert!(
                net.nodes()[0].outputs()[port].is_none(),
                "output port {} should be unallocated",
                port
            );
        }
        for port in 0..4 {
            assert!(net.nodes()[0].inputs()[port].is_none());
        }

        Ok(())
    }

    #[test]
    fn test_wire_up_single_addop_mismatch() {
        let n0: NetworkNode<TestWireIO<i32>> = node!(0 => AddOp::new(10));
        let mut net = NetworkBuilder::<TestWireIO<i32>>::default()
            .nodes(vec![n0])
            .edges(vec![])
            .build()
            .unwrap();

        assert_invalid_config(
            wire_up_network(&mut net),
            "expects 1 inputs, but wired edges used 0",
        );
    }

    #[test]
    fn test_wire_up_two_nodes_ok() -> Result<(), NetworkError> {
        let n0: NetworkNode<TestWireIO<i32>> = node!(0 => ConstantOp::new(42));
        let n1: NetworkNode<TestWireIO<i32>> = node!(1 => AddOp::new(5));
        let e = edge!(0:0 -> 1:0);

        let mut net = NetworkBuilder::<TestWireIO<i32>>::default()
            .nodes(vec![n0, n1])
            .edges(vec![e])
            .build()
            .unwrap();

        wire_up_network(&mut net)?;

        assert!(net.nodes()[0].outputs()[0].is_some());
        assert!(net.nodes()[1].inputs()[0].is_some());
        assert_shares_buffer(&net, (0, 0), (1, 0));

        Ok(())
    }

    #[test]
    fn test_wire_up_fanout_ok() -> Result<(), NetworkError> {
        let n0: NetworkNode<TestWireIO<i32>> = node!(0 => ConstantOp::new(10));
        let n1: NetworkNode<TestWireIO<i32>> = node!(1 => AddOp::new(100));
        let n2: NetworkNode<TestWireIO<i32>> = node!(2 => MultiplyOp::new(2));

        let e1 = edge!(0:0 -> 1:0);
        let e2 = edge!(0:0 -> 2:0);

        let mut net = NetworkBuilder::<TestWireIO<i32>>::default()
            .nodes(vec![n0, n1, n2])
            .edges(vec![e1, e2])
            .build()
            .unwrap();

        wire_up_network(&mut net)?;

        assert!(net.nodes()[0].outputs()[0].is_some());
        assert!(net.nodes()[1].inputs()[0].is_some());
        assert!(net.nodes()[2].inputs()[0].is_some());
        // Both consumers must observe the single producer buffer.
        assert_shares_buffer(&net, (0, 0), (1, 0));
        assert_shares_buffer(&net, (0, 0), (2, 0));

        Ok(())
    }

    #[test]
    fn test_wire_up_double_feed_same_input_port() {
        let n0: NetworkNode<TestWireIO<i32>> = node!(0 => ConstantOp::new(42));
        let n1: NetworkNode<TestWireIO<i32>> = node!(1 => AddOp::new(5));
        let n2: NetworkNode<TestWireIO<i32>> = node!(2 => ConstantOp::new(123));

        // Two producers both target node 1's single input port.
        let e0 = edge!(0:0 -> 1:0);
        let e1 = edge!(2:0 -> 1:0);

        let mut net = NetworkBuilder::<TestWireIO<i32>>::default()
            .nodes(vec![n0, n1, n2])
            .edges(vec![e0, e1])
            .build()
            .unwrap();

        assert_invalid_config(
            wire_up_network(&mut net),
            "expects 1 inputs, but wired edges used 2",
        );
    }

    #[test]
    fn test_wire_up_source_output_port_out_of_range() {
        // ConstantOp exposes a single output, so port 1 does not exist.
        let n0: NetworkNode<TestWireIO<i32>> = node!(0 => ConstantOp::new(1));
        let n1: NetworkNode<TestWireIO<i32>> = node!(1 => AddOp::new(5));
        let e = edge!(0:1 -> 1:0);

        let mut net = NetworkBuilder::<TestWireIO<i32>>::default()
            .nodes(vec![n0, n1])
            .edges(vec![e])
            .build()
            .unwrap();

        assert_invalid_config(
            wire_up_network(&mut net),
            "output port 1, but operator only has 1 outputs",
        );
    }

    #[test]
    fn test_wire_up_operator_with_2_outputs_but_only_1_used() -> Result<(), NetworkError> {
        // constant -> double-output -> multiply, leaving the second output of
        // the double-output node without a downstream edge.
        let n0_const: NetworkNode<TestWireIO<i32>> = node!(0 => ConstantOp::new(77));
        let n1_double: NetworkNode<TestWireIO<i32>> = node!(1 => DoubleOutOp::default());
        let n2_mul: NetworkNode<TestWireIO<i32>> = node!(2 => MultiplyOp::new(2));

        let e_const_to_double = edge!(0:0 -> 1:0);
        let e_double_to_mul = edge!(1:0 -> 2:0);

        let mut net = NetworkBuilder::<TestWireIO<i32>>::default()
            .nodes(vec![n0_const, n1_double, n2_mul])
            .edges(vec![e_const_to_double, e_double_to_mul])
            .build()
            .unwrap();

        wire_up_network(&mut net)?;

        // Both declared outputs get a buffer, even though only one is consumed.
        assert!(net.nodes()[1].outputs()[0].is_some());
        assert!(net.nodes()[1].outputs()[1].is_some());
        assert!(net.nodes()[2].inputs()[0].is_some());
        assert_shares_buffer(&net, (1, 0), (2, 0));

        Ok(())
    }
}