use rand::Rng;
use serde::de::Error;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fmt;
/// Opaque identifier of a node in a [`Graph`], wrapping a 128-bit value.
///
/// Equality, ordering and hashing all follow the raw 128-bit value. Fresh ids come from
/// `NodeId::gen`; the textual form is `node_` followed by base64.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NodeId(u128);
impl NodeId {
    /// Creates a new identifier from 128 bits drawn from `rand::thread_rng()`.
    pub fn gen() -> NodeId {
        let bits: u128 = rand::thread_rng().gen();
        NodeId(bits)
    }
}
impl fmt::Display for NodeId {
    /// Writes `node_` followed by the first 22 characters of `base64::encode` applied to the
    /// id's big-endian bytes.
    ///
    /// 22 base64 characters (132 bits) are enough for the 128-bit value; anything after them is
    /// trailing padding.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = base64::encode(self.0.to_be_bytes());
        f.write_str("node_")?;
        f.write_str(&encoded[..22])
    }
}
impl fmt::Debug for NodeId {
    /// Debug output is the same `node_…` text as `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}
impl Serialize for NodeId {
    /// Serializes the id as a string: its `Display` form (`node_…`).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text = self.to_string();
        serializer.serialize_str(&text)
    }
}
impl<'de> Deserialize<'de> for NodeId {
fn deserialize<D>(deserializer: D) -> Result<NodeId, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
if let Some(suffix) = s.strip_prefix("node_") {
let decoded_bytes: Vec<u8> = base64::decode(suffix).map_err(D::Error::custom)?;
let decoded_value = <u128>::from_be_bytes(
decoded_bytes
.try_into()
.map_err(|_| D::Error::custom("node id is wrong length"))?,
);
Ok(NodeId(decoded_value))
} else {
Err(D::Error::custom(
"not a valid node id (must start with node_)",
))
}
}
}
/// A connection from the output of one node to a numbered input of another node.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct Edge {
    /// Node whose output feeds this edge.
    pub from: NodeId,
    /// Node whose input this edge connects to.
    pub to: NodeId,
    /// Index of the input slot on `to` that this edge drives.
    pub input: u32,
}
/// A node graph: a list of nodes plus the edges wiring node outputs to node inputs.
///
/// The order of `nodes` is meaningful: `Graph::fix` rewrites it into topological order.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct Graph {
    /// Every node in the graph.
    pub nodes: Vec<NodeId>,
    /// Output-to-input connections between nodes.
    pub edges: Vec<Edge>,
}
/// Where to splice a node (or a chain of nodes) into a [`Graph`].
#[derive(Clone, Debug, Default)]
pub struct InsertionPoint {
    /// Output that should feed input 0 of the inserted node, if any.
    pub from_output: Option<NodeId>,
    /// `(node, input)` slots that the inserted node's output should drive.
    pub to_inputs: Vec<(NodeId, u32)>,
}
impl InsertionPoint {
pub fn open() -> Self {
InsertionPoint {
from_output: None,
to_inputs: Vec::new(),
}
}
}
/// Drops edges that cannot exist in a well-formed graph, keeping the survivors in order.
///
/// An edge is removed when either of these holds:
/// - its `from` or `to` node is not in `nodes`;
/// - an earlier *retained* edge already drives the same `(to, input)` slot.
///
/// So the first valid edge into a slot wins, the same rule `map_inputs` applies.
fn remove_invalid_edges(nodes: &[NodeId], edges: &mut Vec<Edge>) {
    let node_set: HashSet<NodeId> = nodes.iter().copied().collect();
    let mut claimed_inputs = HashSet::<(NodeId, u32)>::new();
    edges.retain(|edge| {
        let endpoints_exist = node_set.contains(&edge.from) && node_set.contains(&edge.to);
        // Only an edge that is kept may claim its input slot. A dangling edge that is about to
        // be dropped must not shadow a later valid edge into the same slot. `insert` returns
        // false when the slot was already claimed, and short-circuiting skips it for dangling
        // edges.
        endpoints_exist && claimed_inputs.insert((edge.to, edge.input))
    });
}
/// Builds, for every node in `nodes`, the list of nodes feeding its numbered inputs.
///
/// Entry `i` of a node's vector is the node driving input `i`. It is `None` for unconnected
/// slots below the highest connected one. Every node gets an entry, possibly empty.
///
/// Edges with an endpoint outside `nodes` are ignored. When several edges target the same
/// slot, the first one in `edges` wins.
fn map_inputs(nodes: &[NodeId], edges: &[Edge]) -> HashMap<NodeId, Vec<Option<NodeId>>> {
    let mut mapping: HashMap<NodeId, Vec<Option<NodeId>>> =
        nodes.iter().map(|&node| (node, Vec::new())).collect();
    for edge in edges {
        if !mapping.contains_key(&edge.from) {
            continue;
        }
        let slots = match mapping.get_mut(&edge.to) {
            Some(slots) => slots,
            None => continue,
        };
        let slot = edge.input as usize;
        if slots.len() <= slot {
            // Grow with unconnected slots so that `slot` becomes addressable.
            slots.resize(slot + 1, None);
        }
        if slots[slot].is_none() {
            slots[slot] = Some(edge.from);
        }
    }
    mapping
}
/// Builds, for every node in `nodes`, the set of `(consumer, input)` slots its output drives.
///
/// Edges with an endpoint outside `nodes` are ignored. Every node gets an entry, even one that
/// drives nothing.
fn map_outputs(nodes: &[NodeId], edges: &[Edge]) -> HashMap<NodeId, HashSet<(NodeId, u32)>> {
    let mut outputs: HashMap<NodeId, HashSet<(NodeId, u32)>> =
        nodes.iter().map(|&node| (node, HashSet::new())).collect();
    for edge in edges {
        if !outputs.contains_key(&edge.to) {
            continue;
        }
        if let Some(targets) = outputs.get_mut(&edge.from) {
            targets.insert((edge.to, edge.input));
        }
    }
    outputs
}
/// Finds a set of edges whose removal breaks every cycle in the graph.
///
/// The search is an iterative depth-first search over `input_mapping`, walking from each node
/// to the nodes feeding its inputs. It starts from every node of `nodes` in order.
///
/// When an input of the node being expanded points at a node that is still on the current DFS
/// path (grey), that input edge closes a cycle and is recorded. Removing all recorded edges
/// leaves an acyclic graph.
///
/// Each returned `Edge` carries the real input index, so it compares equal to the matching
/// entry of `Graph::edges`.
///
/// # Panics
/// Panics if a node in `nodes`, or a node named as an input, has no entry in `input_mapping`.
fn find_cycles(
    nodes: &[NodeId],
    input_mapping: &HashMap<NodeId, Vec<Option<NodeId>>>,
) -> HashSet<Edge> {
    // White: not yet expanded. Grey: expanded, still on the DFS path. Black: finished.
    enum Color {
        White,
        Grey,
        Black,
    }
    let mut cyclic_edges = HashSet::<Edge>::new();
    // Reversed so the first node of `nodes` is explored first.
    let mut stack: Vec<NodeId> = nodes.iter().rev().copied().collect();
    let mut color: HashMap<NodeId, Color> = input_mapping
        .keys()
        .map(|&node| (node, Color::White))
        .collect();
    while let Some(&n) = stack.last() {
        let cn = color.get(&n).unwrap();
        match cn {
            Color::White => {
                color.insert(n, Color::Grey);
                // Enumerate BEFORE reversing, so that `i` is the true input index. Iterating in
                // reverse pushes input 0 last, so it is explored first.
                for (i, opt_m) in input_mapping.get(&n).unwrap().iter().enumerate().rev() {
                    if let Some(m) = opt_m {
                        match color.get(m).unwrap() {
                            Color::White => {
                                stack.push(*m);
                            }
                            Color::Grey => {
                                // `m` is an ancestor on the current path: n <- m closes a cycle.
                                cyclic_edges.insert(Edge {
                                    from: *m,
                                    to: n,
                                    input: i as u32,
                                });
                            }
                            Color::Black => {}
                        }
                    }
                }
            }
            Color::Grey => {
                // All inputs pushed above `n` have been finished.
                color.insert(n, Color::Black);
                stack.pop();
            }
            Color::Black => {
                // Stale duplicate entry for a node already finished via another path.
                stack.pop();
            }
        }
    }
    cyclic_edges
}
/// Returns the nodes whose output feeds no other node's input, in `nodes` order.
///
/// Only nodes that have an entry in `input_mapping` are considered. A node is excluded as soon
/// as any entry of `input_mapping` lists it as an input.
fn start_nodes(
    nodes: &[NodeId],
    input_mapping: &HashMap<NodeId, Vec<Option<NodeId>>>,
) -> Vec<NodeId> {
    let consumed: HashSet<NodeId> = input_mapping
        .values()
        .flat_map(|inputs| inputs.iter().flatten().copied())
        .collect();
    nodes
        .iter()
        .copied()
        .filter(|node| input_mapping.contains_key(node) && !consumed.contains(node))
        .collect()
}
/// Walks upstream along input 0 from `start_node` and returns the last node reached.
///
/// The walk stops at the first node that has no connected input 0. That covers an unconnected
/// slot, an empty input list, or no entry in `input_mapping` at all.
///
/// NOTE: there is no visited check, so this never returns if the input-0 links form a cycle.
fn first_input(
    input_mapping: &HashMap<NodeId, Vec<Option<NodeId>>>,
    start_node: NodeId,
) -> NodeId {
    let mut current = start_node;
    loop {
        let upstream = input_mapping
            .get(&current)
            .and_then(|inputs| inputs.first().copied().flatten());
        match upstream {
            Some(next) => current = next,
            None => return current,
        }
    }
}
/// Returns every node reachable from `start_nodes` in topological order.
///
/// Reachability follows inputs upstream. In the result, each node appears exactly once and
/// after all the nodes feeding its inputs.
///
/// The iterative DFS may push the same node more than once. For example, a node can feed one
/// consumer both directly and through another node. Copies that surface after the node is
/// already finished are discarded.
///
/// # Panics
/// - Panics with "Cycle detected in graph" if a cycle is reachable from `start_nodes`.
/// - Panics if a start node, or a node named as an input, has no entry in `input_mapping`.
pub fn topo_sort_nodes(
    start_nodes: &[NodeId],
    input_mapping: &HashMap<NodeId, Vec<Option<NodeId>>>,
) -> Vec<NodeId> {
    // White: not yet expanded. Grey: expanded, still on the DFS path. Black: emitted.
    enum Color {
        White,
        Grey,
        Black,
    }
    // Reversed so the first start node is explored first.
    let mut stack: Vec<NodeId> = start_nodes.iter().rev().copied().collect();
    let mut color: HashMap<NodeId, Color> = input_mapping
        .keys()
        .map(|&node| (node, Color::White))
        .collect();
    let mut topo_order = Vec::<NodeId>::new();
    while let Some(&n) = stack.last() {
        let cn = color.get(&n).unwrap();
        match cn {
            Color::White => {
                color.insert(n, Color::Grey);
                for m in input_mapping.get(&n).unwrap().iter().rev().flatten() {
                    match color.get(m).unwrap() {
                        Color::White => {
                            stack.push(*m);
                        }
                        Color::Grey => {
                            panic!("Cycle detected in graph");
                        }
                        Color::Black => {}
                    }
                }
            }
            Color::Grey => {
                // Every input of `n` has been emitted, so `n` can follow them.
                color.insert(n, Color::Black);
                stack.pop();
                topo_order.push(n);
            }
            Color::Black => {
                // A duplicate entry for a node already emitted via another path; this is
                // expected whenever a node was pushed twice before being expanded.
                stack.pop();
            }
        }
    }
    topo_order
}
impl Graph {
    /// Returns the graph's start nodes together with its input mapping.
    ///
    /// The input mapping gives, for each node, the node feeding each numbered input (`None`
    /// for an unconnected slot). Start nodes are the nodes whose output feeds no other node's
    /// input, listed in `self.nodes` order.
    pub fn mapping(&self) -> (Vec<NodeId>, HashMap<NodeId, Vec<Option<NodeId>>>) {
        let input_mapping = map_inputs(&self.nodes, &self.edges);
        let start_nodes = start_nodes(&self.nodes, &input_mapping);
        (start_nodes, input_mapping)
    }

    /// Repairs the graph in place and reorders `self.nodes` topologically.
    ///
    /// 1. Removes edges whose endpoints are not in `self.nodes`, and edges targeting an input
    ///    slot that an earlier edge already drives.
    /// 2. Removes the edges `find_cycles` reports as closing a cycle.
    /// 3. Replaces `self.nodes` with `topo_sort_nodes` run from the start nodes, so each node
    ///    comes after the nodes feeding its inputs.
    ///
    /// The number of edges removed in steps 1 and 2 is printed to stdout.
    pub fn fix(&mut self) {
        let orig_n_edges = self.edges.len();
        remove_invalid_edges(&self.nodes, &mut self.edges);
        let edges_removed = orig_n_edges - self.edges.len();
        if edges_removed > 0 {
            println!("Removed {} invalid edges", edges_removed);
        }
        let mut input_mapping = map_inputs(&self.nodes, &self.edges);
        let cyclic_edges = find_cycles(&self.nodes, &input_mapping);
        if !cyclic_edges.is_empty() {
            self.edges.retain(|e| !cyclic_edges.contains(e));
            println!("Removed {} edges to break cycles", cyclic_edges.len());
            // The mapping is stale once edges were removed.
            input_mapping = map_inputs(&self.nodes, &self.edges);
        }
        let start_nodes = start_nodes(&self.nodes, &input_mapping);
        self.nodes = topo_sort_nodes(&start_nodes, &input_mapping);
    }

    /// Removes `nodes`, and every edge touching them, from the graph and heals the gaps.
    ///
    /// The deleted nodes are treated as chains linked through input 0. For each chain, the
    /// node that fed input 0 of the chain's upstream end (if any) is connected directly to
    /// every slot that the chain's downstream end used to drive.
    pub fn delete_nodes(&mut self, nodes: &HashSet<NodeId>) {
        let input_mapping = map_inputs(&self.nodes, &self.edges);
        let output_mapping = map_outputs(&self.nodes, &self.edges);
        let subgraph_nodes: Vec<NodeId> = self
            .nodes
            .iter()
            .filter(|n| nodes.contains(n))
            .cloned()
            .collect();
        let subgraph_input_mapping = map_inputs(&subgraph_nodes, &self.edges);
        // Within the deleted set, a "start node" feeds no other deleted node: it is the
        // downstream end of a chain. Walking input 0 upstream from it finds the upstream end.
        let start_nodes = start_nodes(&subgraph_nodes, &subgraph_input_mapping);
        for &start_node in start_nodes.iter() {
            let end_node = first_input(&subgraph_input_mapping, start_node);
            let outgoing_connections = output_mapping.get(&start_node).unwrap();
            // The node feeding input 0 of the chain's upstream end, in the full graph.
            let incoming_connection = input_mapping
                .get(&end_node)
                .unwrap()
                .get(0)
                .cloned()
                .flatten();
            if let Some(from) = incoming_connection {
                // Bypass the chain: its feeder now drives the chain's former consumers.
                for &(to, input) in outgoing_connections.iter() {
                    self.edges.push(Edge { from, to, input });
                }
            }
        }
        self.nodes.retain(|n| !nodes.contains(n));
        self.edges
            .retain(|e| !nodes.contains(&e.from) && !nodes.contains(&e.to));
    }

    /// Appends `node` to the graph and wires it in at `insertion_point`.
    ///
    /// `node`'s output is connected to every `(node, input)` slot in `to_inputs`. If
    /// `from_output` is set, two more things happen:
    /// - any existing edges into the `to_inputs` slots are removed first, so each slot keeps a
    ///   single driver;
    /// - `from_output` is connected to `node`'s input 0.
    ///
    /// Consumers of `from_output` that are not listed in `to_inputs` keep their connections.
    pub fn insert_node(&mut self, node: NodeId, insertion_point: &InsertionPoint) {
        self.nodes.push(node);
        if let Some(from_output) = insertion_point.from_output {
            // Only the slots being rewired lose their old edges. Other edges leaving
            // `from_output` (branches not part of this insertion) must stay intact; `move_nodes`
            // likewise only cuts edges into the `to_inputs` slots.
            self.edges
                .retain(|e| !insertion_point.to_inputs.contains(&(e.to, e.input)));
            self.edges.push(Edge {
                from: from_output,
                to: node,
                input: 0,
            });
        }
        for &(to, input) in insertion_point.to_inputs.iter() {
            self.edges.push(Edge {
                from: node,
                to,
                input,
            });
        }
    }

    /// Moves `nodes` to `insertion_point`, treating them as chains linked through input 0.
    ///
    /// For each chain:
    /// - its old input-0 feeder is connected directly to the chain's former consumers;
    /// - `from_output` (if set) is connected to input 0 of the chain's upstream end;
    /// - the chain's downstream end is connected to every slot in `to_inputs`.
    ///
    /// Existing edges from `from_output` into the `to_inputs` slots are removed.
    ///
    /// NOTE(review): with several disjoint chains, each is wired to the same insertion point
    /// in turn, so a `to_inputs` slot can end up with more than one incoming edge.
    pub fn move_nodes(&mut self, nodes: &HashSet<NodeId>, insertion_point: &InsertionPoint) {
        let input_mapping = map_inputs(&self.nodes, &self.edges);
        let output_mapping = map_outputs(&self.nodes, &self.edges);
        let subgraph_nodes: Vec<NodeId> = self
            .nodes
            .iter()
            .filter(|n| nodes.contains(n))
            .cloned()
            .collect();
        let subgraph_input_mapping = map_inputs(&subgraph_nodes, &self.edges);
        // See `delete_nodes`: start node = downstream end, `first_input` = upstream end.
        let start_nodes = start_nodes(&subgraph_nodes, &subgraph_input_mapping);
        for &start_node in start_nodes.iter() {
            let end_node = first_input(&subgraph_input_mapping, start_node);
            let outgoing_connections = output_mapping.get(&start_node).unwrap();
            let incoming_connection = input_mapping
                .get(&end_node)
                .unwrap()
                .get(0)
                .cloned()
                .flatten();
            // Detach the chain from both sides, and cut the edges the chain is replacing.
            self.edges.retain(|e| {
                !(e.from == start_node
                    || (e.to == end_node && e.input == 0)
                    || (Some(e.from) == insertion_point.from_output
                        && insertion_point.to_inputs.contains(&(e.to, e.input))))
            });
            if let Some(from_output) = insertion_point.from_output {
                self.edges.push(Edge {
                    from: from_output,
                    to: end_node,
                    input: 0,
                });
            }
            for &(to_node, to_input) in insertion_point.to_inputs.iter() {
                self.edges.push(Edge {
                    from: start_node,
                    to: to_node,
                    input: to_input,
                });
            }
            // Close the gap the chain left at its old position.
            if let Some(from) = incoming_connection {
                for &(to, input) in outgoing_connections.iter() {
                    self.edges.push(Edge { from, to, input });
                }
            }
        }
    }

    /// Replaces `old_node` with `new_node` in every edge and at its first position in
    /// `self.nodes`.
    pub fn replace_node(&mut self, old_node: NodeId, new_node: NodeId) {
        for edge in self.edges.iter_mut() {
            if edge.from == old_node {
                edge.from = new_node;
            }
            if edge.to == old_node {
                edge.to = new_node;
            }
        }
        for node in self.nodes.iter_mut() {
            if node == &old_node {
                *node = new_node;
                break;
            }
        }
    }

    /// Returns the nodes reachable downstream from exactly one of `a` and `b`.
    ///
    /// Downstream sets follow outputs and include the node itself. If neither of `a` and `b`
    /// is downstream of the other, an empty set is returned.
    ///
    /// # Panics
    /// Panics if `a` or `b` is not in `self.nodes`.
    pub fn nodes_between(&self, a: NodeId, b: NodeId) -> HashSet<NodeId> {
        let output_mapping = map_outputs(&self.nodes, &self.edges);
        /// Adds `node` and everything downstream of it to `dependents`.
        fn walk_dependents_of(
            output_mapping: &HashMap<NodeId, HashSet<(NodeId, u32)>>,
            dependents: &mut HashSet<NodeId>,
            node: NodeId,
        ) {
            // Already visited means its downstream nodes are already in the set. Without this
            // check, a diamond-shaped graph is re-walked once per path (exponential time) and
            // a cycle recurses until the stack overflows.
            if !dependents.insert(node) {
                return;
            }
            let outgoing_connections = output_mapping.get(&node).unwrap();
            for &(next_node, _) in outgoing_connections.iter() {
                walk_dependents_of(output_mapping, dependents, next_node);
            }
        }
        let mut dependents_of_a = HashSet::<NodeId>::new();
        walk_dependents_of(&output_mapping, &mut dependents_of_a, a);
        let mut dependents_of_b = HashSet::<NodeId>::new();
        walk_dependents_of(&output_mapping, &mut dependents_of_b, b);
        if !dependents_of_a.contains(&b) && !dependents_of_b.contains(&a) {
            return HashSet::<_>::new();
        }
        dependents_of_a
            .symmetric_difference(&dependents_of_b)
            .cloned()
            .collect()
    }

    /// Adds to `selected` each node of `new_nodes` that has at least one selected input node
    /// and at least one selected consumer.
    ///
    /// Candidates are checked in order and `selected` is updated as they are accepted, so an
    /// earlier accepted candidate can enclose a later one.
    ///
    /// # Panics
    /// Panics if a node in `new_nodes` is not in `self.nodes`.
    pub fn select_surrounded(&self, new_nodes: &[NodeId], selected: &mut HashSet<NodeId>) {
        let input_mapping = map_inputs(&self.nodes, &self.edges);
        let output_mapping = map_outputs(&self.nodes, &self.edges);
        for &target in new_nodes {
            if !input_mapping
                .get(&target)
                .unwrap()
                .iter()
                .flatten()
                .any(|&upstream_node| selected.contains(&upstream_node))
            {
                continue;
            }
            if !output_mapping
                .get(&target)
                .unwrap()
                .iter()
                .any(|&(downstream_node, _)| selected.contains(&downstream_node))
            {
                continue;
            }
            selected.insert(target);
        }
    }
}