// a3s_flow/nodes/question_classifier.rs
//! `"question-classifier"` node — LLM-powered routing.
//!
//! Classifies an input question into one of several user-defined classes using
//! an LLM, then outputs `{ "branch": "class_id" }` — the same shape as the
//! `"if-else"` node, so `run_if` conditions work identically.
//!
//! Uses the same OpenAI-compatible API as the [`LlmNode`](super::llm::LlmNode).
//!
//! # Config schema
//!
//! ```json
//! {
//!   "model":    "gpt-4o-mini",
//!   "question": "{{ user_input }}",
//!   "classes": [
//!     { "id": "technical", "name": "Technical question",
//!       "description": "Questions about code, APIs, or system behaviour" },
//!     { "id": "billing",   "name": "Billing question" },
//!     { "id": "general",   "name": "General question" }
//!   ],
//!   "api_base":    "https://api.openai.com/v1",
//!   "api_key":     "sk-...",
//!   "temperature": 0.0
//! }
//! ```
//!
//! | Field | Type | Required | Description |
//! |-------|------|:--------:|-------------|
//! | `model` | string | ✅ | Model identifier |
//! | `question` | string | ✅ | Question to classify — rendered as Jinja2 template |
//! | `classes` | array | ✅ | At least 2 classes; each requires `id` and `name` |
//! | `classes[].id` | string | ✅ | Unique identifier returned as `branch` |
//! | `classes[].name` | string | ✅ | Human-readable class name |
//! | `classes[].description` | string | — | Optional extra guidance for the LLM |
//! | `api_base`, `api_key`, `temperature`, `max_tokens` | | — | Same as `"llm"` node |
//!
//! # Output schema
//!
//! ```json
//! { "branch": "technical" }
//! ```
//!
//! If the LLM response does not match any declared class ID (case-insensitive),
//! the node falls back to the first class. Use `run_if` on downstream nodes to
//! route by branch:
//!
//! ```json
//! {
//!   "run_if": { "from": "classifier", "path": "branch",
//!               "op": "eq", "value": "technical" }
//! }
//! ```
53
use std::collections::HashSet;

use async_trait::async_trait;
use serde::Deserialize;
use serde_json::{json, Value};

use crate::error::{FlowError, Result};
use crate::node::{ExecContext, Node};

use super::llm::{build_jinja_context, do_chat_completion, render, ChatMessage, LlmConfig};
62
/// One user-declared classification target — an element of `data.classes`.
#[derive(Debug, Deserialize)]
struct ClassDecl {
    // Identifier returned verbatim as the output `branch` value.
    id: String,
    // Human-readable class name; shown to the LLM in the prompt.
    name: String,
    // Optional extra guidance for the LLM; empty string when omitted.
    #[serde(default)]
    description: String,
}
70
/// Stateless node implementation registered under `"question-classifier"`.
pub struct QuestionClassifierNode;
72
73#[async_trait]
74impl Node for QuestionClassifierNode {
75    fn node_type(&self) -> &str {
76        "question-classifier"
77    }
78
79    async fn execute(&self, ctx: ExecContext) -> Result<Value> {
80        let config = LlmConfig::from_connection_data(&ctx.data)?;
81
82        let question_template = ctx.data["question"].as_str().ok_or_else(|| {
83            FlowError::InvalidDefinition("question-classifier: missing data.question".into())
84        })?;
85
86        let classes: Vec<ClassDecl> =
87            serde_json::from_value(ctx.data["classes"].clone()).map_err(|e| {
88                FlowError::InvalidDefinition(format!(
89                    "question-classifier: invalid classes declaration: {e}"
90                ))
91            })?;
92
93        if classes.len() < 2 {
94            return Err(FlowError::InvalidDefinition(
95                "question-classifier: at least 2 classes required".into(),
96            ));
97        }
98
99        let jinja_ctx = build_jinja_context(&ctx);
100        let question = render(question_template, &jinja_ctx)?;
101
102        // Build a deterministic classification prompt.
103        let class_list: String = classes
104            .iter()
105            .map(|c| {
106                if c.description.is_empty() {
107                    format!("- {}: {}", c.id, c.name)
108                } else {
109                    format!("- {}: {} — {}", c.id, c.name, c.description)
110                }
111            })
112            .collect::<Vec<_>>()
113            .join("\n");
114
115        let system_prompt = format!(
116            "You are a question classifier. Classify the user's question into \
117             exactly one of the following categories:\n{class_list}\n\n\
118             Respond with ONLY the category ID (e.g., \"{}\"). \
119             Do not include any explanation or punctuation.",
120            classes[0].id
121        );
122
123        let messages = vec![
124            ChatMessage {
125                role: "system".into(),
126                content: system_prompt,
127            },
128            ChatMessage {
129                role: "user".into(),
130                content: question,
131            },
132        ];
133
134        let result = do_chat_completion(
135            &config.api_base,
136            &config.api_key,
137            &config.model,
138            messages,
139            Some(config.temperature),
140            Some(config.max_tokens.unwrap_or(16)), // short response — just the ID
141        )
142        .await?;
143
144        let branch = pick_class(&result.text.trim().to_lowercase(), &classes);
145        Ok(json!({ "branch": branch }))
146    }
147}
148
149/// Find the matching class ID from the LLM response text.
150///
151/// Strategy (in order):
152/// 1. Exact match against a class ID (case-insensitive)
153/// 2. A class ID appears anywhere in the response
154/// 3. Fall back to the first class
155fn pick_class(response: &str, classes: &[ClassDecl]) -> String {
156    // 1. Exact match.
157    for c in classes {
158        if response == c.id.to_lowercase() {
159            return c.id.clone();
160        }
161    }
162    // 2. Substring match.
163    for c in classes {
164        if response.contains(c.id.to_lowercase().as_str()) {
165            return c.id.clone();
166        }
167    }
168    // 3. Fallback.
169    classes[0].id.clone()
170}
171
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    /// Three-class fixture shared by the `pick_class` tests below.
    /// "tech" is deliberately first so it doubles as the fallback class.
    fn make_classes() -> Vec<ClassDecl> {
        vec![
            ClassDecl {
                id: "tech".into(),
                name: "Technical".into(),
                description: String::new(),
            },
            ClassDecl {
                id: "billing".into(),
                name: "Billing".into(),
                description: String::new(),
            },
            ClassDecl {
                id: "general".into(),
                name: "General".into(),
                description: String::new(),
            },
        ]
    }

    // ── pick_class ─────────────────────────────────────────────────────────

    #[test]
    fn exact_match_returns_class_id() {
        let classes = make_classes();
        assert_eq!(pick_class("billing", &classes), "billing");
    }

    #[test]
    fn case_insensitive_exact_match() {
        // `pick_class` expects an already-lowercased response — `execute`
        // calls it with `result.text.trim().to_lowercase()` — so mirror the
        // caller here. The previous assertion `pick_class("TECH", ...)` only
        // passed because "tech" happens to be the first-class fallback; using
        // a non-first class proves the match is genuine.
        let classes = make_classes();
        assert_eq!(pick_class(&"BILLING".to_lowercase(), &classes), "billing");
    }

    #[test]
    fn substring_match_when_llm_adds_extra_words() {
        let classes = make_classes();
        // "general inquiry" is not an exact ID, but it contains the ID
        // "general", so the substring fallback resolves it.
        assert_eq!(pick_class("general inquiry", &classes), "general");
    }

    #[test]
    fn fallback_to_first_class_on_no_match() {
        let classes = make_classes();
        assert_eq!(pick_class("something completely unknown", &classes), "tech");
    }

    // ── Config validation ──────────────────────────────────────────────────

    #[tokio::test]
    async fn rejects_missing_question() {
        let node = QuestionClassifierNode;
        let err = node
            .execute(ExecContext {
                data: json!({
                    "model": "gpt-4o-mini",
                    "classes": [
                        { "id": "a", "name": "A" },
                        { "id": "b", "name": "B" }
                    ]
                }),
                ..Default::default()
            })
            .await
            .unwrap_err();
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn rejects_fewer_than_two_classes() {
        let node = QuestionClassifierNode;
        let err = node
            .execute(ExecContext {
                data: json!({
                    "model": "gpt-4o-mini",
                    "question": "hi",
                    "classes": [{ "id": "only", "name": "Only one" }]
                }),
                ..Default::default()
            })
            .await
            .unwrap_err();
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn rejects_missing_model() {
        let node = QuestionClassifierNode;
        let err = node
            .execute(ExecContext {
                data: json!({
                    "question": "hi",
                    "classes": [
                        { "id": "a", "name": "A" },
                        { "id": "b", "name": "B" }
                    ]
                }),
                ..Default::default()
            })
            .await
            .unwrap_err();
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[tokio::test]
    async fn question_template_renders_with_variables() {
        // Verify rendering happens before any network call: an invalid
        // template must fail with the render error, not a network error.
        let node = QuestionClassifierNode;
        let err = node
            .execute(ExecContext {
                data: json!({
                    "model": "gpt-4o-mini",
                    "question": "{{ undefined_var | some_unknown_filter }}",
                    "classes": [
                        { "id": "a", "name": "A" },
                        { "id": "b", "name": "B" }
                    ]
                }),
                variables: HashMap::new(),
                inputs: HashMap::new(),
                ..Default::default()
            })
            .await
            .unwrap_err();
        // Render errors map to FlowError::Internal.
        assert!(matches!(err, FlowError::Internal(_)));
    }
}
309}