// egui_graph_edit/graph_impls.rs

1use super::*;
2
3impl<NodeData, DataType, ValueType> Graph<NodeData, DataType, ValueType>
4where
5    DataType: PartialEq,
6{
7    pub fn new() -> Self {
8        Self {
9            nodes: SlotMap::default(),
10            inputs: SlotMap::default(),
11            outputs: SlotMap::default(),
12            connections: SecondaryMap::default(),
13        }
14    }
15
16    pub fn add_node(
17        &mut self,
18        label: String,
19        user_data: NodeData,
20        f: impl FnOnce(&mut Graph<NodeData, DataType, ValueType>, NodeId),
21    ) -> NodeId {
22        let node_id = self.nodes.insert_with_key(|node_id| {
23            Node {
24                id: node_id,
25                label,
26                // These get filled in later by the user function
27                inputs: Vec::default(),
28                outputs: Vec::default(),
29                user_data,
30            }
31        });
32
33        f(self, node_id);
34
35        node_id
36    }
37
38    pub fn add_input_param(
39        &mut self,
40        node_id: NodeId,
41        name: String,
42        typ: DataType,
43        value: ValueType,
44        kind: InputParamKind,
45        shown_inline: bool,
46    ) -> InputId {
47        let input_id = self.inputs.insert_with_key(|input_id| InputParam {
48            id: input_id,
49            typ,
50            value,
51            kind,
52            node: node_id,
53            shown_inline,
54        });
55        self.nodes[node_id].inputs.push((name, input_id));
56        input_id
57    }
58
59    pub fn update_input_param(
60        &mut self,
61        input_id: InputId,
62        name: Option<String>,
63        typ: Option<DataType>,
64        value: Option<ValueType>,
65        kind: Option<InputParamKind>,
66        shown_inline: Option<bool>,
67    ) {
68        if let Some(input_param) = self.inputs.get_mut(input_id) {
69            if let Some(new_typ) = typ {
70                input_param.typ = new_typ;
71            }
72            if let Some(new_value) = value {
73                input_param.value = new_value;
74            }
75            if let Some(new_kind) = kind {
76                input_param.kind = new_kind;
77            }
78            if let Some(new_shown_inline) = shown_inline {
79                input_param.shown_inline = new_shown_inline;
80            }
81
82            if let Some(new_name) = name {
83                for (curr_name, id) in &mut self.nodes[input_param.node].inputs {
84                    if *id == input_id {
85                        *curr_name = new_name;
86                        break;
87                    }
88                }
89            }
90        }
91
92        self.ensure_connection_types(AnyParameterId::Input(input_id));
93    }
94
95    pub fn remove_input_param(&mut self, param: InputId) {
96        let node = self[param].node;
97        self[node].inputs.retain(|(_, id)| *id != param);
98        self.inputs.remove(param);
99        self.connections.retain(|i, _| i != param);
100    }
101
102    pub fn add_output_param(&mut self, node_id: NodeId, name: String, typ: DataType) -> OutputId {
103        let output_id = self.outputs.insert_with_key(|output_id| OutputParam {
104            id: output_id,
105            node: node_id,
106            typ,
107        });
108        self.nodes[node_id].outputs.push((name, output_id));
109        output_id
110    }
111
112    pub fn update_output_param(
113        &mut self,
114        output_id: OutputId,
115        name: Option<String>,
116        typ: Option<DataType>,
117    ) {
118        if let Some(output_param) = self.outputs.get_mut(output_id) {
119            if let Some(new_typ) = typ {
120                output_param.typ = new_typ;
121            }
122
123            if let Some(new_name) = name {
124                for (curr_name, id) in &mut self.nodes[output_param.node].outputs {
125                    if *id == output_id {
126                        *curr_name = new_name;
127                        break;
128                    }
129                }
130            }
131        }
132
133        self.ensure_connection_types(AnyParameterId::Output(output_id));
134    }
135
136    pub fn remove_output_param(&mut self, param: OutputId) {
137        let node = self[param].node;
138        self[node].outputs.retain(|(_, id)| *id != param);
139        self.outputs.remove(param);
140        self.connections.retain(|_, o| *o != param);
141    }
142
143    /// Deletes mistyped connection made with param_id
144    ///
145    /// This is only needed connection param type is changed with means
146    /// other than [`Graph::update_input_param`].
147    pub fn ensure_connection_types(&mut self, param_id: AnyParameterId) {
148        let mut to_remove = Vec::default();
149
150        for (to_id, from_id) in self.iter_connections() {
151            // ignore connections that don't touch param_id.
152            if AnyParameterId::Input(to_id) != param_id
153                && AnyParameterId::Output(from_id) != param_id
154            {
155                continue;
156            }
157
158            // connection has mismatched types
159            if self.get_input(to_id).typ != self.get_output(from_id).typ {
160                to_remove.push(to_id);
161            }
162        }
163
164        for in_id in to_remove {
165            self.remove_connection(in_id);
166        }
167    }
168
169    /// Removes a node from the graph with given `node_id`. This also removes
170    /// any incoming or outgoing connections from that node
171    ///
172    /// This function returns the list of connections that has been removed
173    /// after deleting this node as input-output pairs. Note that one of the two
174    /// ids in the pair (the one on `node_id`'s end) will be invalid after
175    /// calling this function.
176    pub fn remove_node(&mut self, node_id: NodeId) -> (Node<NodeData>, Vec<(InputId, OutputId)>) {
177        let mut disconnect_events = vec![];
178
179        self.connections.retain(|i, o| {
180            if self.outputs[*o].node == node_id || self.inputs[i].node == node_id {
181                disconnect_events.push((i, *o));
182                false
183            } else {
184                true
185            }
186        });
187
188        // NOTE: Collect is needed because we can't borrow the input ids while
189        // we remove them inside the loop.
190        for input in self[node_id].input_ids().collect::<SVec<_>>() {
191            self.inputs.remove(input);
192        }
193        for output in self[node_id].output_ids().collect::<SVec<_>>() {
194            self.outputs.remove(output);
195        }
196        let removed_node = self.nodes.remove(node_id).expect("Node should exist");
197
198        (removed_node, disconnect_events)
199    }
200
201    pub fn remove_connection(&mut self, input_id: InputId) -> Option<OutputId> {
202        self.connections.remove(input_id)
203    }
204
205    pub fn iter_nodes(&self) -> impl Iterator<Item = NodeId> + '_ {
206        self.nodes.iter().map(|(id, _)| id)
207    }
208
209    pub fn add_connection(&mut self, output: OutputId, input: InputId) {
210        self.connections.insert(input, output);
211    }
212
213    pub fn iter_connections(&self) -> impl Iterator<Item = (InputId, OutputId)> + '_ {
214        self.connections.iter().map(|(o, i)| (o, *i))
215    }
216
217    pub fn connection(&self, input: InputId) -> Option<OutputId> {
218        self.connections.get(input).copied()
219    }
220
221    pub fn any_param_type(&self, param: AnyParameterId) -> Result<&DataType, EguiGraphError> {
222        match param {
223            AnyParameterId::Input(input) => self.inputs.get(input).map(|x| &x.typ),
224            AnyParameterId::Output(output) => self.outputs.get(output).map(|x| &x.typ),
225        }
226        .ok_or(EguiGraphError::InvalidParameterId(param))
227    }
228
229    pub fn try_get_input(&self, input: InputId) -> Option<&InputParam<DataType, ValueType>> {
230        self.inputs.get(input)
231    }
232
233    pub fn get_input(&self, input: InputId) -> &InputParam<DataType, ValueType> {
234        &self.inputs[input]
235    }
236
237    pub fn try_get_output(&self, output: OutputId) -> Option<&OutputParam<DataType>> {
238        self.outputs.get(output)
239    }
240
241    pub fn get_output(&self, output: OutputId) -> &OutputParam<DataType> {
242        &self.outputs[output]
243    }
244}
245
246impl<NodeData, DataType, ValueType> Default for Graph<NodeData, DataType, ValueType>
247where
248    DataType: PartialEq,
249{
250    fn default() -> Self {
251        Self::new()
252    }
253}
254
255impl<NodeData> Node<NodeData> {
256    pub fn inputs<'a, DataType: PartialEq, DataValue>(
257        &'a self,
258        graph: &'a Graph<NodeData, DataType, DataValue>,
259    ) -> impl Iterator<Item = &'a InputParam<DataType, DataValue>> + 'a {
260        self.input_ids().map(|id| graph.get_input(id))
261    }
262
263    pub fn outputs<'a, DataType: PartialEq, DataValue>(
264        &'a self,
265        graph: &'a Graph<NodeData, DataType, DataValue>,
266    ) -> impl Iterator<Item = &'a OutputParam<DataType>> + 'a {
267        self.output_ids().map(|id| graph.get_output(id))
268    }
269
270    pub fn input_ids(&self) -> impl Iterator<Item = InputId> + '_ {
271        self.inputs.iter().map(|(_name, id)| *id)
272    }
273
274    pub fn output_ids(&self) -> impl Iterator<Item = OutputId> + '_ {
275        self.outputs.iter().map(|(_name, id)| *id)
276    }
277
278    pub fn get_input(&self, name: &str) -> Result<InputId, EguiGraphError> {
279        self.inputs
280            .iter()
281            .find(|(param_name, _id)| param_name == name)
282            .map(|x| x.1)
283            .ok_or_else(|| EguiGraphError::NoParameterNamed(self.id, name.into()))
284    }
285
286    pub fn get_output(&self, name: &str) -> Result<OutputId, EguiGraphError> {
287        self.outputs
288            .iter()
289            .find(|(param_name, _id)| param_name == name)
290            .map(|x| x.1)
291            .ok_or_else(|| EguiGraphError::NoParameterNamed(self.id, name.into()))
292    }
293}
294
295impl<DataType, ValueType> InputParam<DataType, ValueType> {
296    pub fn value(&self) -> &ValueType {
297        &self.value
298    }
299
300    pub fn kind(&self) -> InputParamKind {
301        self.kind
302    }
303
304    pub fn node(&self) -> NodeId {
305        self.node
306    }
307}