//! LLM client for inference calls.

3use crate::error::{SageError, SageResult};
4use serde::{Deserialize, Serialize};
5
/// Client for making LLM inference calls.
///
/// Bundles a reusable `reqwest::Client` with the [`LlmConfig`] that
/// supplies the endpoint, model, and credentials.
#[derive(Clone)]
pub struct LlmClient {
    // HTTP client reused across requests.
    client: reqwest::Client,
    // Endpoint, model, and API-key settings; see `LlmConfig`.
    config: LlmConfig,
}
12
/// Configuration for the LLM client.
///
/// Construct via [`LlmConfig::from_env`] for real usage or
/// [`LlmConfig::mock`] for tests.
#[derive(Clone)]
pub struct LlmConfig {
    /// API key for authentication (sent as a `Bearer` token).
    pub api_key: String,
    /// Base URL for the API, without the `/chat/completions` suffix.
    pub base_url: String,
    /// Model to use.
    pub model: String,
}
23
24impl LlmConfig {
25    /// Create a config from environment variables.
26    pub fn from_env() -> Self {
27        Self {
28            api_key: std::env::var("SAGE_API_KEY").unwrap_or_default(),
29            base_url: std::env::var("SAGE_LLM_URL")
30                .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()),
31            model: std::env::var("SAGE_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()),
32        }
33    }
34
35    /// Create a mock config for testing.
36    pub fn mock() -> Self {
37        Self {
38            api_key: "mock".to_string(),
39            base_url: "mock".to_string(),
40            model: "mock".to_string(),
41        }
42    }
43
44    /// Check if this is a mock configuration.
45    pub fn is_mock(&self) -> bool {
46        self.api_key == "mock"
47    }
48}
49
50impl LlmClient {
51    /// Create a new LLM client with the given configuration.
52    pub fn new(config: LlmConfig) -> Self {
53        Self {
54            client: reqwest::Client::new(),
55            config,
56        }
57    }
58
59    /// Create a client from environment variables.
60    pub fn from_env() -> Self {
61        Self::new(LlmConfig::from_env())
62    }
63
64    /// Create a mock client for testing.
65    pub fn mock() -> Self {
66        Self::new(LlmConfig::mock())
67    }
68
69    /// Call the LLM with a prompt and return the raw string response.
70    pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
71        if self.config.is_mock() {
72            return Ok(format!("[Mock LLM response for: {prompt}]"));
73        }
74
75        let request = ChatRequest {
76            model: &self.config.model,
77            messages: vec![ChatMessage {
78                role: "user",
79                content: prompt,
80            }],
81        };
82
83        let response = self
84            .client
85            .post(format!("{}/chat/completions", self.config.base_url))
86            .header("Authorization", format!("Bearer {}", self.config.api_key))
87            .header("Content-Type", "application/json")
88            .json(&request)
89            .send()
90            .await?;
91
92        if !response.status().is_success() {
93            let status = response.status();
94            let body = response.text().await.unwrap_or_default();
95            return Err(SageError::Llm(format!("API error {status}: {body}")));
96        }
97
98        let chat_response: ChatResponse = response.json().await?;
99        let content = chat_response
100            .choices
101            .into_iter()
102            .next()
103            .map(|c| c.message.content)
104            .unwrap_or_default();
105
106        Ok(content)
107    }
108
109    /// Call the LLM with a prompt and parse the response as the given type.
110    pub async fn infer<T>(&self, prompt: &str) -> SageResult<T>
111    where
112        T: serde::de::DeserializeOwned,
113    {
114        let response = self.infer_string(prompt).await?;
115
116        // Try to parse as JSON first
117        if let Ok(value) = serde_json::from_str(&response) {
118            return Ok(value);
119        }
120
121        // Try to parse as JSON, stripping markdown code blocks if present
122        let cleaned = response
123            .trim()
124            .strip_prefix("```json")
125            .unwrap_or(&response)
126            .strip_prefix("```")
127            .unwrap_or(&response)
128            .strip_suffix("```")
129            .unwrap_or(&response)
130            .trim();
131
132        serde_json::from_str(cleaned).map_err(|e| {
133            SageError::Llm(format!(
134                "Failed to parse LLM response as {}: {e}\nResponse: {response}",
135                std::any::type_name::<T>()
136            ))
137        })
138    }
139}
140
/// Request body for the `/chat/completions` endpoint.
///
/// Borrows the model name and prompt so serialization needs no owned
/// copies; field names are the wire format and must not change.
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<ChatMessage<'a>>,
}
146
/// One chat message in an outgoing request (`role` plus text `content`).
/// Field names are the wire format and must not change.
#[derive(Serialize)]
struct ChatMessage<'a> {
    role: &'a str,
    content: &'a str,
}
152
/// Minimal deserialization target for a chat-completions response;
/// every field other than `choices` is ignored.
#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}
157
/// A single completion choice; only the message payload is retained.
#[derive(Deserialize)]
struct Choice {
    message: ResponseMessage,
}
162
/// The assistant message inside a choice; only the text content is kept.
#[derive(Deserialize)]
struct ResponseMessage {
    content: String,
}
167
168#[cfg(test)]
169mod tests {
170    use super::*;
171
172    #[tokio::test]
173    async fn mock_client_returns_placeholder() {
174        let client = LlmClient::mock();
175        let response = client.infer_string("test prompt").await.unwrap();
176        assert!(response.contains("Mock LLM response"));
177        assert!(response.contains("test prompt"));
178    }
179}