// a3s_flow/nodes/parameter_extractor.rs
use async_trait::async_trait;
53use serde::Deserialize;
54use serde_json::Value;
55
56use crate::error::{FlowError, Result};
57use crate::node::{ExecContext, Node};
58
59use super::llm::{build_jinja_context, do_chat_completion, render, ChatMessage, LlmConfig};
60
61#[derive(Debug, Deserialize)]
64struct ParamDecl {
65 name: String,
66 #[serde(rename = "type", default = "default_param_type")]
67 param_type: String,
68 #[serde(default)]
69 description: String,
70 #[serde(default)]
71 required: bool,
72}
73
/// Serde default for `ParamDecl::param_type` when a declaration omits `type`.
fn default_param_type() -> String {
    String::from("string")
}
77
/// Workflow node (type `"parameter-extractor"`) that asks an LLM to pull a
/// declared set of parameters out of a rendered query and returns them as a
/// JSON object.
///
/// The node carries no state; all of its configuration is read from the
/// node's `data` at execution time.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParameterExtractorNode;
82
83#[async_trait]
84impl Node for ParameterExtractorNode {
85 fn node_type(&self) -> &str {
86 "parameter-extractor"
87 }
88
89 async fn execute(&self, ctx: ExecContext) -> Result<Value> {
90 let config = LlmConfig::from_connection_data(&ctx.data)?;
91
92 let query_template = ctx.data["query"].as_str().ok_or_else(|| {
93 FlowError::InvalidDefinition("parameter-extractor: missing data.query".into())
94 })?;
95
96 let params: Vec<ParamDecl> = serde_json::from_value(ctx.data["parameters"].clone())
97 .map_err(|e| {
98 FlowError::InvalidDefinition(format!(
99 "parameter-extractor: invalid parameters declaration: {e}"
100 ))
101 })?;
102
103 if params.is_empty() {
104 return Err(FlowError::InvalidDefinition(
105 "parameter-extractor: at least one parameter required".into(),
106 ));
107 }
108
109 let jinja_ctx = build_jinja_context(&ctx);
110 let query = render(query_template, &jinja_ctx)?;
111
112 let param_list: String = params
114 .iter()
115 .map(|p| {
116 let req = if p.required {
117 " (required)"
118 } else {
119 " (optional)"
120 };
121 if p.description.is_empty() {
122 format!("- {}: {}{}", p.name, p.param_type, req)
123 } else {
124 format!("- {}: {} — {}{}", p.name, p.param_type, p.description, req)
125 }
126 })
127 .collect::<Vec<_>>()
128 .join("\n");
129
130 let param_names: String = params
131 .iter()
132 .map(|p| format!("\"{}\"", p.name))
133 .collect::<Vec<_>>()
134 .join(", ");
135
136 let system_prompt = format!(
137 "You are a parameter extraction assistant. Extract the following parameters \
138 from the user's text and return them as a JSON object.\n\
139 Parameters to extract:\n{param_list}\n\n\
140 Return ONLY a valid JSON object with keys: {param_names}.\n\
141 For optional parameters that cannot be found, use null.\n\
142 Do not include any explanation or markdown code fences."
143 );
144
145 let messages = vec![
146 ChatMessage {
147 role: "system".into(),
148 content: system_prompt,
149 },
150 ChatMessage {
151 role: "user".into(),
152 content: query,
153 },
154 ];
155
156 let result = do_chat_completion(
157 &config.api_base,
158 &config.api_key,
159 &config.model,
160 messages,
161 Some(config.temperature),
162 config.max_tokens,
163 )
164 .await?;
165
166 let json_text = strip_markdown_fences(result.text.trim());
168 let extracted: Value = serde_json::from_str(json_text).map_err(|e| {
169 FlowError::Internal(format!(
170 "parameter-extractor: LLM returned invalid JSON: {e}\nResponse: {}",
171 result.text
172 ))
173 })?;
174
175 let obj = extracted.as_object().ok_or_else(|| {
176 FlowError::Internal("parameter-extractor: LLM response is not a JSON object".into())
177 })?;
178
179 for param in ¶ms {
181 if param.required && matches!(obj.get(¶m.name), None | Some(Value::Null)) {
182 return Err(FlowError::Internal(format!(
183 "parameter-extractor: required parameter '{}' not found in LLM response",
184 param.name
185 )));
186 }
187 }
188
189 Ok(extracted)
190 }
191}
192
/// Extracts the body of a Markdown code fence from an LLM reply.
///
/// Models often wrap JSON in a fence of three backticks, optionally followed
/// by a language tag such as `json`, sometimes after a line of preamble.
///
/// - The result is the trimmed text between the first opening fence and the
///   last closing fence.
/// - The rest of the opening fence line is skipped only when it looks like a
///   language tag, so JSON that starts on the fence line keeps its opening
///   brace.
/// - If the reply was cut off before a closing fence, everything after the
///   opening fence is returned.
/// - Text that contains no fence is returned unchanged.
fn strip_markdown_fences(text: &str) -> &str {
    const FENCE: &str = "```";

    let open = match text.find(FENCE) {
        Some(idx) => idx,
        None => return text,
    };
    let after_open = &text[open + FENCE.len()..];

    let body = match after_open.find('\n') {
        Some(nl) if is_fence_info_string(&after_open[..nl]) => &after_open[nl + 1..],
        _ => after_open,
    };

    let body = match body.rfind(FENCE) {
        Some(close) => &body[..close],
        None => body,
    };
    body.trim()
}

/// Returns `true` when `line` (the remainder of an opening fence line) is empty
/// or a language tag such as `json`, rather than the start of the content.
fn is_fence_info_string(line: &str) -> bool {
    line.trim()
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '+' | '.'))
}
209
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// Runs the node against `ctx` and returns the error it produced,
    /// panicking if execution unexpectedly succeeds.
    async fn execute_err(ctx: ExecContext) -> FlowError {
        ParameterExtractorNode.execute(ctx).await.unwrap_err()
    }

    // --- strip_markdown_fences ------------------------------------------

    #[test]
    fn strips_json_code_fence() {
        let fenced = "```json\n{\"a\": 1}\n```";
        assert_eq!(strip_markdown_fences(fenced), r#"{"a": 1}"#);
    }

    #[test]
    fn strips_plain_code_fence() {
        let fenced = "```\n{\"a\": 1}\n```";
        assert_eq!(strip_markdown_fences(fenced), r#"{"a": 1}"#);
    }

    #[test]
    fn passthrough_when_no_fence() {
        let raw = r#"{"a": 1}"#;
        assert_eq!(strip_markdown_fences(raw), raw);
    }

    // --- execute: definition validation ---------------------------------

    #[tokio::test]
    async fn rejects_missing_model() {
        let ctx = ExecContext {
            data: json!({
                "query": "hello",
                "parameters": [{ "name": "x", "type": "string" }]
            }),
            ..Default::default()
        };
        let err = execute_err(ctx).await;
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn rejects_missing_query() {
        let ctx = ExecContext {
            data: json!({
                "model": "gpt-4o-mini",
                "parameters": [{ "name": "x" }]
            }),
            ..Default::default()
        };
        let err = execute_err(ctx).await;
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn rejects_empty_parameters() {
        let ctx = ExecContext {
            data: json!({
                "model": "gpt-4o-mini",
                "query": "hello",
                "parameters": []
            }),
            ..Default::default()
        };
        let err = execute_err(ctx).await;
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn rejects_invalid_template() {
        // An unknown Jinja filter makes rendering fail before any LLM call.
        let ctx = ExecContext {
            data: json!({
                "model": "gpt-4o-mini",
                "query": "{{ x | bad_filter }}",
                "parameters": [{ "name": "x" }]
            }),
            variables: HashMap::new(),
            inputs: HashMap::new(),
            ..Default::default()
        };
        let err = execute_err(ctx).await;
        assert!(matches!(err, FlowError::Internal(_)));
    }
}