1use crate::graph::GraphTrait;
2use crate::operation::marker::MarkerSet;
3use crate::operation::{OperationError, OperationResult};
4use crate::semantics::AbstractGraph;
5use crate::util::bimap::BiMap;
6use crate::util::{InternString, log};
7use crate::{NodeKey, Semantics, SubstMarker, interned_string_newtype};
8use derive_more::From;
9use error_stack::bail;
10use petgraph::visit::UndirectedAdaptor;
11use serde::{Deserialize, Serialize};
12use std::borrow::Cow;
13use std::cell::RefCell;
14use std::collections::{HashMap, HashSet};
15use thiserror::Error;
/// Errors reported by [`OperationParameter::check_validity`].
#[derive(Debug, Error)]
pub enum OperationParameterError {
    /// A connected component of the (undirected view of the) parameter graph contains no
    /// explicit input node. The payload is the substitution marker of one node of that
    /// component, reported as an example.
    #[error(
        "Context node {0:?} is not connected to any explicit input nodes in the parameter graph"
    )]
    ContextNodeNotConnected(SubstMarker),
}
26
/// The parameter of an operation: an abstract graph whose nodes are named by
/// [`SubstMarker`]s, a subset of which must be supplied explicitly by the caller
/// (see [`ParameterSubstitution::infer_explicit_for_param`]).
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(bound = "S: crate::serde::SemanticsSerde")
)]
pub struct OperationParameter<S: Semantics> {
    /// Markers of the explicitly selected inputs. Selected nodes are matched to these
    /// markers positionally, in this order.
    pub explicit_input_nodes: Vec<SubstMarker>,
    /// The abstract graph of the parameter; its nodes are identified by [`NodeKey`]s.
    pub parameter_graph: AbstractGraph<S>,
    /// Bijection between node keys of `parameter_graph` and their substitution markers.
    /// `check_validity` expects every node of `parameter_graph` to have an entry here.
    pub node_keys_to_subst: BiMap<NodeKey, SubstMarker>,
}
42
43impl<S: Semantics> PartialEq for OperationParameter<S> {
44 fn eq(&self, other: &Self) -> bool {
45 self.explicit_input_nodes == other.explicit_input_nodes
47 && self
48 .parameter_graph
49 .semantically_matches_with_same_keys(&other.parameter_graph)
50 && self.node_keys_to_subst == other.node_keys_to_subst
51 }
52}
53
54impl<S: Semantics> Clone for OperationParameter<S> {
55 fn clone(&self) -> Self {
56 OperationParameter {
57 explicit_input_nodes: self.explicit_input_nodes.clone(),
58 parameter_graph: self.parameter_graph.clone(),
59 node_keys_to_subst: self.node_keys_to_subst.clone(),
60 }
61 }
62}
63
64impl<S: Semantics> OperationParameter<S> {
65 pub fn new_empty() -> Self {
66 OperationParameter {
67 explicit_input_nodes: Vec::new(),
68 parameter_graph: AbstractGraph::<S>::new(),
69 node_keys_to_subst: BiMap::new(),
70 }
71 }
72
73 pub fn check_validity(&self) -> Result<(), OperationParameterError> {
74 let undi = UndirectedAdaptor(&self.parameter_graph.graph);
76 let components = petgraph::algo::tarjan_scc(&undi);
77
78 for component in components {
79 let mut contains_explicit_input = false;
80 for key in &component {
81 let subst_marker = self
82 .node_keys_to_subst
83 .get_left(key)
84 .expect("internal error: should find subst marker for node key");
85 if self.explicit_input_nodes.contains(subst_marker) {
86 contains_explicit_input = true;
87 break;
88 }
89 }
90 if !contains_explicit_input {
91 let example_context_node = component[0];
92 let subst_marker = self
93 .node_keys_to_subst
94 .get_left(&example_context_node)
95 .expect("internal error: should find subst marker for node key");
96 return Err(OperationParameterError::ContextNodeNotConnected(
97 *subst_marker,
98 ));
99 }
100 }
101
102 Ok(())
103 }
104}
105
/// Binds the substitution markers of an [`OperationParameter`] to concrete node keys.
#[derive(Debug)]
pub struct ParameterSubstitution {
    /// Parameter marker -> concrete node key it is bound to.
    pub mapping: HashMap<SubstMarker, NodeKey>,
}
112
113impl ParameterSubstitution {
114 pub fn new(mapping: HashMap<SubstMarker, NodeKey>) -> Self {
115 ParameterSubstitution { mapping }
116 }
117
118 pub fn infer_explicit_for_param(
119 selected_nodes: &[NodeKey],
120 param: &OperationParameter<impl Semantics>,
121 ) -> OperationResult<Self> {
122 if param.explicit_input_nodes.len() != selected_nodes.len() {
123 bail!(OperationError::InvalidOperationArgumentCount {
124 expected: param.explicit_input_nodes.len(),
125 actual: selected_nodes.len(),
126 });
127 }
128
129 let mapping = param
130 .explicit_input_nodes
131 .iter()
132 .zip(selected_nodes.iter())
133 .map(|(subst_marker, node_key)| (subst_marker.clone(), *node_key))
134 .collect();
135 Ok(ParameterSubstitution { mapping })
136 }
137}
138
/// Identifies a node created while an operation runs (see `GraphWithSubstitution::add_node`).
#[derive(Debug, Clone, Copy, From, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NewNodeMarker {
    /// A marker carrying an interned name.
    Named(InternString),
    /// A numbered marker minted by `GraphWithSubstitution::new_node_marker`.
    /// Excluded from the `From` derive so only `Named` gets a conversion.
    #[from(ignore)]
    Implicit(u32),
}
// NOTE(review): `interned_string_newtype!` is defined elsewhere; presumably it adds
// string conversions that construct `NewNodeMarker::Named` — confirm.
interned_string_newtype!(NewNodeMarker, NewNodeMarker::Named);
148
/// Addresses a node in a `GraphWithSubstitution`: either through the parameter substitution
/// (`Subst`) or as a node created during the operation (`New`).
#[derive(Debug, Clone, Copy, From, Hash, Eq, PartialEq)]
pub enum NodeMarker {
    /// Resolved through `ParameterSubstitution::mapping`.
    Subst(SubstMarker),
    /// Resolved through the nodes added by `GraphWithSubstitution::add_node`.
    New(NewNodeMarker),
}
155
/// A mutable view of a graph in which nodes are addressed by [`NodeMarker`]s instead of raw
/// [`NodeKey`]s, and which records every change made through it.
pub struct GraphWithSubstitution<'a, G: GraphTrait> {
    /// The graph being edited.
    pub graph: &'a mut G,
    /// Binds parameter [`SubstMarker`]s to node keys of `graph`.
    pub subst: &'a ParameterSubstitution,
    /// Keys of nodes added through this wrapper, by the marker they were added under.
    new_nodes_map: HashMap<NewNodeMarker, NodeKey>,
    /// Counter used to mint the next [`NewNodeMarker::Implicit`] marker.
    max_new_node_marker: u32,
    /// Keys of all nodes added through this wrapper, in insertion order.
    new_nodes: Vec<NodeKey>,
    /// `(src, dst)` pairs passed to `add_edge` whose endpoints both resolved.
    new_edges: Vec<(NodeKey, NodeKey)>,
    /// Keys of nodes whose deletion through this wrapper returned a value.
    removed_nodes: Vec<NodeKey>,
    /// `(src, dst)` pairs whose deletion through this wrapper returned a value.
    removed_edges: Vec<(NodeKey, NodeKey)>,
    /// Node attribute values recorded by the `set`/`maybe_set` node-value methods.
    changed_node_av: HashMap<NodeKey, G::NodeAttr>,
    /// Edge attribute values recorded by the `set`/`maybe_set` edge-value methods.
    changed_edge_av: HashMap<(NodeKey, NodeKey), G::EdgeAttr>,
}
171
172impl<'a, G: GraphTrait<NodeAttr: Clone, EdgeAttr: Clone>> GraphWithSubstitution<'a, G> {
173 pub fn new(graph: &'a mut G, subst: &'a ParameterSubstitution) -> Self {
174 GraphWithSubstitution {
175 graph,
176 subst,
177 new_nodes_map: HashMap::new(),
178 max_new_node_marker: 0,
179 new_nodes: Vec::new(),
180 new_edges: Vec::new(),
181 removed_nodes: Vec::new(),
182 removed_edges: Vec::new(),
183 changed_node_av: HashMap::new(),
184 changed_edge_av: HashMap::new(),
185 }
186 }
187
188 pub fn get_node_key(&self, marker: &NodeMarker) -> Option<NodeKey> {
189 let found_key = match marker {
190 NodeMarker::Subst(sm) => {
191 self.subst.mapping.get(&sm).copied()
193 }
194 NodeMarker::New(nnm) => self.new_nodes_map.get(&nnm).copied(),
195 };
196 if let Some(key) = found_key {
197 if self.removed_nodes.contains(&key) {
198 return None;
200 }
201 }
202 found_key
203 }
204
205 pub fn new_node_marker(&mut self) -> NewNodeMarker {
206 let marker = NewNodeMarker::Implicit(self.max_new_node_marker);
207 self.max_new_node_marker += 1;
208 marker
209 }
210
211 pub fn add_node(&mut self, marker: impl Into<NewNodeMarker>, value: G::NodeAttr) {
212 let marker = marker.into();
213 if self.get_node_key(&NodeMarker::New(marker)).is_some() {
215 panic!(
218 "Marker {:?} already exists in the substitution mapping",
219 marker
220 );
221 }
222 let node_key = self.graph.add_node(value);
223 self.new_nodes.push(node_key);
224 self.new_nodes_map.insert(marker, node_key);
225 }
226 pub fn delete_node(&mut self, marker: impl Into<NodeMarker>) -> Option<G::NodeAttr> {
227 let marker = marker.into();
228 let Some(node_key) = self.get_node_key(&marker) else {
229 return None; };
231 let removed_value = self.graph.delete_node(node_key);
232 if removed_value.is_some() {
233 self.removed_nodes.push(node_key);
234 }
235 removed_value
236 }
237
238 pub fn add_edge(
239 &mut self,
240 src_marker: impl Into<NodeMarker>,
241 dst_marker: impl Into<NodeMarker>,
242 value: G::EdgeAttr,
243 ) -> Option<G::EdgeAttr> {
244 let src_marker = src_marker.into();
245 let dst_marker = dst_marker.into();
246 let src_key = self.get_node_key(&src_marker)?;
247 let dst_key = self.get_node_key(&dst_marker)?;
248 self.new_edges.push((src_key, dst_key));
249 self.graph.add_edge(src_key, dst_key, value)
250 }
251
252 pub fn delete_edge(
253 &mut self,
254 src_marker: impl Into<NodeMarker>,
255 dst_marker: impl Into<NodeMarker>,
256 ) -> Option<G::EdgeAttr> {
257 let src_marker = src_marker.into();
258 let dst_marker = dst_marker.into();
259 let src_key = self.get_node_key(&src_marker)?;
260 let dst_key = self.get_node_key(&dst_marker)?;
261 let removed_value = self.graph.delete_edge(src_key, dst_key);
262 if removed_value.is_some() {
263 self.removed_edges.push((src_key, dst_key));
264 }
265 removed_value
266 }
267
268 pub fn get_node_value(&self, marker: impl Into<NodeMarker>) -> Option<&G::NodeAttr> {
269 let marker = marker.into();
270 self.get_node_key(&marker)
271 .and_then(|node_key| self.graph.get_node_attr(node_key))
272 }
273
274 pub fn set_node_value(
278 &mut self,
279 marker: impl Into<NodeMarker>,
280 value: G::NodeAttr,
281 ) -> Option<G::NodeAttr> {
282 let marker = marker.into();
283 let node_key = self.get_node_key(&marker)?;
284 self.changed_node_av.insert(node_key, value.clone());
285 let old_value = self.graph.set_node_attr(node_key, value.clone());
286 if old_value.is_some() {
287 self.changed_node_av.insert(node_key, value);
289 }
290 old_value
291 }
292
293 pub fn maybe_set_node_value(
297 &mut self,
298 marker: impl Into<NodeMarker>,
299 maybe_written_av: G::NodeAttr,
300 join: impl Fn(&G::NodeAttr, &G::NodeAttr) -> Option<G::NodeAttr>,
301 ) -> Option<G::NodeAttr> {
302 let marker = marker.into();
303 let node_key = self.get_node_key(&marker)?;
304 if let Some(old_av) = self.graph.get_node_attr(node_key) {
305 self.changed_node_av
307 .insert(node_key, maybe_written_av.clone());
308 let merged_av = join(old_av, &maybe_written_av)
310 .expect("must be able to join. TODO: think about if this requirement makes sense");
317 self.graph.set_node_attr(node_key, merged_av)
319 } else {
320 None
321 }
322 }
323
324 pub fn get_edge_value(
325 &self,
326 src_marker: impl Into<NodeMarker>,
327 dst_marker: impl Into<NodeMarker>,
328 ) -> Option<&G::EdgeAttr> {
329 let src_marker = src_marker.into();
330 let dst_marker = dst_marker.into();
331 let src_key = self.get_node_key(&src_marker)?;
332 let dst_key = self.get_node_key(&dst_marker)?;
333 self.graph.get_edge_attr((src_key, dst_key))
334 }
335
336 pub fn set_edge_value(
337 &mut self,
338 src_marker: impl Into<NodeMarker>,
339 dst_marker: impl Into<NodeMarker>,
340 value: G::EdgeAttr,
341 ) -> Option<G::EdgeAttr> {
342 let src_marker = src_marker.into();
343 let dst_marker = dst_marker.into();
344 let src_key = self.get_node_key(&src_marker)?;
345 let dst_key = self.get_node_key(&dst_marker)?;
346 self.changed_edge_av
347 .insert((src_key, dst_key), value.clone());
348 let old_value = self.graph.set_edge_attr((src_key, dst_key), value.clone());
349 if old_value.is_some() {
350 self.changed_edge_av.insert((src_key, dst_key), value);
352 } else {
353 log::warn!(
354 "Attempted to set edge value for non-existing edge from {:?} to {:?}.",
355 src_key,
356 dst_key
357 );
358 }
359 old_value
360 }
361
362 pub fn maybe_set_edge_value(
363 &mut self,
364 src_marker: impl Into<NodeMarker>,
365 dst_marker: impl Into<NodeMarker>,
366 maybe_written_av: G::EdgeAttr,
367 join: impl Fn(&G::EdgeAttr, &G::EdgeAttr) -> Option<G::EdgeAttr>,
368 ) -> Option<G::EdgeAttr> {
369 let src_marker = src_marker.into();
370 let dst_marker = dst_marker.into();
371 let src_key = self.get_node_key(&src_marker)?;
372 let dst_key = self.get_node_key(&dst_marker)?;
373 if let Some(old_av) = self.graph.get_edge_attr((src_key, dst_key)) {
374 self.changed_edge_av
376 .insert((src_key, dst_key), maybe_written_av.clone());
377 let merged_av = join(old_av, &maybe_written_av)
379 .expect("must be able to join. TODO: think about if this requirement makes sense");
380 self.graph.set_edge_attr((src_key, dst_key), merged_av)
382 } else {
383 log::warn!(
384 "Attempted to set edge value for non-existing edge from {:?} to {:?}.",
385 src_key,
386 dst_key
387 );
388 None
389 }
390 }
391
392 fn get_new_nodes_and_edges_from_desired_names(
393 &self,
394 desired_node_output_names: &HashMap<NewNodeMarker, AbstractOutputNodeMarker>,
395 ) -> (
396 HashMap<AbstractOutputNodeMarker, NodeKey>,
397 Vec<(NodeKey, NodeKey)>,
398 ) {
399 let mut new_nodes = HashMap::new();
400 for (marker, node_key) in &self.new_nodes_map {
401 let Some(output_marker) = desired_node_output_names.get(&marker) else {
402 continue;
403 };
404 new_nodes.insert(*output_marker, *node_key);
405 }
406 let mut new_edges = Vec::new();
407 let new_node_or_existing = |node_key: &NodeKey| {
408 new_nodes.values().any(|&n| n == *node_key)
409 || self.subst.mapping.values().any(|&n| n == *node_key)
410 };
411 for (src_key, dst_key) in &self.new_edges {
413 if new_node_or_existing(src_key) || new_node_or_existing(dst_key) {
414 new_edges.push((*src_key, *dst_key));
415 }
416 }
417 (new_nodes, new_edges)
418 }
419
420 pub fn get_abstract_output<
421 S: Semantics<NodeAbstract = G::NodeAttr, EdgeAbstract = G::EdgeAttr>,
422 >(
423 &self,
424 desired_node_output_names: HashMap<NewNodeMarker, AbstractOutputNodeMarker>,
425 ) -> AbstractOperationOutput<S> {
426 let (new_nodes, new_edges) =
428 self.get_new_nodes_and_edges_from_desired_names(&desired_node_output_names);
429
430 let existing_nodes: HashSet<NodeKey> = self.subst.mapping.values().cloned().collect();
433 let mut existing_edges = HashSet::new();
434 for (src, dst, _) in self.graph.edges() {
435 existing_edges.insert((src, dst));
436 }
437 let mut changed_abstract_values_nodes = HashMap::new();
438 for (node_key, node_av) in &self.changed_node_av {
439 if existing_nodes.contains(node_key) {
440 changed_abstract_values_nodes.insert(*node_key, node_av.clone());
441 }
442 }
443 let mut changed_abstract_edges = HashMap::new();
444 for (&(src, dst), edge_av) in &self.changed_edge_av {
445 if existing_edges.contains(&(src, dst)) {
446 changed_abstract_edges.insert((src, dst), edge_av.clone());
447 }
448 }
449
450 AbstractOperationOutput {
451 new_nodes,
452 new_edges,
454 removed_edges: self.removed_edges.clone(),
457 removed_nodes: self.removed_nodes.clone(),
458 changed_abstract_values_nodes,
459 changed_abstract_values_edges: changed_abstract_edges,
460 }
461 }
462
463 pub fn get_concrete_output(
464 &self,
465 desired_node_output_names: HashMap<NewNodeMarker, AbstractOutputNodeMarker>,
466 ) -> OperationOutput {
467 let (new_nodes, _new_edges) =
468 self.get_new_nodes_and_edges_from_desired_names(&desired_node_output_names);
469
470 OperationOutput {
471 new_nodes,
472 removed_nodes: self.removed_nodes.clone(),
474 }
475 }
476}
477
/// Bundles the arguments supplied when an operation is invoked.
/// NOTE(review): role inferred from the name and field types — confirm against callers.
#[derive(Debug)]
pub struct OperationArgument<'a> {
    /// Nodes selected by the caller. Presumably these are matched positionally to the
    /// parameter's explicit inputs (cf. `ParameterSubstitution::infer_explicit_for_param`)
    /// — confirm.
    pub selected_input_nodes: Cow<'a, [NodeKey]>,
    /// Binds the parameter's substitution markers to concrete node keys.
    pub subst: ParameterSubstitution,
    /// NOTE(review): semantics not visible in this module; presumably nodes the operation
    /// must not observe or match — confirm.
    pub hidden_nodes: HashSet<NodeKey>,
    /// Shared marker set with interior mutability (`RefCell`); its use is not visible here.
    pub marker_set: &'a RefCell<MarkerSet>,
}
490
/// Name under which a newly created node is exposed in an operation's output
/// (the keys of `OperationOutput::new_nodes` / `AbstractOperationOutput::new_nodes`).
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq, From)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AbstractOutputNodeMarker(pub InternString);
// NOTE(review): `interned_string_newtype!` is defined elsewhere; presumably it adds string
// conversions for this newtype — confirm.
interned_string_newtype!(AbstractOutputNodeMarker);
495
/// Concrete result of running an operation, as built by
/// `GraphWithSubstitution::get_concrete_output`.
pub struct OperationOutput {
    /// Newly created nodes that were given an output name, keyed by that name.
    pub new_nodes: HashMap<AbstractOutputNodeMarker, NodeKey>,
    /// Nodes deleted by the operation.
    pub removed_nodes: Vec<NodeKey>,
}
508
509impl OperationOutput {
510 pub fn no_changes() -> Self {
511 OperationOutput {
512 new_nodes: HashMap::new(),
513 removed_nodes: Vec::new(),
514 }
515 }
516}
517
/// Abstract summary of the changes an operation made, as built by
/// `GraphWithSubstitution::get_abstract_output`.
pub struct AbstractOperationOutput<S: Semantics> {
    /// Newly created nodes that were given an output name, keyed by that name.
    pub new_nodes: HashMap<AbstractOutputNodeMarker, NodeKey>,
    /// Nodes deleted by the operation.
    pub removed_nodes: Vec<NodeKey>,
    /// Edges added by the operation between nodes visible in the output.
    pub new_edges: Vec<(NodeKey, NodeKey)>,
    /// Edges deleted by the operation, as `(src, dst)`.
    pub removed_edges: Vec<(NodeKey, NodeKey)>,
    /// Abstract values written to nodes bound by the parameter substitution.
    pub changed_abstract_values_nodes: HashMap<NodeKey, S::NodeAbstract>,
    /// Abstract values written to edges still present in the graph, keyed by `(src, dst)`.
    pub changed_abstract_values_edges: HashMap<(NodeKey, NodeKey), S::EdgeAbstract>,
}
527