Skip to main content

a3s_flow/
registry.rs

1//! Node registry — maps node type strings to [`Node`] implementations.
2//!
3//! [`NodeRegistry`] is the extension point for adding custom node types.
4//! The default registry ships with all built-in Dify-compatible nodes.
5
6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8
9use serde::{Deserialize, Serialize};
10use serde_json::{json, Value};
11
12use crate::error::{FlowError, Result};
13use crate::node::Node;
14use crate::nodes::assign::AssignNode;
15use crate::nodes::code::CodeNode;
16use crate::nodes::cond::IfElseNode;
17use crate::nodes::context_get::ContextGetNode;
18use crate::nodes::context_set::ContextSetNode;
19use crate::nodes::csv_parse::CsvParseNode;
20use crate::nodes::end::EndNode;
21use crate::nodes::http::HttpRequestNode;
22use crate::nodes::iteration::IterationNode;
23use crate::nodes::list_operator::ListOperatorNode;
24use crate::nodes::llm::LlmNode;
25use crate::nodes::loop_node::LoopNode;
26use crate::nodes::noop::NoopNode;
27use crate::nodes::parameter_extractor::ParameterExtractorNode;
28use crate::nodes::question_classifier::QuestionClassifierNode;
29use crate::nodes::start::StartNode;
30use crate::nodes::subflow::SubFlowNode;
31use crate::nodes::template_transform::TemplateTransformNode;
32use crate::nodes::variable_aggregator::VariableAggregatorNode;
33
34/// Static capability descriptor for a registered node type.
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
36pub struct NodeDescriptor {
37    /// Stable type string used in flow definitions.
38    pub node_type: String,
39    /// Human-friendly display name suitable for UI lists.
40    pub display_name: String,
41    /// Short category label for grouping similar node types.
42    pub category: String,
43    /// Single-sentence summary of what the node does.
44    pub summary: String,
45    /// Suggested default config payload for editor-side node creation.
46    #[serde(default = "default_node_config")]
47    pub default_data: Value,
48    /// Optional field hints for editors and capability discovery UIs.
49    #[serde(default)]
50    pub fields: Vec<NodeFieldDescriptor>,
51}
52
53/// Field-level hint for a node configuration shape.
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
55pub struct NodeFieldDescriptor {
56    pub key: String,
57    pub kind: String,
58    pub required: bool,
59    pub description: String,
60}
61
62fn default_node_config() -> Value {
63    json!({})
64}
65
66/// Registry mapping node type strings to [`Node`] implementations.
67///
68/// `Clone` is cheap — all values are `Arc`-wrapped.
69#[derive(Clone)]
70pub struct NodeRegistry {
71    nodes: HashMap<String, Arc<dyn Node>>,
72    descriptors: HashMap<String, NodeDescriptor>,
73    builtin_types: HashSet<String>,
74}
75
76impl NodeRegistry {
77    /// Create a registry pre-loaded with all built-in node types.
78    ///
79    /// | Type string | Node |
80    /// |-------------|------|
81    /// | `"noop"` | [`NoopNode`] — passes inputs through |
82    /// | `"start"` | [`StartNode`] — Dify-compatible input declaration with defaults |
83    /// | `"end"` | [`EndNode`] — gathers upstream values via JSON pointer paths |
84    /// | `"http-request"` | [`HttpRequestNode`] — HTTP GET/POST/PUT/DELETE/PATCH |
85    /// | `"if-else"` | [`IfElseNode`] — multi-case conditional routing |
86    /// | `"template-transform"` | [`TemplateTransformNode`] — Jinja2 rendering |
87    /// | `"variable-aggregator"` | [`VariableAggregatorNode`] — first non-null fan-in |
88    /// | `"code"` | [`CodeNode`] — sandboxed Rhai script |
89    /// | `"csv-parse"` | [`CsvParseNode`] — parse CSV text into JSON array |
90    /// | `"iteration"` | [`IterationNode`] — sub-flow loop over an array |
91    /// | `"sub-flow"` | [`SubFlowNode`] — execute a named flow as an inline step |
92    /// | `"llm"` | [`LlmNode`] — OpenAI-compatible chat completion |
93    /// | `"question-classifier"` | [`QuestionClassifierNode`] — LLM-powered routing |
94    /// | `"assign"` | [`AssignNode`] — write key-value pairs into the flow's variable scope |
95    /// | `"context-get"` | [`ContextGetNode`] — read keys from the shared execution context |
96    /// | `"context-set"` | [`ContextSetNode`] — write key-value pairs into the shared execution context |
97    /// | `"parameter-extractor"` | [`ParameterExtractorNode`] — LLM-powered structured parameter extraction |
98    /// | `"loop"` | [`LoopNode`] — while-loop over a sub-flow with break condition |
99    /// | `"list-operator"` | [`ListOperatorNode`] — filter / sort / deduplicate / limit a JSON array |
100    pub fn with_defaults() -> Self {
101        let mut r = Self {
102            nodes: HashMap::new(),
103            descriptors: HashMap::new(),
104            builtin_types: HashSet::new(),
105        };
106        r.register_builtin(
107            Arc::new(NoopNode),
108            "No-op",
109            "utility",
110            "Pass inputs through unchanged for placeholder or fan-in flows.",
111            json!({}),
112            vec![],
113        );
114        r.register_builtin(
115            Arc::new(StartNode),
116            "Start",
117            "control",
118            "Declares workflow inputs and default values at the entry point.",
119            json!({ "inputs": [] }),
120            vec![field(
121                "inputs",
122                "array",
123                false,
124                "Input variable declarations.",
125            )],
126        );
127        r.register_builtin(
128            Arc::new(EndNode),
129            "End",
130            "control",
131            "Collects final outputs from upstream nodes.",
132            json!({ "outputs": {} }),
133            vec![field("outputs", "object", false, "Output field mapping.")],
134        );
135        r.register_builtin(
136            Arc::new(HttpRequestNode),
137            "HTTP Request",
138            "integration",
139            "Calls external HTTP APIs with configurable method, headers, and body.",
140            json!({ "method": "GET", "url": "", "headers": {} }),
141            vec![
142                field("method", "string", true, "HTTP method."),
143                field("url", "string", true, "Request URL."),
144            ],
145        );
146        r.register_builtin(
147            Arc::new(IfElseNode),
148            "If/Else",
149            "logic",
150            "Routes execution to a branch based on evaluated conditions.",
151            json!({ "cases": [{ "id": "case_1", "logical_operator": "and", "conditions": [] }] }),
152            vec![field(
153                "cases",
154                "array",
155                true,
156                "Branch definitions keyed by case id.",
157            )],
158        );
159        r.register_builtin(
160            Arc::new(TemplateTransformNode),
161            "Template Transform",
162            "transform",
163            "Renders structured or text output from a Jinja-style template.",
164            json!({ "template": "" }),
165            vec![field(
166                "template",
167                "string",
168                true,
169                "Jinja-style template body.",
170            )],
171        );
172        r.register_builtin(
173            Arc::new(VariableAggregatorNode),
174            "Variable Aggregator",
175            "transform",
176            "Selects the first non-null value from multiple upstream branches.",
177            json!({ "mode": "first_non_null" }),
178            vec![field("mode", "string", false, "Aggregation strategy.")],
179        );
180        r.register_builtin(
181            Arc::new(CodeNode),
182            "Code",
183            "compute",
184            "Executes sandboxed Rhai code against the current flow state.",
185            json!({ "script": "" }),
186            vec![field("script", "string", true, "Rhai script body.")],
187        );
188        r.register_builtin(
189            Arc::new(CsvParseNode),
190            "CSV Parse",
191            "transform",
192            "Parses CSV text into structured JSON rows.",
193            json!({ "text": "" }),
194            vec![field("text", "string", true, "Raw CSV text input.")],
195        );
196        r.register_builtin(
197            Arc::new(IterationNode),
198            "Iteration",
199            "control",
200            "Runs a sub-flow for each item in an input array.",
201            json!({ "input_selector": "", "output_selector": "", "mode": "parallel", "flow": { "nodes": [], "edges": [] } }),
202            vec![
203                field("input_selector", "string", true, "Selector for input array."),
204                field("flow", "object", true, "Nested flow definition."),
205            ],
206        );
207        r.register_builtin(
208            Arc::new(SubFlowNode),
209            "Sub-flow",
210            "control",
211            "Invokes another named flow as a nested step.",
212            json!({ "flow_name": "" }),
213            vec![field("flow_name", "string", true, "Target named flow.")],
214        );
215        r.register_builtin(
216            Arc::new(LlmNode),
217            "LLM",
218            "ai",
219            "Sends a prompt to an OpenAI-compatible chat completion model.",
220            json!({ "model": "gpt-4o-mini", "system_prompt": "", "user_prompt": "", "api_base": "https://api.openai.com/v1", "api_key": "", "temperature": 0.7 }),
221            vec![
222                field("model", "string", true, "Model identifier."),
223                field("user_prompt", "string", true, "User prompt template."),
224            ],
225        );
226        r.register_builtin(
227            Arc::new(QuestionClassifierNode),
228            "Question Classifier",
229            "ai",
230            "Classifies user intent into predefined categories with an LLM.",
231            json!({ "model": "gpt-4o-mini", "question": "", "classes": [], "api_base": "https://api.openai.com/v1", "api_key": "", "temperature": 0 }),
232            vec![
233                field("question", "string", true, "Input question to classify."),
234                field("classes", "array", true, "Target label list."),
235            ],
236        );
237        r.register_builtin(
238            Arc::new(AssignNode),
239            "Assign",
240            "context",
241            "Writes key-value pairs into the flow variable scope.",
242            json!({ "assigns": {} }),
243            vec![field("assigns", "object", true, "Key-value assignments.")],
244        );
245        r.register_builtin(
246            Arc::new(ContextGetNode),
247            "Context Get",
248            "context",
249            "Reads values from shared execution context.",
250            json!({ "keys": [] }),
251            vec![field("keys", "array", true, "Context keys to read.")],
252        );
253        r.register_builtin(
254            Arc::new(ContextSetNode),
255            "Context Set",
256            "context",
257            "Writes values into shared execution context.",
258            json!({ "values": {} }),
259            vec![field("values", "object", true, "Context values to write.")],
260        );
261        r.register_builtin(
262            Arc::new(ParameterExtractorNode),
263            "Parameter Extractor",
264            "ai",
265            "Extracts structured parameters from natural language with an LLM.",
266            json!({ "model": "gpt-4o-mini", "query": "", "parameters": [], "api_base": "https://api.openai.com/v1", "api_key": "", "temperature": 0 }),
267            vec![
268                field("query", "string", true, "Natural language query."),
269                field("parameters", "array", true, "Parameter definitions."),
270            ],
271        );
272        r.register_builtin(
273            Arc::new(LoopNode),
274            "Loop",
275            "control",
276            "Repeats a sub-flow until a break condition is met.",
277            json!({ "output_selector": "", "max_iterations": 10, "flow": { "nodes": [], "edges": [] } }),
278            vec![
279                field("max_iterations", "number", false, "Maximum loop iterations."),
280                field("flow", "object", true, "Loop body flow definition."),
281            ],
282        );
283        r.register_builtin(
284            Arc::new(ListOperatorNode),
285            "List Operator",
286            "transform",
287            "Filters, sorts, deduplicates, or limits a JSON array.",
288            json!({ "operation": "limit", "input": "", "limit": 10 }),
289            vec![field("operation", "string", true, "List operation kind.")],
290        );
291        r
292    }
293
294    /// Register a custom node implementation.
295    ///
296    /// The node's [`Node::node_type`] is used as the lookup key.
297    /// Overwrites any existing registration for the same type string.
298    pub fn register(&mut self, node: Arc<dyn Node>) {
299        let node_type = node.node_type().to_string();
300        self.nodes.insert(node_type.clone(), node);
301        self.descriptors
302            .entry(node_type.clone())
303            .or_insert_with(|| NodeDescriptor {
304                node_type: node_type.clone(),
305                display_name: node_type.clone(),
306                category: "custom".to_string(),
307                summary: "Custom node registered at runtime.".to_string(),
308                default_data: default_node_config(),
309                fields: Vec::new(),
310            });
311    }
312
313    /// Register a node implementation with explicit discovery metadata.
314    pub fn register_with_descriptor(&mut self, node: Arc<dyn Node>, descriptor: NodeDescriptor) {
315        let node_type = node.node_type().to_string();
316        self.nodes.insert(node_type.clone(), node);
317        self.descriptors.insert(
318            node_type.clone(),
319            NodeDescriptor {
320                node_type,
321                ..descriptor
322            },
323        );
324    }
325
326    /// Remove a registered node type and its discovery metadata.
327    ///
328    /// Returns `true` if the node type existed and was removed, `false`
329    /// otherwise.
330    pub fn unregister(&mut self, node_type: &str) -> Result<bool> {
331        if self.is_builtin(node_type) {
332            return Err(FlowError::ProtectedNodeType(node_type.to_string()));
333        }
334        let removed_node = self.nodes.remove(node_type).is_some();
335        let removed_descriptor = self.descriptors.remove(node_type).is_some();
336        Ok(removed_node || removed_descriptor)
337    }
338
339    /// Return whether a node type is part of the built-in catalog.
340    pub fn is_builtin(&self, node_type: &str) -> bool {
341        self.builtin_types.contains(node_type)
342    }
343
344    /// Look up a node implementation by type string.
345    pub fn get(&self, node_type: &str) -> Result<Arc<dyn Node>> {
346        self.nodes.get(node_type).cloned().ok_or_else(|| {
347            FlowError::InvalidDefinition(format!("unknown node type: '{node_type}'"))
348        })
349    }
350
351    /// Return all registered node type strings, sorted alphabetically.
352    pub fn list_types(&self) -> Vec<String> {
353        let mut types: Vec<String> = self.nodes.keys().cloned().collect();
354        types.sort();
355        types
356    }
357
358    /// Return node descriptors sorted by node type.
359    pub fn list_descriptors(&self) -> Vec<NodeDescriptor> {
360        let mut descriptors: Vec<NodeDescriptor> = self.descriptors.values().cloned().collect();
361        descriptors.sort_by(|a, b| a.node_type.cmp(&b.node_type));
362        descriptors
363    }
364
365    fn register_builtin(
366        &mut self,
367        node: Arc<dyn Node>,
368        display_name: &str,
369        category: &str,
370        summary: &str,
371        default_data: Value,
372        fields: Vec<NodeFieldDescriptor>,
373    ) {
374        let node_type = node.node_type().to_string();
375        self.register_with_descriptor(
376            node,
377            NodeDescriptor {
378                node_type: String::new(),
379                display_name: display_name.to_string(),
380                category: category.to_string(),
381                summary: summary.to_string(),
382                default_data,
383                fields,
384            },
385        );
386        self.builtin_types.insert(node_type);
387    }
388}
389
390fn field(key: &str, kind: &str, required: bool, description: &str) -> NodeFieldDescriptor {
391    NodeFieldDescriptor {
392        key: key.to_string(),
393        kind: kind.to_string(),
394        required,
395        description: description.to_string(),
396    }
397}
398
399impl Default for NodeRegistry {
400    fn default() -> Self {
401        Self::with_defaults()
402    }
403}