// sage_runtime/llm.rs

//! LLM client for inference calls.
//!
//! Provides [`LlmClient`] / [`LlmConfig`] for OpenAI-compatible
//! `/chat/completions` endpoints, including local Ollama instances,
//! plus JSON-schema-guided structured inference with retries.

use crate::error::{SageError, SageResult};
use serde::{de::DeserializeOwned, Deserialize, Serialize};

/// Default number of retries for structured inference, used when the
/// `SAGE_INFER_RETRIES` environment variable is unset or unparsable.
const DEFAULT_INFER_RETRIES: usize = 3;
8
/// Client for making LLM inference calls.
///
/// Cloning is cheap: `reqwest::Client` is internally reference-counted,
/// so clones share the same connection pool.
#[derive(Clone)]
pub struct LlmClient {
    // Shared HTTP client used for all requests.
    client: reqwest::Client,
    // Endpoint, credentials, model, and sampling settings.
    config: LlmConfig,
}
15
/// Configuration for the LLM client.
///
/// Build one via [`LlmConfig::from_env`], [`LlmConfig::mock`], or
/// [`LlmConfig::with_model`]; fields are public so callers may also
/// construct or adjust a config directly.
#[derive(Clone)]
pub struct LlmConfig {
    /// API key for authentication (sent as a `Bearer` token).
    pub api_key: String,
    /// Base URL for the API (OpenAI-compatible endpoint root).
    pub base_url: String,
    /// Model to use.
    pub model: String,
    /// Max retries for structured inference.
    pub infer_retries: usize,
    /// Temperature for sampling (0.0 - 2.0). None uses API default.
    pub temperature: Option<f64>,
    /// Maximum tokens to generate. None uses API default.
    pub max_tokens: Option<i64>,
}
32
33impl LlmConfig {
34    /// Create a config from environment variables.
35    pub fn from_env() -> Self {
36        Self {
37            api_key: std::env::var("SAGE_API_KEY").unwrap_or_default(),
38            base_url: std::env::var("SAGE_LLM_URL")
39                .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()),
40            model: std::env::var("SAGE_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()),
41            infer_retries: std::env::var("SAGE_INFER_RETRIES")
42                .ok()
43                .and_then(|s| s.parse().ok())
44                .unwrap_or(DEFAULT_INFER_RETRIES),
45            temperature: std::env::var("SAGE_TEMPERATURE")
46                .ok()
47                .and_then(|s| s.parse().ok()),
48            max_tokens: std::env::var("SAGE_MAX_TOKENS")
49                .ok()
50                .and_then(|s| s.parse().ok()),
51        }
52    }
53
54    /// Create a mock config for testing.
55    pub fn mock() -> Self {
56        Self {
57            api_key: "mock".to_string(),
58            base_url: "mock".to_string(),
59            model: "mock".to_string(),
60            infer_retries: DEFAULT_INFER_RETRIES,
61            temperature: None,
62            max_tokens: None,
63        }
64    }
65
66    /// Create a config with specific model and defaults for other settings.
67    ///
68    /// This is useful when you want to override only specific fields like model
69    /// from an effect handler, while keeping API key and base URL from environment.
70    pub fn with_model(model: impl Into<String>) -> Self {
71        let mut config = Self::from_env();
72        config.model = model.into();
73        config
74    }
75
76    /// Set the temperature for this config.
77    #[must_use]
78    pub fn with_temperature(mut self, temp: f64) -> Self {
79        self.temperature = Some(temp);
80        self
81    }
82
83    /// Set the max tokens for this config.
84    #[must_use]
85    pub fn with_max_tokens(mut self, tokens: i64) -> Self {
86        self.max_tokens = Some(tokens);
87        self
88    }
89
90    /// Check if this is a mock configuration.
91    pub fn is_mock(&self) -> bool {
92        self.api_key == "mock"
93    }
94
95    /// Check if the base URL points to a local Ollama instance.
96    pub fn is_ollama(&self) -> bool {
97        self.base_url.contains("localhost") || self.base_url.contains("127.0.0.1")
98    }
99}
100
101impl LlmClient {
102    /// Create a new LLM client with the given configuration.
103    pub fn new(config: LlmConfig) -> Self {
104        Self {
105            client: reqwest::Client::new(),
106            config,
107        }
108    }
109
110    /// Create a client from environment variables.
111    pub fn from_env() -> Self {
112        Self::new(LlmConfig::from_env())
113    }
114
115    /// Create a mock client for testing.
116    pub fn mock() -> Self {
117        Self::new(LlmConfig::mock())
118    }
119
120    /// Call the LLM with a prompt and return the raw string response.
121    pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
122        if self.config.is_mock() {
123            return Ok(format!("[Mock LLM response for: {prompt}]"));
124        }
125
126        let request = ChatRequest::new(
127            &self.config.model,
128            vec![ChatMessage {
129                role: "user",
130                content: prompt,
131            }],
132        )
133        .with_config(&self.config);
134
135        self.send_request(&request).await
136    }
137
138    /// Call the LLM with a prompt and parse the response as the given type.
139    pub async fn infer<T>(&self, prompt: &str) -> SageResult<T>
140    where
141        T: DeserializeOwned,
142    {
143        let response = self.infer_string(prompt).await?;
144        parse_json_response(&response)
145    }
146
147    /// Call the LLM with schema-injected prompt engineering for structured output.
148    ///
149    /// The schema is injected as a system message, and the runtime retries up to
150    /// `SAGE_INFER_RETRIES` times (default 3) on parse failure.
151    pub async fn infer_structured<T>(&self, prompt: &str, schema: &str) -> SageResult<T>
152    where
153        T: DeserializeOwned,
154    {
155        if self.config.is_mock() {
156            // For mock mode, return an error since we can't produce valid structured output
157            return Err(SageError::Llm(
158                "Mock client cannot produce structured output".to_string(),
159            ));
160        }
161
162        let system_prompt = format!(
163            "You are a precise assistant that always responds with valid JSON.\n\
164             You must respond with a JSON object matching this exact schema:\n\n\
165             {schema}\n\n\
166             Respond with JSON only. No explanation, no markdown, no code blocks."
167        );
168
169        let mut last_error: Option<String> = None;
170
171        for attempt in 0..self.config.infer_retries {
172            let response = if attempt == 0 {
173                self.send_structured_request(&system_prompt, prompt, None)
174                    .await?
175            } else {
176                let error_feedback = format!(
177                    "Your previous response could not be parsed: {}\n\
178                     Please try again, responding with valid JSON only.",
179                    last_error.as_deref().unwrap_or("unknown error")
180                );
181                self.send_structured_request(&system_prompt, prompt, Some(&error_feedback))
182                    .await?
183            };
184
185            match parse_json_response::<T>(&response) {
186                Ok(value) => return Ok(value),
187                Err(e) => {
188                    last_error = Some(e.to_string());
189                    // Continue to next retry
190                }
191            }
192        }
193
194        Err(SageError::Llm(format!(
195            "Failed to parse structured response after {} attempts: {}",
196            self.config.infer_retries,
197            last_error.unwrap_or_else(|| "unknown error".to_string())
198        )))
199    }
200
201    /// Send a structured inference request with optional error feedback.
202    async fn send_structured_request(
203        &self,
204        system_prompt: &str,
205        user_prompt: &str,
206        error_feedback: Option<&str>,
207    ) -> SageResult<String> {
208        let mut messages = vec![
209            ChatMessage {
210                role: "system",
211                content: system_prompt,
212            },
213            ChatMessage {
214                role: "user",
215                content: user_prompt,
216            },
217        ];
218
219        if let Some(feedback) = error_feedback {
220            messages.push(ChatMessage {
221                role: "user",
222                content: feedback,
223            });
224        }
225
226        let mut request = ChatRequest::new(&self.config.model, messages).with_config(&self.config);
227
228        // Add format: json hint for Ollama
229        if self.config.is_ollama() {
230            request = request.with_json_format();
231        }
232
233        self.send_request(&request).await
234    }
235
236    /// Send a chat request and return the response content.
237    async fn send_request(&self, request: &ChatRequest<'_>) -> SageResult<String> {
238        let response = self
239            .client
240            .post(format!("{}/chat/completions", self.config.base_url))
241            .header("Authorization", format!("Bearer {}", self.config.api_key))
242            .header("Content-Type", "application/json")
243            .json(request)
244            .send()
245            .await?;
246
247        if !response.status().is_success() {
248            let status = response.status();
249            let body = response.text().await.unwrap_or_default();
250            return Err(SageError::Llm(format!("API error {status}: {body}")));
251        }
252
253        let chat_response: ChatResponse = response.json().await?;
254        let content = chat_response
255            .choices
256            .into_iter()
257            .next()
258            .map(|c| c.message.content)
259            .unwrap_or_default();
260
261        Ok(content)
262    }
263}
264
265/// Strip markdown code fences from a response and parse as JSON.
266fn parse_json_response<T: DeserializeOwned>(response: &str) -> SageResult<T> {
267    // Try to parse as-is first
268    if let Ok(value) = serde_json::from_str(response) {
269        return Ok(value);
270    }
271
272    // Strip markdown code blocks if present
273    let cleaned = response
274        .trim()
275        .strip_prefix("```json")
276        .or_else(|| response.trim().strip_prefix("```"))
277        .unwrap_or(response.trim());
278
279    let cleaned = cleaned.strip_suffix("```").unwrap_or(cleaned).trim();
280
281    serde_json::from_str(cleaned).map_err(|e| {
282        SageError::Llm(format!(
283            "Failed to parse LLM response as {}: {e}\nResponse: {response}",
284            std::any::type_name::<T>()
285        ))
286    })
287}
288
/// Outgoing OpenAI-style `/chat/completions` request body.
///
/// Borrows all string data from the caller; optional fields are omitted
/// from the serialized JSON when unset.
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<ChatMessage<'a>>,
    // Ollama-specific output-format hint (set to "json" via with_json_format).
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<&'a str>,
    // Sampling temperature; None leaves the API default.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    // Generation cap; None leaves the API default.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<i64>,
}
300
/// A single chat message (role + content) in an outgoing request.
#[derive(Serialize)]
struct ChatMessage<'a> {
    // "system" or "user" as used by this module.
    role: &'a str,
    content: &'a str,
}
306
307impl<'a> ChatRequest<'a> {
308    fn new(model: &'a str, messages: Vec<ChatMessage<'a>>) -> Self {
309        Self {
310            model,
311            messages,
312            format: None,
313            temperature: None,
314            max_tokens: None,
315        }
316    }
317
318    fn with_json_format(mut self) -> Self {
319        self.format = Some("json");
320        self
321    }
322
323    fn with_config(mut self, config: &LlmConfig) -> Self {
324        self.temperature = config.temperature;
325        self.max_tokens = config.max_tokens;
326        self
327    }
328}
329
/// Minimal deserialization of the `/chat/completions` response;
/// only the fields this module reads are modeled.
#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}
334
/// One completion choice; only the message payload is used.
#[derive(Deserialize)]
struct Choice {
    message: ResponseMessage,
}
339
/// The assistant message inside a completion choice.
#[derive(Deserialize)]
struct ResponseMessage {
    content: String,
}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a non-mock config with the given base URL and model;
    /// all other fields take simple test defaults.
    fn config_with_url(base_url: &str, model: &str) -> LlmConfig {
        LlmConfig {
            api_key: "test".to_string(),
            base_url: base_url.to_string(),
            model: model.to_string(),
            infer_retries: 3,
            temperature: None,
            max_tokens: None,
        }
    }

    #[tokio::test]
    async fn mock_client_returns_placeholder() {
        let response = LlmClient::mock()
            .infer_string("test prompt")
            .await
            .unwrap();
        assert!(response.contains("Mock LLM response"));
        assert!(response.contains("test prompt"));
    }

    #[test]
    fn parse_json_strips_markdown_fences() {
        let parsed: serde_json::Value =
            parse_json_response("```json\n{\"value\": 42}\n```").unwrap();
        assert_eq!(parsed["value"], 42);
    }

    #[test]
    fn parse_json_handles_plain_json() {
        let parsed: serde_json::Value = parse_json_response(r#"{"name": "test"}"#).unwrap();
        assert_eq!(parsed["name"], "test");
    }

    #[test]
    fn parse_json_handles_generic_code_block() {
        let parsed: serde_json::Value = parse_json_response("```\n{\"x\": 1}\n```").unwrap();
        assert_eq!(parsed["x"], 1);
    }

    #[test]
    fn ollama_detection_localhost() {
        assert!(config_with_url("http://localhost:11434/v1", "llama2").is_ollama());
    }

    #[test]
    fn ollama_detection_127() {
        assert!(config_with_url("http://127.0.0.1:11434/v1", "llama2").is_ollama());
    }

    #[test]
    fn not_ollama_for_openai() {
        assert!(!config_with_url("https://api.openai.com/v1", "gpt-4").is_ollama());
    }

    #[test]
    fn chat_request_json_format() {
        let serialized =
            serde_json::to_string(&ChatRequest::new("model", vec![]).with_json_format()).unwrap();
        assert!(serialized.contains(r#""format":"json""#));
    }

    #[test]
    fn chat_request_no_format_by_default() {
        let serialized = serde_json::to_string(&ChatRequest::new("model", vec![])).unwrap();
        assert!(!serialized.contains("format"));
    }

    #[tokio::test]
    async fn infer_structured_fails_on_mock() {
        let result: Result<serde_json::Value, _> =
            LlmClient::mock().infer_structured("test", "{}").await;
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Mock client"));
    }
}