// a3s_flow/nodes/parameter_extractor.rs

1//! `"parameter-extractor"` node — LLM-powered structured parameter extraction.
2//!
3//! Uses an LLM to infer and extract typed parameters from natural language
4//! text. The output is a JSON object keyed by the declared parameter names,
5//! suitable for passing directly to downstream HTTP-request or code nodes.
6//!
7//! Works with any OpenAI-compatible API — same backend options as the `"llm"`
8//! node (OpenAI, Ollama, vLLM, Together AI, etc.).
9//!
10//! # Config schema
11//!
12//! ```json
13//! {
14//!   "model":      "gpt-4o-mini",
15//!   "query":      "Book a flight from {{ city }} to London on {{ date }}",
16//!   "parameters": [
17//!     { "name": "origin",      "type": "string",  "description": "Departure city", "required": true },
18//!     { "name": "destination", "type": "string",  "description": "Arrival city",   "required": true },
19//!     { "name": "date",        "type": "string",  "description": "ISO date",        "required": false }
20//!   ],
21//!   "api_base":    "https://api.openai.com/v1",
22//!   "api_key":     "sk-...",
23//!   "temperature": 0.0
24//! }
25//! ```
26//!
27//! | Field | Type | Required | Description |
28//! |-------|------|:--------:|-------------|
29//! | `model` | string | ✅ | Model identifier |
30//! | `query` | string | ✅ | Text to extract from — rendered as Jinja2 template |
31//! | `parameters` | array | ✅ | At least one parameter declaration |
32//! | `parameters[].name` | string | ✅ | Parameter key in the output object |
33//! | `parameters[].type` | string | — | Type hint for the LLM (`"string"`, `"number"`, `"boolean"`, `"object"`, `"array"`) — default `"string"` |
34//! | `parameters[].description` | string | — | Extraction guidance for the LLM |
35//! | `parameters[].required` | boolean | — | If true, extraction fails when the LLM cannot find the value (default `false`) |
36//! | `api_base`, `api_key`, `temperature`, `max_tokens` | | — | Same as `"llm"` node |
37//!
38//! ## Template context
39//!
40//! `query` is a Jinja2 template rendered with the same context as the `"llm"` node:
41//! global `variables` plus upstream node outputs keyed by node ID.
42//!
43//! # Output schema
44//!
45//! ```json
46//! { "origin": "Paris", "destination": "London", "date": null }
47//! ```
48//!
49//! Optional parameters that the LLM cannot find are set to `null`.
50//! Required parameters that cannot be found cause an `Internal` error.
51
52use async_trait::async_trait;
53use serde::Deserialize;
54use serde_json::Value;
55
56use crate::error::{FlowError, Result};
57use crate::node::{ExecContext, Node};
58
59use super::llm::{build_jinja_context, do_chat_completion, render, ChatMessage, LlmConfig};
60
61// ── Parameter declaration ──────────────────────────────────────────────────
62
/// One parameter declaration from `data.parameters` (see module docs for the
/// full config schema).
#[derive(Debug, Deserialize)]
struct ParamDecl {
    // Key under which the extracted value appears in the output JSON object.
    name: String,
    // Type hint forwarded verbatim to the LLM ("string", "number", …).
    // Deserialized from the JSON key "type"; defaults to "string".
    #[serde(rename = "type", default = "default_param_type")]
    param_type: String,
    // Optional extraction guidance included in the prompt; empty when omitted.
    #[serde(default)]
    description: String,
    // When true, a missing or null value in the LLM response is an error.
    #[serde(default)]
    required: bool,
}
73
/// Serde default for [`ParamDecl::param_type`].
fn default_param_type() -> String {
    String::from("string")
}
77
78// ── Node ──────────────────────────────────────────────────────────────────
79
/// Parameter extractor node — LLM-powered structured extraction.
///
/// Stateless unit struct: all configuration is read from `ExecContext::data`
/// at execution time.
pub struct ParameterExtractorNode;
82
83#[async_trait]
84impl Node for ParameterExtractorNode {
85    fn node_type(&self) -> &str {
86        "parameter-extractor"
87    }
88
89    async fn execute(&self, ctx: ExecContext) -> Result<Value> {
90        let config = LlmConfig::from_connection_data(&ctx.data)?;
91
92        let query_template = ctx.data["query"].as_str().ok_or_else(|| {
93            FlowError::InvalidDefinition("parameter-extractor: missing data.query".into())
94        })?;
95
96        let params: Vec<ParamDecl> = serde_json::from_value(ctx.data["parameters"].clone())
97            .map_err(|e| {
98                FlowError::InvalidDefinition(format!(
99                    "parameter-extractor: invalid parameters declaration: {e}"
100                ))
101            })?;
102
103        if params.is_empty() {
104            return Err(FlowError::InvalidDefinition(
105                "parameter-extractor: at least one parameter required".into(),
106            ));
107        }
108
109        let jinja_ctx = build_jinja_context(&ctx);
110        let query = render(query_template, &jinja_ctx)?;
111
112        // ── Build extraction prompt ────────────────────────────────────────
113        let param_list: String = params
114            .iter()
115            .map(|p| {
116                let req = if p.required {
117                    " (required)"
118                } else {
119                    " (optional)"
120                };
121                if p.description.is_empty() {
122                    format!("- {}: {}{}", p.name, p.param_type, req)
123                } else {
124                    format!("- {}: {} — {}{}", p.name, p.param_type, p.description, req)
125                }
126            })
127            .collect::<Vec<_>>()
128            .join("\n");
129
130        let param_names: String = params
131            .iter()
132            .map(|p| format!("\"{}\"", p.name))
133            .collect::<Vec<_>>()
134            .join(", ");
135
136        let system_prompt = format!(
137            "You are a parameter extraction assistant. Extract the following parameters \
138             from the user's text and return them as a JSON object.\n\
139             Parameters to extract:\n{param_list}\n\n\
140             Return ONLY a valid JSON object with keys: {param_names}.\n\
141             For optional parameters that cannot be found, use null.\n\
142             Do not include any explanation or markdown code fences."
143        );
144
145        let messages = vec![
146            ChatMessage {
147                role: "system".into(),
148                content: system_prompt,
149            },
150            ChatMessage {
151                role: "user".into(),
152                content: query,
153            },
154        ];
155
156        let result = do_chat_completion(
157            &config.api_base,
158            &config.api_key,
159            &config.model,
160            messages,
161            Some(config.temperature),
162            config.max_tokens,
163        )
164        .await?;
165
166        // ── Parse LLM response ─────────────────────────────────────────────
167        let json_text = strip_markdown_fences(result.text.trim());
168        let extracted: Value = serde_json::from_str(json_text).map_err(|e| {
169            FlowError::Internal(format!(
170                "parameter-extractor: LLM returned invalid JSON: {e}\nResponse: {}",
171                result.text
172            ))
173        })?;
174
175        let obj = extracted.as_object().ok_or_else(|| {
176            FlowError::Internal("parameter-extractor: LLM response is not a JSON object".into())
177        })?;
178
179        // Validate required parameters.
180        for param in &params {
181            if param.required && matches!(obj.get(&param.name), None | Some(Value::Null)) {
182                return Err(FlowError::Internal(format!(
183                    "parameter-extractor: required parameter '{}' not found in LLM response",
184                    param.name
185                )));
186            }
187        }
188
189        Ok(extracted)
190    }
191}
192
193// ── Helpers ───────────────────────────────────────────────────────────────
194
/// Strip markdown code fences from a string (e.g. ` ```json ... ``` `).
///
/// Handles both the multi-line layout (language tag on the fence line,
/// content on following lines) and the single-line layout where a `json`
/// tag directly precedes the payload with no newline. Returns the input
/// unchanged when no complete fence pair is found.
fn strip_markdown_fences(text: &str) -> &str {
    if let Some(start) = text.find("```") {
        let after_fence = &text[start + 3..];
        // Content normally starts after the first newline (skipping any
        // language tag on the fence line). With no newline — the whole block
        // on one line — strip a literal "json" tag so a response like
        // ```json{"a":1}``` still yields valid JSON. (The original code fell
        // back to offset 0 here, leaving the "json" tag glued to the payload.)
        let content = match after_fence.find('\n') {
            Some(n) => &after_fence[n + 1..],
            None => after_fence.strip_prefix("json").unwrap_or(after_fence),
        };
        if let Some(end) = content.rfind("```") {
            return content[..end].trim();
        }
    }
    text
}
209
210// ── Tests ─────────────────────────────────────────────────────────────────────
211
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Runs the node with the given `data` (all other context fields
    /// defaulted) and returns the error the execution is expected to yield.
    async fn execute_expecting_error(data: Value) -> FlowError {
        ParameterExtractorNode
            .execute(ExecContext {
                data,
                ..Default::default()
            })
            .await
            .unwrap_err()
    }

    // ── strip_markdown_fences ──────────────────────────────────────────────

    #[test]
    fn strips_json_code_fence() {
        assert_eq!(
            strip_markdown_fences("```json\n{\"a\": 1}\n```"),
            "{\"a\": 1}"
        );
    }

    #[test]
    fn strips_plain_code_fence() {
        assert_eq!(strip_markdown_fences("```\n{\"a\": 1}\n```"), "{\"a\": 1}");
    }

    #[test]
    fn passthrough_when_no_fence() {
        assert_eq!(strip_markdown_fences("{\"a\": 1}"), "{\"a\": 1}");
    }

    // ── Config validation ──────────────────────────────────────────────────

    #[tokio::test]
    async fn rejects_missing_model() {
        let err = execute_expecting_error(json!({
            "query": "hello",
            "parameters": [{ "name": "x", "type": "string" }]
        }))
        .await;
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn rejects_missing_query() {
        let err = execute_expecting_error(json!({
            "model": "gpt-4o-mini",
            "parameters": [{ "name": "x" }]
        }))
        .await;
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn rejects_empty_parameters() {
        let err = execute_expecting_error(json!({
            "model": "gpt-4o-mini",
            "query": "hello",
            "parameters": []
        }))
        .await;
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn rejects_invalid_template() {
        // Template rendering fails before any network call is attempted.
        let err = execute_expecting_error(json!({
            "model": "gpt-4o-mini",
            "query": "{{ x | bad_filter }}",
            "parameters": [{ "name": "x" }]
        }))
        .await;
        assert!(matches!(err, FlowError::Internal(_)));
    }
}