Skip to main content

a3s_flow/nodes/
llm.rs

//! `"llm"` node — OpenAI-compatible chat completion.
//!
//! Renders system and user prompts as Jinja2 templates, calls a
//! `/v1/chat/completions` endpoint, and returns the assistant's reply along
//! with token-usage statistics.
//!
//! Works with any OpenAI-compatible API: OpenAI, Anthropic (via proxy),
//! Ollama, LM Studio, vLLM, Together AI, etc.
//!
//! # Config schema
//!
//! ```json
//! {
//!   "model":         "gpt-4o-mini",
//!   "user_prompt":   "Answer concisely: {{ query }}",
//!   "system_prompt": "You are a helpful assistant.",
//!   "api_base":      "https://api.openai.com/v1",
//!   "api_key":       "sk-...",
//!   "temperature":   0.7,
//!   "max_tokens":    1024
//! }
//! ```
//!
//! | Field | Type | Required | Default | Description |
//! |-------|------|:--------:|---------|-------------|
//! | `model` | string | ✅ | — | Model identifier |
//! | `user_prompt` | string | ✅ | — | User turn — rendered as Jinja2 template |
//! | `system_prompt` | string | — | _(none)_ | System turn — rendered as Jinja2 template |
//! | `api_base` | string | — | `https://api.openai.com/v1` | Base URL (no trailing slash) |
//! | `api_key` | string | — | `""` | Bearer token; may be empty for local models |
//! | `temperature` | number | — | `0.7` | Sampling temperature `[0, 2]` |
//! | `max_tokens` | integer | — | _(none)_ | Max completion tokens |
//!
//! ## Template context
//!
//! Both prompts are Jinja2 templates. The rendering context contains:
//! - All global flow `variables` (by key)
//! - All upstream node outputs (by node ID)
//!
//! Upstream inputs shadow variables with the same key.
//!
//! # Output schema
//!
//! ```json
//! {
//!   "text":          "The answer is 42.",
//!   "model":         "gpt-4o-mini",
//!   "finish_reason": "stop",
//!   "usage": {
//!     "prompt_tokens":     15,
//!     "completion_tokens":  8,
//!     "total_tokens":      23
//!   }
//! }
//! ```

57use std::collections::HashMap;
58
59use async_trait::async_trait;
60use serde::Serialize;
61use serde_json::{json, Value};
62
63use crate::error::{FlowError, Result};
64use crate::node::{ExecContext, Node};
65
/// Fallback base URL when the config omits `api_base` (no trailing slash).
const DEFAULT_API_BASE: &str = "https://api.openai.com/v1";
/// Fallback sampling temperature when the config omits `temperature`.
const DEFAULT_TEMPERATURE: f64 = 0.7;
68
69// ── Public node ───────────────────────────────────────────────────────────────
70
71/// LLM chat-completion node (OpenAI-compatible).
72pub struct LlmNode;
73
74#[async_trait]
75impl Node for LlmNode {
76    fn node_type(&self) -> &str {
77        "llm"
78    }
79
80    async fn execute(&self, ctx: ExecContext) -> Result<Value> {
81        let config = LlmConfig::from_data(&ctx.data)?;
82        let jinja_ctx = build_jinja_context(&ctx);
83
84        let user_prompt = render(&config.user_prompt, &jinja_ctx)?;
85        let system_prompt = config
86            .system_prompt
87            .as_deref()
88            .map(|t| render(t, &jinja_ctx))
89            .transpose()?;
90
91        let mut messages: Vec<ChatMessage> = Vec::new();
92        if let Some(sys) = system_prompt {
93            messages.push(ChatMessage {
94                role: "system".into(),
95                content: sys,
96            });
97        }
98        messages.push(ChatMessage {
99            role: "user".into(),
100            content: user_prompt,
101        });
102
103        let result = do_chat_completion(
104            &config.api_base,
105            &config.api_key,
106            &config.model,
107            messages,
108            Some(config.temperature),
109            config.max_tokens,
110        )
111        .await?;
112
113        Ok(json!({
114            "text": result.text,
115            "model": result.model,
116            "finish_reason": result.finish_reason,
117            "usage": {
118                "prompt_tokens":     result.prompt_tokens,
119                "completion_tokens": result.completion_tokens,
120                "total_tokens":      result.total_tokens,
121            }
122        }))
123    }
124}
125
126// ── Shared internals (used by question-classifier too) ────────────────────────
127
/// Parsed node configuration.
#[derive(Debug)]
pub(crate) struct LlmConfig {
    // Model identifier, e.g. "gpt-4o-mini" (required).
    pub model: String,
    // User-turn Jinja2 template; empty when built via `from_connection_data`.
    pub user_prompt: String,
    // Optional system-turn Jinja2 template.
    pub system_prompt: Option<String>,
    // Base URL with any trailing slash stripped.
    pub api_base: String,
    // Bearer token; empty string means "send no Authorization header".
    pub api_key: String,
    // Sampling temperature; defaults to `DEFAULT_TEMPERATURE`.
    pub temperature: f64,
    // Max completion tokens; `None` omits the field from the request.
    pub max_tokens: Option<u64>,
}
139
140impl LlmConfig {
141    pub(crate) fn from_data(data: &Value) -> Result<Self> {
142        let model = data["model"]
143            .as_str()
144            .ok_or_else(|| FlowError::InvalidDefinition("llm: missing data.model".into()))?
145            .to_string();
146
147        let user_prompt = data["user_prompt"]
148            .as_str()
149            .ok_or_else(|| FlowError::InvalidDefinition("llm: missing data.user_prompt".into()))?
150            .to_string();
151
152        let system_prompt = data["system_prompt"].as_str().map(str::to_string);
153        let api_base = data["api_base"]
154            .as_str()
155            .unwrap_or(DEFAULT_API_BASE)
156            .trim_end_matches('/')
157            .to_string();
158        let api_key = data["api_key"].as_str().unwrap_or("").to_string();
159        let temperature = data["temperature"].as_f64().unwrap_or(DEFAULT_TEMPERATURE);
160        let max_tokens = data["max_tokens"].as_u64();
161
162        Ok(Self {
163            model,
164            user_prompt,
165            system_prompt,
166            api_base,
167            api_key,
168            temperature,
169            max_tokens,
170        })
171    }
172
173    /// Parse only the connection-level fields (model, api_base, api_key, temperature,
174    /// max_tokens). Does NOT require `user_prompt` — used by nodes that build their
175    /// own prompts (e.g. `question-classifier`).
176    pub(crate) fn from_connection_data(data: &Value) -> Result<Self> {
177        let model = data["model"]
178            .as_str()
179            .ok_or_else(|| FlowError::InvalidDefinition("llm: missing data.model".into()))?
180            .to_string();
181
182        let api_base = data["api_base"]
183            .as_str()
184            .unwrap_or(DEFAULT_API_BASE)
185            .trim_end_matches('/')
186            .to_string();
187        let api_key = data["api_key"].as_str().unwrap_or("").to_string();
188        let temperature = data["temperature"].as_f64().unwrap_or(DEFAULT_TEMPERATURE);
189        let max_tokens = data["max_tokens"].as_u64();
190
191        Ok(Self {
192            model,
193            user_prompt: String::new(),
194            system_prompt: None,
195            api_base,
196            api_key,
197            temperature,
198            max_tokens,
199        })
200    }
201}
202
/// One message in a chat conversation.
///
/// Serialized directly into the `messages` array of the request body.
#[derive(Debug, Serialize)]
pub(crate) struct ChatMessage {
    // "system", "user", or "assistant".
    pub role: String,
    // The message text (already template-rendered).
    pub content: String,
}
209
/// Extracted fields from a successful chat-completion response.
#[derive(Debug)]
pub(crate) struct CompletionResult {
    // Assistant reply: `choices[0].message.content`.
    pub text: String,
    // Model name echoed by the server; "unknown" when absent.
    pub model: String,
    // `choices[0].finish_reason`; defaults to "stop" when absent.
    pub finish_reason: String,
    // Usage counters; each defaults to 0 when the server omits `usage`.
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
220
221/// Build a Jinja2 rendering context from the execution context.
222///
223/// Variables have lower priority; upstream inputs shadow same-named variables.
224pub(crate) fn build_jinja_context(ctx: &ExecContext) -> HashMap<String, Value> {
225    let mut map: HashMap<String, Value> = ctx.variables.clone();
226    for (k, v) in &ctx.inputs {
227        map.insert(k.clone(), v.clone());
228    }
229    map
230}
231
232/// Render a Jinja2 template string against the given context map.
233pub(crate) fn render(template: &str, context: &HashMap<String, Value>) -> Result<String> {
234    let env = minijinja::Environment::new();
235    env.render_str(template, context)
236        .map_err(|e| FlowError::Internal(format!("llm: template render error: {e}")))
237}
238
239/// Call the `/v1/chat/completions` endpoint and return the parsed result.
240pub(crate) async fn do_chat_completion(
241    api_base: &str,
242    api_key: &str,
243    model: &str,
244    messages: Vec<ChatMessage>,
245    temperature: Option<f64>,
246    max_tokens: Option<u64>,
247) -> Result<CompletionResult> {
248    let mut body = json!({
249        "model": model,
250        "messages": messages,
251        "temperature": temperature.unwrap_or(DEFAULT_TEMPERATURE),
252    });
253    if let Some(max_tok) = max_tokens {
254        body["max_tokens"] = json!(max_tok);
255    }
256
257    let url = format!("{api_base}/chat/completions");
258    let client = reqwest::Client::new();
259    let mut req = client.post(&url).json(&body);
260    if !api_key.is_empty() {
261        req = req.bearer_auth(api_key);
262    }
263
264    let response = req
265        .send()
266        .await
267        .map_err(|e| FlowError::Internal(format!("llm: HTTP request failed: {e}")))?;
268
269    let status = response.status();
270    let text = response
271        .text()
272        .await
273        .map_err(|e| FlowError::Internal(format!("llm: failed to read response body: {e}")))?;
274
275    if !status.is_success() {
276        return Err(FlowError::Internal(format!(
277            "llm: API returned {status}: {text}"
278        )));
279    }
280
281    let resp: Value = serde_json::from_str(&text)
282        .map_err(|e| FlowError::Internal(format!("llm: failed to parse response JSON: {e}")))?;
283
284    parse_completion_response(&resp)
285}
286
287/// Parse a raw `/v1/chat/completions` JSON response into [`CompletionResult`].
288///
289/// Extracted as a separate function so it can be unit-tested without network.
290pub(crate) fn parse_completion_response(resp: &Value) -> Result<CompletionResult> {
291    let text = resp
292        .pointer("/choices/0/message/content")
293        .and_then(|v| v.as_str())
294        .ok_or_else(|| {
295            FlowError::Internal(
296                "llm: unexpected response shape (missing choices[0].message.content)".into(),
297            )
298        })?
299        .to_string();
300
301    let finish_reason = resp
302        .pointer("/choices/0/finish_reason")
303        .and_then(|v| v.as_str())
304        .unwrap_or("stop")
305        .to_string();
306
307    let model = resp["model"].as_str().unwrap_or("unknown").to_string();
308    let prompt_tokens = resp
309        .pointer("/usage/prompt_tokens")
310        .and_then(|v| v.as_u64())
311        .unwrap_or(0);
312    let completion_tokens = resp
313        .pointer("/usage/completion_tokens")
314        .and_then(|v| v.as_u64())
315        .unwrap_or(0);
316    let total_tokens = resp
317        .pointer("/usage/total_tokens")
318        .and_then(|v| v.as_u64())
319        .unwrap_or(0);
320
321    Ok(CompletionResult {
322        text,
323        model,
324        finish_reason,
325        prompt_tokens,
326        completion_tokens,
327        total_tokens,
328    })
329}
330
331// ── Tests ─────────────────────────────────────────────────────────────────────
332
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    // ── Config validation ──────────────────────────────────────────────────

    // `model` is mandatory: its absence must be an InvalidDefinition error.
    #[test]
    fn rejects_missing_model() {
        let err = LlmConfig::from_data(&json!({ "user_prompt": "hi" })).unwrap_err();
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    // `user_prompt` is mandatory for the full config parser.
    #[test]
    fn rejects_missing_user_prompt() {
        let err = LlmConfig::from_data(&json!({ "model": "gpt-4o" })).unwrap_err();
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    // Every optional field falls back to its documented default.
    #[test]
    fn applies_defaults() {
        let cfg = LlmConfig::from_data(&json!({
            "model": "gpt-4o-mini",
            "user_prompt": "hello"
        }))
        .unwrap();
        assert_eq!(cfg.api_base, DEFAULT_API_BASE);
        assert_eq!(cfg.api_key, "");
        // Float compare via epsilon rather than exact equality.
        assert!((cfg.temperature - DEFAULT_TEMPERATURE).abs() < 1e-9);
        assert!(cfg.max_tokens.is_none());
        assert!(cfg.system_prompt.is_none());
    }

    // Trailing "/" is stripped so URL construction never yields "//".
    #[test]
    fn trailing_slash_stripped_from_api_base() {
        let cfg = LlmConfig::from_data(&json!({
            "model": "x",
            "user_prompt": "y",
            "api_base": "http://localhost:11434/v1/"
        }))
        .unwrap();
        assert_eq!(cfg.api_base, "http://localhost:11434/v1");
    }

    // ── Template rendering ─────────────────────────────────────────────────

    // Flat variables are substituted by key.
    #[test]
    fn renders_user_prompt_with_variables() {
        let ctx_map = HashMap::from([("query".to_string(), json!("What is 2+2?"))]);
        let rendered = render("Answer: {{ query }}", &ctx_map).unwrap();
        assert_eq!(rendered, "Answer: What is 2+2?");
    }

    // Nested upstream outputs are reachable with dotted access.
    #[test]
    fn renders_user_prompt_with_upstream_input() {
        let ctx_map = HashMap::from([("fetch".to_string(), json!({ "body": "data" }))]);
        let rendered = render("Got: {{ fetch.body }}", &ctx_map).unwrap();
        assert_eq!(rendered, "Got: data");
    }

    // On a key collision, the upstream input must win over the variable.
    #[test]
    fn inputs_shadow_variables_in_context() {
        let mut ctx = ExecContext {
            variables: HashMap::from([("x".to_string(), json!("from_var"))]),
            inputs: HashMap::from([("x".to_string(), json!("from_input"))]),
            ..Default::default()
        };
        ctx.data = json!({});
        let map = build_jinja_context(&ctx);
        assert_eq!(map["x"], json!("from_input"));
    }

    // ── Response parsing ───────────────────────────────────────────────────

    // Happy path: all fields present and extracted verbatim.
    #[test]
    fn parses_standard_completion_response() {
        let resp = json!({
            "model": "gpt-4o-mini",
            "choices": [{
                "message": { "role": "assistant", "content": "Hello!" },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        });
        let result = parse_completion_response(&resp).unwrap();
        assert_eq!(result.text, "Hello!");
        assert_eq!(result.model, "gpt-4o-mini");
        assert_eq!(result.finish_reason, "stop");
        assert_eq!(result.prompt_tokens, 10);
        assert_eq!(result.completion_tokens, 5);
        assert_eq!(result.total_tokens, 15);
    }

    // An empty `choices` array has no content and must error.
    #[test]
    fn missing_choices_returns_error() {
        let err = parse_completion_response(&json!({ "model": "x", "choices": [] })).unwrap_err();
        assert!(matches!(err, FlowError::Internal(_)));
    }

    // A message without `content` must error rather than return "".
    #[test]
    fn missing_content_returns_error() {
        let err = parse_completion_response(&json!({
            "choices": [{ "message": { "role": "assistant" } }]
        }))
        .unwrap_err();
        assert!(matches!(err, FlowError::Internal(_)));
    }

    // A missing `usage` object yields zeroed counters, not an error.
    #[test]
    fn partial_usage_fields_default_to_zero() {
        let resp = json!({
            "model": "x",
            "choices": [{ "message": { "content": "ok" }, "finish_reason": "stop" }]
        });
        let result = parse_completion_response(&resp).unwrap();
        assert_eq!(result.total_tokens, 0);
    }
}