Skip to main content

jellyflow_core/ops/mutation/
batch.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use crate::core::{Edge, EdgeId, Graph, Node, NodeId, Port, PortId};
4use crate::ops::{EdgeEndpoints, GraphOp, GraphTransaction};
5
6use super::error::GraphMutationError;
7use super::planner::GraphMutationPlanner;
8
/// Plans a small transaction while tracking ids created by earlier staged ops.
///
/// Ops are accumulated in the order the staging methods are called. Ports and
/// edges introduced by earlier staged ops are treated as known by later ones,
/// even though the borrowed graph itself is never modified by the planner.
pub struct GraphMutationBatchPlanner<'a> {
    /// Base graph that staged ops are checked against; only read, never mutated.
    graph: &'a Graph,
    /// Ops staged so far, in staging order.
    ops: Vec<GraphOp>,
    /// Ids (and edge endpoints) introduced or updated by the ops staged so far.
    staged: StagedMutationIds,
}
15
/// Bookkeeping for ids that exist only in the batch being planned.
#[derive(Default)]
struct StagedMutationIds {
    /// Nodes added by staged ops.
    staged_nodes: BTreeSet<NodeId>,
    /// Ports added by staged ops.
    staged_ports: BTreeSet<PortId>,
    /// Edges added by staged ops.
    staged_edges: BTreeSet<EdgeId>,
    /// Most recent endpoints recorded in this batch, keyed by edge id.
    /// Written both when an edge is staged and when endpoints are re-pointed,
    /// so it may also hold edges that are not in `staged_edges`.
    staged_edge_endpoints: BTreeMap<EdgeId, EdgeEndpoints>,
}
23
24impl<'a> GraphMutationBatchPlanner<'a> {
25    pub fn new(graph: &'a Graph) -> Self {
26        Self {
27            graph,
28            ops: Vec::new(),
29            staged: StagedMutationIds::default(),
30        }
31    }
32
33    pub fn is_empty(&self) -> bool {
34        self.ops.is_empty()
35    }
36
37    pub fn into_ops(self) -> Vec<GraphOp> {
38        self.ops
39    }
40
41    pub fn into_transaction(self, label: impl Into<String>) -> GraphTransaction {
42        GraphTransaction::from_ops(self.ops).with_label(label)
43    }
44
45    pub fn add_node_with_ports(
46        &mut self,
47        id: NodeId,
48        node: Node,
49        ports: impl IntoIterator<Item = (PortId, Port)>,
50    ) -> Result<(), GraphMutationError> {
51        if self.staged.contains_node(id) {
52            return Err(GraphMutationError::NodeAlreadyExists(id));
53        }
54
55        let ports: Vec<(PortId, Port)> = ports.into_iter().collect();
56        for (port_id, _) in &ports {
57            if self.staged.contains_port(*port_id) {
58                return Err(GraphMutationError::PortAlreadyExists(*port_id));
59            }
60        }
61
62        let ops = GraphMutationPlanner::new(self.graph).add_node_with_ports_ops(
63            id,
64            node,
65            ports.clone(),
66        )?;
67
68        self.staged.insert_node(id);
69        for (port_id, _) in ports {
70            self.staged.insert_port(port_id);
71        }
72        self.extend_ops(ops);
73        Ok(())
74    }
75
76    pub fn add_edge(&mut self, id: EdgeId, edge: Edge) -> Result<(), GraphMutationError> {
77        if self.graph.edges.contains_key(&id) || self.staged.contains_edge(id) {
78            return Err(GraphMutationError::EdgeAlreadyExists(id));
79        }
80        self.require_known_port(edge.from)?;
81        self.require_known_port(edge.to)?;
82
83        self.staged.insert_edge(id, &edge);
84        self.push_op(GraphOp::AddEdge { id, edge });
85        Ok(())
86    }
87
88    pub fn set_edge_endpoints(
89        &mut self,
90        id: EdgeId,
91        to: EdgeEndpoints,
92    ) -> Result<(), GraphMutationError> {
93        self.require_known_port(to.from)?;
94        self.require_known_port(to.to)?;
95
96        let from = if let Some(endpoints) = self.staged.edge_endpoints(id) {
97            endpoints
98        } else {
99            let edge = self
100                .graph
101                .edges
102                .get(&id)
103                .ok_or(GraphMutationError::MissingEdge(id))?;
104            EdgeEndpoints::from_edge(edge)
105        };
106
107        self.staged.set_edge_endpoints(id, to);
108        self.push_op(GraphOp::SetEdgeEndpoints { id, from, to });
109        Ok(())
110    }
111
112    fn push_op(&mut self, op: GraphOp) {
113        self.ops.push(op);
114    }
115
116    fn extend_ops(&mut self, ops: impl IntoIterator<Item = GraphOp>) {
117        self.ops.extend(ops);
118    }
119
120    fn require_known_port(&self, id: PortId) -> Result<(), GraphMutationError> {
121        if self.graph.ports.contains_key(&id) || self.staged.contains_port(id) {
122            Ok(())
123        } else {
124            Err(GraphMutationError::MissingPort(id))
125        }
126    }
127}
128
129impl StagedMutationIds {
130    fn contains_node(&self, id: NodeId) -> bool {
131        self.staged_nodes.contains(&id)
132    }
133
134    fn contains_port(&self, id: PortId) -> bool {
135        self.staged_ports.contains(&id)
136    }
137
138    fn contains_edge(&self, id: EdgeId) -> bool {
139        self.staged_edges.contains(&id)
140    }
141
142    fn insert_node(&mut self, id: NodeId) {
143        self.staged_nodes.insert(id);
144    }
145
146    fn insert_port(&mut self, id: PortId) {
147        self.staged_ports.insert(id);
148    }
149
150    fn insert_edge(&mut self, id: EdgeId, edge: &Edge) {
151        self.staged_edges.insert(id);
152        self.staged_edge_endpoints
153            .insert(id, EdgeEndpoints::from_edge(edge));
154    }
155
156    fn edge_endpoints(&self, id: EdgeId) -> Option<EdgeEndpoints> {
157        self.staged_edge_endpoints.get(&id).copied()
158    }
159
160    fn set_edge_endpoints(&mut self, id: EdgeId, to: EdgeEndpoints) {
161        self.staged_edge_endpoints.insert(id, to);
162    }
163}