// sage_runtime/llm.rs

1//! LLM client for inference calls.
2
3use crate::error::{SageError, SageResult};
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5
6/// Default number of retries for structured inference.
7const DEFAULT_INFER_RETRIES: usize = 3;
8
/// Client for making LLM inference calls.
///
/// NOTE(review): `reqwest::Client` is documented as internally
/// reference-counted, so `Clone` shares the connection pool rather than
/// creating a new one.
#[derive(Clone)]
pub struct LlmClient {
    // Shared HTTP client used for all requests.
    client: reqwest::Client,
    // Endpoint, model, retry, and sampling settings.
    config: LlmConfig,
}
15
/// Configuration for the LLM client.
///
/// Construct via [`LlmConfig::from_env`] or [`LlmConfig::mock`], then adjust
/// with the `with_*` builder methods.
#[derive(Clone)]
pub struct LlmConfig {
    /// API key for authentication. The sentinel value `"mock"` marks a mock
    /// configuration (see [`LlmConfig::is_mock`]).
    pub api_key: String,
    /// Base URL for the API (e.g. `https://api.openai.com/v1`).
    pub base_url: String,
    /// Model to use.
    pub model: String,
    /// Max retries for structured inference.
    pub infer_retries: usize,
    /// Temperature for sampling (0.0 - 2.0). None uses API default.
    pub temperature: Option<f64>,
    /// Maximum tokens to generate. None uses API default.
    pub max_tokens: Option<i64>,
}
32
33impl LlmConfig {
34    /// Create a config from environment variables (native) or WASM config injection (browser).
35    #[cfg(not(target_arch = "wasm32"))]
36    pub fn from_env() -> Self {
37        Self {
38            api_key: std::env::var("SAGE_API_KEY").unwrap_or_default(),
39            base_url: std::env::var("SAGE_LLM_URL")
40                .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()),
41            model: std::env::var("SAGE_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()),
42            infer_retries: std::env::var("SAGE_INFER_RETRIES")
43                .ok()
44                .and_then(|s| s.parse().ok())
45                .unwrap_or(DEFAULT_INFER_RETRIES),
46            temperature: std::env::var("SAGE_TEMPERATURE")
47                .ok()
48                .and_then(|s| s.parse().ok()),
49            max_tokens: std::env::var("SAGE_MAX_TOKENS")
50                .ok()
51                .and_then(|s| s.parse().ok()),
52        }
53    }
54
55    /// Create a config from WASM config injection (set via `sage_configure()` from JS).
56    #[cfg(target_arch = "wasm32")]
57    pub fn from_env() -> Self {
58        let wasm_config = sage_runtime_web::get_llm_config();
59        Self {
60            api_key: wasm_config.api_key,
61            base_url: wasm_config.base_url,
62            model: wasm_config.model,
63            infer_retries: DEFAULT_INFER_RETRIES,
64            temperature: None,
65            max_tokens: None,
66        }
67    }
68
69    /// Create a mock config for testing.
70    pub fn mock() -> Self {
71        Self {
72            api_key: "mock".to_string(),
73            base_url: "mock".to_string(),
74            model: "mock".to_string(),
75            infer_retries: DEFAULT_INFER_RETRIES,
76            temperature: None,
77            max_tokens: None,
78        }
79    }
80
81    /// Create a config with specific model and defaults for other settings.
82    ///
83    /// This is useful when you want to override only specific fields like model
84    /// from an effect handler, while keeping API key and base URL from environment.
85    pub fn with_model(model: impl Into<String>) -> Self {
86        let mut config = Self::from_env();
87        config.model = model.into();
88        config
89    }
90
91    /// Set the temperature for this config.
92    #[must_use]
93    pub fn with_temperature(mut self, temp: f64) -> Self {
94        self.temperature = Some(temp);
95        self
96    }
97
98    /// Set the max tokens for this config.
99    #[must_use]
100    pub fn with_max_tokens(mut self, tokens: i64) -> Self {
101        self.max_tokens = Some(tokens);
102        self
103    }
104
105    /// Check if this is a mock configuration.
106    pub fn is_mock(&self) -> bool {
107        self.api_key == "mock"
108    }
109
110    /// Check if the base URL points to a local Ollama instance.
111    pub fn is_ollama(&self) -> bool {
112        self.base_url.contains("localhost") || self.base_url.contains("127.0.0.1")
113    }
114}
115
116impl LlmClient {
117    /// Create a new LLM client with the given configuration.
118    pub fn new(config: LlmConfig) -> Self {
119        Self {
120            client: reqwest::Client::new(),
121            config,
122        }
123    }
124
125    /// Create a client from environment variables.
126    pub fn from_env() -> Self {
127        Self::new(LlmConfig::from_env())
128    }
129
130    /// Create a mock client for testing.
131    pub fn mock() -> Self {
132        Self::new(LlmConfig::mock())
133    }
134
135    /// Call the LLM with a prompt and return the raw string response.
136    pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
137        if self.config.is_mock() {
138            return Ok(format!("[Mock LLM response for: {prompt}]"));
139        }
140
141        let request = ChatRequest::new(
142            &self.config.model,
143            vec![ChatMessage {
144                role: "user",
145                content: prompt,
146            }],
147        )
148        .with_config(&self.config);
149
150        self.send_request(&request).await
151    }
152
153    /// Call the LLM with a prompt and parse the response as the given type.
154    pub async fn infer<T>(&self, prompt: &str) -> SageResult<T>
155    where
156        T: DeserializeOwned,
157    {
158        let response = self.infer_string(prompt).await?;
159        parse_json_response(&response)
160    }
161
162    /// Call the LLM with schema-injected prompt engineering for structured output.
163    ///
164    /// The schema is injected as a system message, and the runtime retries up to
165    /// `SAGE_INFER_RETRIES` times (default 3) on parse failure.
166    pub async fn infer_structured<T>(&self, prompt: &str, schema: &str) -> SageResult<T>
167    where
168        T: DeserializeOwned,
169    {
170        if self.config.is_mock() {
171            // For mock mode, return an error since we can't produce valid structured output
172            return Err(SageError::Llm(
173                "Mock client cannot produce structured output".to_string(),
174            ));
175        }
176
177        let system_prompt = format!(
178            "You are a precise assistant that always responds with valid JSON.\n\
179             You must respond with a JSON object matching this exact schema:\n\n\
180             {schema}\n\n\
181             Respond with JSON only. No explanation, no markdown, no code blocks."
182        );
183
184        let mut last_error: Option<String> = None;
185
186        for attempt in 0..self.config.infer_retries {
187            let response = if attempt == 0 {
188                self.send_structured_request(&system_prompt, prompt, None)
189                    .await?
190            } else {
191                let error_feedback = format!(
192                    "Your previous response could not be parsed: {}\n\
193                     Please try again, responding with valid JSON only.",
194                    last_error.as_deref().unwrap_or("unknown error")
195                );
196                self.send_structured_request(&system_prompt, prompt, Some(&error_feedback))
197                    .await?
198            };
199
200            match parse_json_response::<T>(&response) {
201                Ok(value) => return Ok(value),
202                Err(e) => {
203                    last_error = Some(e.to_string());
204                    // Continue to next retry
205                }
206            }
207        }
208
209        Err(SageError::Llm(format!(
210            "Failed to parse structured response after {} attempts: {}",
211            self.config.infer_retries,
212            last_error.unwrap_or_else(|| "unknown error".to_string())
213        )))
214    }
215
216    /// Send a structured inference request with optional error feedback.
217    async fn send_structured_request(
218        &self,
219        system_prompt: &str,
220        user_prompt: &str,
221        error_feedback: Option<&str>,
222    ) -> SageResult<String> {
223        let mut messages = vec![
224            ChatMessage {
225                role: "system",
226                content: system_prompt,
227            },
228            ChatMessage {
229                role: "user",
230                content: user_prompt,
231            },
232        ];
233
234        if let Some(feedback) = error_feedback {
235            messages.push(ChatMessage {
236                role: "user",
237                content: feedback,
238            });
239        }
240
241        let mut request = ChatRequest::new(&self.config.model, messages).with_config(&self.config);
242
243        // Add format: json hint for Ollama
244        if self.config.is_ollama() {
245            request = request.with_json_format();
246        }
247
248        self.send_request(&request).await
249    }
250
251    /// Send a chat request and return the response content.
252    async fn send_request(&self, request: &ChatRequest<'_>) -> SageResult<String> {
253        let response = self
254            .client
255            .post(format!("{}/chat/completions", self.config.base_url))
256            .header("Authorization", format!("Bearer {}", self.config.api_key))
257            .header("Content-Type", "application/json")
258            .json(request)
259            .send()
260            .await?;
261
262        if !response.status().is_success() {
263            let status = response.status();
264            let body = response.text().await.unwrap_or_default();
265            return Err(SageError::Llm(format!("API error {status}: {body}")));
266        }
267
268        let chat_response: ChatResponse = response.json().await?;
269        let content = chat_response
270            .choices
271            .into_iter()
272            .next()
273            .map(|c| c.message.content)
274            .unwrap_or_default();
275
276        Ok(content)
277    }
278}
279
280/// Strip markdown code fences from a response and parse as JSON.
281fn parse_json_response<T: DeserializeOwned>(response: &str) -> SageResult<T> {
282    // Try to parse as-is first
283    if let Ok(value) = serde_json::from_str(response) {
284        return Ok(value);
285    }
286
287    // Strip markdown code blocks if present
288    let cleaned = response
289        .trim()
290        .strip_prefix("```json")
291        .or_else(|| response.trim().strip_prefix("```"))
292        .unwrap_or(response.trim());
293
294    let cleaned = cleaned.strip_suffix("```").unwrap_or(cleaned).trim();
295
296    serde_json::from_str(cleaned).map_err(|e| {
297        SageError::Llm(format!(
298            "Failed to parse LLM response as {}: {e}\nResponse: {response}",
299            std::any::type_name::<T>()
300        ))
301    })
302}
303
/// Outgoing chat-completions request body (OpenAI-compatible wire format).
/// Optional fields are omitted from the JSON when `None` so the API's own
/// defaults apply.
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<ChatMessage<'a>>,
    /// Output-format hint (`"json"`), used for Ollama endpoints.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<&'a str>,
    /// Sampling temperature copied from [`LlmConfig`].
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Generation cap copied from [`LlmConfig`].
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<i64>,
}
315
/// A single chat message; this file uses the `"system"` and `"user"` roles.
#[derive(Serialize)]
struct ChatMessage<'a> {
    role: &'a str,
    content: &'a str,
}
321
322impl<'a> ChatRequest<'a> {
323    fn new(model: &'a str, messages: Vec<ChatMessage<'a>>) -> Self {
324        Self {
325            model,
326            messages,
327            format: None,
328            temperature: None,
329            max_tokens: None,
330        }
331    }
332
333    fn with_json_format(mut self) -> Self {
334        self.format = Some("json");
335        self
336    }
337
338    fn with_config(mut self, config: &LlmConfig) -> Self {
339        self.temperature = config.temperature;
340        self.max_tokens = config.max_tokens;
341        self
342    }
343}
344
/// Minimal view of the chat-completions response: only the fields we read.
#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}
349
/// One completion choice; only the first choice is consumed.
#[derive(Deserialize)]
struct Choice {
    message: ResponseMessage,
}
354
/// The assistant message inside a choice; only the text content is kept.
#[derive(Deserialize)]
struct ResponseMessage {
    content: String,
}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a test config pointing at `base_url`; all other fields fixed.
    fn config_at(base_url: &str) -> LlmConfig {
        LlmConfig {
            api_key: "test".to_string(),
            base_url: base_url.to_string(),
            model: "llama2".to_string(),
            infer_retries: 3,
            temperature: None,
            max_tokens: None,
        }
    }

    #[tokio::test]
    async fn mock_client_returns_placeholder() {
        let reply = LlmClient::mock()
            .infer_string("test prompt")
            .await
            .unwrap();
        assert!(reply.contains("Mock LLM response"));
        assert!(reply.contains("test prompt"));
    }

    #[test]
    fn parse_json_strips_markdown_fences() {
        let parsed: serde_json::Value =
            parse_json_response("```json\n{\"value\": 42}\n```").unwrap();
        assert_eq!(parsed["value"], 42);
    }

    #[test]
    fn parse_json_handles_plain_json() {
        let parsed: serde_json::Value = parse_json_response(r#"{"name": "test"}"#).unwrap();
        assert_eq!(parsed["name"], "test");
    }

    #[test]
    fn parse_json_handles_generic_code_block() {
        let parsed: serde_json::Value = parse_json_response("```\n{\"x\": 1}\n```").unwrap();
        assert_eq!(parsed["x"], 1);
    }

    #[test]
    fn ollama_detection_localhost() {
        assert!(config_at("http://localhost:11434/v1").is_ollama());
    }

    #[test]
    fn ollama_detection_127() {
        assert!(config_at("http://127.0.0.1:11434/v1").is_ollama());
    }

    #[test]
    fn not_ollama_for_openai() {
        assert!(!config_at("https://api.openai.com/v1").is_ollama());
    }

    #[test]
    fn chat_request_json_format() {
        let request = ChatRequest::new("model", vec![]).with_json_format();
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains(r#""format":"json""#));
    }

    #[test]
    fn chat_request_no_format_by_default() {
        let json = serde_json::to_string(&ChatRequest::new("model", vec![])).unwrap();
        assert!(!json.contains("format"));
    }

    #[tokio::test]
    async fn infer_structured_fails_on_mock() {
        let outcome: Result<serde_json::Value, _> =
            LlmClient::mock().infer_structured("test", "{}").await;
        assert!(outcome.unwrap_err().to_string().contains("Mock client"));
    }
}
453}