1use crate::error::{SageError, SageResult};
4use serde::{Deserialize, Serialize};
5
/// Client for an OpenAI-compatible chat-completions API.
///
/// Cloning is cheap: `reqwest::Client` is internally reference-counted.
#[derive(Clone)]
pub struct LlmClient {
    // Reused across requests so connection pooling works.
    client: reqwest::Client,
    // Endpoint URL, credentials, and model selection.
    config: LlmConfig,
}
12
/// Connection settings for an OpenAI-compatible LLM endpoint.
#[derive(Clone)]
pub struct LlmConfig {
    /// Bearer token sent in the `Authorization` header.
    /// The sentinel value `"mock"` switches the client into offline mock mode.
    pub api_key: String,
    /// API base URL, e.g. `https://api.openai.com/v1`.
    pub base_url: String,
    /// Model identifier, e.g. `gpt-4o-mini`.
    pub model: String,
}
23
24impl LlmConfig {
25 pub fn from_env() -> Self {
27 Self {
28 api_key: std::env::var("SAGE_API_KEY").unwrap_or_default(),
29 base_url: std::env::var("SAGE_LLM_URL")
30 .unwrap_or_else(|_| "https://api.openai.com/v1".to_string()),
31 model: std::env::var("SAGE_MODEL").unwrap_or_else(|_| "gpt-4o-mini".to_string()),
32 }
33 }
34
35 pub fn mock() -> Self {
37 Self {
38 api_key: "mock".to_string(),
39 base_url: "mock".to_string(),
40 model: "mock".to_string(),
41 }
42 }
43
44 pub fn is_mock(&self) -> bool {
46 self.api_key == "mock"
47 }
48}
49
50impl LlmClient {
51 pub fn new(config: LlmConfig) -> Self {
53 Self {
54 client: reqwest::Client::new(),
55 config,
56 }
57 }
58
59 pub fn from_env() -> Self {
61 Self::new(LlmConfig::from_env())
62 }
63
64 pub fn mock() -> Self {
66 Self::new(LlmConfig::mock())
67 }
68
69 pub async fn infer_string(&self, prompt: &str) -> SageResult<String> {
71 if self.config.is_mock() {
72 return Ok(format!("[Mock LLM response for: {prompt}]"));
73 }
74
75 let request = ChatRequest {
76 model: &self.config.model,
77 messages: vec![ChatMessage {
78 role: "user",
79 content: prompt,
80 }],
81 };
82
83 let response = self
84 .client
85 .post(format!("{}/chat/completions", self.config.base_url))
86 .header("Authorization", format!("Bearer {}", self.config.api_key))
87 .header("Content-Type", "application/json")
88 .json(&request)
89 .send()
90 .await?;
91
92 if !response.status().is_success() {
93 let status = response.status();
94 let body = response.text().await.unwrap_or_default();
95 return Err(SageError::Llm(format!("API error {status}: {body}")));
96 }
97
98 let chat_response: ChatResponse = response.json().await?;
99 let content = chat_response
100 .choices
101 .into_iter()
102 .next()
103 .map(|c| c.message.content)
104 .unwrap_or_default();
105
106 Ok(content)
107 }
108
109 pub async fn infer<T>(&self, prompt: &str) -> SageResult<T>
111 where
112 T: serde::de::DeserializeOwned,
113 {
114 let response = self.infer_string(prompt).await?;
115
116 if let Ok(value) = serde_json::from_str(&response) {
118 return Ok(value);
119 }
120
121 let cleaned = response
123 .trim()
124 .strip_prefix("```json")
125 .unwrap_or(&response)
126 .strip_prefix("```")
127 .unwrap_or(&response)
128 .strip_suffix("```")
129 .unwrap_or(&response)
130 .trim();
131
132 serde_json::from_str(cleaned).map_err(|e| {
133 SageError::Llm(format!(
134 "Failed to parse LLM response as {}: {e}\nResponse: {response}",
135 std::any::type_name::<T>()
136 ))
137 })
138 }
139}
140
/// Request payload for the `/chat/completions` endpoint.
///
/// Borrows the model name and message text from the caller so no
/// allocation is needed per request.
#[derive(Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<ChatMessage<'a>>,
}
146
/// One message in the outgoing chat transcript (here always role `"user"`).
#[derive(Serialize)]
struct ChatMessage<'a> {
    role: &'a str,
    content: &'a str,
}
152
/// Minimal view of the completions response: only the fields we read.
/// serde ignores any extra fields the API returns.
#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}
157
/// A single completion choice returned by the API.
#[derive(Deserialize)]
struct Choice {
    message: ResponseMessage,
}
162
/// The assistant message inside a choice; only its text content is used.
#[derive(Deserialize)]
struct ResponseMessage {
    content: String,
}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock client must answer offline with a placeholder that both
    /// identifies itself as a mock and echoes the prompt.
    #[tokio::test]
    async fn mock_client_returns_placeholder() {
        let client = LlmClient::mock();
        let reply = client.infer_string("test prompt").await.unwrap();
        for needle in ["Mock LLM response", "test prompt"] {
            assert!(reply.contains(needle), "missing {needle:?} in {reply:?}");
        }
    }
}