egui_graph_edit/
editor_ui.rs

1use std::collections::HashSet;
2
3use crate::color_hex_utils::*;
4use crate::utils::ColorUtils;
5
6use super::*;
7use egui::epaint::{CubicBezierShape, RectShape};
8use egui::*;
9
10pub type PortLocations = std::collections::HashMap<AnyParameterId, Pos2>;
11pub type NodeRects = std::collections::HashMap<NodeId, Rect>;
12
13const DISTANCE_TO_CONNECT: f32 = 10.0;
14
15/// Nodes communicate certain events to the parent graph when drawn. There is
16/// one special `User` variant which can be used by users as the return value
17/// when executing some custom actions in the UI of the node.
18#[derive(Clone, Debug)]
19pub enum NodeResponse<UserResponse: UserResponseTrait, NodeData: NodeDataTrait> {
20    ConnectEventStarted(NodeId, AnyParameterId),
21    ConnectEventEnded {
22        output: OutputId,
23        input: InputId,
24    },
25    CreatedNode(NodeId),
26    SelectNode(NodeId),
27    /// As a user of this library, prefer listening for `DeleteNodeFull` which
28    /// will also contain the user data for the deleted node.
29    DeleteNodeUi(NodeId),
30    /// Emitted when a node is deleted. The node will no longer exist in the
31    /// graph after this response is returned from the draw function, but its
32    /// contents are passed along with the event.
33    DeleteNodeFull {
34        node_id: NodeId,
35        node: Node<NodeData>,
36    },
37    DisconnectEvent {
38        output: OutputId,
39        input: InputId,
40    },
41    /// Emitted when a node is interacted with, and should be raised
42    RaiseNode(NodeId),
43    MoveNode {
44        node: NodeId,
45        drag_delta: Vec2,
46    },
47    User(UserResponse),
48}
49
50/// The return value of [`draw_graph_editor`]. This value can be used to make
51/// user code react to specific events that happened when drawing the graph.
52#[derive(Clone, Debug)]
53pub struct GraphResponse<UserResponse: UserResponseTrait, NodeData: NodeDataTrait> {
54    /// Events that occurred during this frame of rendering the graph. Check the
55    /// [`UserResponse`] type for a description of each event.
56    pub node_responses: Vec<NodeResponse<UserResponse, NodeData>>,
57    /// Is the mouse currently hovering the graph editor? Note that the node
58    /// finder is considered part of the graph editor, even when it floats
59    /// outside the graph editor rect.
60    pub cursor_in_editor: bool,
61    /// Is the mouse currently hovering the node finder?
62    pub cursor_in_finder: bool,
63}
64
65impl<UserResponse: UserResponseTrait, NodeData: NodeDataTrait> Default
66    for GraphResponse<UserResponse, NodeData>
67{
68    fn default() -> Self {
69        Self {
70            node_responses: Default::default(),
71            cursor_in_editor: false,
72            cursor_in_finder: false,
73        }
74    }
75}
76
77pub struct GraphNodeWidget<'a, NodeData, DataType, ValueType> {
78    pub position: &'a mut Pos2,
79    pub orientation: &'a mut NodeOrientation,
80    pub graph: &'a mut Graph<NodeData, DataType, ValueType>,
81    pub port_locations: &'a mut PortLocations,
82    pub node_rects: &'a mut NodeRects,
83    pub node_id: NodeId,
84    pub ongoing_drag: Option<(NodeId, AnyParameterId)>,
85    pub selected: bool,
86    pub pan: egui::Vec2,
87}
88
89impl<NodeData, DataType, ValueType, NodeTemplate, UserResponse, UserState, CategoryType>
90    GraphEditorState<NodeData, DataType, ValueType, NodeTemplate, UserState>
91where
92    NodeData: NodeDataTrait<
93        Response = UserResponse,
94        UserState = UserState,
95        DataType = DataType,
96        ValueType = ValueType,
97    >,
98    UserResponse: UserResponseTrait,
99    ValueType:
100        WidgetValueTrait<Response = UserResponse, UserState = UserState, NodeData = NodeData>,
101    NodeTemplate: NodeTemplateTrait<
102        NodeData = NodeData,
103        DataType = DataType,
104        ValueType = ValueType,
105        UserState = UserState,
106        CategoryType = CategoryType,
107    >,
108    DataType: DataTypeTrait<UserState>,
109    CategoryType: CategoryTrait,
110{
111    #[must_use]
112    pub fn draw_graph_editor(
113        &mut self,
114        ui: &mut Ui,
115        all_kinds: impl NodeTemplateIter<Item = NodeTemplate>,
116        user_state: &mut UserState,
117        prepend_responses: Vec<NodeResponse<UserResponse, NodeData>>,
118    ) -> GraphResponse<UserResponse, NodeData> {
119        ui.set_clip_rect(ui.max_rect());
120        let clip_rect = ui.clip_rect();
121        // Zoom may have never taken place, so ensure we use parent style
122        if !self.pan_zoom.started {
123            self.zoom(ui, 1.0);
124            self.pan_zoom.started = true;
125        }
126
127        // Zoom only within area where graph is shown
128        if ui.rect_contains_pointer(clip_rect) {
129            let scroll_delta = ui.input(|i| i.smooth_scroll_delta.y);
130            if scroll_delta != 0.0 {
131                let zoom_delta = (scroll_delta * 0.002).exp();
132                self.zoom(ui, zoom_delta);
133            }
134        }
135
136        // Render graph zoomed
137        let zoomed_style = self.pan_zoom.zoomed_style.clone();
138        let graph_response = show_zoomed(ui.style().clone(), zoomed_style, ui, |ui| {
139            self.draw_graph_editor_inside_zoom(ui, all_kinds, user_state, prepend_responses)
140        });
141
142        graph_response
143    }
144
145    /// Reset zoom to 1.0
146    pub fn reset_zoom(&mut self, ui: &Ui) {
147        let new_zoom = 1.0 / self.pan_zoom.zoom;
148        self.zoom(ui, new_zoom);
149    }
150
151    /// Zoom within the where you call `draw_graph_editor`. Use values like 1.01, or 0.99 to zoom.
152    /// For example: `let zoom_delta = (scroll_delta * 0.002).exp();`
153    pub fn zoom(&mut self, ui: &Ui, zoom_delta: f32) {
154        // Update zoom, and styles
155        let zoom_before = self.pan_zoom.zoom;
156        self.pan_zoom.zoom(ui.clip_rect(), ui.style(), zoom_delta);
157        if zoom_before != self.pan_zoom.zoom {
158            let actual_delta = self.pan_zoom.zoom / zoom_before;
159            self.update_node_positions_after_zoom(actual_delta);
160        }
161    }
162
163    fn update_node_positions_after_zoom(&mut self, zoom_delta: f32) {
164        // Update node positions, zoom towards center
165        let half_size = self.pan_zoom.clip_rect.size() / 2.0;
166        for (_id, node_pos) in self.node_positions.iter_mut() {
167            // 1. Get node local position (relative to origo)
168            let local_pos = node_pos.to_vec2() - half_size + self.pan_zoom.pan;
169            // 2. Scale local position by zoom delta
170            let scaled_local_pos = (local_pos * zoom_delta).to_pos2();
171            // 3. Transform back to global position
172            *node_pos = scaled_local_pos + half_size - self.pan_zoom.pan;
173            // This way we can retain pan untouched when zooming :)
174        }
175    }
176
177    fn draw_graph_editor_inside_zoom(
178        &mut self,
179        ui: &mut Ui,
180        all_kinds: impl NodeTemplateIter<Item = NodeTemplate>,
181        user_state: &mut UserState,
182        prepend_responses: Vec<NodeResponse<UserResponse, NodeData>>,
183    ) -> GraphResponse<UserResponse, NodeData> {
184        // This causes the graph editor to use as much free space as it can.
185        // (so for windows it will use up to the resizeably set limit
186        // and for a Panel it will fill it completely)
187        let editor_rect = ui.max_rect();
188        let resp = ui.allocate_rect(editor_rect, Sense::hover());
189
190        let cursor_pos = ui
191            .ctx()
192            .input(|i| i.pointer.hover_pos().unwrap_or(Pos2::ZERO));
193        let mut cursor_in_editor = resp.contains_pointer();
194        let mut cursor_in_finder = false;
195
196        // Gets filled with the node metrics as they are drawn
197        let mut port_locations = PortLocations::new();
198        let mut node_rects = NodeRects::new();
199
200        // The responses returned from node drawing have side effects that are best
201        // executed at the end of this function.
202        let mut delayed_responses: Vec<NodeResponse<UserResponse, NodeData>> = prepend_responses;
203
204        // Used to detect drag events in the background
205        let mut drag_started_on_background = false;
206        let mut drag_released_on_background = false;
207
208        debug_assert_eq!(
209            self.node_order.iter().copied().collect::<HashSet<_>>(),
210            self.graph.iter_nodes().collect::<HashSet<_>>(),
211            "The node_order field of the GraphEditorself was left in an \
212        inconsistent self. It has either more or less values than the graph."
213        );
214
215        // Allocate rect before the nodes, otherwise this will block the interaction
216        // with the nodes.
217        let r = ui.allocate_rect(ui.min_rect(), Sense::click().union(Sense::drag()));
218        if r.drag_started() {
219            drag_started_on_background = true;
220        } else if r.drag_stopped() {
221            drag_released_on_background = true;
222        }
223
224        /* Draw nodes */
225        for node_id in self.node_order.iter().copied() {
226            let responses = GraphNodeWidget {
227                position: self.node_positions.get_mut(node_id).unwrap(),
228                orientation: self.node_orientations.get_mut(node_id).unwrap(),
229                graph: &mut self.graph,
230                port_locations: &mut port_locations,
231                node_rects: &mut node_rects,
232                node_id,
233                ongoing_drag: self.connection_in_progress,
234                selected: self.selected_nodes.contains(&node_id),
235                pan: self.pan_zoom.pan + editor_rect.min.to_vec2(),
236            }
237            .show(&self.pan_zoom, ui, user_state);
238
239            // Actions executed later
240            delayed_responses.extend(responses);
241        }
242
243        /* Draw the node finder, if open */
244        let mut should_close_node_finder = false;
245        if let Some(ref mut node_finder) = self.node_finder {
246            let mut node_finder_area = Area::new(Id::from("node_finder")).order(Order::Foreground);
247            if let Some(pos) = node_finder.position {
248                node_finder_area = node_finder_area.current_pos(pos);
249            }
250            node_finder_area.show(ui.ctx(), |ui| {
251                if let Some(node_kind) = node_finder.show(ui, all_kinds, user_state) {
252                    let new_node = self.graph.add_node(
253                        node_kind.node_graph_label(user_state),
254                        node_kind.user_data(user_state),
255                        |graph, node_id| node_kind.build_node(graph, user_state, node_id),
256                    );
257                    self.node_positions.insert(
258                        new_node,
259                        node_finder.position.unwrap_or(cursor_pos)
260                            - self.pan_zoom.pan
261                            - editor_rect.min.to_vec2(),
262                    );
263                    self.node_orientations
264                        .insert(new_node, NodeOrientation::LeftToRight);
265                    self.node_order.push(new_node);
266
267                    should_close_node_finder = true;
268                    delayed_responses.push(NodeResponse::CreatedNode(new_node));
269                }
270                let finder_rect = ui.min_rect();
271                // If the cursor is not in the main editor, check if the cursor is in the finder
272                // if the cursor is in the finder, then we can consider that also in the editor.
273                if finder_rect.contains(cursor_pos) {
274                    cursor_in_editor = true;
275                    cursor_in_finder = true;
276                }
277            });
278        }
279        if should_close_node_finder {
280            self.node_finder = None;
281        }
282
283        /* Draw connections */
284        fn port_control(param_id: &AnyParameterId, orientation: NodeOrientation) -> Vec2 {
285            match (param_id, orientation) {
286                (AnyParameterId::Input(_), NodeOrientation::LeftToRight) => -Vec2::X,
287                (AnyParameterId::Input(_), NodeOrientation::RightToLeft) => Vec2::X,
288                (AnyParameterId::Output(_), NodeOrientation::LeftToRight) => Vec2::X,
289                (AnyParameterId::Output(_), NodeOrientation::RightToLeft) => -Vec2::X,
290            }
291        }
292
293        if let Some((_, ref locator)) = self.connection_in_progress {
294            let port_type = self.graph.any_param_type(*locator).unwrap();
295            let connection_color = port_type.data_type_color(user_state);
296            let start_pos = port_locations[locator];
297
298            // Find a port to connect to
299            fn snap_to_ports<
300                NodeData,
301                UserState,
302                DataType: DataTypeTrait<UserState>,
303                ValueType,
304                Key: slotmap::Key + Into<AnyParameterId>,
305                Value,
306            >(
307                graph: &Graph<NodeData, DataType, ValueType>,
308                port_type: &DataType,
309                ports: &SlotMap<Key, Value>,
310                port_locations: &PortLocations,
311                node_orientations: &SecondaryMap<NodeId, NodeOrientation>,
312                cursor_pos: Pos2,
313                default_control: Vec2,
314            ) -> (Pos2, Vec2) {
315                ports
316                    .iter()
317                    .find_map(|(port_id, _)| {
318                        let compatible_ports = graph
319                            .any_param_type(port_id.into())
320                            .map(|other| other == port_type)
321                            .unwrap_or(false);
322
323                        if compatible_ports {
324                            port_locations.get(&port_id.into()).and_then(|port_pos| {
325                                if port_pos.distance(cursor_pos) < DISTANCE_TO_CONNECT {
326                                    let param_id: AnyParameterId = port_id.into();
327                                    let dst_node_id = match param_id {
328                                        AnyParameterId::Output(id) => graph.get_output(id).node,
329                                        AnyParameterId::Input(id) => graph.get_input(id).node,
330                                    };
331                                    let dst_orientation = node_orientations[dst_node_id];
332                                    let dst_control = port_control(&param_id, dst_orientation);
333
334                                    Some((*port_pos, dst_control))
335                                } else {
336                                    None
337                                }
338                            })
339                        } else {
340                            None
341                        }
342                    })
343                    .unwrap_or((cursor_pos, default_control))
344            }
345
346            // Figure out where source connection should point to
347            let src_node_id = match locator {
348                AnyParameterId::Output(out_id) => self.graph.get_output(*out_id).node,
349                AnyParameterId::Input(in_id) => self.graph.get_input(*in_id).node,
350            };
351            let src_orientation = self.node_orientations[src_node_id];
352            let src_control = port_control(locator, src_orientation);
353
354            // Figure out where destination connection should point to
355            let (dst_pos, dst_control) = match locator {
356                AnyParameterId::Output(_) => snap_to_ports(
357                    &self.graph,
358                    port_type,
359                    &self.graph.inputs,
360                    &port_locations,
361                    &self.node_orientations,
362                    cursor_pos,
363                    -src_control,
364                ),
365
366                AnyParameterId::Input(_) => snap_to_ports(
367                    &self.graph,
368                    port_type,
369                    &self.graph.outputs,
370                    &port_locations,
371                    &self.node_orientations,
372                    cursor_pos,
373                    -src_control,
374                ),
375            };
376            draw_connection(
377                &self.pan_zoom,
378                ui.painter(),
379                start_pos,
380                src_control,
381                dst_pos,
382                dst_control,
383                connection_color,
384            );
385        }
386
387        for (input, output) in self.graph.iter_connections() {
388            let port_type = self
389                .graph
390                .any_param_type(AnyParameterId::Output(output))
391                .unwrap();
392            let connection_color = port_type.data_type_color(user_state);
393            let src_pos = port_locations[&AnyParameterId::Output(output)];
394            let dst_pos = port_locations[&AnyParameterId::Input(input)];
395            let src_id = self.graph.get_output(output).node;
396            let dst_id = self.graph.get_input(input).node;
397            let src_orientation = self.node_orientations[src_id];
398            let dst_orientation = self.node_orientations[dst_id];
399            let src_control = port_control(&output.into(), src_orientation);
400            let dst_control = port_control(&input.into(), dst_orientation);
401            draw_connection(
402                &self.pan_zoom,
403                ui.painter(),
404                src_pos,
405                src_control,
406                dst_pos,
407                dst_control,
408                connection_color,
409            );
410        }
411
412        /* Handle responses from drawing nodes */
413
414        // Some responses generate additional responses when processed. These
415        // are stored here to report them back to the user.
416        let mut extra_responses: Vec<NodeResponse<UserResponse, NodeData>> = Vec::new();
417
418        for response in delayed_responses.iter() {
419            match response {
420                NodeResponse::ConnectEventStarted(node_id, port) => {
421                    self.connection_in_progress = Some((*node_id, *port));
422                }
423                NodeResponse::ConnectEventEnded { input, output } => {
424                    self.graph.add_connection(*output, *input)
425                }
426                NodeResponse::CreatedNode(_) => {
427                    //Convenience NodeResponse for users
428                }
429                NodeResponse::SelectNode(node_id) => {
430                    self.selected_nodes = Vec::from([*node_id]);
431                }
432                NodeResponse::DeleteNodeUi(node_id) => {
433                    let (node, disc_events) = self.graph.remove_node(*node_id);
434                    // Pass the disconnection responses first so user code can perform cleanup
435                    // before node removal response.
436                    extra_responses.extend(
437                        disc_events
438                            .into_iter()
439                            .map(|(input, output)| NodeResponse::DisconnectEvent { input, output }),
440                    );
441                    // Pass the full node as a response so library users can
442                    // listen for it and get their user data.
443                    extra_responses.push(NodeResponse::DeleteNodeFull {
444                        node_id: *node_id,
445                        node,
446                    });
447                    self.node_positions.remove(*node_id);
448                    // Make sure to not leave references to old nodes hanging
449                    self.selected_nodes.retain(|id| *id != *node_id);
450                    self.node_order.retain(|id| *id != *node_id);
451                }
452                NodeResponse::DisconnectEvent { input, output } => {
453                    let other_node = self.graph.get_output(*output).node;
454                    self.graph.remove_connection(*input);
455                    self.connection_in_progress =
456                        Some((other_node, AnyParameterId::Output(*output)));
457                }
458                NodeResponse::RaiseNode(node_id) => {
459                    let old_pos = self
460                        .node_order
461                        .iter()
462                        .position(|id| *id == *node_id)
463                        .expect("Node to be raised should be in `node_order`");
464                    self.node_order.remove(old_pos);
465                    self.node_order.push(*node_id);
466                }
467                NodeResponse::MoveNode { node, drag_delta } => {
468                    self.node_positions[*node] += *drag_delta;
469                    // Handle multi-node selection movement
470                    if self.selected_nodes.contains(node) && self.selected_nodes.len() > 1 {
471                        for n in self.selected_nodes.iter().copied() {
472                            if n != *node {
473                                self.node_positions[n] += *drag_delta;
474                            }
475                        }
476                    }
477                }
478                NodeResponse::User(_) => {
479                    // These are handled by the user code.
480                }
481                NodeResponse::DeleteNodeFull { .. } => {
482                    unreachable!("The UI should never produce a DeleteNodeFull event.")
483                }
484            }
485        }
486
487        // Handle box selection
488        if let Some(box_start) = self.ongoing_box_selection {
489            let selection_rect = Rect::from_two_pos(cursor_pos, box_start);
490            let bg_color = Color32::from_rgba_unmultiplied(200, 200, 200, 20);
491            let stroke_color = Color32::from_rgba_unmultiplied(200, 200, 200, 180);
492            ui.painter().rect(
493                selection_rect,
494                2.0,
495                bg_color,
496                Stroke::new(3.0, stroke_color),
497                StrokeKind::Outside,
498            );
499
500            self.selected_nodes = node_rects
501                .into_iter()
502                .filter_map(|(node_id, rect)| {
503                    if selection_rect.intersects(rect) {
504                        Some(node_id)
505                    } else {
506                        None
507                    }
508                })
509                .collect();
510        }
511
512        // Push any responses that were generated during response handling.
513        // These are only informative for the end-user and need no special
514        // treatment here.
515        delayed_responses.extend(extra_responses);
516
517        /* Mouse input handling */
518
519        // This locks the context, so don't hold on to it for too long.
520        let mouse = &ui.ctx().input(|i| i.pointer.clone());
521
522        if mouse.any_released() && self.connection_in_progress.is_some() {
523            self.connection_in_progress = None;
524        }
525
526        if mouse.secondary_released() && !cursor_in_finder {
527            self.node_finder = Some(NodeFinder::new_at(cursor_pos));
528        }
529        if ui.ctx().input(|i| i.key_pressed(Key::Escape)) {
530            self.node_finder = None;
531        }
532
533        if r.dragged() && ui.ctx().input(|i| i.pointer.middle_down()) {
534            self.pan_zoom.pan += ui.ctx().input(|i| i.pointer.delta());
535        }
536
537        // Deselect and deactivate finder if the editor backround is clicked,
538        // *or* if the the mouse clicks off the ui
539        if mouse.any_pressed() && !cursor_in_finder {
540            self.selected_nodes = Vec::new();
541            self.node_finder = None;
542        }
543
544        if drag_started_on_background && mouse.primary_down() {
545            self.ongoing_box_selection = Some(cursor_pos);
546        }
547        if mouse.primary_released() || drag_released_on_background {
548            self.ongoing_box_selection = None;
549        }
550
551        GraphResponse {
552            node_responses: delayed_responses,
553            cursor_in_editor,
554            cursor_in_finder,
555        }
556    }
557}
558
559fn draw_connection(
560    pan_zoom: &PanZoom,
561    painter: &Painter,
562    src_pos: Pos2,
563    src_control: Vec2,
564    dst_pos: Pos2,
565    dst_control: Vec2,
566    color: Color32,
567) {
568    let connection_stroke = egui::Stroke {
569        width: 5.0 * pan_zoom.zoom,
570        color,
571    };
572
573    let control_scale = ((dst_pos.x - src_pos.x) / 2.0).abs().max(30.0);
574    let src_control = src_pos + src_control * control_scale;
575    let dst_control = dst_pos + dst_control * control_scale;
576
577    let bezier = CubicBezierShape::from_points_stroke(
578        [src_pos, src_control, dst_control, dst_pos],
579        false,
580        Color32::TRANSPARENT,
581        connection_stroke,
582    );
583
584    painter.add(bezier);
585
586    let [r, g, b, a] = color.to_srgba_unmultiplied();
587    let wide_stroke = egui::Stroke {
588        width: 10.0,
589        color: Color32::from_rgba_unmultiplied(r / 2, g / 2, b / 2, a / 2),
590    };
591
592    let wide_bezier = CubicBezierShape::from_points_stroke(
593        [src_pos, src_control, dst_control, dst_pos],
594        false,
595        Color32::TRANSPARENT,
596        wide_stroke,
597    );
598
599    painter.add(wide_bezier);
600}
601
602#[derive(Clone, Copy, Debug)]
603struct OuterRectMemory(Rect);
604
605impl<NodeData, DataType, ValueType, UserResponse, UserState>
606    GraphNodeWidget<'_, NodeData, DataType, ValueType>
607where
608    NodeData: NodeDataTrait<
609        Response = UserResponse,
610        UserState = UserState,
611        DataType = DataType,
612        ValueType = ValueType,
613    >,
614    UserResponse: UserResponseTrait,
615    ValueType:
616        WidgetValueTrait<Response = UserResponse, UserState = UserState, NodeData = NodeData>,
617    DataType: DataTypeTrait<UserState>,
618{
619    pub const MAX_NODE_SIZE: [f32; 2] = [200.0, 200.0];
620
621    pub fn show(
622        self,
623        pan_zoom: &PanZoom,
624        ui: &mut Ui,
625        user_state: &mut UserState,
626    ) -> Vec<NodeResponse<UserResponse, NodeData>> {
627        let mut child_ui = ui.new_child(
628            UiBuilder::new()
629                .max_rect(Rect::from_min_size(
630                    *self.position + self.pan,
631                    Self::MAX_NODE_SIZE.into(),
632                ))
633                .layout(*ui.layout())
634                .id_salt(self.node_id),
635        );
636
637        Self::show_graph_node(self, pan_zoom, &mut child_ui, user_state)
638    }
639
640    /// Draws this node. Also fills in the list of port locations with all of its ports.
641    /// Returns responses indicating multiple events.
642    fn show_graph_node(
643        self,
644        pan_zoom: &PanZoom,
645        ui: &mut Ui,
646        user_state: &mut UserState,
647    ) -> Vec<NodeResponse<UserResponse, NodeData>> {
648        let margin = egui::vec2(15.0, 5.0) * pan_zoom.zoom;
649        let mut responses = Vec::<NodeResponse<UserResponse, NodeData>>::new();
650
651        let background_color;
652        let text_color;
653        if ui.visuals().dark_mode {
654            background_color = color_from_hex("#3f3f3f").unwrap();
655            text_color = color_from_hex("#fefefe").unwrap();
656        } else {
657            background_color = color_from_hex("#ffffff").unwrap();
658            text_color = color_from_hex("#505050").unwrap();
659        }
660
661        ui.visuals_mut().widgets.noninteractive.fg_stroke =
662            Stroke::new(2.0 * pan_zoom.zoom, text_color);
663
664        // Preallocate shapes to paint below contents
665        let outline_shape = ui.painter().add(Shape::Noop);
666        let background_shape = ui.painter().add(Shape::Noop);
667
668        let mut outer_rect_bounds = ui.available_rect_before_wrap();
669        // Scale hack, otherwise some (larger) rects expand too much when zoomed out
670        outer_rect_bounds.max.x =
671            outer_rect_bounds.min.x + outer_rect_bounds.width() * pan_zoom.zoom;
672        let mut inner_rect = outer_rect_bounds.shrink2(margin);
673
674        // Make sure we don't shrink to the negative:
675        inner_rect.max.x = inner_rect.max.x.max(inner_rect.min.x);
676        inner_rect.max.y = inner_rect.max.y.max(inner_rect.min.y);
677
678        let mut child_ui = ui.new_child(UiBuilder::new().max_rect(inner_rect).layout(*ui.layout()));
679
680        // Get interaction rect from memory, it may expand after the window response on resize.
681        let interaction_rect = ui
682            .ctx()
683            .memory_mut(|mem| {
684                mem.data
685                    .get_temp::<OuterRectMemory>(child_ui.id())
686                    .map(|stored| stored.0)
687            })
688            .unwrap_or(outer_rect_bounds);
689        // After 0.20, layers added over others can block hover interaction. Call this first
690        // before creating the node content.
691        let window_response = ui.interact(
692            interaction_rect,
693            Id::new((self.node_id, "window")),
694            Sense::click_and_drag(),
695        );
696
697        let mut title_height = 0.0;
698
699        let mut input_port_heights = vec![];
700        let mut output_port_heights = vec![];
701
702        child_ui.vertical(|ui| {
703            ui.horizontal(|ui| {
704                ui.add(
705                    Label::new(
706                        RichText::new(&self.graph[self.node_id].label)
707                            .text_style(TextStyle::Button)
708                            .color(text_color),
709                    )
710                    .selectable(false),
711                );
712                responses.extend(self.graph[self.node_id].user_data.top_bar_ui(
713                    ui,
714                    self.node_id,
715                    self.graph,
716                    user_state,
717                ));
718                ui.add_space(8.0 * pan_zoom.zoom); // The size of the little h-flip icon
719                ui.add_space(4.0 * pan_zoom.zoom); // margin
720                ui.add_space(8.0 * pan_zoom.zoom); // The size of the little cross icon
721            });
722            ui.add_space(margin.y);
723            title_height = ui.min_size().y;
724
725            // First pass: Draw the inner fields. Compute port heights
726            let input_layout = match self.orientation {
727                NodeOrientation::LeftToRight => Layout::left_to_right(Align::default()),
728                NodeOrientation::RightToLeft => Layout::right_to_left(Align::default()),
729            };
730            let output_layout = match self.orientation {
731                NodeOrientation::LeftToRight => Layout::right_to_left(Align::default()),
732                NodeOrientation::RightToLeft => Layout::left_to_right(Align::default()),
733            };
734
735            let inputs = self.graph[self.node_id].inputs.clone();
736            for (param_name, param_id) in inputs {
737                if self.graph[param_id].shown_inline {
738                    let height_before = ui.min_rect().bottom();
739                    // NOTE: We want to pass the `user_data` to
740                    // `value_widget`, but we can't since that would require
741                    // borrowing the graph twice. Here, we make the
742                    // assumption that the value is cheaply replaced, and
743                    // use `std::mem::take` to temporarily replace it with a
744                    // dummy value. This requires `ValueType` to implement
745                    // Default, but results in a totally safe alternative.
746                    let mut value = std::mem::take(&mut self.graph[param_id].value);
747
748                    ui.with_layout(input_layout, |ui| {
749                        if self.graph.connection(param_id).is_some() {
750                            let node_responses = value.value_widget_connected(
751                                &param_name,
752                                self.node_id,
753                                ui,
754                                user_state,
755                                &self.graph[self.node_id].user_data,
756                            );
757
758                            responses.extend(node_responses.into_iter().map(NodeResponse::User));
759                        } else {
760                            let node_responses = value.value_widget(
761                                &param_name,
762                                self.node_id,
763                                ui,
764                                user_state,
765                                &self.graph[self.node_id].user_data,
766                            );
767
768                            responses.extend(node_responses.into_iter().map(NodeResponse::User));
769                        }
770                    });
771
772                    self.graph[param_id].value = value;
773
774                    self.graph[self.node_id].user_data.separator(
775                        ui,
776                        self.node_id,
777                        AnyParameterId::Input(param_id),
778                        self.graph,
779                        user_state,
780                    );
781
782                    let height_after = ui.min_rect().bottom();
783                    input_port_heights.push((height_before + height_after) / 2.0);
784                }
785            }
786
787            let outputs = self.graph[self.node_id].outputs.clone();
788            for (param_name, param_id) in outputs {
789                let height_before = ui.min_rect().bottom();
790                ui.with_layout(output_layout, |ui| {
791                    responses.extend(
792                        self.graph[self.node_id]
793                            .user_data
794                            .output_ui(ui, self.node_id, self.graph, user_state, &param_name)
795                            .into_iter(),
796                    );
797                });
798
799                self.graph[self.node_id].user_data.separator(
800                    ui,
801                    self.node_id,
802                    AnyParameterId::Output(param_id),
803                    self.graph,
804                    user_state,
805                );
806
807                let height_after = ui.min_rect().bottom();
808                output_port_heights.push((height_before + height_after) / 2.0);
809            }
810
811            responses.extend(self.graph[self.node_id].user_data.bottom_ui(
812                ui,
813                self.node_id,
814                self.graph,
815                user_state,
816            ));
817        });
818
819        // Second pass, iterate again to draw the ports. This happens outside
820        // the child_ui because we want ports to overflow the node background.
821
822        let outer_rect = child_ui.min_rect().expand2(margin);
823        let port_left = outer_rect.left();
824        let port_right = outer_rect.right();
825
826        // Save expanded rect to memory.
827        ui.ctx().memory_mut(|mem| {
828            mem.data
829                .insert_temp(child_ui.id(), OuterRectMemory(outer_rect))
830        });
831
832        #[allow(clippy::too_many_arguments)]
833        fn draw_port<NodeData, DataType, ValueType, UserResponse, UserState>(
834            pan_zoom: &PanZoom,
835            ui: &mut Ui,
836            graph: &Graph<NodeData, DataType, ValueType>,
837            node_id: NodeId,
838            user_state: &mut UserState,
839            port_pos: Pos2,
840            responses: &mut Vec<NodeResponse<UserResponse, NodeData>>,
841            param_id: AnyParameterId,
842            port_locations: &mut PortLocations,
843            ongoing_drag: Option<(NodeId, AnyParameterId)>,
844            is_connected_input: bool,
845        ) where
846            DataType: DataTypeTrait<UserState>,
847            UserResponse: UserResponseTrait,
848            NodeData: NodeDataTrait,
849        {
850            let port_type = graph.any_param_type(param_id).unwrap();
851
852            let port_rect = Rect::from_center_size(
853                port_pos,
854                egui::vec2(DISTANCE_TO_CONNECT * 2.0, DISTANCE_TO_CONNECT * 2.0) * pan_zoom.zoom,
855            );
856
857            let sense = if ongoing_drag.is_some() {
858                Sense::hover()
859            } else {
860                Sense::click_and_drag()
861            };
862
863            let resp = ui.allocate_rect(port_rect, sense);
864
865            // Check if the distance between the port and the mouse is the distance to connect
866            let close_enough = if let Some(pointer_pos) = ui.ctx().pointer_hover_pos() {
867                port_rect.center().distance(pointer_pos) < DISTANCE_TO_CONNECT * pan_zoom.zoom
868            } else {
869                false
870            };
871
872            let port_color = if close_enough {
873                Color32::WHITE
874            } else {
875                port_type.data_type_color(user_state)
876            };
877            ui.painter().circle(
878                port_rect.center(),
879                5.0 * pan_zoom.zoom,
880                port_color,
881                Stroke::NONE,
882            );
883
884            if resp.drag_started() {
885                if is_connected_input {
886                    let input = param_id.assume_input();
887                    let corresp_output = graph
888                        .connection(input)
889                        .expect("Connection data should be valid");
890                    responses.push(NodeResponse::DisconnectEvent {
891                        input: param_id.assume_input(),
892                        output: corresp_output,
893                    });
894                } else {
895                    responses.push(NodeResponse::ConnectEventStarted(node_id, param_id));
896                }
897            }
898
899            if let Some((origin_node, origin_param)) = ongoing_drag {
900                if origin_node != node_id {
901                    // Don't allow self-loops
902                    if graph.any_param_type(origin_param).unwrap() == port_type
903                        && close_enough
904                        && ui.input(|i| i.pointer.any_released())
905                    {
906                        match (param_id, origin_param) {
907                            (AnyParameterId::Input(input), AnyParameterId::Output(output))
908                            | (AnyParameterId::Output(output), AnyParameterId::Input(input)) => {
909                                responses.push(NodeResponse::ConnectEventEnded { input, output });
910                            }
911                            _ => { /* Ignore in-in or out-out connections */ }
912                        }
913                    }
914                }
915            }
916
917            port_locations.insert(param_id, port_rect.center());
918        }
919
920        // Input ports
921        for ((_, param), port_height) in self.graph[self.node_id]
922            .inputs
923            .iter()
924            .zip(input_port_heights.into_iter())
925        {
926            let should_draw = match self.graph[*param].kind() {
927                InputParamKind::ConnectionOnly => true,
928                InputParamKind::ConstantOnly => false,
929                InputParamKind::ConnectionOrConstant => true,
930            };
931
932            if should_draw {
933                let port_pos = match self.orientation {
934                    NodeOrientation::LeftToRight => pos2(port_left, port_height),
935                    NodeOrientation::RightToLeft => pos2(port_right, port_height),
936                };
937                draw_port(
938                    pan_zoom,
939                    ui,
940                    self.graph,
941                    self.node_id,
942                    user_state,
943                    port_pos,
944                    &mut responses,
945                    AnyParameterId::Input(*param),
946                    self.port_locations,
947                    self.ongoing_drag,
948                    self.graph.connection(*param).is_some(),
949                );
950            }
951        }
952
953        // Output ports
954        for ((_, param), port_height) in self.graph[self.node_id]
955            .outputs
956            .iter()
957            .zip(output_port_heights.into_iter())
958        {
959            let port_pos = match self.orientation {
960                NodeOrientation::LeftToRight => pos2(port_right, port_height),
961                NodeOrientation::RightToLeft => pos2(port_left, port_height),
962            };
963            draw_port(
964                pan_zoom,
965                ui,
966                self.graph,
967                self.node_id,
968                user_state,
969                port_pos,
970                &mut responses,
971                AnyParameterId::Output(*param),
972                self.port_locations,
973                self.ongoing_drag,
974                false,
975            );
976        }
977
978        // Draw the background shape.
979        // NOTE: This code is a bit more involved than it needs to be because egui
980        // does not support drawing rectangles with asymmetrical round corners.
981
982        let (shape, outline) = {
983            let rounding_radius = (4.0 * pan_zoom.zoom) as u8;
984            let corner_radius = CornerRadius::same(rounding_radius);
985
986            let titlebar_height = title_height + margin.y;
987            let titlebar_rect =
988                Rect::from_min_size(outer_rect.min, vec2(outer_rect.width(), titlebar_height));
989            let titlebar = Shape::Rect(RectShape {
990                blur_width: 0.0,
991                rect: titlebar_rect,
992                corner_radius,
993                fill: self.graph[self.node_id]
994                    .user_data
995                    .titlebar_color(ui, self.node_id, self.graph, user_state)
996                    .unwrap_or_else(|| background_color.lighten(0.8)),
997                stroke: Stroke::NONE,
998                stroke_kind: StrokeKind::Inside,
999                round_to_pixels: None,
1000                brush: None,
1001            });
1002
1003            let body_rect = Rect::from_min_size(
1004                outer_rect.min + vec2(0.0, titlebar_height - rounding_radius as f32),
1005                vec2(outer_rect.width(), outer_rect.height() - titlebar_height),
1006            );
1007            let body = Shape::Rect(RectShape {
1008                blur_width: 0.0,
1009                rect: body_rect,
1010                corner_radius: CornerRadius::ZERO,
1011                fill: background_color,
1012                stroke: Stroke::NONE,
1013                stroke_kind: StrokeKind::Inside,
1014                round_to_pixels: None,
1015                brush: None,
1016            });
1017
1018            let bottom_body_rect = Rect::from_min_size(
1019                body_rect.min + vec2(0.0, body_rect.height() - titlebar_height * 0.5),
1020                vec2(outer_rect.width(), titlebar_height),
1021            );
1022            let bottom_body = Shape::Rect(RectShape {
1023                blur_width: 0.0,
1024                rect: bottom_body_rect,
1025                corner_radius,
1026                fill: background_color,
1027                stroke: Stroke::NONE,
1028                stroke_kind: StrokeKind::Inside,
1029                round_to_pixels: None,
1030                brush: None,
1031            });
1032
1033            let node_rect = titlebar_rect.union(body_rect).union(bottom_body_rect);
1034            let outline = if self.selected {
1035                Shape::Rect(RectShape {
1036                    blur_width: 0.0,
1037                    rect: node_rect.expand(1.0 * pan_zoom.zoom),
1038                    corner_radius,
1039                    fill: Color32::WHITE.lighten(0.8),
1040                    stroke: Stroke::NONE,
1041                    stroke_kind: StrokeKind::Inside,
1042                    round_to_pixels: None,
1043                    brush: None,
1044                })
1045            } else {
1046                Shape::Noop
1047            };
1048
1049            // Take note of the node rect, so the editor can use it later to compute intersections.
1050            self.node_rects.insert(self.node_id, node_rect);
1051
1052            (Shape::Vec(vec![titlebar, body, bottom_body]), outline)
1053        };
1054
1055        ui.painter().set(background_shape, shape);
1056        ui.painter().set(outline_shape, outline);
1057
1058        // --- Interaction ---
1059
1060        // Titlebar buttons
1061        let can_flip =
1062            self.graph.nodes[self.node_id]
1063                .user_data
1064                .can_flip(self.node_id, self.graph, user_state);
1065
1066        if can_flip && Self::flip_button(pan_zoom, ui, outer_rect).clicked() {
1067            *self.orientation = self.orientation.flip();
1068        }
1069
1070        let can_delete = self.graph.nodes[self.node_id].user_data.can_delete(
1071            self.node_id,
1072            self.graph,
1073            user_state,
1074        );
1075
1076        if can_delete && Self::close_button(pan_zoom, ui, outer_rect).clicked() {
1077            responses.push(NodeResponse::DeleteNodeUi(self.node_id));
1078        };
1079
1080        // Movement
1081        let drag_delta = window_response.drag_delta();
1082        if drag_delta.length_sq() > 0.0 {
1083            responses.push(NodeResponse::MoveNode {
1084                node: self.node_id,
1085                drag_delta,
1086            });
1087            responses.push(NodeResponse::RaiseNode(self.node_id));
1088        }
1089
1090        // Node selection
1091        //
1092        // HACK: Only set the select response when no other response is active.
1093        // This prevents some issues.
1094        if responses.is_empty() && window_response.clicked_by(PointerButton::Primary) {
1095            responses.push(NodeResponse::SelectNode(self.node_id));
1096            responses.push(NodeResponse::RaiseNode(self.node_id));
1097        }
1098
1099        responses
1100    }
1101
1102    fn close_button(pan_zoom: &PanZoom, ui: &mut Ui, node_rect: Rect) -> Response {
1103        // Measurements
1104        let margin = 8.0 * pan_zoom.zoom;
1105        let size = 10.0 * pan_zoom.zoom;
1106        let stroke_width = 2.0;
1107        let offs = margin + size / 2.0;
1108
1109        let position = pos2(node_rect.right() - offs, node_rect.top() + offs);
1110        let rect = Rect::from_center_size(position, vec2(size, size));
1111        let resp = ui.allocate_rect(rect, Sense::click());
1112
1113        let dark_mode = ui.visuals().dark_mode;
1114        let color = if resp.clicked() {
1115            if dark_mode {
1116                color_from_hex("#ffffff").unwrap()
1117            } else {
1118                color_from_hex("#000000").unwrap()
1119            }
1120        } else if resp.hovered() {
1121            if dark_mode {
1122                color_from_hex("#dddddd").unwrap()
1123            } else {
1124                color_from_hex("#222222").unwrap()
1125            }
1126        } else {
1127            #[allow(clippy::collapsible_else_if)]
1128            if dark_mode {
1129                color_from_hex("#aaaaaa").unwrap()
1130            } else {
1131                color_from_hex("#555555").unwrap()
1132            }
1133        };
1134        let stroke = Stroke {
1135            width: stroke_width,
1136            color,
1137        };
1138
1139        ui.painter()
1140            .line_segment([rect.left_top(), rect.right_bottom()], stroke);
1141        ui.painter()
1142            .line_segment([rect.right_top(), rect.left_bottom()], stroke);
1143
1144        resp
1145    }
1146
1147    fn flip_button(pan_zoom: &PanZoom, ui: &mut Ui, node_rect: Rect) -> Response {
1148        // Measurements
1149        let margin = 8.0 * pan_zoom.zoom;
1150        let size = 10.0 * pan_zoom.zoom;
1151        let stroke_width = 2.0;
1152        let offs = margin + size / 2.0;
1153
1154        let position = pos2(node_rect.right() - offs * 2.0 - 4.0, node_rect.top() + offs);
1155        let rect = Rect::from_center_size(position, vec2(size, size));
1156        let resp = ui.allocate_rect(rect, Sense::click());
1157
1158        let dark_mode = ui.visuals().dark_mode;
1159        let color = if resp.clicked() {
1160            if dark_mode {
1161                color_from_hex("#ffffff").unwrap()
1162            } else {
1163                color_from_hex("#000000").unwrap()
1164            }
1165        } else if resp.hovered() {
1166            if dark_mode {
1167                color_from_hex("#dddddd").unwrap()
1168            } else {
1169                color_from_hex("#222222").unwrap()
1170            }
1171        } else {
1172            #[allow(clippy::collapsible_else_if)]
1173            if dark_mode {
1174                color_from_hex("#aaaaaa").unwrap()
1175            } else {
1176                color_from_hex("#555555").unwrap()
1177            }
1178        };
1179        let stroke = Stroke {
1180            width: stroke_width,
1181            color,
1182        };
1183
1184        let lines = [
1185            [rect.left_center(), rect.right_center()],
1186            [
1187                rect.left_center(),
1188                rect.left_center().lerp(rect.center_top(), 0.5),
1189            ],
1190            [
1191                rect.left_center(),
1192                rect.left_center().lerp(rect.center_bottom(), 0.5),
1193            ],
1194            [
1195                rect.right_center(),
1196                rect.right_center().lerp(rect.center_top(), 0.5),
1197            ],
1198            [
1199                rect.right_center(),
1200                rect.right_center().lerp(rect.center_bottom(), 0.5),
1201            ],
1202        ];
1203
1204        for line in lines {
1205            ui.painter().line_segment(line, stroke);
1206        }
1207
1208        resp
1209    }
1210}