use std::collections::HashMap;
use std::sync::OnceLock;

use async_trait::async_trait;
use serde::Serialize;
use serde_json::{json, Value};

use crate::error::{FlowError, Result};
use crate::node::{ExecContext, Node};
65
/// Fallback OpenAI-compatible endpoint used when `data.api_base` is absent.
const DEFAULT_API_BASE: &str = "https://api.openai.com/v1";
/// Sampling temperature applied when `data.temperature` is absent.
const DEFAULT_TEMPERATURE: f64 = 0.7;
68
69pub struct LlmNode;
73
74#[async_trait]
75impl Node for LlmNode {
76 fn node_type(&self) -> &str {
77 "llm"
78 }
79
80 async fn execute(&self, ctx: ExecContext) -> Result<Value> {
81 let config = LlmConfig::from_data(&ctx.data)?;
82 let jinja_ctx = build_jinja_context(&ctx);
83
84 let user_prompt = render(&config.user_prompt, &jinja_ctx)?;
85 let system_prompt = config
86 .system_prompt
87 .as_deref()
88 .map(|t| render(t, &jinja_ctx))
89 .transpose()?;
90
91 let mut messages: Vec<ChatMessage> = Vec::new();
92 if let Some(sys) = system_prompt {
93 messages.push(ChatMessage {
94 role: "system".into(),
95 content: sys,
96 });
97 }
98 messages.push(ChatMessage {
99 role: "user".into(),
100 content: user_prompt,
101 });
102
103 let result = do_chat_completion(
104 &config.api_base,
105 &config.api_key,
106 &config.model,
107 messages,
108 Some(config.temperature),
109 config.max_tokens,
110 )
111 .await?;
112
113 Ok(json!({
114 "text": result.text,
115 "model": result.model,
116 "finish_reason": result.finish_reason,
117 "usage": {
118 "prompt_tokens": result.prompt_tokens,
119 "completion_tokens": result.completion_tokens,
120 "total_tokens": result.total_tokens,
121 }
122 }))
123 }
124}
125
/// Fully-resolved configuration for one LLM node invocation, extracted from
/// the node's JSON `data` blob.
#[derive(Debug)]
pub(crate) struct LlmConfig {
    /// Model identifier sent verbatim to the API (e.g. "gpt-4o-mini").
    pub model: String,
    /// MiniJinja template for the user message (rendered at execute time).
    pub user_prompt: String,
    /// Optional MiniJinja template for a leading system message.
    pub system_prompt: Option<String>,
    /// API root without trailing slash; defaults to `DEFAULT_API_BASE`.
    pub api_base: String,
    /// Bearer token; empty string means "send no Authorization header".
    pub api_key: String,
    /// Sampling temperature; defaults to `DEFAULT_TEMPERATURE`.
    pub temperature: f64,
    /// Optional completion-length cap; omitted from the request when `None`.
    pub max_tokens: Option<u64>,
}
139
140impl LlmConfig {
141 pub(crate) fn from_data(data: &Value) -> Result<Self> {
142 let model = data["model"]
143 .as_str()
144 .ok_or_else(|| FlowError::InvalidDefinition("llm: missing data.model".into()))?
145 .to_string();
146
147 let user_prompt = data["user_prompt"]
148 .as_str()
149 .ok_or_else(|| FlowError::InvalidDefinition("llm: missing data.user_prompt".into()))?
150 .to_string();
151
152 let system_prompt = data["system_prompt"].as_str().map(str::to_string);
153 let api_base = data["api_base"]
154 .as_str()
155 .unwrap_or(DEFAULT_API_BASE)
156 .trim_end_matches('/')
157 .to_string();
158 let api_key = data["api_key"].as_str().unwrap_or("").to_string();
159 let temperature = data["temperature"].as_f64().unwrap_or(DEFAULT_TEMPERATURE);
160 let max_tokens = data["max_tokens"].as_u64();
161
162 Ok(Self {
163 model,
164 user_prompt,
165 system_prompt,
166 api_base,
167 api_key,
168 temperature,
169 max_tokens,
170 })
171 }
172
173 pub(crate) fn from_connection_data(data: &Value) -> Result<Self> {
177 let model = data["model"]
178 .as_str()
179 .ok_or_else(|| FlowError::InvalidDefinition("llm: missing data.model".into()))?
180 .to_string();
181
182 let api_base = data["api_base"]
183 .as_str()
184 .unwrap_or(DEFAULT_API_BASE)
185 .trim_end_matches('/')
186 .to_string();
187 let api_key = data["api_key"].as_str().unwrap_or("").to_string();
188 let temperature = data["temperature"].as_f64().unwrap_or(DEFAULT_TEMPERATURE);
189 let max_tokens = data["max_tokens"].as_u64();
190
191 Ok(Self {
192 model,
193 user_prompt: String::new(),
194 system_prompt: None,
195 api_base,
196 api_key,
197 temperature,
198 max_tokens,
199 })
200 }
201}
202
/// One OpenAI-style chat message; serialized directly into the request body.
#[derive(Debug, Serialize)]
pub(crate) struct ChatMessage {
    /// Message author: "system" or "user" in this module.
    pub role: String,
    /// Message text.
    pub content: String,
}
209
/// Normalized result of a chat-completion call, produced by
/// `parse_completion_response`.
#[derive(Debug)]
pub(crate) struct CompletionResult {
    /// Assistant reply text from `choices[0].message.content`.
    pub text: String,
    /// Model name echoed by the API ("unknown" when absent).
    pub model: String,
    /// `choices[0].finish_reason`, defaulting to "stop" when absent.
    pub finish_reason: String,
    /// `usage.prompt_tokens`; zero when the API omits it.
    pub prompt_tokens: u64,
    /// `usage.completion_tokens`; zero when the API omits it.
    pub completion_tokens: u64,
    /// `usage.total_tokens`; zero when the API omits it.
    pub total_tokens: u64,
}
220
221pub(crate) fn build_jinja_context(ctx: &ExecContext) -> HashMap<String, Value> {
225 let mut map: HashMap<String, Value> = ctx.variables.clone();
226 for (k, v) in &ctx.inputs {
227 map.insert(k.clone(), v.clone());
228 }
229 map
230}
231
232pub(crate) fn render(template: &str, context: &HashMap<String, Value>) -> Result<String> {
234 let env = minijinja::Environment::new();
235 env.render_str(template, context)
236 .map_err(|e| FlowError::Internal(format!("llm: template render error: {e}")))
237}
238
239pub(crate) async fn do_chat_completion(
241 api_base: &str,
242 api_key: &str,
243 model: &str,
244 messages: Vec<ChatMessage>,
245 temperature: Option<f64>,
246 max_tokens: Option<u64>,
247) -> Result<CompletionResult> {
248 let mut body = json!({
249 "model": model,
250 "messages": messages,
251 "temperature": temperature.unwrap_or(DEFAULT_TEMPERATURE),
252 });
253 if let Some(max_tok) = max_tokens {
254 body["max_tokens"] = json!(max_tok);
255 }
256
257 let url = format!("{api_base}/chat/completions");
258 let client = reqwest::Client::new();
259 let mut req = client.post(&url).json(&body);
260 if !api_key.is_empty() {
261 req = req.bearer_auth(api_key);
262 }
263
264 let response = req
265 .send()
266 .await
267 .map_err(|e| FlowError::Internal(format!("llm: HTTP request failed: {e}")))?;
268
269 let status = response.status();
270 let text = response
271 .text()
272 .await
273 .map_err(|e| FlowError::Internal(format!("llm: failed to read response body: {e}")))?;
274
275 if !status.is_success() {
276 return Err(FlowError::Internal(format!(
277 "llm: API returned {status}: {text}"
278 )));
279 }
280
281 let resp: Value = serde_json::from_str(&text)
282 .map_err(|e| FlowError::Internal(format!("llm: failed to parse response JSON: {e}")))?;
283
284 parse_completion_response(&resp)
285}
286
287pub(crate) fn parse_completion_response(resp: &Value) -> Result<CompletionResult> {
291 let text = resp
292 .pointer("/choices/0/message/content")
293 .and_then(|v| v.as_str())
294 .ok_or_else(|| {
295 FlowError::Internal(
296 "llm: unexpected response shape (missing choices[0].message.content)".into(),
297 )
298 })?
299 .to_string();
300
301 let finish_reason = resp
302 .pointer("/choices/0/finish_reason")
303 .and_then(|v| v.as_str())
304 .unwrap_or("stop")
305 .to_string();
306
307 let model = resp["model"].as_str().unwrap_or("unknown").to_string();
308 let prompt_tokens = resp
309 .pointer("/usage/prompt_tokens")
310 .and_then(|v| v.as_u64())
311 .unwrap_or(0);
312 let completion_tokens = resp
313 .pointer("/usage/completion_tokens")
314 .and_then(|v| v.as_u64())
315 .unwrap_or(0);
316 let total_tokens = resp
317 .pointer("/usage/total_tokens")
318 .and_then(|v| v.as_u64())
319 .unwrap_or(0);
320
321 Ok(CompletionResult {
322 text,
323 model,
324 finish_reason,
325 prompt_tokens,
326 completion_tokens,
327 total_tokens,
328 })
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;

    // --- LlmConfig::from_data: validation and defaults ---

    #[test]
    fn rejects_missing_model() {
        let err = LlmConfig::from_data(&json!({ "user_prompt": "hi" })).unwrap_err();
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[test]
    fn rejects_missing_user_prompt() {
        let err = LlmConfig::from_data(&json!({ "model": "gpt-4o" })).unwrap_err();
        assert!(matches!(err, FlowError::InvalidDefinition(_)));
    }

    #[test]
    fn applies_defaults() {
        let cfg = LlmConfig::from_data(&json!({
            "model": "gpt-4o-mini",
            "user_prompt": "hello"
        }))
        .unwrap();
        assert_eq!(cfg.api_base, DEFAULT_API_BASE);
        assert_eq!(cfg.api_key, "");
        // Float compare via epsilon rather than ==.
        assert!((cfg.temperature - DEFAULT_TEMPERATURE).abs() < 1e-9);
        assert!(cfg.max_tokens.is_none());
        assert!(cfg.system_prompt.is_none());
    }

    #[test]
    fn trailing_slash_stripped_from_api_base() {
        let cfg = LlmConfig::from_data(&json!({
            "model": "x",
            "user_prompt": "y",
            "api_base": "http://localhost:11434/v1/"
        }))
        .unwrap();
        assert_eq!(cfg.api_base, "http://localhost:11434/v1");
    }

    // --- Template rendering and context construction ---

    #[test]
    fn renders_user_prompt_with_variables() {
        let ctx_map = HashMap::from([("query".to_string(), json!("What is 2+2?"))]);
        let rendered = render("Answer: {{ query }}", &ctx_map).unwrap();
        assert_eq!(rendered, "Answer: What is 2+2?");
    }

    #[test]
    fn renders_user_prompt_with_upstream_input() {
        let ctx_map = HashMap::from([("fetch".to_string(), json!({ "body": "data" }))]);
        let rendered = render("Got: {{ fetch.body }}", &ctx_map).unwrap();
        assert_eq!(rendered, "Got: data");
    }

    #[test]
    fn inputs_shadow_variables_in_context() {
        // Same key in both maps: the input value must win.
        let mut ctx = ExecContext {
            variables: HashMap::from([("x".to_string(), json!("from_var"))]),
            inputs: HashMap::from([("x".to_string(), json!("from_input"))]),
            ..Default::default()
        };
        ctx.data = json!({});
        let map = build_jinja_context(&ctx);
        assert_eq!(map["x"], json!("from_input"));
    }

    // --- Completion-response parsing ---

    #[test]
    fn parses_standard_completion_response() {
        let resp = json!({
            "model": "gpt-4o-mini",
            "choices": [{
                "message": { "role": "assistant", "content": "Hello!" },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        });
        let result = parse_completion_response(&resp).unwrap();
        assert_eq!(result.text, "Hello!");
        assert_eq!(result.model, "gpt-4o-mini");
        assert_eq!(result.finish_reason, "stop");
        assert_eq!(result.prompt_tokens, 10);
        assert_eq!(result.completion_tokens, 5);
        assert_eq!(result.total_tokens, 15);
    }

    #[test]
    fn missing_choices_returns_error() {
        let err = parse_completion_response(&json!({ "model": "x", "choices": [] })).unwrap_err();
        assert!(matches!(err, FlowError::Internal(_)));
    }

    #[test]
    fn missing_content_returns_error() {
        let err = parse_completion_response(&json!({
            "choices": [{ "message": { "role": "assistant" } }]
        }))
        .unwrap_err();
        assert!(matches!(err, FlowError::Internal(_)));
    }

    #[test]
    fn partial_usage_fields_default_to_zero() {
        // No "usage" object at all: every counter should fall back to 0.
        let resp = json!({
            "model": "x",
            "choices": [{ "message": { "content": "ok" }, "finish_reason": "stop" }]
        });
        let result = parse_completion_response(&resp).unwrap();
        assert_eq!(result.total_tokens, 0);
    }
}