garbage_code_hunter/llm/client.rs

//! HTTP client for communicating with LLM endpoints.
//!
//! Supports two provider types:
//! - Ollama: Local LLM inference via `/api/generate`
//! - OpenAI-compatible: Any endpoint implementing the OpenAI chat completions API
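//!
//! A minimal end-to-end sketch (marked `ignore`: it needs a live endpoint,
//! and the import path assumes this module is reachable as
//! `garbage_code_hunter::llm::client`):
//!
//! ```ignore
//! use garbage_code_hunter::llm::client::{LlmClient, LlmConfig};
//!
//! fn main() -> anyhow::Result<()> {
//!     // Local Ollama with all defaults (http://localhost:11434, llama3.2).
//!     let config = LlmConfig::from_args("ollama", None, None, None, 30);
//!     let client = LlmClient::new(config);
//!     let reply = client.call_blocking("Say hello in one word.")?;
//!     println!("{reply}");
//!     Ok(())
//! }
//! ```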

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};

/// Supported LLM provider types.
#[derive(Debug, Clone)]
pub enum LlmProviderType {
    /// Local Ollama server (default endpoint: http://localhost:11434)
    Ollama,
    /// OpenAI-compatible API (default endpoint: http://localhost:1234 for LM Studio)
    OpenAICompatible,
}

/// Configuration for connecting to an LLM endpoint.
#[derive(Debug, Clone)]
pub struct LlmConfig {
    pub provider: LlmProviderType,
    pub endpoint: String,
    pub model: String,
    pub api_key: Option<String>,
    pub timeout_secs: u64,
}

impl LlmConfig {
    /// Build configuration from CLI arguments with sensible defaults.
    ///
    /// - Ollama default endpoint: `http://localhost:11434`, model: `llama3.2`
    /// - OpenAI-compatible default endpoint: `http://localhost:1234`, model: `gpt-3.5-turbo`
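    ///
    /// # Examples
    ///
    /// A minimal sketch (`ignore`: the import path assumes this module is
    /// reachable as `garbage_code_hunter::llm::client`):
    ///
    /// ```ignore
    /// use garbage_code_hunter::llm::client::LlmConfig;
    ///
    /// // No overrides: fall back to the Ollama defaults listed above.
    /// let config = LlmConfig::from_args("ollama", None, None, None, 30);
    /// assert_eq!(config.endpoint, "http://localhost:11434");
    /// assert_eq!(config.model, "llama3.2");
    /// ```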
    pub fn from_args(
        provider: &str,
        endpoint: Option<&str>,
        model: Option<&str>,
        api_key: Option<&str>,
        timeout: u64,
    ) -> Self {
        let provider_type = match provider.to_lowercase().as_str() {
            "ollama" => LlmProviderType::Ollama,
            _ => LlmProviderType::OpenAICompatible,
        };

        let default_endpoint = match provider_type {
            LlmProviderType::Ollama => "http://localhost:11434",
            LlmProviderType::OpenAICompatible => "http://localhost:1234",
        };

        let default_model = match provider_type {
            LlmProviderType::Ollama => "llama3.2",
            LlmProviderType::OpenAICompatible => "gpt-3.5-turbo",
        };

        Self {
            provider: provider_type,
            endpoint: endpoint.unwrap_or(default_endpoint).to_string(),
            model: model.unwrap_or(default_model).to_string(),
            api_key: api_key.map(String::from),
            timeout_secs: timeout,
        }
    }
}

// --- Ollama request/response types ---

#[derive(Serialize)]
struct OllamaRequest {
    model: String,
    prompt: String,
    stream: bool,
    // Omit the field entirely when unset, rather than sending `"format": null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    format: Option<String>,
}

#[derive(Deserialize)]
struct OllamaResponse {
    response: String,
}

// --- OpenAI-compatible request/response types ---

#[derive(Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    temperature: f64,
    response_format: Option<serde_json::Value>,
}

#[derive(Serialize, Deserialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

#[derive(Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessage,
}

// --- Client ---

/// HTTP client that communicates with LLM endpoints.
///
/// Each call creates a short-lived, single-threaded tokio runtime and
/// performs a single request-response cycle.
pub struct LlmClient {
    config: LlmConfig,
}

impl LlmClient {
    /// Create a new client with the given configuration.
    pub fn new(config: LlmConfig) -> Self {
        Self { config }
    }

    /// Send a prompt to the LLM and return the response text.
    ///
    /// This is a blocking call that creates a temporary tokio runtime.
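    ///
    /// # Examples
    ///
    /// A minimal sketch (`ignore`: it needs a live endpoint, and the import
    /// path assumes this module is reachable as
    /// `garbage_code_hunter::llm::client`):
    ///
    /// ```ignore
    /// # fn main() -> anyhow::Result<()> {
    /// use garbage_code_hunter::llm::client::{LlmClient, LlmConfig};
    ///
    /// let client = LlmClient::new(LlmConfig::from_args(
    ///     "openai-compatible", None, None, None, 30,
    /// ));
    /// // Blocks the calling thread for one request/response cycle.
    /// let text = client.call_blocking("Review this: fn main() {}")?;
    /// println!("{text}");
    /// # Ok(())
    /// # }
    /// ```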
    pub fn call_blocking(&self, prompt: &str) -> Result<String> {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .context("Failed to create tokio runtime")?;

        rt.block_on(self.call_async(prompt))
    }

    async fn call_async(&self, prompt: &str) -> Result<String> {
        let client = reqwest::Client::builder()
            // `max(120)` floors the timeout at two minutes regardless of the
            // configured value; local models can be slow to load and respond.
            .timeout(std::time::Duration::from_secs(
                self.config.timeout_secs.max(120),
            ))
            .build()
            .context("Failed to build HTTP client")?;

        match self.config.provider {
            LlmProviderType::Ollama => self.call_ollama(&client, prompt).await,
            LlmProviderType::OpenAICompatible => self.call_openai_compatible(&client, prompt).await,
        }
    }

    async fn call_ollama(&self, client: &reqwest::Client, prompt: &str) -> Result<String> {
        let url = format!("{}/api/generate", self.config.endpoint);

        let request = OllamaRequest {
            model: self.config.model.clone(),
            prompt: prompt.to_string(),
            stream: false,
            // Don't force JSON format: some models (gemma, etc.) return empty
            // responses when the json format flag is set. Instead, we instruct
            // JSON output in the prompt and parse it from the free-form response.
            format: None,
        };

        tracing::debug!(
            "Ollama request: model={}, endpoint={}",
            self.config.model,
            self.config.endpoint
        );

        let resp = client
            .post(&url)
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Ollama")?
            // Fail on HTTP error statuses here, rather than letting them
            // surface later as confusing JSON parse errors.
            .error_for_status()
            .context("Ollama returned an error status")?;

        let body: OllamaResponse = resp
            .json()
            .await
            .context("Failed to parse Ollama response")?;

        // Truncate the preview on a char boundary; byte-slicing at a fixed
        // index can panic on multi-byte UTF-8.
        let preview: String = body.response.chars().take(500).collect();
        tracing::debug!(
            "Ollama raw response ({} bytes): {}",
            body.response.len(),
            preview
        );

        Ok(body.response)
    }

    async fn call_openai_compatible(
        &self,
        client: &reqwest::Client,
        prompt: &str,
    ) -> Result<String> {
        let url = format!("{}/v1/chat/completions", self.config.endpoint);

        let messages = vec![
            OpenAIMessage {
                role: "system".to_string(),
                content: "You are a sarcastic code reviewer. Always respond with valid JSON."
                    .to_string(),
            },
            OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            },
        ];

        let request = OpenAIRequest {
            model: self.config.model.clone(),
            messages,
            temperature: 0.8,
            response_format: Some(serde_json::json!({"type": "json_object"})),
        };

        let mut req_builder = client.post(&url).json(&request);

        if let Some(ref api_key) = self.config.api_key {
            req_builder = req_builder.bearer_auth(api_key);
        }

        let resp = req_builder
            .send()
            .await
            .context("Failed to send request to OpenAI-compatible endpoint")?
            // Fail fast on HTTP error statuses (e.g. 401 from a bad API key).
            .error_for_status()
            .context("OpenAI-compatible endpoint returned an error status")?;

        let body: OpenAIResponse = resp
            .json()
            .await
            .context("Failed to parse OpenAI-compatible response")?;

        body.choices
            .into_iter()
            .next()
            .map(|c| c.message.content)
            .ok_or_else(|| anyhow::anyhow!("No choices in LLM response"))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_defaults_for_ollama() {
        // Objective: Verify Ollama config has correct default endpoint and model.
        // Invariants: Default endpoint must be localhost:11434, model must be llama3.2.
        let config = LlmConfig::from_args("ollama", None, None, None, 30);

        assert!(
            matches!(config.provider, LlmProviderType::Ollama),
            "Provider type must be Ollama"
        );
        assert_eq!(
            config.endpoint, "http://localhost:11434",
            "Default Ollama endpoint must be localhost:11434"
        );
        assert_eq!(
            config.model, "llama3.2",
            "Default Ollama model must be llama3.2"
        );
        assert!(
            config.api_key.is_none(),
            "Ollama should not require an API key"
        );
    }

    #[test]
    fn test_config_defaults_for_openai_compatible() {
        // Objective: Verify OpenAI-compatible config has correct defaults.
        // Invariants: Default endpoint must be localhost:1234.
        let config = LlmConfig::from_args("openai-compatible", None, None, None, 30);

        assert!(
            matches!(config.provider, LlmProviderType::OpenAICompatible),
            "Provider type must be OpenAICompatible"
        );
        assert_eq!(
            config.endpoint, "http://localhost:1234",
            "Default OpenAI-compatible endpoint must be localhost:1234"
        );
    }

    #[test]
    fn test_config_overrides_defaults() {
        // Objective: Verify custom values override defaults.
        // Invariants: All custom values must be preserved exactly.
        let config = LlmConfig::from_args(
            "ollama",
            Some("http://custom:9999"),
            Some("mistral"),
            Some("sk-test"),
            60,
        );

        assert_eq!(config.endpoint, "http://custom:9999");
        assert_eq!(config.model, "mistral");
        assert_eq!(config.api_key.as_deref(), Some("sk-test"));
        assert_eq!(config.timeout_secs, 60);
    }

    #[test]
    fn test_config_unknown_provider_defaults_to_openai_compatible() {
        // Objective: Verify unknown provider strings default to OpenAI-compatible.
        // Invariants: Any non-"ollama" string must produce OpenACompatible variant.
        let config = LlmConfig::from_args("lmstudio", None, None, None, 30);
        assert!(
            matches!(config.provider, LlmProviderType::OpenAICompatible),
            "Unknown provider '{}' should default to OpenAICompatible",
            "lmstudio"
        );
    }
}