use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};

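/// Which kind of LLM backend the client talks to.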
#[derive(Debug, Clone)]
pub enum LlmProviderType {
    Ollama,
    OpenAICompatible,
}

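/// Connection settings resolved from CLI arguments.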
#[derive(Debug, Clone)]
pub struct LlmConfig {
    pub provider: LlmProviderType,
    pub endpoint: String,
    pub model: String,
    pub api_key: Option<String>,
    pub timeout_secs: u64,
}

impl LlmConfig {
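    /// Build a config from CLI-style arguments.
    ///
    /// `endpoint` and `model` fall back to per-provider defaults when `None`,
    /// and any provider name other than "ollama" is treated as an
    /// OpenAI-compatible server.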
    pub fn from_args(
        provider: &str,
        endpoint: Option<&str>,
        model: Option<&str>,
        api_key: Option<&str>,
        timeout: u64,
    ) -> Self {
        let provider_type = match provider.to_lowercase().as_str() {
            "ollama" => LlmProviderType::Ollama,
            _ => LlmProviderType::OpenAICompatible,
        };

        let default_endpoint = match provider_type {
            LlmProviderType::Ollama => "http://localhost:11434",
            LlmProviderType::OpenAICompatible => "http://localhost:1234",
        };

        let default_model = match provider_type {
            LlmProviderType::Ollama => "llama3.2",
            LlmProviderType::OpenAICompatible => "gpt-3.5-turbo",
        };

        Self {
            provider: provider_type,
            endpoint: endpoint.unwrap_or(default_endpoint).to_string(),
            model: model.unwrap_or(default_model).to_string(),
            api_key: api_key.map(String::from),
            timeout_secs: timeout,
        }
    }
}

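// Request/response shapes for Ollama's non-streaming `/api/generate` endpoint.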
#[derive(Serialize)]
struct OllamaRequest {
    model: String,
    prompt: String,
    stream: bool,
    format: Option<String>,
}

#[derive(Deserialize)]
struct OllamaResponse {
    response: String,
}

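// Request/response shapes for OpenAI-style `/v1/chat/completions` endpoints.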
#[derive(Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    temperature: f64,
    response_format: Option<serde_json::Value>,
}

#[derive(Serialize, Deserialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

#[derive(Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessage,
}

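/// Client for a local LLM over HTTP (Ollama or any OpenAI-compatible server).
///
/// A minimal usage sketch (the provider and prompt below are illustrative):
///
/// ```ignore
/// let config = LlmConfig::from_args("ollama", None, None, None, 30);
/// let client = LlmClient::new(config);
/// let reply = client.call_blocking("Review this diff")?;
/// ```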
pub struct LlmClient {
    config: LlmConfig,
}

impl LlmClient {
    pub fn new(config: LlmConfig) -> Self {
        Self { config }
    }

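    /// Synchronous entry point: spins up a single-threaded Tokio runtime for
    /// the duration of one request so callers don't need an async context.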
    pub fn call_blocking(&self, prompt: &str) -> Result<String> {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .context("Failed to create tokio runtime")?;

        rt.block_on(self.call_async(prompt))
    }

    async fn call_async(&self, prompt: &str) -> Result<String> {
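        // Note: `.max(120)` enforces a *minimum* timeout of 120 seconds, so
        // configured values below that are raised (presumably to accommodate
        // slow local models).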
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(
                self.config.timeout_secs.max(120),
            ))
            .build()
            .context("Failed to build HTTP client")?;

        match self.config.provider {
            LlmProviderType::Ollama => self.call_ollama(&client, prompt).await,
            LlmProviderType::OpenAICompatible => self.call_openai_compatible(&client, prompt).await,
        }
    }

    async fn call_ollama(&self, client: &reqwest::Client, prompt: &str) -> Result<String> {
        let url = format!("{}/api/generate", self.config.endpoint);

        let request = OllamaRequest {
            model: self.config.model.clone(),
            prompt: prompt.to_string(),
            stream: false,
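            // Left unset; Ollama accepts "json" here when strictly structured
            // output is required.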
            format: None,
        };

        tracing::debug!(
            "Ollama request: model={}, endpoint={}",
            self.config.model,
            self.config.endpoint
        );

        let resp = client
            .post(&url)
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Ollama")?;

        let body: OllamaResponse = resp
            .json()
            .await
            .context("Failed to parse Ollama response")?;

        // Log a truncated preview; slicing by byte index could panic inside a
        // multi-byte character, so truncate by chars instead.
        tracing::debug!(
            "Ollama raw response ({} bytes): {}",
            body.response.len(),
            body.response.chars().take(500).collect::<String>()
        );

        Ok(body.response)
    }

    async fn call_openai_compatible(
        &self,
        client: &reqwest::Client,
        prompt: &str,
    ) -> Result<String> {
        let url = format!("{}/v1/chat/completions", self.config.endpoint);

        let messages = vec![
            OpenAIMessage {
                role: "system".to_string(),
                content: "You are a sarcastic code reviewer. Always respond with valid JSON."
                    .to_string(),
            },
            OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            },
        ];

        let request = OpenAIRequest {
            model: self.config.model.clone(),
            messages,
            temperature: 0.8,
            response_format: Some(serde_json::json!({"type": "json_object"})),
        };

        let mut req_builder = client.post(&url).json(&request);

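        // Attach a Bearer token only when an API key was supplied; local
        // servers typically run without one.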
        if let Some(ref api_key) = self.config.api_key {
            req_builder = req_builder.bearer_auth(api_key);
        }

        let resp = req_builder
            .send()
            .await
            .context("Failed to send request to OpenAI-compatible endpoint")?;

        let body: OpenAIResponse = resp
            .json()
            .await
            .context("Failed to parse OpenAI-compatible response")?;

        body.choices
            .into_iter()
            .next()
            .map(|c| c.message.content)
            .ok_or_else(|| anyhow::anyhow!("No choices in LLM response"))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_defaults_for_ollama() {
        let config = LlmConfig::from_args("ollama", None, None, None, 30);

        assert!(
            matches!(config.provider, LlmProviderType::Ollama),
            "Provider type must be Ollama"
        );
        assert_eq!(
            config.endpoint, "http://localhost:11434",
            "Default Ollama endpoint must be localhost:11434"
        );
        assert_eq!(
            config.model, "llama3.2",
            "Default Ollama model must be llama3.2"
        );
        assert!(
            config.api_key.is_none(),
            "Ollama should not require an API key"
        );
    }

    #[test]
    fn test_config_defaults_for_openai_compatible() {
        let config = LlmConfig::from_args("openai-compatible", None, None, None, 30);

        assert!(
            matches!(config.provider, LlmProviderType::OpenAICompatible),
            "Provider type must be OpenAICompatible"
        );
        assert_eq!(
            config.endpoint, "http://localhost:1234",
            "Default OpenAI-compatible endpoint must be localhost:1234"
        );
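        assert_eq!(
            config.model, "gpt-3.5-turbo",
            "Default OpenAI-compatible model must be gpt-3.5-turbo"
        );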
    }

    #[test]
    fn test_config_overrides_defaults() {
        let config = LlmConfig::from_args(
            "ollama",
            Some("http://custom:9999"),
            Some("mistral"),
            Some("sk-test"),
            60,
        );

        assert_eq!(config.endpoint, "http://custom:9999");
        assert_eq!(config.model, "mistral");
        assert_eq!(config.api_key.as_deref(), Some("sk-test"));
        assert_eq!(config.timeout_secs, 60);
    }

    #[test]
    fn test_config_unknown_provider_defaults_to_openai_compatible() {
        let config = LlmConfig::from_args("lmstudio", None, None, None, 30);

        assert!(
            matches!(config.provider, LlmProviderType::OpenAICompatible),
            "Unknown provider '{}' should default to OpenAICompatible",
            "lmstudio"
        );
    }
}