scud/llm/client.rs

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::env;
use std::path::PathBuf;

use crate::config::Config;
use crate::storage::Storage;

// Anthropic API structures
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}

#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}

// OpenAI-compatible API structures (used by xAI, OpenAI, OpenRouter)
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}

pub struct LLMClient {
    config: Config,
    api_key: String,
    client: reqwest::Client,
}

impl LLMClient {
    pub fn new() -> Result<Self> {
        Self::from_storage(Storage::new(None))
    }

    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
        Self::from_storage(Storage::new(Some(project_root)))
    }

    /// Shared construction path: load config and resolve the API key.
    fn from_storage(storage: Storage) -> Result<Self> {
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new() // CLI-backed providers (claude-cli, codex) need no API key
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

    pub async fn complete(&self, prompt: &str) -> Result<String> {
        self.complete_with_model(prompt, None, None).await
    }

    /// Complete using the smart model (for validation/analysis tasks with large context).
    /// Uses the caller's override if provided, otherwise falls back to the configured smart_model.
    pub async fn complete_smart(&self, prompt: &str, model_override: Option<&str>) -> Result<String> {
        let model = model_override.unwrap_or(self.config.smart_model());
        let provider = self.config.smart_provider();
        self.complete_with_model(prompt, Some(model), Some(provider)).await
    }

    /// Complete using the fast model (for generation tasks).
    /// Uses the caller's override if provided, otherwise falls back to the configured fast_model.
    pub async fn complete_fast(&self, prompt: &str, model_override: Option<&str>) -> Result<String> {
        let model = model_override.unwrap_or(self.config.fast_model());
        let provider = self.config.fast_provider();
        self.complete_with_model(prompt, Some(model), Some(provider)).await
    }
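
    // Illustrative call sites (hypothetical; actual callers live elsewhere
    // in the crate):
    //
    //   let review = client.complete_smart(&prompt, args.model.as_deref()).await?;
    //   let draft = client.complete_fast(&prompt, None).await?;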

    pub async fn complete_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider_override: Option<&str>,
    ) -> Result<String> {
        let provider = provider_override.unwrap_or(&self.config.llm.provider);
        match provider {
            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
            "codex" => self.complete_codex_cli(prompt, model_override).await,
            "anthropic" => {
                self.complete_anthropic_with_model(prompt, model_override)
                    .await
            }
            "xai" | "openai" | "openrouter" => {
                // Pass the effective provider through so the OpenRouter-only
                // headers and the error messages stay accurate under overrides
                self.complete_openai_compatible_with_model(prompt, model_override, provider)
                    .await
            }
            _ => anyhow::bail!("Unsupported provider: {}", provider),
        }
    }
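
    // For orientation, the config fields this dispatch relies on (inferred
    // from the accessors used above; the schema itself lives in crate::config):
    //
    //   llm.provider   -- "anthropic" | "xai" | "openai" | "openrouter"
    //                     | "claude-cli" | "codex"
    //   llm.model      -- default model id
    //   llm.max_tokens -- per-request output token cap
    //   smart_model()/smart_provider() and fast_model()/fast_provider()
    //                  -- per-tier settings behind complete_smart/complete_fast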

    async fn complete_anthropic_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = AnthropicRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let response = self
            .client
            .post(self.config.api_endpoint())
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Anthropic API")?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
        }

        let api_response: AnthropicResponse = response
            .json()
            .await
            .context("Failed to parse Anthropic API response")?;

        Ok(api_response
            .content
            .first()
            .map(|c| c.text.clone())
            .unwrap_or_default())
    }

    async fn complete_openai_compatible_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider: &str,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = OpenAIRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let mut request_builder = self
            .client
            .post(self.config.api_endpoint())
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json");

        // OpenRouter requires additional attribution headers
        if provider == "openrouter" {
            request_builder = request_builder
                .header("HTTP-Referer", "https://github.com/scud-cli")
                .header("X-Title", "SCUD Task Master");
        }

        let response = request_builder
            .json(&request)
            .send()
            .await
            .with_context(|| format!("Failed to send request to {} API", provider))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("{} API error ({}): {}", provider, status, error_text);
        }

        let api_response: OpenAIResponse = response
            .json()
            .await
            .with_context(|| format!("Failed to parse {} API response", provider))?;

        Ok(api_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default())
    }

    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        self.complete_json_with_model(prompt, None).await
    }

    /// Complete JSON using the smart model (for validation/analysis tasks)
    pub async fn complete_json_smart<T>(&self, prompt: &str, model_override: Option<&str>) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_smart(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

    /// Complete JSON using the fast model (for generation tasks)
    pub async fn complete_json_fast<T>(&self, prompt: &str, model_override: Option<&str>) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_fast(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

    pub async fn complete_json_with_model<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_with_model(prompt, model_override, None).await?;
        Self::parse_json_response(&response_text)
    }

    fn parse_json_response<T>(response_text: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        // Try to find JSON in the response (LLM might include markdown or explanations)
        let json_str = Self::extract_json(response_text);

        serde_json::from_str(json_str).with_context(|| {
            // Provide helpful error context, truncating on a char boundary so
            // multi-byte UTF-8 in the response cannot cause a slicing panic
            let preview: String = json_str.chars().take(500).collect();
            let preview = if json_str.len() > preview.len() {
                format!("{}...", preview)
            } else {
                preview
            };
            format!(
                "Failed to parse JSON from LLM response. Response preview:\n{}",
                preview
            )
        })
    }

    /// Extract JSON from LLM response, handling markdown code blocks and extra text
    fn extract_json(response: &str) -> &str {
        // First, try to extract from markdown code blocks
        if let Some(start) = response.find("```json") {
            let content_start = start + 7; // Skip "```json"
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Try plain code blocks
        if let Some(start) = response.find("```") {
            let content_start = start + 3;
            // Skip language identifier if present (e.g., "```\n")
            let content_start = response[content_start..]
                .find('\n')
                .map(|i| content_start + i + 1)
                .unwrap_or(content_start);
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Fall back to the outermost bracketed span. Check which opener
        // appears first so that an object which merely contains an array
        // is not truncated down to that array.
        let obj_start = response.find('{');
        let arr_start = response.find('[');
        let (start, end) = match (obj_start, arr_start) {
            (Some(o), Some(a)) if a < o => (Some(a), response.rfind(']')),
            (Some(o), _) => (Some(o), response.rfind('}')),
            (None, Some(a)) => (Some(a), response.rfind(']')),
            (None, None) => (None, None),
        };
        if let (Some(start), Some(end)) = (start, end) {
            if end > start {
                return &response[start..=end];
            }
        }

        response.trim()
    }

    async fn complete_claude_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        // Build the claude command
        let mut cmd = Command::new("claude");
        cmd.arg("-p") // Print mode (headless)
            .arg("--output-format")
            .arg("json")
            .arg("--model")
            .arg(model)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Spawn the process
        let mut child = cmd.spawn().context("Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH")?;

        // Write prompt to stdin
        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to claude stdin")?;
            drop(stdin); // Close stdin
        }

        // Wait for completion
        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for claude command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Claude CLI error: {}", stderr);
        }

        // Parse JSON output
        let stdout =
            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct ClaudeCliResponse {
            result: String,
        }

        let response: ClaudeCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;

        Ok(response.result)
    }

    async fn complete_codex_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        // Build the codex command
        // Codex CLI uses similar headless mode to Claude Code
        let mut cmd = Command::new("codex");
        cmd.arg("-p") // Prompt mode (headless/non-interactive)
            .arg("--model")
            .arg(model)
            .arg("--output-format")
            .arg("json")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Spawn the process
        let mut child = cmd.spawn().context("Failed to spawn 'codex' command. Make sure OpenAI Codex CLI is installed and 'codex' is in your PATH")?;

        // Write prompt to stdin
        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to codex stdin")?;
            drop(stdin); // Close stdin
        }

        // Wait for completion
        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for codex command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Codex CLI error: {}", stderr);
        }

        // Parse JSON output
        let stdout =
            String::from_utf8(output.stdout).context("Codex CLI output is not valid UTF-8")?;

        // Codex outputs JSON with a result field similar to Claude CLI
        #[derive(Deserialize)]
        struct CodexCliResponse {
            result: String,
        }

        let response: CodexCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Codex CLI JSON response")?;

        Ok(response.result)
    }
474}