use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::env;
use std::path::PathBuf;

use crate::config::Config;
use crate::storage::Storage;

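/// Request body for Anthropic's Messages API.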
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}

#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}

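/// Request body for OpenAI-compatible chat completion APIs
/// (OpenAI, xAI, OpenRouter).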
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}

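/// Client for sending prompts to the configured LLM provider, either a
/// hosted API (Anthropic, OpenAI, xAI, OpenRouter) or a local CLI tool
/// (`claude`, `codex`).
///
/// A minimal usage sketch (assumes a loadable config and, for hosted
/// providers, the provider's API key in the environment):
///
/// ```ignore
/// let client = LLMClient::new()?;
/// let reply = client.complete("Summarize this project in one line.").await?;
/// println!("{reply}");
/// ```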
pub struct LLMClient {
    config: Config,
    api_key: String,
    client: reqwest::Client,
}

impl LLMClient {
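    /// Create a client using configuration discovered from the current
    /// directory; fails if the config cannot be loaded or a required
    /// API-key environment variable is unset.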
    pub fn new() -> Result<Self> {
        let storage = Storage::new(None);
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new()
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

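    /// Like [`Self::new`], but loads configuration from an explicit
    /// project root.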
    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
        let storage = Storage::new(Some(project_root));
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new()
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

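    /// Complete a prompt with the default provider and model from config.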
    pub async fn complete(&self, prompt: &str) -> Result<String> {
        self.complete_with_model(prompt, None, None).await
    }

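    /// Complete with the configured "smart" model/provider pair, unless a
    /// model override is supplied.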
    pub async fn complete_smart(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.smart_model());
        let provider = self.config.smart_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }

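    /// Complete with the configured "fast" model/provider pair, unless a
    /// model override is supplied.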
    pub async fn complete_fast(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.fast_model());
        let provider = self.config.fast_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }

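    /// Core dispatch: route the prompt to the selected provider, honoring
    /// optional model and provider overrides.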
    pub async fn complete_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider_override: Option<&str>,
    ) -> Result<String> {
        let provider = provider_override.unwrap_or(&self.config.llm.provider);
        match provider {
            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
            "codex" => self.complete_codex_cli(prompt, model_override).await,
            "anthropic" => {
                self.complete_anthropic_with_model(prompt, model_override)
                    .await
            }
            "xai" | "openai" | "openrouter" => {
                self.complete_openai_compatible_with_model(prompt, model_override)
                    .await
            }
            // Report the provider that was actually selected, including any
            // override, rather than the one in the config.
            _ => anyhow::bail!("Unsupported provider: {}", provider),
        }
    }

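    /// Call the Anthropic Messages API over HTTP.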
    async fn complete_anthropic_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = AnthropicRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let response = self
            .client
            .post(self.config.api_endpoint())
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Anthropic API")?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
        }

        let api_response: AnthropicResponse = response
            .json()
            .await
            .context("Failed to parse Anthropic API response")?;

        Ok(api_response
            .content
            .first()
            .map(|c| c.text.clone())
            .unwrap_or_default())
    }

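    /// Call an OpenAI-compatible chat completions endpoint; OpenRouter
    /// additionally gets its attribution headers.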
    async fn complete_openai_compatible_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = OpenAIRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let mut request_builder = self
            .client
            .post(self.config.api_endpoint())
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json");

        if self.config.llm.provider == "openrouter" {
            request_builder = request_builder
                .header("HTTP-Referer", "https://github.com/scud-cli")
                .header("X-Title", "SCUD Task Master");
        }

        let response = request_builder
            .json(&request)
            .send()
            .await
            .with_context(|| {
                format!("Failed to send request to {} API", self.config.llm.provider)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!(
                "{} API error ({}): {}",
                self.config.llm.provider,
                status,
                error_text
            );
        }

        let api_response: OpenAIResponse = response.json().await.with_context(|| {
            format!("Failed to parse {} API response", self.config.llm.provider)
        })?;

        Ok(api_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default())
    }

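    /// Complete and deserialize the JSON response into `T`.
    ///
    /// A sketch of the intended pattern (the `Task` type here is
    /// hypothetical, not part of this module):
    ///
    /// ```ignore
    /// #[derive(serde::Deserialize)]
    /// struct Task { title: String }
    ///
    /// let tasks: Vec<Task> = client.complete_json("List tasks as JSON").await?;
    /// ```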
    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        self.complete_json_with_model(prompt, None).await
    }

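    /// JSON-deserializing variant of [`Self::complete_smart`].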
    pub async fn complete_json_smart<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_smart(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

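    /// JSON-deserializing variant of [`Self::complete_fast`].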
    pub async fn complete_json_fast<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_fast(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

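    /// Complete with an optional model override, then parse the response
    /// as JSON.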
    pub async fn complete_json_with_model<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_with_model(prompt, model_override, None).await?;
        Self::parse_json_response(&response_text)
    }

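    /// Parse an LLM response as JSON after stripping surrounding prose or
    /// code fences; on failure, the error carries a preview of the text.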
    fn parse_json_response<T>(response_text: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let json_str = Self::extract_json(response_text);

        serde_json::from_str(json_str).with_context(|| {
            let preview = if json_str.len() > 500 {
                // Back off to a char boundary so the slice cannot panic on
                // multi-byte UTF-8 characters.
                let mut end = 500;
                while !json_str.is_char_boundary(end) {
                    end -= 1;
                }
                format!("{}...", &json_str[..end])
            } else {
                json_str.to_string()
            };
            format!(
                "Failed to parse JSON from LLM response. Response preview:\n{}",
                preview
            )
        })
    }

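    /// Best-effort extraction of a JSON payload from free-form LLM output:
    /// a `json`-tagged code fence first, then any code fence, then the
    /// outermost bracketed span, and finally the trimmed text itself.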
    fn extract_json(response: &str) -> &str {
        // Prefer an explicit ```json fence.
        if let Some(start) = response.find("```json") {
            let content_start = start + "```json".len();
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Fall back to any generic fence, skipping a language tag on the
        // opening line if present.
        if let Some(start) = response.find("```") {
            let content_start = start + 3;
            let content_start = response[content_start..]
                .find('\n')
                .map(|i| content_start + i + 1)
                .unwrap_or(content_start);
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Otherwise take the outermost bracketed span; arrays are checked
        // before objects.
        if let Some(start) = response.find('[') {
            if let Some(end) = response.rfind(']') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        if let Some(start) = response.find('{') {
            if let Some(end) = response.rfind('}') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        response.trim()
    }

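    /// Run the prompt through the local `claude` CLI (Claude Code) rather
    /// than the HTTP API, reading the completion from its JSON envelope.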
    async fn complete_claude_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        let mut cmd = Command::new("claude");
        cmd.arg("-p")
            .arg("--output-format")
            .arg("json")
            .arg("--model")
            .arg(model)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let mut child = cmd.spawn().context(
            "Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH",
        )?;

        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to claude stdin")?;
            // Close stdin so the CLI sees EOF and starts processing.
            drop(stdin);
        }

        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for claude command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Claude CLI error: {}", stderr);
        }

        let stdout =
            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;

        // With --output-format json, the completion text arrives in the
        // `result` field of a JSON envelope.
        #[derive(Deserialize)]
        struct ClaudeCliResponse {
            result: String,
        }

        let response: ClaudeCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;

        Ok(response.result)
    }

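    /// Run the prompt through the local `codex` CLI, mirroring the
    /// claude-cli path.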
    async fn complete_codex_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        let mut cmd = Command::new("codex");
        cmd.arg("-p")
            .arg("--model")
            .arg(model)
            .arg("--output-format")
            .arg("json")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let mut child = cmd.spawn().context(
            "Failed to spawn 'codex' command. Make sure OpenAI Codex CLI is installed and 'codex' is in your PATH",
        )?;

        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to codex stdin")?;
            // Close stdin so the CLI sees EOF and starts processing.
            drop(stdin);
        }

        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for codex command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Codex CLI error: {}", stderr);
        }

        let stdout =
            String::from_utf8(output.stdout).context("Codex CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct CodexCliResponse {
            result: String,
        }

        let response: CodexCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Codex CLI JSON response")?;

        Ok(response.result)
    }
}