// flow_graph_interpreter/graph.rs

1mod error;
2mod helpers;
3mod operation_settings;
4pub(crate) mod types;
5use std::collections::HashMap;
6
7pub use error::Error as GraphError;
8use flow_expression_parser::ast::{
9  BlockExpression,
10  ConnectionExpression,
11  ConnectionTargetExpression,
12  FlowExpression,
13  InstancePort,
14  InstanceTarget,
15};
16use flow_graph::NodeReference;
17use serde_json::Value;
18use types::*;
19use wick_config::config::components::{ComponentConfig, OperationConfig};
20use wick_config::config::{ComponentImplementation, ExecutionSettings, FlowOperation};
21use wick_packet::RuntimeConfig;
22
23use self::helpers::{ensure_added, ParseHelper};
24pub(crate) use self::operation_settings::{LiquidOperationConfig, OperationSettings};
25use crate::interpreter::components::core;
26use crate::HandlerMap;
27
28pub(crate) trait NodeDecorator {
29  fn decorate(node: &mut Node) -> Result<(), String>;
30}
31
32#[derive(Debug)]
33#[must_use]
34pub(crate) struct Reference(NodeReference);
35
36impl From<&NodeReference> for Reference {
37  fn from(v: &NodeReference) -> Self {
38    Self(v.clone())
39  }
40}
41
42impl Reference {
43  pub(crate) fn name(&self) -> &str {
44    self.0.name()
45  }
46  pub(crate) fn namespace(&self) -> &str {
47    self.0.component_id()
48  }
49}
50
51fn register_operation(
52  mut scope: Vec<String>,
53  network: &mut Network,
54  flow: &mut FlowOperation,
55  handlers: &HandlerMap,
56  op_config_base: &LiquidOperationConfig,
57) -> Result<(), GraphError> {
58  scope.push(flow.name().to_owned());
59
60  for flow in flow.flows_mut() {
61    let scope = scope.clone();
62    register_operation(scope, network, flow, handlers, op_config_base)?;
63  }
64  let name = scope.join("::");
65  let mut schematic = Schematic::new(name, Default::default(), Default::default());
66  let mut ids = flow.instances().keys().cloned().collect::<Vec<_>>();
67  ids.sort();
68
69  for name in ids {
70    let def = flow.instances().get(&name).unwrap();
71    debug!(%name, config=?def.data(),settings=?def.settings(), "registering operation");
72    let mut op_config = op_config_base.clone();
73    op_config.set_template(def.data().cloned());
74
75    let node = schematic.add_and_get_mut(
76      name,
77      NodeReference::new(def.component_id(), def.name()),
78      OperationSettings::new(op_config.clone(), def.settings().cloned()),
79    );
80    helpers::decorate(def.component_id(), def.name(), node, handlers)?;
81  }
82
83  expand_until_done(&mut schematic, flow, handlers, op_config_base, expand_expressions)?;
84
85  for expression in flow.expressions() {
86    process_flow_expression(&mut schematic, expression, handlers)?;
87  }
88
89  network.add_schematic(schematic);
90  Ok(())
91}
92
93fn process_flow_expression(
94  schematic: &mut Schematic,
95  expr: &FlowExpression,
96  handlers: &HandlerMap,
97) -> Result<(), GraphError> {
98  match expr {
99    FlowExpression::ConnectionExpression(expr) => process_connection_expression(schematic, expr, handlers)?,
100    FlowExpression::BlockExpression(expr) => {
101      for expr in expr.iter() {
102        process_flow_expression(schematic, expr, handlers)?;
103      }
104    }
105  }
106  Ok(())
107}
108
109fn process_connection_expression(
110  schematic: &mut Schematic,
111  expr: &ConnectionExpression,
112  _handlers: &HandlerMap,
113) -> Result<(), GraphError> {
114  let from = expr.from();
115  let to = expr.to();
116  assert!(
117    to.port().name().is_some(),
118    "Missing downstream port for expr: {:?}",
119    expr
120  );
121  let to_port = schematic
122    .find_mut(to.instance().id().unwrap())
123    .map(|component| component.add_input(to.port().name().unwrap()));
124
125  if to_port.is_none() {
126    error!("missing downstream: instance {:?}", to);
127    return Err(GraphError::missing_downstream(to.instance().id().unwrap()));
128  }
129  let to_port = to_port.unwrap();
130
131  if let Some(component) = schematic.find_mut(from.instance().id().unwrap()) {
132    let from_port = component.add_output(from.port().name().unwrap());
133    trace!(
134      ?from_port,
135      from = %expr.from(),
136      ?to_port,
137      to = %expr.to(),
138      "graph:connecting"
139    );
140    schematic.connect(from_port, to_port, Default::default())?;
141  } else {
142    panic!("Can't find component {}", from.instance());
143  }
144  Ok(())
145}
146
147#[allow(trivial_casts)]
148fn expand_until_done(
149  schematic: &mut Schematic,
150  expressions: &mut FlowOperation,
151  handlers: &HandlerMap,
152  config: &LiquidOperationConfig,
153  func: fn(
154    &mut Schematic,
155    &mut FlowOperation,
156    &HandlerMap,
157    &LiquidOperationConfig,
158    &mut usize,
159  ) -> Result<ExpandResult, GraphError>,
160) -> Result<(), GraphError> {
161  let mut id_index = 0;
162  loop {
163    let result = func(schematic, expressions, handlers, config, &mut id_index)?;
164
165    if result == ExpandResult::Done {
166      break;
167    }
168  }
169  Ok(())
170}
171
/// Outcome of one expansion pass over a flow's expressions.
#[derive(Debug, PartialEq, Clone, Copy)]
enum ExpandResult {
  /// The pass rewrote nothing.
  Done,
  /// The pass rewrote at least one expression, so another pass may find more work.
  Continue,
}

impl ExpandResult {
  /// Combines two pass outcomes. The result is `Continue` if either side is
  /// `Continue`, and `Done` only when both are `Done`.
  fn update(self, next: ExpandResult) -> Self {
    match self {
      ExpandResult::Continue => ExpandResult::Continue,
      ExpandResult::Done => next,
    }
  }
}
187
188#[allow(clippy::option_if_let_else)]
189fn expand_expressions(
190  schematic: &mut Schematic,
191  flow: &mut FlowOperation,
192  handlers: &HandlerMap,
193  config: &LiquidOperationConfig,
194  inline_id: &mut usize,
195) -> Result<ExpandResult, GraphError> {
196  let result = ExpandResult::Done;
197
198  let config_map = flow
199    .instances()
200    .iter()
201    .map(|(k, v)| {
202      let mut base = config.clone();
203      base.set_template(v.data().cloned());
204
205      Ok::<_, GraphError>((k.clone(), (base, v.settings().cloned())))
206    })
207    .collect::<Result<HashMap<_, _>, _>>()?;
208  add_nodes_to_schematic(schematic, flow.expressions_mut(), handlers, &config_map, inline_id)?;
209  let result = result.update(expand_port_paths(schematic, flow.expressions_mut())?);
210  let result = result.update(expand_defaulted_ports(schematic, flow.expressions_mut())?);
211  Ok(result)
212}
213
214fn add_nodes_to_schematic(
215  schem: &mut Schematic,
216  flow: &mut [FlowExpression],
217  handlers: &HandlerMap,
218  config_map: &HashMap<String, (LiquidOperationConfig, Option<ExecutionSettings>)>,
219  id_index: &mut usize,
220) -> Result<(), GraphError> {
221  for (_i, expression) in flow.iter_mut().enumerate() {
222    match expression {
223      FlowExpression::ConnectionExpression(conn) => {
224        let config = conn
225          .from()
226          .instance()
227          .id()
228          .and_then(|id| config_map.get(id).cloned())
229          .unwrap_or((LiquidOperationConfig::default(), None));
230
231        ensure_added(schem, conn.from_mut().instance_mut(), handlers, config, id_index)?;
232
233        let config = conn
234          .to()
235          .instance()
236          .id()
237          .and_then(|id| config_map.get(id).cloned())
238          .unwrap_or((LiquidOperationConfig::default(), None));
239
240        ensure_added(schem, conn.to_mut().instance_mut(), handlers, config, id_index)?;
241      }
242      FlowExpression::BlockExpression(expressions) => {
243        add_nodes_to_schematic(schem, expressions.inner_mut(), handlers, config_map, id_index)?;
244      }
245    }
246  }
247
248  Ok(())
249}
250
251fn connection(
252  from: (InstanceTarget, impl Into<InstancePort>),
253  to: (InstanceTarget, impl Into<InstancePort>),
254) -> FlowExpression {
255  FlowExpression::connection(ConnectionExpression::new(
256    ConnectionTargetExpression::new(from.0, from.1),
257    ConnectionTargetExpression::new(to.0, to.1),
258  ))
259}
260
261#[allow(clippy::option_if_let_else, clippy::too_many_lines, clippy::cognitive_complexity)]
262fn expand_defaulted_ports(
263  schematic: &mut Schematic,
264  expressions: &mut [FlowExpression],
265) -> Result<ExpandResult, GraphError> {
266  let mut result = ExpandResult::Done;
267  for (_i, expression) in expressions.iter_mut().enumerate() {
268    match expression {
269      FlowExpression::ConnectionExpression(expr) => {
270        let (from, to) = expr.clone().into_parts();
271        let (from_inst, from_port, _) = from.into_parts();
272        let (to_inst, to_port, _) = to.into_parts();
273        match (from_port, to_port) {
274          (InstancePort::None, InstancePort::None) => {
275            let from_node = schematic.get_node(&from_inst)?;
276            let to_node = schematic.get_node(&to_inst)?;
277            let from_node_ports = from_node.outputs();
278            let to_node_ports = to_node.inputs();
279            debug!(
280              from = %from_inst, from_ports = ?from_node_ports, to = %to_inst, to_ports = ?to_node_ports,
281              "graph:inferring ports for both up and down"
282            );
283            if from_node_ports.is_empty() && to_node_ports.is_empty() {
284              // can't do anything yet.
285              continue;
286            }
287
288            // If there's only one port on each side, connect them.
289            if from_node_ports.len() == 1 && to_node_ports.len() == 1 {
290              let from_port = from_node_ports[0].name();
291              let to_port = to_node_ports[0].name();
292              debug!(from = %from_inst, from_port,to = %to_inst, to_port, reason="unary", "graph:inferred ports");
293              expression.replace(connection((from_inst, from_port), (to_inst, to_port)));
294              result = ExpandResult::Continue;
295              continue;
296            }
297
298            let mut new_connections = Vec::new();
299            // if either side is a schematic input/output node, adopt the names of all ports we're pointing to.
300            if matches!(from_inst, InstanceTarget::Input | InstanceTarget::Default) {
301              for port in to_node_ports {
302                let port_name = port.name();
303                debug!(from = %from_inst, from_port=port_name,to = %to_inst, to_port=port_name, reason="upstream_default", "graph:inferred ports");
304                new_connections.push(connection((from_inst.clone(), port_name), (to_inst.clone(), port_name)));
305              }
306            } else if matches!(to_inst, InstanceTarget::Output | InstanceTarget::Default) {
307              for port in from_node_ports {
308                let port_name = port.name();
309                debug!(from = %from_inst, from_port=port_name,to = %to_inst, to_port=port_name, reason="downstream_default", "graph:inferred ports");
310                new_connections.push(connection((from_inst.clone(), port_name), (to_inst.clone(), port_name)));
311              }
312            } else {
313              for port in from_node_ports {
314                if !to_node_ports.contains(port) && !matches!(to_inst, InstanceTarget::Output | InstanceTarget::Default)
315                {
316                  return Err(GraphError::port_inference_down(
317                    &from_inst,
318                    port.name(),
319                    to_inst,
320                    to_node_ports,
321                  ));
322                }
323                let port_name = port.name();
324                debug!(from = %from_inst, from_port=port_name,to = %to_inst, to_port=port_name, reason="all_downstream", "graph:inferred ports");
325                new_connections.push(connection(
326                  (from_inst.clone(), port.name()),
327                  (to_inst.clone(), port.name()),
328                ));
329              }
330            }
331
332            assert!(!new_connections.is_empty(), "unhandled case for port inference");
333            result = ExpandResult::Continue;
334            expression.replace(FlowExpression::block(BlockExpression::new(new_connections)));
335          }
336          (InstancePort::None, to_port) => {
337            let port_name = to_port.name().unwrap();
338            let from_node = schematic.get_node(&from_inst)?;
339            let ports = from_node.outputs();
340            debug!(
341              from = %from_inst, from_ports = ?ports, to = %to_inst,
342              "graph:inferring ports for upstream"
343            );
344            // if we're at a schematic input node, adopt the name of what we're pointing to.
345            if matches!(from_inst, InstanceTarget::Input | InstanceTarget::Default) {
346              expression.replace(connection((from_inst, port_name), (to_inst, to_port.clone())));
347              result = ExpandResult::Continue;
348              continue;
349            }
350            if ports.len() == 1 {
351              expression.replace(connection((from_inst, ports[0].name()), (to_inst, to_port.clone())));
352              result = ExpandResult::Continue;
353              continue;
354            }
355
356            if !ports.iter().any(|p| p.name() == port_name) {
357              return Err(GraphError::port_inference_up(&to_inst, port_name, from_inst, ports));
358            }
359
360            result = ExpandResult::Continue;
361            expression.replace(connection((from_inst, port_name), (to_inst, to_port.clone())));
362          }
363          (from_port, InstancePort::None) => {
364            let port_name = from_port.name().unwrap();
365            let to_node = schematic.get_node(&to_inst)?;
366            let ports = to_node.inputs();
367            debug!(
368              from = %from_inst, to = %to_inst, to_ports = ?ports,
369              "graph:inferring ports for downstream"
370            );
371
372            // if we're at a schematic input node, adopt the name of what we're pointing to.
373            if matches!(to_inst, InstanceTarget::Output | InstanceTarget::Default) {
374              expression.replace(connection((from_inst, from_port.clone()), (to_inst, port_name)));
375              result = ExpandResult::Continue;
376              continue;
377            }
378
379            if ports.len() == 1 {
380              expression.replace(connection((from_inst, from_port.clone()), (to_inst, ports[0].name())));
381              result = ExpandResult::Continue;
382              continue;
383            }
384
385            if !ports.iter().any(|p| p.name() == port_name) {
386              return Err(GraphError::port_inference_down(&from_inst, port_name, to_inst, ports));
387            }
388
389            result = ExpandResult::Continue;
390            expression.replace(connection((from_inst, from_port.clone()), (to_inst, port_name)));
391          }
392          _ => continue,
393        }
394      }
395      FlowExpression::BlockExpression(expressions) => {
396        result = result.update(expand_defaulted_ports(schematic, expressions.inner_mut())?);
397      }
398    }
399  }
400  Ok(result)
401}
402
403#[allow(clippy::option_if_let_else)]
404fn expand_port_paths(
405  schematic: &mut Schematic,
406  expressions: &mut [FlowExpression],
407) -> Result<ExpandResult, GraphError> {
408  let mut result = ExpandResult::Done;
409  for (i, expression) in expressions.iter_mut().enumerate() {
410    match expression {
411      FlowExpression::ConnectionExpression(expr) => {
412        let (from, to) = expr.clone().into_parts();
413        let (from_inst, from_port, _) = from.into_parts();
414        let (to_inst, to_port, _) = to.into_parts();
415        if let InstancePort::Path(name, parts) = from_port {
416          let id = format!("{}_pluck_{}_{}_[{}]", schematic.name(), i, name, parts.join(","));
417          let config = HashMap::from([(
418            "path".to_owned(),
419            Value::Array(parts.into_iter().map(Value::String).collect()),
420          )]);
421
422          let node = schematic.add_and_get_mut(
423            &id,
424            NodeReference::new("core", "pluck"),
425            OperationSettings::new(Some(RuntimeConfig::from(config)).into(), None),
426          );
427          core::pluck::Op::decorate(node).map_err(|e| GraphError::config(id.clone(), e))?;
428
429          expression.replace(FlowExpression::block(BlockExpression::new(vec![
430            connection((from_inst, &name), (InstanceTarget::named(&id), InstancePort::None)),
431            connection((InstanceTarget::named(&id), InstancePort::None), (to_inst, to_port)),
432          ])));
433          result = ExpandResult::Continue;
434        }
435      }
436      FlowExpression::BlockExpression(expressions) => {
437        result = result.update(expand_port_paths(schematic, expressions.inner_mut())?);
438      }
439    }
440  }
441  Ok(result)
442}
443
444pub fn from_def(
445  manifest: &mut wick_config::config::ComponentConfiguration,
446  handlers: &HandlerMap,
447) -> Result<Network, GraphError> {
448  let mut network = Network::new(
449    manifest.name().cloned().unwrap_or_default(),
450    OperationSettings::new(manifest.root_config().cloned().into(), None),
451  );
452
453  let mut op_config_base = LiquidOperationConfig::default();
454  op_config_base.set_root(manifest.root_config().cloned());
455
456  if let ComponentImplementation::Composite(composite) = manifest.component_mut() {
457    for flow in composite.operations_mut() {
458      register_operation(vec![], &mut network, flow, handlers, &op_config_base)?;
459    }
460  }
461
462  #[cfg(debug_assertions)]
463  {
464    let names: Vec<_> = network.schematics().iter().map(|s| s.name()).collect();
465    trace!(nodes=?names,"graph:nodes");
466    for schematic in network.schematics() {
467      let schem_name = &schematic.name();
468      for node in schematic.nodes() {
469        let name = &node.name;
470        let inputs = node.inputs().iter().map(|n| n.name()).collect::<Vec<_>>();
471        let outputs = node.outputs().iter().map(|n| n.name()).collect::<Vec<_>>();
472        trace!(schematic = schem_name, node = name, ?inputs, ?outputs, data=?node.data(), "graph:node");
473      }
474    }
475  }
476
477  Ok(network)
478}