Skip to main content

ferrum_flow/canvas/
node_renderer.rs

1use gpui::*;
2use std::collections::HashMap;
3
4use crate::node::Node;
5use crate::plugin::{NodeCardVariant, RenderContext};
6use crate::{Graph, Port, PortId, PortPosition};
7
8pub trait NodeRenderer: Send + Sync {
9    /// render node inner UI
10    fn render(&self, node: &Node, ctx: &mut RenderContext) -> AnyElement;
11
12    // custom render port UI
13    fn port_render(&self, node: &Node, port: &Port, ctx: &mut RenderContext) -> Option<AnyElement> {
14        let size = port.size;
15        let position = port_screen_position(node, port.id, &ctx)?;
16
17        Some(
18            div()
19                .absolute()
20                .left(position.x - size.width / 2.0 * ctx.viewport.zoom)
21                .top(position.y - size.height / 2.0 * ctx.viewport.zoom)
22                .w(size.width * ctx.viewport.zoom)
23                .h(size.height * ctx.viewport.zoom)
24                .rounded_full()
25                .bg(rgb(ctx.theme.default_port_fill))
26                .into_any(),
27        )
28    }
29
30    /// computing the position of port relative to node
31    /// built-in Node Plugin is cached this.
32    fn port_offset(&self, node: &Node, port: &Port, graph: &Graph) -> Point<Pixels> {
33        let ports: Vec<&Port> = graph
34            .ports
35            .values()
36            .filter(|p| p.node_id == node.id && p.kind == port.kind && p.position == port.position)
37            .collect();
38
39        let total = ports.len() as f32;
40        let index = port.index as f32;
41        let size = node.size;
42
43        match port.position {
44            PortPosition::Left => {
45                let spacing = size.height / (total + 1.0);
46                Point::new(px(0.0), spacing * (index + 1.0))
47            }
48            PortPosition::Right => {
49                let spacing = size.height / (total + 1.0);
50                Point::new(size.width, spacing * (index + 1.0))
51            }
52            PortPosition::Top => {
53                let spacing = size.width / (total + 1.0);
54                Point::new(spacing * (index + 1.0), px(0.0))
55            }
56            PortPosition::Bottom => {
57                let spacing = size.width / (total + 1.0);
58                Point::new(spacing * (index + 1.0), size.height)
59            }
60        }
61    }
62}
63
64pub struct RendererRegistry {
65    map: HashMap<String, Box<dyn NodeRenderer>>,
66    default: Box<dyn NodeRenderer>,
67    undefined: Box<dyn NodeRenderer>,
68}
69
70impl RendererRegistry {
71    pub fn new() -> Self {
72        Self {
73            map: HashMap::new(),
74            default: Box::new(DefaultNodeRenderer {}),
75            undefined: Box::new(UndefinedNodeRenderer {}),
76        }
77    }
78
79    pub fn register<R>(&mut self, name: impl Into<String>, renderer: R)
80    where
81        R: NodeRenderer + 'static,
82    {
83        self.map.insert(name.into(), Box::new(renderer));
84    }
85
86    pub fn get(&self, name: &str) -> &dyn NodeRenderer {
87        if name.is_empty() {
88            return self.default.as_ref();
89        }
90
91        self.map
92            .get(name)
93            .map(|r| r.as_ref())
94            .unwrap_or(self.undefined.as_ref())
95    }
96}
97
98struct DefaultNodeRenderer;
99
100impl NodeRenderer for DefaultNodeRenderer {
101    fn render(&self, node: &Node, ctx: &mut RenderContext) -> AnyElement {
102        let node_id = node.id;
103        let selected = ctx
104            .graph
105            .selected_node
106            .iter()
107            .find(|id| **id == node_id)
108            .is_some();
109
110        ctx.node_card_shell(node, selected, NodeCardVariant::Default)
111            .child(
112                div()
113                    .size_full()
114                    .flex()
115                    .items_center()
116                    .justify_center()
117                    .text_center()
118                    .px_2()
119                    .child(default_node_caption(node))
120                    .text_color(rgb(ctx.theme.node_caption_text)),
121            )
122            .into_any()
123    }
124}
125
126struct UndefinedNodeRenderer;
127
128impl NodeRenderer for UndefinedNodeRenderer {
129    fn render(&self, node: &Node, ctx: &mut RenderContext) -> AnyElement {
130        ctx.node_card_shell(node, false, NodeCardVariant::UndefinedType)
131            .child(
132                div()
133                    .size_full()
134                    .flex()
135                    .items_center()
136                    .justify_center()
137                    .text_center()
138                    .px_2()
139                    .child(undefined_node_caption(node))
140                    .text_color(rgb(ctx.theme.undefined_node_caption_text)),
141            )
142            .into_any()
143    }
144}
145
146pub fn port_screen_position(
147    node: &Node,
148    port_id: PortId,
149    ctx: &RenderContext,
150) -> Option<Point<Pixels>> {
151    let node_pos = node.point();
152
153    let offset = ctx.port_offset_cached(&node.id, &port_id)?;
154
155    Some(ctx.viewport.world_to_screen(node_pos + offset))
156}
157
158fn data_title(data: &serde_json::Value) -> Option<String> {
159    if let Some(s) = data.get("label").and_then(|v| v.as_str()) {
160        let t = s.trim();
161        if !t.is_empty() {
162            return Some(t.to_string());
163        }
164    }
165    None
166}
167
168/// Label for [`DefaultNodeRenderer`]: user-facing title from `data`, else `node_type`, else a generic word.
169/// UUID stays off-canvas; use debug/inspector/tooltip if operators need the id.
170pub fn default_node_caption(node: &Node) -> String {
171    if let Some(s) = data_title(&node.data) {
172        return s;
173    }
174    if !node.node_type.is_empty() {
175        return node.node_type.clone();
176    }
177    "Node".to_string()
178}
179
180fn undefined_node_caption(node: &Node) -> String {
181    if !node.node_type.is_empty() {
182        return format!("Unknown type: {}", node.node_type);
183    }
184    "Unknown node type".to_string()
185}