a3s_flow/nodes/
question_classifier.rs1use async_trait::async_trait;
55use serde::Deserialize;
56use serde_json::{json, Value};
57
58use crate::error::{FlowError, Result};
59use crate::node::{ExecContext, Node};
60
61use super::llm::{build_jinja_context, do_chat_completion, render, ChatMessage, LlmConfig};
62
/// One entry of the node's `data.classes` array.
#[derive(Debug, Deserialize)]
struct ClassDecl {
    /// Identifier the model is asked to reply with; `pick_class` returns it as the
    /// node's `branch` value.
    id: String,
    /// Human-readable label listed next to the id in the system prompt.
    name: String,
    /// Optional extra guidance appended to the class's prompt line when non-empty;
    /// defaults to an empty string when absent from the declaration.
    #[serde(default)]
    description: String,
}
70
71pub struct QuestionClassifierNode;
72
73#[async_trait]
74impl Node for QuestionClassifierNode {
75 fn node_type(&self) -> &str {
76 "question-classifier"
77 }
78
79 async fn execute(&self, ctx: ExecContext) -> Result<Value> {
80 let config = LlmConfig::from_connection_data(&ctx.data)?;
81
82 let question_template = ctx.data["question"].as_str().ok_or_else(|| {
83 FlowError::InvalidDefinition("question-classifier: missing data.question".into())
84 })?;
85
86 let classes: Vec<ClassDecl> =
87 serde_json::from_value(ctx.data["classes"].clone()).map_err(|e| {
88 FlowError::InvalidDefinition(format!(
89 "question-classifier: invalid classes declaration: {e}"
90 ))
91 })?;
92
93 if classes.len() < 2 {
94 return Err(FlowError::InvalidDefinition(
95 "question-classifier: at least 2 classes required".into(),
96 ));
97 }
98
99 let jinja_ctx = build_jinja_context(&ctx);
100 let question = render(question_template, &jinja_ctx)?;
101
102 let class_list: String = classes
104 .iter()
105 .map(|c| {
106 if c.description.is_empty() {
107 format!("- {}: {}", c.id, c.name)
108 } else {
109 format!("- {}: {} — {}", c.id, c.name, c.description)
110 }
111 })
112 .collect::<Vec<_>>()
113 .join("\n");
114
115 let system_prompt = format!(
116 "You are a question classifier. Classify the user's question into \
117 exactly one of the following categories:\n{class_list}\n\n\
118 Respond with ONLY the category ID (e.g., \"{}\"). \
119 Do not include any explanation or punctuation.",
120 classes[0].id
121 );
122
123 let messages = vec![
124 ChatMessage {
125 role: "system".into(),
126 content: system_prompt,
127 },
128 ChatMessage {
129 role: "user".into(),
130 content: question,
131 },
132 ];
133
134 let result = do_chat_completion(
135 &config.api_base,
136 &config.api_key,
137 &config.model,
138 messages,
139 Some(config.temperature),
140 Some(config.max_tokens.unwrap_or(16)), )
142 .await?;
143
144 let branch = pick_class(&result.text.trim().to_lowercase(), &classes);
145 Ok(json!({ "branch": branch }))
146 }
147}
148
149fn pick_class(response: &str, classes: &[ClassDecl]) -> String {
156 for c in classes {
158 if response == c.id.to_lowercase() {
159 return c.id.clone();
160 }
161 }
162 for c in classes {
164 if response.contains(c.id.to_lowercase().as_str()) {
165 return c.id.clone();
166 }
167 }
168 classes[0].id.clone()
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// Builds a class declaration without a description.
    fn class(id: &str, name: &str) -> ClassDecl {
        ClassDecl {
            id: id.into(),
            name: name.into(),
            description: String::new(),
        }
    }

    fn make_classes() -> Vec<ClassDecl> {
        vec![
            class("tech", "Technical"),
            class("billing", "Billing"),
            class("general", "General"),
        ]
    }

    #[test]
    fn exact_match_returns_class_id() {
        let classes = make_classes();
        assert_eq!(pick_class("billing", &classes), "billing");
    }

    #[test]
    fn mixed_case_ids_match_and_keep_declared_casing() {
        // `execute` lower-cases the reply before calling `pick_class`, so ids must be
        // compared case-insensitively and returned exactly as declared. The expected
        // class is deliberately not the first one, so the fallback cannot mask a miss.
        let classes = vec![class("Tech", "Technical"), class("Billing", "Billing")];
        assert_eq!(pick_class("billing", &classes), "Billing");
    }

    #[test]
    fn substring_match_when_llm_adds_extra_words() {
        let classes = make_classes();
        assert_eq!(pick_class("general inquiry", &classes), "general");
    }

    #[test]
    fn fallback_to_first_class_on_no_match() {
        let classes = make_classes();
        assert_eq!(pick_class("something completely unknown", &classes), "tech");
    }

    #[tokio::test]
    async fn rejects_missing_question() {
        let node = QuestionClassifierNode;
        let err = node
            .execute(ExecContext {
                data: json!({
                    "model": "gpt-4o-mini",
                    "classes": [
                        { "id": "a", "name": "A" },
                        { "id": "b", "name": "B" }
                    ]
                }),
                ..Default::default()
            })
            .await
            .unwrap_err();
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn rejects_fewer_than_two_classes() {
        let node = QuestionClassifierNode;
        let err = node
            .execute(ExecContext {
                data: json!({
                    "model": "gpt-4o-mini",
                    "question": "hi",
                    "classes": [{ "id": "only", "name": "Only one" }]
                }),
                ..Default::default()
            })
            .await
            .unwrap_err();
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn rejects_missing_model() {
        let node = QuestionClassifierNode;
        let err = node
            .execute(ExecContext {
                data: json!({
                    "question": "hi",
                    "classes": [
                        { "id": "a", "name": "A" },
                        { "id": "b", "name": "B" }
                    ]
                }),
                ..Default::default()
            })
            .await
            .unwrap_err();
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn rejects_unrenderable_question_template() {
        // An unknown filter makes template rendering fail; `execute` renders before
        // calling the LLM, so no request is attempted.
        let node = QuestionClassifierNode;
        let err = node
            .execute(ExecContext {
                data: json!({
                    "model": "gpt-4o-mini",
                    "question": "{{ undefined_var | some_unknown_filter }}",
                    "classes": [
                        { "id": "a", "name": "A" },
                        { "id": "b", "name": "B" }
                    ]
                }),
                variables: HashMap::new(),
                inputs: HashMap::new(),
                ..Default::default()
            })
            .await
            .unwrap_err();
        assert!(matches!(err, FlowError::Internal(_)));
    }
}