use crate::error::{SageError, SageResult};
use serde::{de::DeserializeOwned, Deserialize, Serialize};

const DEFAULT_INFER_RETRIES: usize = 3;

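/// Async client for OpenAI-compatible chat-completions APIs, with a mock
/// mode for tests and basic structured-output support.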
#[derive(Clone)]
pub struct LlmClient {
    client: reqwest::Client,
    config: LlmConfig,
}

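/// Connection and sampling settings for [`LlmClient`], typically read from
/// `SAGE_*` environment variables via [`LlmConfig::from_env`].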
#[derive(Clone)]
pub struct LlmConfig {
    pub api_key: String,
    pub base_url: String,
    pub model: String,
    pub infer_retries: usize,
    pub temperature: Option<f64>,
    pub max_tokens: Option<i64>,
}

impl LlmConfig {
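    /// Builds a config from environment variables: `SAGE_API_KEY`,
    /// `SAGE_LLM_URL` (default: the OpenAI endpoint), `SAGE_MODEL`
    /// (default: `gpt-4o-mini`), `SAGE_INFER_RETRIES`, `SAGE_TEMPERATURE`,
    /// and `SAGE_MAX_TOKENS`. Unset or unparsable values fall back to
    /// their defaults.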
    pub fn from_env() -> Self {
        Self {
            api_key: std::env::var("SAGE_API_KEY").unwrap_or_default(),
            base_url: std::env::var("SAGE_LLM_URL")
                .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()),
            model: std::env::var("SAGE_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()),
            infer_retries: std::env::var("SAGE_INFER_RETRIES")
                .ok()
                .and_then(|s| s.parse().ok())
                .unwrap_or(DEFAULT_INFER_RETRIES),
            temperature: std::env::var("SAGE_TEMPERATURE")
                .ok()
                .and_then(|s| s.parse().ok()),
            max_tokens: std::env::var("SAGE_MAX_TOKENS")
                .ok()
                .and_then(|s| s.parse().ok()),
        }
    }

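    /// Config for a no-network mock client; [`LlmConfig::is_mock`] keys off
    /// the literal `"mock"` API key.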
    pub fn mock() -> Self {
        Self {
            api_key: "mock".to_string(),
            base_url: "mock".to_string(),
            model: "mock".to_string(),
            infer_retries: DEFAULT_INFER_RETRIES,
            temperature: None,
            max_tokens: None,
        }
    }

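    /// Builds a config from the environment, then overrides the model.
    ///
    /// Hypothetical usage (the model name is illustrative):
    /// ```ignore
    /// let config = LlmConfig::with_model("gpt-4o").with_temperature(0.2);
    /// ```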
    pub fn with_model(model: impl Into<String>) -> Self {
        let mut config = Self::from_env();
        config.model = model.into();
        config
    }

    #[must_use]
    pub fn with_temperature(mut self, temp: f64) -> Self {
        self.temperature = Some(temp);
        self
    }

    #[must_use]
    pub fn with_max_tokens(mut self, tokens: i64) -> Self {
        self.max_tokens = Some(tokens);
        self
    }

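    /// True when this config was produced by [`LlmConfig::mock`].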
    pub fn is_mock(&self) -> bool {
        self.api_key == "mock"
    }

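    /// Heuristic: any localhost or 127.0.0.1 base URL is assumed to be a
    /// local Ollama server.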
    pub fn is_ollama(&self) -> bool {
        self.base_url.contains("localhost") || self.base_url.contains("127.0.0.1")
    }
}

impl LlmClient {
    pub fn new(config: LlmConfig) -> Self {
        Self {
            client: reqwest::Client::new(),
            config,
        }
    }

    pub fn from_env() -> Self {
        Self::new(LlmConfig::from_env())
    }

    pub fn mock() -> Self {
        Self::new(LlmConfig::mock())
    }

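    /// Sends `prompt` as a single user message and returns the raw response
    /// text. The mock client short-circuits with a placeholder string.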
    pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
        if self.config.is_mock() {
            return Ok(format!("[Mock LLM response for: {prompt}]"));
        }

        let request = ChatRequest::new(
            &self.config.model,
            vec![ChatMessage {
                role: "user",
                content: prompt,
            }],
        )
        .with_config(&self.config);

        self.send_request(&request).await
    }

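    /// Like [`LlmClient::infer_string`], but parses the response as JSON
    /// into `T`, tolerating markdown code fences around the payload.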
    pub async fn infer<T>(&self, prompt: &str) -> SageResult<T>
    where
        T: DeserializeOwned,
    {
        let response = self.infer_string(prompt).await?;
        parse_json_response(&response)
    }

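    /// Requests JSON matching `schema` via a system prompt, retrying up to
    /// `infer_retries` times and feeding parse errors back to the model.
    ///
    /// A hypothetical call site (`Plan` is illustrative, not part of this
    /// module):
    /// ```ignore
    /// #[derive(serde::Deserialize)]
    /// struct Plan { steps: Vec<String> }
    ///
    /// let plan: Plan = client
    ///     .infer_structured("Plan a release", r#"{"steps": ["string"]}"#)
    ///     .await?;
    /// ```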
    pub async fn infer_structured<T>(&self, prompt: &str, schema: &str) -> SageResult<T>
    where
        T: DeserializeOwned,
    {
        if self.config.is_mock() {
            return Err(SageError::Llm(
                "Mock client cannot produce structured output".to_string(),
            ));
        }

        let system_prompt = format!(
            "You are a precise assistant that always responds with valid JSON.\n\
             You must respond with a JSON object matching this exact schema:\n\n\
             {schema}\n\n\
             Respond with JSON only. No explanation, no markdown, no code blocks."
        );

        let mut last_error: Option<String> = None;

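        // First attempt sends only the schema-bearing system prompt; later
        // attempts also append the previous parse error so the model can
        // self-correct.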
        for attempt in 0..self.config.infer_retries {
            let response = if attempt == 0 {
                self.send_structured_request(&system_prompt, prompt, None)
                    .await?
            } else {
                let error_feedback = format!(
                    "Your previous response could not be parsed: {}\n\
                     Please try again, responding with valid JSON only.",
                    last_error.as_deref().unwrap_or("unknown error")
                );
                self.send_structured_request(&system_prompt, prompt, Some(&error_feedback))
                    .await?
            };

            match parse_json_response::<T>(&response) {
                Ok(value) => return Ok(value),
                Err(e) => {
                    last_error = Some(e.to_string());
                }
            }
        }

        Err(SageError::Llm(format!(
            "Failed to parse structured response after {} attempts: {}",
            self.config.infer_retries,
            last_error.unwrap_or_else(|| "unknown error".to_string())
        )))
    }

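    /// Assembles the system/user message pair (plus optional parse-error
    /// feedback) and dispatches it, enabling JSON mode for Ollama.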
    async fn send_structured_request(
        &self,
        system_prompt: &str,
        user_prompt: &str,
        error_feedback: Option<&str>,
    ) -> SageResult<String> {
        let mut messages = vec![
            ChatMessage {
                role: "system",
                content: system_prompt,
            },
            ChatMessage {
                role: "user",
                content: user_prompt,
            },
        ];

        if let Some(feedback) = error_feedback {
            messages.push(ChatMessage {
                role: "user",
                content: feedback,
            });
        }

        let mut request = ChatRequest::new(&self.config.model, messages).with_config(&self.config);

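        // Ollama's JSON mode is a top-level `format: "json"` field; OpenAI
        // endpoints do not use this field, so it is only set for local URLs.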
        if self.config.is_ollama() {
            request = request.with_json_format();
        }

        self.send_request(&request).await
    }

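    /// POSTs the request to `{base_url}/chat/completions` with bearer auth
    /// and returns the first choice's message content.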
    async fn send_request(&self, request: &ChatRequest<'_>) -> SageResult<String> {
        let response = self
            .client
            .post(format!("{}/chat/completions", self.config.base_url))
            .header("Authorization", format!("Bearer {}", self.config.api_key))
            .header("Content-Type", "application/json")
            .json(request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            return Err(SageError::Llm(format!("API error {status}: {body}")));
        }

        let chat_response: ChatResponse = response.json().await?;
        let content = chat_response
            .choices
            .into_iter()
            .next()
            .map(|c| c.message.content)
            .unwrap_or_default();

        Ok(content)
    }
}

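/// Parses `response` as JSON into `T`: first verbatim, then after stripping
/// surrounding markdown code fences (with or without a `json` tag).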
fn parse_json_response<T: DeserializeOwned>(response: &str) -> SageResult<T> {
    if let Ok(value) = serde_json::from_str(response) {
        return Ok(value);
    }

    let cleaned = response
        .trim()
        .strip_prefix("```json")
        .or_else(|| response.trim().strip_prefix("```"))
        .unwrap_or(response.trim());

    let cleaned = cleaned.strip_suffix("```").unwrap_or(cleaned).trim();

    serde_json::from_str(cleaned).map_err(|e| {
        SageError::Llm(format!(
            "Failed to parse LLM response as {}: {e}\nResponse: {response}",
            std::any::type_name::<T>()
        ))
    })
}

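// Wire types for the OpenAI-style `/chat/completions` request and response.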
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<ChatMessage<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<&'a str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<i64>,
}

#[derive(Serialize)]
struct ChatMessage<'a> {
    role: &'a str,
    content: &'a str,
}

impl<'a> ChatRequest<'a> {
    fn new(model: &'a str, messages: Vec<ChatMessage<'a>>) -> Self {
        Self {
            model,
            messages,
            format: None,
            temperature: None,
            max_tokens: None,
        }
    }

    fn with_json_format(mut self) -> Self {
        self.format = Some("json");
        self
    }

    fn with_config(mut self, config: &LlmConfig) -> Self {
        self.temperature = config.temperature;
        self.max_tokens = config.max_tokens;
        self
    }
}

#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

#[derive(Deserialize)]
struct Choice {
    message: ResponseMessage,
}

#[derive(Deserialize)]
struct ResponseMessage {
    content: String,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn mock_client_returns_placeholder() {
        let client = LlmClient::mock();
        let response = client.infer_string("test prompt").await.unwrap();
        assert!(response.contains("Mock LLM response"));
        assert!(response.contains("test prompt"));
    }

    #[test]
    fn parse_json_strips_markdown_fences() {
        let response = "```json\n{\"value\": 42}\n```";
        let result: serde_json::Value = parse_json_response(response).unwrap();
        assert_eq!(result["value"], 42);
    }

    #[test]
    fn parse_json_handles_plain_json() {
        let response = r#"{"name": "test"}"#;
        let result: serde_json::Value = parse_json_response(response).unwrap();
        assert_eq!(result["name"], "test");
    }

    #[test]
    fn parse_json_handles_generic_code_block() {
        let response = "```\n{\"x\": 1}\n```";
        let result: serde_json::Value = parse_json_response(response).unwrap();
        assert_eq!(result["x"], 1);
    }

    #[test]
    fn ollama_detection_localhost() {
        let config = LlmConfig {
            api_key: "test".to_string(),
            base_url: "http://localhost:11434/v1".to_string(),
            model: "llama2".to_string(),
            infer_retries: 3,
            temperature: None,
            max_tokens: None,
        };
        assert!(config.is_ollama());
    }

    #[test]
    fn ollama_detection_127() {
        let config = LlmConfig {
            api_key: "test".to_string(),
            base_url: "http://127.0.0.1:11434/v1".to_string(),
            model: "llama2".to_string(),
            infer_retries: 3,
            temperature: None,
            max_tokens: None,
        };
        assert!(config.is_ollama());
    }

    #[test]
    fn not_ollama_for_openai() {
        let config = LlmConfig {
            api_key: "test".to_string(),
            base_url: "https://api.openai.com/v1".to_string(),
            model: "gpt-4".to_string(),
            infer_retries: 3,
            temperature: None,
            max_tokens: None,
        };
        assert!(!config.is_ollama());
    }

    #[test]
    fn chat_request_json_format() {
        let request = ChatRequest::new("model", vec![]).with_json_format();
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains(r#""format":"json""#));
    }

    #[test]
    fn chat_request_no_format_by_default() {
        let request = ChatRequest::new("model", vec![]);
        let json = serde_json::to_string(&request).unwrap();
        assert!(!json.contains("format"));
    }

    #[tokio::test]
    async fn infer_structured_fails_on_mock() {
        let client = LlmClient::mock();
        let result: Result<serde_json::Value, _> = client.infer_structured("test", "{}").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Mock client"));
    }
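
    // Added sketch: checks that fence stripping also trims surrounding
    // whitespace (behavior inferred from parse_json_response above).
    #[test]
    fn parse_json_trims_whitespace_inside_fences() {
        let response = "```json\n  {\"ok\": true}  \n```";
        let result: serde_json::Value = parse_json_response(response).unwrap();
        assert_eq!(result["ok"], true);
    }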
}