use crate::graph::GraphTrait;
use crate::operation::marker::MarkerSet;
use crate::operation::{OperationError, OperationResult};
use crate::semantics::AbstractGraph;
use crate::util::bimap::BiMap;
use crate::util::{InternString, log};
use crate::{NodeKey, Semantics, SubstMarker, interned_string_newtype};
use derive_more::From;
use error_stack::bail;
use petgraph::visit::UndirectedAdaptor;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use thiserror::Error;
/// Errors reported by [`OperationParameter::check_validity`].
#[derive(Debug, Error)]
pub enum OperationParameterError {
/// Some connected component of the parameter graph contains no explicit input node.
/// The payload is the substitution marker of one node of that component.
#[error(
"Context node {0:?} is not connected to any explicit input nodes in the parameter graph"
)]
ContextNodeNotConnected(SubstMarker),
}
/// The input pattern of an operation: a parameter graph whose nodes are named by
/// [`SubstMarker`]s, some of which are explicit inputs.
///
/// [`OperationParameter::check_validity`] requires every connected component of the parameter
/// graph to contain at least one explicit input node; other nodes are "context" nodes.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(bound = "S: crate::serde::SemanticsSerde")
)]
pub struct OperationParameter<S: Semantics> {
/// Markers of the nodes that must be supplied explicitly, in argument order
/// (see `ParameterSubstitution::infer_explicit_for_param`).
pub explicit_input_nodes: Vec<SubstMarker>,
/// The parameter graph itself.
pub parameter_graph: AbstractGraph<S>,
/// Bidirectional mapping between node keys of `parameter_graph` and their markers.
pub node_keys_to_subst: BiMap<NodeKey, SubstMarker>,
}
impl<S: Semantics> PartialEq for OperationParameter<S> {
    /// Two parameters are equal when their explicit inputs are identical (same order), their
    /// parameter graphs match according to `semantically_matches_with_same_keys`, and their
    /// key-to-marker mappings are equal. Checks run in that order and stop at the first
    /// mismatch.
    fn eq(&self, other: &Self) -> bool {
        if self.explicit_input_nodes != other.explicit_input_nodes {
            return false;
        }
        let graphs_match = self
            .parameter_graph
            .semantically_matches_with_same_keys(&other.parameter_graph);
        if !graphs_match {
            return false;
        }
        self.node_keys_to_subst == other.node_keys_to_subst
    }
}
// Written by hand rather than derived: `#[derive(Clone)]` would add an `S: Clone` bound.
impl<S: Semantics> Clone for OperationParameter<S> {
    fn clone(&self) -> Self {
        let Self {
            explicit_input_nodes,
            parameter_graph,
            node_keys_to_subst,
        } = self;
        Self {
            explicit_input_nodes: explicit_input_nodes.clone(),
            parameter_graph: parameter_graph.clone(),
            node_keys_to_subst: node_keys_to_subst.clone(),
        }
    }
}
impl<S: Semantics> OperationParameter<S> {
    /// Returns a parameter with no explicit inputs, an empty parameter graph and an empty
    /// key-to-marker mapping.
    pub fn new_empty() -> Self {
        Self {
            explicit_input_nodes: Vec::new(),
            parameter_graph: AbstractGraph::<S>::new(),
            node_keys_to_subst: BiMap::new(),
        }
    }

    /// Checks that every connected component of the parameter graph (edge direction ignored)
    /// contains at least one explicit input node.
    ///
    /// # Errors
    /// Returns [`OperationParameterError::ContextNodeNotConnected`] for the first component
    /// without an explicit input, naming the marker of one of its nodes.
    ///
    /// # Panics
    /// Panics if an inspected node of the parameter graph has no entry in
    /// `node_keys_to_subst`.
    pub fn check_validity(&self) -> Result<(), OperationParameterError> {
        // On an undirected view, Tarjan's strongly connected components are exactly the
        // connected components.
        let undirected = UndirectedAdaptor(&self.parameter_graph.graph);
        for component in petgraph::algo::tarjan_scc(&undirected) {
            let has_explicit_input = component.iter().any(|key| {
                let marker = self
                    .node_keys_to_subst
                    .get_left(key)
                    .expect("internal error: should find subst marker for node key");
                self.explicit_input_nodes.contains(marker)
            });
            if has_explicit_input {
                continue;
            }
            // Components produced by Tarjan's algorithm are never empty.
            let representative = component[0];
            let marker = self
                .node_keys_to_subst
                .get_left(&representative)
                .expect("internal error: should find subst marker for node key");
            return Err(OperationParameterError::ContextNodeNotConnected(*marker));
        }
        Ok(())
    }
}
/// Binds the substitution markers of an operation parameter to node keys.
#[derive(Debug)]
pub struct ParameterSubstitution {
/// Marker-to-node-key bindings.
pub mapping: HashMap<SubstMarker, NodeKey>,
}
impl ParameterSubstitution {
    /// Wraps an already-built marker-to-node mapping.
    pub fn new(mapping: HashMap<SubstMarker, NodeKey>) -> Self {
        Self { mapping }
    }

    /// Binds each explicit input marker of `param`, in order, to the node at the same
    /// position in `selected_nodes`.
    ///
    /// # Errors
    /// Fails with `OperationError::InvalidOperationArgumentCount` when the number of selected
    /// nodes differs from the number of explicit input markers.
    pub fn infer_explicit_for_param(
        selected_nodes: &[NodeKey],
        param: &OperationParameter<impl Semantics>,
    ) -> OperationResult<Self> {
        let expected = param.explicit_input_nodes.len();
        let actual = selected_nodes.len();
        if expected != actual {
            bail!(OperationError::InvalidOperationArgumentCount { expected, actual });
        }
        let mapping = param
            .explicit_input_nodes
            .iter()
            .copied()
            .zip(selected_nodes.iter().copied())
            .collect();
        Ok(Self { mapping })
    }
}
/// Identifies a node created during an operation (see `GraphWithSubstitution::add_node`).
#[derive(Debug, Clone, Copy, From, Hash, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NewNodeMarker {
/// A marker identified by an interned string name.
Named(InternString),
/// An automatically numbered marker, as handed out by
/// `GraphWithSubstitution::new_node_marker`.
// Excluded from the derived `From` impls, so no `From<u32>` is generated.
#[from(ignore)]
Implicit(u32),
}
// NOTE(review): macro defined elsewhere; presumably wires string-based construction to
// `NewNodeMarker::Named` — confirm against its definition.
interned_string_newtype!(NewNodeMarker, NewNodeMarker::Named);
/// Refers to a node either through the parameter substitution or as a node created during
/// the operation; resolved by `GraphWithSubstitution::get_node_key`.
///
/// The derived `From` impls let a `SubstMarker` or `NewNodeMarker` be passed wherever an
/// `impl Into<NodeMarker>` is accepted.
#[derive(Debug, Clone, Copy, From, Hash, Eq, PartialEq)]
pub enum NodeMarker {
/// Resolved through `ParameterSubstitution::mapping`.
Subst(SubstMarker),
/// Resolved through the nodes added via `GraphWithSubstitution::add_node`.
New(NewNodeMarker),
}
/// A mutable view of a graph in which nodes are addressed by [`NodeMarker`]s and every
/// mutation is recorded.
///
/// Mutations are applied to `graph` immediately; the recorded changes are summarized by
/// `get_abstract_output` and `get_concrete_output`.
pub struct GraphWithSubstitution<'a, G: GraphTrait> {
/// The graph that mutations are applied to.
pub graph: &'a mut G,
/// Resolves `NodeMarker::Subst` markers to node keys.
pub subst: &'a ParameterSubstitution,
/// Key of each node added via `add_node`, by the marker it was added under.
new_nodes_map: HashMap<NewNodeMarker, NodeKey>,
/// Next index handed out by `new_node_marker` (despite the name, not a maximum).
max_new_node_marker: u32,
/// Keys of nodes added via `add_node`, in insertion order.
// NOTE(review): no method in this section reads this list — confirm it is still needed.
new_nodes: Vec<NodeKey>,
/// (source, target) pairs passed to `add_edge` whose markers resolved, recorded before the
/// graph call.
new_edges: Vec<(NodeKey, NodeKey)>,
/// Nodes deleted via `delete_node`; `get_node_key` no longer resolves markers to them.
removed_nodes: Vec<NodeKey>,
/// Edges deleted via `delete_edge`.
removed_edges: Vec<(NodeKey, NodeKey)>,
/// Last attribute passed per node to `set_node_value` / `maybe_set_node_value` (for the
/// latter, the value passed in, not the joined result).
changed_node_av: HashMap<NodeKey, G::NodeAttr>,
/// Last attribute passed per edge to `set_edge_value` / `maybe_set_edge_value` (for the
/// latter, the value passed in, not the joined result).
changed_edge_av: HashMap<(NodeKey, NodeKey), G::EdgeAttr>,
}
impl<'a, G: GraphTrait<NodeAttr: Clone, EdgeAttr: Clone>> GraphWithSubstitution<'a, G> {
    /// Wraps `graph` so that its nodes can be addressed through [`NodeMarker`]s.
    ///
    /// Substitution markers are resolved via `subst`; new-node markers via the nodes added
    /// through [`Self::add_node`]. Every mutation is applied to `graph` immediately and also
    /// recorded, so that [`Self::get_abstract_output`] / [`Self::get_concrete_output`] can
    /// summarize it later.
    pub fn new(graph: &'a mut G, subst: &'a ParameterSubstitution) -> Self {
        GraphWithSubstitution {
            graph,
            subst,
            new_nodes_map: HashMap::new(),
            max_new_node_marker: 0,
            new_nodes: Vec::new(),
            new_edges: Vec::new(),
            removed_nodes: Vec::new(),
            removed_edges: Vec::new(),
            changed_node_av: HashMap::new(),
            changed_edge_av: HashMap::new(),
        }
    }

    /// Resolves `marker` to the key of the node it currently refers to.
    ///
    /// Returns `None` if the marker is not bound, or if its node was deleted through
    /// [`Self::delete_node`].
    pub fn get_node_key(&self, marker: &NodeMarker) -> Option<NodeKey> {
        let found_key = match marker {
            NodeMarker::Subst(sm) => self.subst.mapping.get(sm).copied(),
            NodeMarker::New(nnm) => self.new_nodes_map.get(nnm).copied(),
        };
        // A deleted node must no longer be addressable even though its marker is still bound.
        found_key.filter(|key| !self.removed_nodes.contains(key))
    }

    /// Returns a fresh implicit marker taken from an internal counter, so successive calls
    /// yield distinct markers.
    pub fn new_node_marker(&mut self) -> NewNodeMarker {
        let marker = NewNodeMarker::Implicit(self.max_new_node_marker);
        self.max_new_node_marker += 1;
        marker
    }

    /// Adds a node with attribute `value` to the graph and binds it to `marker`.
    ///
    /// # Panics
    /// Panics if `marker` already refers to a node that has not been deleted.
    pub fn add_node(&mut self, marker: impl Into<NewNodeMarker>, value: G::NodeAttr) {
        let marker = marker.into();
        if self.get_node_key(&NodeMarker::New(marker)).is_some() {
            panic!(
                "Marker {:?} already exists in the substitution mapping",
                marker
            );
        }
        let node_key = self.graph.add_node(value);
        self.new_nodes.push(node_key);
        self.new_nodes_map.insert(marker, node_key);
    }

    /// Deletes the node referred to by `marker` and returns its attribute.
    ///
    /// Returns `None`, recording nothing, if the marker does not resolve or the graph
    /// reports no node to delete.
    pub fn delete_node(&mut self, marker: impl Into<NodeMarker>) -> Option<G::NodeAttr> {
        let marker = marker.into();
        let node_key = self.get_node_key(&marker)?;
        let removed_value = self.graph.delete_node(node_key);
        if removed_value.is_some() {
            self.removed_nodes.push(node_key);
        }
        removed_value
    }

    /// Adds an edge between the nodes referred to by the two markers and returns whatever the
    /// graph's `add_edge` returns.
    ///
    /// Returns `None` without touching the graph if either marker does not resolve.
    pub fn add_edge(
        &mut self,
        src_marker: impl Into<NodeMarker>,
        dst_marker: impl Into<NodeMarker>,
        value: G::EdgeAttr,
    ) -> Option<G::EdgeAttr> {
        let src_key = self.get_node_key(&src_marker.into())?;
        let dst_key = self.get_node_key(&dst_marker.into())?;
        self.new_edges.push((src_key, dst_key));
        self.graph.add_edge(src_key, dst_key, value)
    }

    /// Deletes the edge between the nodes referred to by the two markers and returns its
    /// attribute; `None` if either marker does not resolve or the graph has no such edge.
    pub fn delete_edge(
        &mut self,
        src_marker: impl Into<NodeMarker>,
        dst_marker: impl Into<NodeMarker>,
    ) -> Option<G::EdgeAttr> {
        let src_key = self.get_node_key(&src_marker.into())?;
        let dst_key = self.get_node_key(&dst_marker.into())?;
        let removed_value = self.graph.delete_edge(src_key, dst_key);
        if removed_value.is_some() {
            self.removed_edges.push((src_key, dst_key));
        }
        removed_value
    }

    /// Returns the attribute of the node referred to by `marker`, if any.
    pub fn get_node_value(&self, marker: impl Into<NodeMarker>) -> Option<&G::NodeAttr> {
        let node_key = self.get_node_key(&marker.into())?;
        self.graph.get_node_attr(node_key)
    }

    /// Overwrites the attribute of the node referred to by `marker` and returns the previous
    /// attribute.
    ///
    /// The change is recorded only if the graph returns a previous value, so failed writes
    /// never show up in [`Self::get_abstract_output`].
    pub fn set_node_value(
        &mut self,
        marker: impl Into<NodeMarker>,
        value: G::NodeAttr,
    ) -> Option<G::NodeAttr> {
        let node_key = self.get_node_key(&marker.into())?;
        let old_value = self.graph.set_node_attr(node_key, value.clone());
        if old_value.is_some() {
            self.changed_node_av.insert(node_key, value);
        }
        old_value
    }

    /// Joins `maybe_written_av` into the current attribute of the node referred to by
    /// `marker`, stores the joined value and returns the previous attribute.
    ///
    /// The recorded change is `maybe_written_av` itself, not the joined value. Returns `None`
    /// if the marker does not resolve or the node has no attribute.
    ///
    /// # Panics
    /// Panics if `join` returns `None`.
    pub fn maybe_set_node_value(
        &mut self,
        marker: impl Into<NodeMarker>,
        maybe_written_av: G::NodeAttr,
        join: impl Fn(&G::NodeAttr, &G::NodeAttr) -> Option<G::NodeAttr>,
    ) -> Option<G::NodeAttr> {
        let node_key = self.get_node_key(&marker.into())?;
        let old_av = self.graph.get_node_attr(node_key)?;
        self.changed_node_av
            .insert(node_key, maybe_written_av.clone());
        let merged_av = join(old_av, &maybe_written_av)
            .expect("must be able to join. TODO: think about if this requirement makes sense");
        self.graph.set_node_attr(node_key, merged_av)
    }

    /// Returns the attribute of the edge between the nodes referred to by the two markers.
    pub fn get_edge_value(
        &self,
        src_marker: impl Into<NodeMarker>,
        dst_marker: impl Into<NodeMarker>,
    ) -> Option<&G::EdgeAttr> {
        let src_key = self.get_node_key(&src_marker.into())?;
        let dst_key = self.get_node_key(&dst_marker.into())?;
        self.graph.get_edge_attr((src_key, dst_key))
    }

    /// Overwrites the attribute of the edge between the nodes referred to by the two markers
    /// and returns the previous attribute.
    ///
    /// The change is recorded only if the graph returns a previous value; otherwise a
    /// warning is logged and nothing is recorded.
    pub fn set_edge_value(
        &mut self,
        src_marker: impl Into<NodeMarker>,
        dst_marker: impl Into<NodeMarker>,
        value: G::EdgeAttr,
    ) -> Option<G::EdgeAttr> {
        let src_key = self.get_node_key(&src_marker.into())?;
        let dst_key = self.get_node_key(&dst_marker.into())?;
        let old_value = self.graph.set_edge_attr((src_key, dst_key), value.clone());
        if old_value.is_some() {
            self.changed_edge_av.insert((src_key, dst_key), value);
        } else {
            log::warn!(
                "Attempted to set edge value for non-existing edge from {:?} to {:?}.",
                src_key,
                dst_key
            );
        }
        old_value
    }

    /// Joins `maybe_written_av` into the current attribute of the edge between the nodes
    /// referred to by the two markers, stores the joined value and returns the previous
    /// attribute.
    ///
    /// The recorded change is `maybe_written_av` itself, not the joined value. If the edge
    /// does not exist a warning is logged and `None` is returned.
    ///
    /// # Panics
    /// Panics if `join` returns `None`.
    pub fn maybe_set_edge_value(
        &mut self,
        src_marker: impl Into<NodeMarker>,
        dst_marker: impl Into<NodeMarker>,
        maybe_written_av: G::EdgeAttr,
        join: impl Fn(&G::EdgeAttr, &G::EdgeAttr) -> Option<G::EdgeAttr>,
    ) -> Option<G::EdgeAttr> {
        let src_key = self.get_node_key(&src_marker.into())?;
        let dst_key = self.get_node_key(&dst_marker.into())?;
        let Some(old_av) = self.graph.get_edge_attr((src_key, dst_key)) else {
            log::warn!(
                "Attempted to set edge value for non-existing edge from {:?} to {:?}.",
                src_key,
                dst_key
            );
            return None;
        };
        self.changed_edge_av
            .insert((src_key, dst_key), maybe_written_av.clone());
        let merged_av = join(old_av, &maybe_written_av)
            .expect("must be able to join. TODO: think about if this requirement makes sense");
        self.graph.set_edge_attr((src_key, dst_key), merged_av)
    }

    /// Collects the added nodes that have an entry in `desired_node_output_names` (keyed by
    /// that output name), and the added edges whose endpoints are both visible to the caller.
    fn get_new_nodes_and_edges_from_desired_names(
        &self,
        desired_node_output_names: &HashMap<NewNodeMarker, AbstractOutputNodeMarker>,
    ) -> (
        HashMap<AbstractOutputNodeMarker, NodeKey>,
        Vec<(NodeKey, NodeKey)>,
    ) {
        // Only new nodes the caller asked to name are exposed.
        let new_nodes: HashMap<AbstractOutputNodeMarker, NodeKey> = self
            .new_nodes_map
            .iter()
            .filter_map(|(marker, &node_key)| {
                desired_node_output_names
                    .get(marker)
                    .map(|&output_marker| (output_marker, node_key))
            })
            .collect();
        // Nodes the caller can refer to afterwards: the named new nodes plus every node bound
        // by the parameter substitution.
        let visible_nodes: HashSet<NodeKey> = new_nodes
            .values()
            .chain(self.subst.mapping.values())
            .copied()
            .collect();
        // Both endpoints must be visible: an edge touching an unnamed new node would
        // reference a node the caller has no way to identify.
        let new_edges = self
            .new_edges
            .iter()
            .filter(|(src, dst)| visible_nodes.contains(src) && visible_nodes.contains(dst))
            .copied()
            .collect();
        (new_nodes, new_edges)
    }

    /// Summarizes the recorded changes as an [`AbstractOperationOutput`].
    ///
    /// `desired_node_output_names` selects which new nodes are exposed and under which name.
    /// Attribute changes are reported only for nodes bound by the parameter substitution and
    /// for edges that currently exist in the graph.
    pub fn get_abstract_output<
        S: Semantics<NodeAbstract = G::NodeAttr, EdgeAbstract = G::EdgeAttr>,
    >(
        &self,
        desired_node_output_names: HashMap<NewNodeMarker, AbstractOutputNodeMarker>,
    ) -> AbstractOperationOutput<S> {
        let (new_nodes, new_edges) =
            self.get_new_nodes_and_edges_from_desired_names(&desired_node_output_names);
        let existing_nodes: HashSet<NodeKey> = self.subst.mapping.values().copied().collect();
        let existing_edges: HashSet<(NodeKey, NodeKey)> = self
            .graph
            .edges()
            .map(|(src, dst, _)| (src, dst))
            .collect();
        let changed_abstract_values_nodes: HashMap<NodeKey, S::NodeAbstract> = self
            .changed_node_av
            .iter()
            .filter(|(node_key, _)| existing_nodes.contains(*node_key))
            .map(|(&node_key, av)| (node_key, av.clone()))
            .collect();
        let changed_abstract_values_edges: HashMap<(NodeKey, NodeKey), S::EdgeAbstract> = self
            .changed_edge_av
            .iter()
            .filter(|(edge, _)| existing_edges.contains(*edge))
            .map(|(&edge, av)| (edge, av.clone()))
            .collect();
        AbstractOperationOutput {
            new_nodes,
            new_edges,
            removed_edges: self.removed_edges.clone(),
            removed_nodes: self.removed_nodes.clone(),
            changed_abstract_values_nodes,
            changed_abstract_values_edges,
        }
    }

    /// Summarizes the recorded changes as an [`OperationOutput`]: the named new nodes and the
    /// deleted nodes.
    pub fn get_concrete_output(
        &self,
        desired_node_output_names: HashMap<NewNodeMarker, AbstractOutputNodeMarker>,
    ) -> OperationOutput {
        let (new_nodes, _) =
            self.get_new_nodes_and_edges_from_desired_names(&desired_node_output_names);
        OperationOutput {
            new_nodes,
            removed_nodes: self.removed_nodes.clone(),
        }
    }
}
/// The arguments an operation is invoked with.
#[derive(Debug)]
pub struct OperationArgument<'a> {
/// The selected input nodes, either borrowed or owned.
// NOTE(review): presumably the same nodes fed to
// `ParameterSubstitution::infer_explicit_for_param` — confirm at call sites.
pub selected_input_nodes: Cow<'a, [NodeKey]>,
/// Binding of the parameter's substitution markers to node keys.
pub subst: ParameterSubstitution,
/// NOTE(review): the meaning of "hidden" is defined by code outside this section — confirm.
pub hidden_nodes: HashSet<NodeKey>,
/// Marker set shared behind a `RefCell`, so it can be mutated through this shared reference.
pub marker_set: &'a RefCell<MarkerSet>,
}
/// Name under which an operation exposes one of the nodes it created; the key type of
/// `OperationOutput::new_nodes` and `AbstractOperationOutput::new_nodes`.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq, From)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AbstractOutputNodeMarker(pub InternString);
// NOTE(review): macro defined elsewhere; presumably adds string-based construction via
// `InternString` — confirm against its definition.
interned_string_newtype!(AbstractOutputNodeMarker);
/// Summary of the concrete effects of an operation, e.g. as produced by
/// `GraphWithSubstitution::get_concrete_output`.
///
/// Derives `Debug`/`Clone` so results can be logged and stored like the other public types
/// of this module; every field type already implements both.
#[derive(Debug, Clone)]
pub struct OperationOutput {
    /// Newly created nodes that were given an output name, keyed by that name.
    pub new_nodes: HashMap<AbstractOutputNodeMarker, NodeKey>,
    /// Nodes deleted during the operation.
    pub removed_nodes: Vec<NodeKey>,
}
impl OperationOutput {
pub fn no_changes() -> Self {
OperationOutput {
new_nodes: HashMap::new(),
removed_nodes: Vec::new(),
}
}
}
/// Summary of an operation's effects in terms of abstract values, e.g. as produced by
/// `GraphWithSubstitution::get_abstract_output`.
pub struct AbstractOperationOutput<S: Semantics> {
/// Newly created nodes that were given an output name, keyed by that name.
pub new_nodes: HashMap<AbstractOutputNodeMarker, NodeKey>,
/// Nodes deleted during the operation.
pub removed_nodes: Vec<NodeKey>,
/// Edges added during the operation, as (source, target) pairs.
pub new_edges: Vec<(NodeKey, NodeKey)>,
/// Edges deleted during the operation, as (source, target) pairs.
pub removed_edges: Vec<(NodeKey, NodeKey)>,
/// New abstract values of nodes bound by the parameter substitution.
pub changed_abstract_values_nodes: HashMap<NodeKey, S::NodeAbstract>,
/// New abstract values of edges that still exist after the operation.
pub changed_abstract_values_edges: HashMap<(NodeKey, NodeKey), S::EdgeAbstract>,
}