1use crate::error::{SageError, SageResult};
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5
/// Default number of attempts `infer_structured` makes before giving up,
/// overridable via the `SAGE_INFER_RETRIES` environment variable.
const DEFAULT_INFER_RETRIES: usize = 3;
8
/// Client for an OpenAI-compatible chat-completions HTTP API
/// (POSTs to `{base_url}/chat/completions`).
#[derive(Clone)]
pub struct LlmClient {
    // Underlying HTTP client used for all requests.
    client: reqwest::Client,
    // Endpoint, credentials, model, and retry settings.
    config: LlmConfig,
}
15
/// Configuration for [`LlmClient`]: endpoint, credentials, model, retries.
#[derive(Clone)]
pub struct LlmConfig {
    /// API bearer token; the sentinel value `"mock"` marks an offline mock config.
    pub api_key: String,
    /// Base URL of the OpenAI-compatible API (no trailing `/chat/completions`).
    pub base_url: String,
    /// Model identifier sent in each chat request.
    pub model: String,
    /// Number of attempts `infer_structured` makes when the model's JSON
    /// cannot be parsed.
    pub infer_retries: usize,
}
28
29impl LlmConfig {
30 pub fn from_env() -> Self {
32 Self {
33 api_key: std::env::var("SAGE_API_KEY").unwrap_or_default(),
34 base_url: std::env::var("SAGE_LLM_URL")
35 .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()),
36 model: std::env::var("SAGE_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()),
37 infer_retries: std::env::var("SAGE_INFER_RETRIES")
38 .ok()
39 .and_then(|s| s.parse().ok())
40 .unwrap_or(DEFAULT_INFER_RETRIES),
41 }
42 }
43
44 pub fn mock() -> Self {
46 Self {
47 api_key: "mock".to_string(),
48 base_url: "mock".to_string(),
49 model: "mock".to_string(),
50 infer_retries: DEFAULT_INFER_RETRIES,
51 }
52 }
53
54 pub fn is_mock(&self) -> bool {
56 self.api_key == "mock"
57 }
58
59 pub fn is_ollama(&self) -> bool {
61 self.base_url.contains("localhost") || self.base_url.contains("127.0.0.1")
62 }
63}
64
65impl LlmClient {
66 pub fn new(config: LlmConfig) -> Self {
68 Self {
69 client: reqwest::Client::new(),
70 config,
71 }
72 }
73
74 pub fn from_env() -> Self {
76 Self::new(LlmConfig::from_env())
77 }
78
79 pub fn mock() -> Self {
81 Self::new(LlmConfig::mock())
82 }
83
84 pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
86 if self.config.is_mock() {
87 return Ok(format!("[Mock LLM response for: {prompt}]"));
88 }
89
90 let request = ChatRequest::new(
91 &self.config.model,
92 vec![ChatMessage {
93 role: "user",
94 content: prompt,
95 }],
96 );
97
98 self.send_request(&request).await
99 }
100
101 pub async fn infer<T>(&self, prompt: &str) -> SageResult<T>
103 where
104 T: DeserializeOwned,
105 {
106 let response = self.infer_string(prompt).await?;
107 parse_json_response(&response)
108 }
109
110 pub async fn infer_structured<T>(&self, prompt: &str, schema: &str) -> SageResult<T>
115 where
116 T: DeserializeOwned,
117 {
118 if self.config.is_mock() {
119 return Err(SageError::Llm(
121 "Mock client cannot produce structured output".to_string(),
122 ));
123 }
124
125 let system_prompt = format!(
126 "You are a precise assistant that always responds with valid JSON.\n\
127 You must respond with a JSON object matching this exact schema:\n\n\
128 {schema}\n\n\
129 Respond with JSON only. No explanation, no markdown, no code blocks."
130 );
131
132 let mut last_error: Option<String> = None;
133
134 for attempt in 0..self.config.infer_retries {
135 let response = if attempt == 0 {
136 self.send_structured_request(&system_prompt, prompt, None)
137 .await?
138 } else {
139 let error_feedback = format!(
140 "Your previous response could not be parsed: {}\n\
141 Please try again, responding with valid JSON only.",
142 last_error.as_deref().unwrap_or("unknown error")
143 );
144 self.send_structured_request(&system_prompt, prompt, Some(&error_feedback))
145 .await?
146 };
147
148 match parse_json_response::<T>(&response) {
149 Ok(value) => return Ok(value),
150 Err(e) => {
151 last_error = Some(e.to_string());
152 }
154 }
155 }
156
157 Err(SageError::Llm(format!(
158 "Failed to parse structured response after {} attempts: {}",
159 self.config.infer_retries,
160 last_error.unwrap_or_else(|| "unknown error".to_string())
161 )))
162 }
163
164 async fn send_structured_request(
166 &self,
167 system_prompt: &str,
168 user_prompt: &str,
169 error_feedback: Option<&str>,
170 ) -> SageResult<String> {
171 let mut messages = vec![
172 ChatMessage {
173 role: "system",
174 content: system_prompt,
175 },
176 ChatMessage {
177 role: "user",
178 content: user_prompt,
179 },
180 ];
181
182 if let Some(feedback) = error_feedback {
183 messages.push(ChatMessage {
184 role: "user",
185 content: feedback,
186 });
187 }
188
189 let mut request = ChatRequest::new(&self.config.model, messages);
190
191 if self.config.is_ollama() {
193 request = request.with_json_format();
194 }
195
196 self.send_request(&request).await
197 }
198
199 async fn send_request(&self, request: &ChatRequest<'_>) -> SageResult<String> {
201 let response = self
202 .client
203 .post(format!("{}/chat/completions", self.config.base_url))
204 .header("Authorization", format!("Bearer {}", self.config.api_key))
205 .header("Content-Type", "application/json")
206 .json(request)
207 .send()
208 .await?;
209
210 if !response.status().is_success() {
211 let status = response.status();
212 let body = response.text().await.unwrap_or_default();
213 return Err(SageError::Llm(format!("API error {status}: {body}")));
214 }
215
216 let chat_response: ChatResponse = response.json().await?;
217 let content = chat_response
218 .choices
219 .into_iter()
220 .next()
221 .map(|c| c.message.content)
222 .unwrap_or_default();
223
224 Ok(content)
225 }
226}
227
228fn parse_json_response<T: DeserializeOwned>(response: &str) -> SageResult<T> {
230 if let Ok(value) = serde_json::from_str(response) {
232 return Ok(value);
233 }
234
235 let cleaned = response
237 .trim()
238 .strip_prefix("```json")
239 .or_else(|| response.trim().strip_prefix("```"))
240 .unwrap_or(response.trim());
241
242 let cleaned = cleaned.strip_suffix("```").unwrap_or(cleaned).trim();
243
244 serde_json::from_str(cleaned).map_err(|e| {
245 SageError::Llm(format!(
246 "Failed to parse LLM response as {}: {e}\nResponse: {response}",
247 std::any::type_name::<T>()
248 ))
249 })
250}
251
/// Request payload for the `/chat/completions` endpoint; borrows all string
/// data from the caller.
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<ChatMessage<'a>>,
    /// Omitted from the serialized JSON when `None`; set to `"json"` via
    /// `with_json_format` for Ollama-like servers.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<&'a str>,
}
259
/// A single chat message; this client only constructs `"system"` and
/// `"user"` roles.
#[derive(Serialize)]
struct ChatMessage<'a> {
    role: &'a str,
    content: &'a str,
}
265
266impl<'a> ChatRequest<'a> {
267 fn new(model: &'a str, messages: Vec<ChatMessage<'a>>) -> Self {
268 Self {
269 model,
270 messages,
271 format: None,
272 }
273 }
274
275 fn with_json_format(mut self) -> Self {
276 self.format = Some("json");
277 self
278 }
279}
280
/// Minimal deserialization target for the chat-completions response body;
/// fields this client does not use are ignored by serde.
#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}
285
/// One completion choice; only the message payload is consumed.
#[derive(Deserialize)]
struct Choice {
    message: ResponseMessage,
}
290
/// The assistant message inside a choice; only its text content is kept.
#[derive(Deserialize)]
struct ResponseMessage {
    content: String,
}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a throwaway config pointing at `base_url` with the given model.
    fn config_with(base_url: &str, model: &str) -> LlmConfig {
        LlmConfig {
            api_key: "test".to_string(),
            base_url: base_url.to_string(),
            model: model.to_string(),
            infer_retries: 3,
        }
    }

    #[tokio::test]
    async fn mock_client_returns_placeholder() {
        let reply = LlmClient::mock().infer_string("test prompt").await.unwrap();
        assert!(reply.contains("Mock LLM response"));
        assert!(reply.contains("test prompt"));
    }

    #[test]
    fn parse_json_strips_markdown_fences() {
        let parsed: serde_json::Value =
            parse_json_response("```json\n{\"value\": 42}\n```").unwrap();
        assert_eq!(parsed["value"], 42);
    }

    #[test]
    fn parse_json_handles_plain_json() {
        let parsed: serde_json::Value = parse_json_response(r#"{"name": "test"}"#).unwrap();
        assert_eq!(parsed["name"], "test");
    }

    #[test]
    fn parse_json_handles_generic_code_block() {
        let parsed: serde_json::Value = parse_json_response("```\n{\"x\": 1}\n```").unwrap();
        assert_eq!(parsed["x"], 1);
    }

    #[test]
    fn ollama_detection_localhost() {
        assert!(config_with("http://localhost:11434/v1", "llama2").is_ollama());
    }

    #[test]
    fn ollama_detection_127() {
        assert!(config_with("http://127.0.0.1:11434/v1", "llama2").is_ollama());
    }

    #[test]
    fn not_ollama_for_openai() {
        assert!(!config_with("https://api.openai.com/v1", "gpt-4").is_ollama());
    }

    #[test]
    fn chat_request_json_format() {
        let serialized =
            serde_json::to_string(&ChatRequest::new("model", vec![]).with_json_format()).unwrap();
        assert!(serialized.contains(r#""format":"json""#));
    }

    #[test]
    fn chat_request_no_format_by_default() {
        let serialized = serde_json::to_string(&ChatRequest::new("model", vec![])).unwrap();
        assert!(!serialized.contains("format"));
    }

    #[tokio::test]
    async fn infer_structured_fails_on_mock() {
        let outcome: Result<serde_json::Value, _> =
            LlmClient::mock().infer_structured("test", "{}").await;
        let err = outcome.expect_err("mock client must reject structured inference");
        assert!(err.to_string().contains("Mock client"));
    }
}