1use super::*;
2
3impl<NodeData, DataType, ValueType> Graph<NodeData, DataType, ValueType>
4where
5 DataType: PartialEq,
6{
7 pub fn new() -> Self {
8 Self {
9 nodes: SlotMap::default(),
10 inputs: SlotMap::default(),
11 outputs: SlotMap::default(),
12 connections: SecondaryMap::default(),
13 }
14 }
15
16 pub fn add_node(
17 &mut self,
18 label: String,
19 user_data: NodeData,
20 f: impl FnOnce(&mut Graph<NodeData, DataType, ValueType>, NodeId),
21 ) -> NodeId {
22 let node_id = self.nodes.insert_with_key(|node_id| {
23 Node {
24 id: node_id,
25 label,
26 inputs: Vec::default(),
28 outputs: Vec::default(),
29 user_data,
30 }
31 });
32
33 f(self, node_id);
34
35 node_id
36 }
37
38 pub fn add_input_param(
39 &mut self,
40 node_id: NodeId,
41 name: String,
42 typ: DataType,
43 value: ValueType,
44 kind: InputParamKind,
45 shown_inline: bool,
46 ) -> InputId {
47 let input_id = self.inputs.insert_with_key(|input_id| InputParam {
48 id: input_id,
49 typ,
50 value,
51 kind,
52 node: node_id,
53 shown_inline,
54 });
55 self.nodes[node_id].inputs.push((name, input_id));
56 input_id
57 }
58
59 pub fn update_input_param(
60 &mut self,
61 input_id: InputId,
62 name: Option<String>,
63 typ: Option<DataType>,
64 value: Option<ValueType>,
65 kind: Option<InputParamKind>,
66 shown_inline: Option<bool>,
67 ) {
68 if let Some(input_param) = self.inputs.get_mut(input_id) {
69 if let Some(new_typ) = typ {
70 input_param.typ = new_typ;
71 }
72 if let Some(new_value) = value {
73 input_param.value = new_value;
74 }
75 if let Some(new_kind) = kind {
76 input_param.kind = new_kind;
77 }
78 if let Some(new_shown_inline) = shown_inline {
79 input_param.shown_inline = new_shown_inline;
80 }
81
82 if let Some(new_name) = name {
83 for (curr_name, id) in &mut self.nodes[input_param.node].inputs {
84 if *id == input_id {
85 *curr_name = new_name;
86 break;
87 }
88 }
89 }
90 }
91
92 self.ensure_connection_types(AnyParameterId::Input(input_id));
93 }
94
95 pub fn remove_input_param(&mut self, param: InputId) {
96 let node = self[param].node;
97 self[node].inputs.retain(|(_, id)| *id != param);
98 self.inputs.remove(param);
99 self.connections.retain(|i, _| i != param);
100 }
101
102 pub fn add_output_param(&mut self, node_id: NodeId, name: String, typ: DataType) -> OutputId {
103 let output_id = self.outputs.insert_with_key(|output_id| OutputParam {
104 id: output_id,
105 node: node_id,
106 typ,
107 });
108 self.nodes[node_id].outputs.push((name, output_id));
109 output_id
110 }
111
112 pub fn update_output_param(
113 &mut self,
114 output_id: OutputId,
115 name: Option<String>,
116 typ: Option<DataType>,
117 ) {
118 if let Some(output_param) = self.outputs.get_mut(output_id) {
119 if let Some(new_typ) = typ {
120 output_param.typ = new_typ;
121 }
122
123 if let Some(new_name) = name {
124 for (curr_name, id) in &mut self.nodes[output_param.node].outputs {
125 if *id == output_id {
126 *curr_name = new_name;
127 break;
128 }
129 }
130 }
131 }
132
133 self.ensure_connection_types(AnyParameterId::Output(output_id));
134 }
135
136 pub fn remove_output_param(&mut self, param: OutputId) {
137 let node = self[param].node;
138 self[node].outputs.retain(|(_, id)| *id != param);
139 self.outputs.remove(param);
140 self.connections.retain(|_, o| *o != param);
141 }
142
143 pub fn ensure_connection_types(&mut self, param_id: AnyParameterId) {
148 let mut to_remove = Vec::default();
149
150 for (to_id, from_id) in self.iter_connections() {
151 if AnyParameterId::Input(to_id) != param_id
153 && AnyParameterId::Output(from_id) != param_id
154 {
155 continue;
156 }
157
158 if self.get_input(to_id).typ != self.get_output(from_id).typ {
160 to_remove.push(to_id);
161 }
162 }
163
164 for in_id in to_remove {
165 self.remove_connection(in_id);
166 }
167 }
168
169 pub fn remove_node(&mut self, node_id: NodeId) -> (Node<NodeData>, Vec<(InputId, OutputId)>) {
177 let mut disconnect_events = vec![];
178
179 self.connections.retain(|i, o| {
180 if self.outputs[*o].node == node_id || self.inputs[i].node == node_id {
181 disconnect_events.push((i, *o));
182 false
183 } else {
184 true
185 }
186 });
187
188 for input in self[node_id].input_ids().collect::<SVec<_>>() {
191 self.inputs.remove(input);
192 }
193 for output in self[node_id].output_ids().collect::<SVec<_>>() {
194 self.outputs.remove(output);
195 }
196 let removed_node = self.nodes.remove(node_id).expect("Node should exist");
197
198 (removed_node, disconnect_events)
199 }
200
201 pub fn remove_connection(&mut self, input_id: InputId) -> Option<OutputId> {
202 self.connections.remove(input_id)
203 }
204
205 pub fn iter_nodes(&self) -> impl Iterator<Item = NodeId> + '_ {
206 self.nodes.iter().map(|(id, _)| id)
207 }
208
209 pub fn add_connection(&mut self, output: OutputId, input: InputId) {
210 self.connections.insert(input, output);
211 }
212
213 pub fn iter_connections(&self) -> impl Iterator<Item = (InputId, OutputId)> + '_ {
214 self.connections.iter().map(|(o, i)| (o, *i))
215 }
216
217 pub fn connection(&self, input: InputId) -> Option<OutputId> {
218 self.connections.get(input).copied()
219 }
220
221 pub fn any_param_type(&self, param: AnyParameterId) -> Result<&DataType, EguiGraphError> {
222 match param {
223 AnyParameterId::Input(input) => self.inputs.get(input).map(|x| &x.typ),
224 AnyParameterId::Output(output) => self.outputs.get(output).map(|x| &x.typ),
225 }
226 .ok_or(EguiGraphError::InvalidParameterId(param))
227 }
228
229 pub fn try_get_input(&self, input: InputId) -> Option<&InputParam<DataType, ValueType>> {
230 self.inputs.get(input)
231 }
232
233 pub fn get_input(&self, input: InputId) -> &InputParam<DataType, ValueType> {
234 &self.inputs[input]
235 }
236
237 pub fn try_get_output(&self, output: OutputId) -> Option<&OutputParam<DataType>> {
238 self.outputs.get(output)
239 }
240
241 pub fn get_output(&self, output: OutputId) -> &OutputParam<DataType> {
242 &self.outputs[output]
243 }
244}
245
246impl<NodeData, DataType, ValueType> Default for Graph<NodeData, DataType, ValueType>
247where
248 DataType: PartialEq,
249{
250 fn default() -> Self {
251 Self::new()
252 }
253}
254
255impl<NodeData> Node<NodeData> {
256 pub fn inputs<'a, DataType: PartialEq, DataValue>(
257 &'a self,
258 graph: &'a Graph<NodeData, DataType, DataValue>,
259 ) -> impl Iterator<Item = &'a InputParam<DataType, DataValue>> + 'a {
260 self.input_ids().map(|id| graph.get_input(id))
261 }
262
263 pub fn outputs<'a, DataType: PartialEq, DataValue>(
264 &'a self,
265 graph: &'a Graph<NodeData, DataType, DataValue>,
266 ) -> impl Iterator<Item = &'a OutputParam<DataType>> + 'a {
267 self.output_ids().map(|id| graph.get_output(id))
268 }
269
270 pub fn input_ids(&self) -> impl Iterator<Item = InputId> + '_ {
271 self.inputs.iter().map(|(_name, id)| *id)
272 }
273
274 pub fn output_ids(&self) -> impl Iterator<Item = OutputId> + '_ {
275 self.outputs.iter().map(|(_name, id)| *id)
276 }
277
278 pub fn get_input(&self, name: &str) -> Result<InputId, EguiGraphError> {
279 self.inputs
280 .iter()
281 .find(|(param_name, _id)| param_name == name)
282 .map(|x| x.1)
283 .ok_or_else(|| EguiGraphError::NoParameterNamed(self.id, name.into()))
284 }
285
286 pub fn get_output(&self, name: &str) -> Result<OutputId, EguiGraphError> {
287 self.outputs
288 .iter()
289 .find(|(param_name, _id)| param_name == name)
290 .map(|x| x.1)
291 .ok_or_else(|| EguiGraphError::NoParameterNamed(self.id, name.into()))
292 }
293}
294
295impl<DataType, ValueType> InputParam<DataType, ValueType> {
296 pub fn value(&self) -> &ValueType {
297 &self.value
298 }
299
300 pub fn kind(&self) -> InputParamKind {
301 self.kind
302 }
303
304 pub fn node(&self) -> NodeId {
305 self.node
306 }
307}