//! Claude Code CLI provider.
//!
//! Source path: `aster/providers/claude_code.rs`.

use anyhow::Result;
use async_trait::async_trait;
use rmcp::model::Role;
use rmcp::model::Tool;
use serde_json::{json, Value};
use std::ffi::OsString;
use std::path::PathBuf;
use std::process::Stdio;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
use tokio::process::Command;

use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage, Usage};
use super::errors::ProviderError;
use super::utils::{filter_extensions_from_system_prompt, RequestLog};
use crate::config::base::ClaudeCodeCommand;
use crate::config::search_path::SearchPaths;
use crate::config::{AsterMode, Config};
use crate::conversation::message::{Message, MessageContent};
use crate::model::ModelConfig;
use crate::subprocess::configure_command_no_window;

/// Model used when the user has not chosen one.
pub const CLAUDE_CODE_DEFAULT_MODEL: &str = "claude-sonnet-4-20250514";

/// Model names this provider will forward to the CLI via `--model`.
///
/// The default model is included so that a provider running with its own
/// default actually asks the CLI for that model; otherwise `--model` would be
/// omitted and the CLI would fall back to whatever model it picks itself,
/// while usage is still reported under the default's name.
pub const CLAUDE_CODE_KNOWN_MODELS: &[&str] = &["sonnet", "opus", CLAUDE_CODE_DEFAULT_MODEL];

/// Setup documentation for the Claude Code CLI.
pub const CLAUDE_CODE_DOC_URL: &str = "https://code.claude.com/docs/en/setup";

26#[derive(Debug, serde::Serialize)]
27pub struct ClaudeCodeProvider {
28    command: PathBuf,
29    model: ModelConfig,
30    #[serde(skip)]
31    name: String,
32}
33
34impl ClaudeCodeProvider {
35    pub async fn from_env(model: ModelConfig) -> Result<Self> {
36        let config = crate::config::Config::global();
37        let command: OsString = config.get_claude_code_command().unwrap_or_default().into();
38        let resolved_command = SearchPaths::builder().with_npm().resolve(command)?;
39
40        Ok(Self {
41            command: resolved_command,
42            model,
43            name: Self::metadata().name,
44        })
45    }
46
47    /// Convert aster messages to the format expected by claude CLI
48    fn messages_to_claude_format(&self, _system: &str, messages: &[Message]) -> Result<Value> {
49        let mut claude_messages = Vec::new();
50
51        for message in messages.iter().filter(|m| m.is_agent_visible()) {
52            let role = match message.role {
53                Role::User => "user",
54                Role::Assistant => "assistant",
55            };
56
57            let mut content_parts = Vec::new();
58            for content in &message.content {
59                match content {
60                    MessageContent::Text(text_content) => {
61                        content_parts.push(json!({
62                            "type": "text",
63                            "text": text_content.text
64                        }));
65                    }
66                    MessageContent::ToolRequest(tool_request) => {
67                        if let Ok(tool_call) = &tool_request.tool_call {
68                            content_parts.push(json!({
69                                "type": "tool_use",
70                                "id": tool_request.id,
71                                "name": tool_call.name,
72                                "input": tool_call.arguments
73                            }));
74                        }
75                    }
76                    MessageContent::ToolResponse(tool_response) => {
77                        if let Ok(result) = &tool_response.tool_result {
78                            // Convert tool result contents to text
79                            let content_text = result
80                                .content
81                                .iter()
82                                .filter_map(|content| match &content.raw {
83                                    rmcp::model::RawContent::Text(text_content) => {
84                                        Some(text_content.text.as_str())
85                                    }
86                                    _ => None,
87                                })
88                                .collect::<Vec<&str>>()
89                                .join("\n");
90
91                            content_parts.push(json!({
92                                "type": "tool_result",
93                                "tool_use_id": tool_response.id,
94                                "content": content_text
95                            }));
96                        }
97                    }
98                    _ => {
99                        // Skip other content types for now
100                    }
101                }
102            }
103
104            claude_messages.push(json!({
105                "role": role,
106                "content": content_parts
107            }));
108        }
109
110        Ok(json!(claude_messages))
111    }
112
113    /// Parse the JSON response from claude CLI
114    fn apply_permission_flags(cmd: &mut Command) -> Result<(), ProviderError> {
115        let config = Config::global();
116        let aster_mode = config.get_aster_mode().unwrap_or(AsterMode::Auto);
117
118        match aster_mode {
119            AsterMode::Auto => {
120                cmd.arg("--dangerously-skip-permissions");
121            }
122            AsterMode::SmartApprove => {
123                cmd.arg("--permission-mode").arg("acceptEdits");
124            }
125            AsterMode::Approve => {
126                return Err(ProviderError::RequestFailed(
127                    "\n\n\n### NOTE\n\n\n \
128                    Claude Code CLI provider does not support Approve mode.\n \
129                    Please use Auto (which will run anything it needs to) or \
130                    SmartApprove (most things will run or Chat Mode)\n\n\n"
131                        .to_string(),
132                ));
133            }
134            AsterMode::Chat => {
135                // Chat mode doesn't need permission flags
136            }
137        }
138        Ok(())
139    }
140
141    fn parse_claude_response(
142        &self,
143        json_lines: &[String],
144    ) -> Result<(Message, Usage), ProviderError> {
145        let mut all_text_content = Vec::new();
146        let mut usage = Usage::default();
147
148        // Join all lines and parse as a single JSON array
149        let full_response = json_lines.join("");
150        let json_array: Vec<Value> = serde_json::from_str(&full_response).map_err(|e| {
151            ProviderError::RequestFailed(format!("Failed to parse JSON response: {}", e))
152        })?;
153
154        for parsed in json_array {
155            if let Some(msg_type) = parsed.get("type").and_then(|t| t.as_str()) {
156                match msg_type {
157                    "assistant" => {
158                        if let Some(message) = parsed.get("message") {
159                            // Extract text content from this assistant message
160                            if let Some(content) = message.get("content").and_then(|c| c.as_array())
161                            {
162                                for item in content {
163                                    if let Some(content_type) =
164                                        item.get("type").and_then(|t| t.as_str())
165                                    {
166                                        if content_type == "text" {
167                                            if let Some(text) =
168                                                item.get("text").and_then(|t| t.as_str())
169                                            {
170                                                all_text_content.push(text.to_string());
171                                            }
172                                        }
173                                        // Skip tool_use - those are claude CLI's internal tools
174                                    }
175                                }
176                            }
177
178                            // Extract usage information
179                            if let Some(usage_info) = message.get("usage") {
180                                usage.input_tokens = usage_info
181                                    .get("input_tokens")
182                                    .and_then(|v| v.as_i64())
183                                    .map(|v| v as i32);
184                                usage.output_tokens = usage_info
185                                    .get("output_tokens")
186                                    .and_then(|v| v.as_i64())
187                                    .map(|v| v as i32);
188
189                                // Calculate total if not provided
190                                if usage.total_tokens.is_none() {
191                                    if let (Some(input), Some(output)) =
192                                        (usage.input_tokens, usage.output_tokens)
193                                    {
194                                        usage.total_tokens = Some(input + output);
195                                    }
196                                }
197                            }
198                        }
199                    }
200                    "result" => {
201                        // Extract additional usage info from result if available
202                        if let Some(result_usage) = parsed.get("usage") {
203                            if usage.input_tokens.is_none() {
204                                usage.input_tokens = result_usage
205                                    .get("input_tokens")
206                                    .and_then(|v| v.as_i64())
207                                    .map(|v| v as i32);
208                            }
209                            if usage.output_tokens.is_none() {
210                                usage.output_tokens = result_usage
211                                    .get("output_tokens")
212                                    .and_then(|v| v.as_i64())
213                                    .map(|v| v as i32);
214                            }
215                        }
216                    }
217                    _ => {} // Ignore other message types
218                }
219            }
220        }
221
222        // Combine all text content into a single message
223        let combined_text = all_text_content.join("\n\n");
224        if combined_text.is_empty() {
225            return Err(ProviderError::RequestFailed(
226                "No text content found in response".to_string(),
227            ));
228        }
229
230        let message_content = vec![MessageContent::text(combined_text)];
231
232        let response_message = Message::new(
233            Role::Assistant,
234            chrono::Utc::now().timestamp(),
235            message_content,
236        );
237
238        Ok((response_message, usage))
239    }
240
241    async fn execute_command(
242        &self,
243        system: &str,
244        messages: &[Message],
245        _tools: &[Tool],
246    ) -> Result<Vec<String>, ProviderError> {
247        let messages_json = self
248            .messages_to_claude_format(system, messages)
249            .map_err(|e| {
250                ProviderError::RequestFailed(format!("Failed to format messages: {}", e))
251            })?;
252
253        let filtered_system = filter_extensions_from_system_prompt(system);
254
255        if std::env::var("ASTER_CLAUDE_CODE_DEBUG").is_ok() {
256            println!("=== CLAUDE CODE PROVIDER DEBUG ===");
257            println!("Command: {:?}", self.command);
258            println!("Original system prompt length: {} chars", system.len());
259            println!(
260                "Filtered system prompt length: {} chars",
261                filtered_system.len()
262            );
263            println!("Filtered system prompt: {}", filtered_system);
264            println!(
265                "Messages JSON: {}",
266                serde_json::to_string_pretty(&messages_json)
267                    .unwrap_or_else(|_| "Failed to serialize".to_string())
268            );
269            println!("================================");
270        }
271
272        let mut cmd = Command::new(&self.command);
273        configure_command_no_window(&mut cmd);
274        cmd.arg("-p")
275            .arg(messages_json.to_string())
276            .arg("--system-prompt")
277            .arg(&filtered_system);
278
279        // Only pass model parameter if it's in the known models list
280        if CLAUDE_CODE_KNOWN_MODELS.contains(&self.model.model_name.as_str()) {
281            cmd.arg("--model").arg(&self.model.model_name);
282        }
283
284        cmd.arg("--verbose").arg("--output-format").arg("json");
285
286        // Add permission mode based on ASTER_MODE setting
287        Self::apply_permission_flags(&mut cmd)?;
288
289        cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
290
291        let mut child = cmd.spawn().map_err(|e| {
292            ProviderError::RequestFailed(format!(
293                "Failed to spawn Claude CLI command '{:?}': {}.",
294                self.command, e
295            ))
296        })?;
297
298        let stdout = child
299            .stdout
300            .take()
301            .ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdout".to_string()))?;
302
303        let mut reader = BufReader::new(stdout);
304        let mut lines = Vec::new();
305        let mut line = String::new();
306
307        loop {
308            line.clear();
309            match reader.read_line(&mut line).await {
310                Ok(0) => break, // EOF
311                Ok(_) => {
312                    let trimmed = line.trim();
313                    if !trimmed.is_empty() {
314                        lines.push(trimmed.to_string());
315                    }
316                }
317                Err(e) => {
318                    return Err(ProviderError::RequestFailed(format!(
319                        "Failed to read output: {}",
320                        e
321                    )));
322                }
323            }
324        }
325
326        let exit_status = child.wait().await.map_err(|e| {
327            ProviderError::RequestFailed(format!("Failed to wait for command: {}", e))
328        })?;
329
330        if !exit_status.success() {
331            return Err(ProviderError::RequestFailed(format!(
332                "Command failed with exit code: {:?}",
333                exit_status.code()
334            )));
335        }
336
337        tracing::debug!("Command executed successfully, got {} lines", lines.len());
338        for (i, line) in lines.iter().enumerate() {
339            tracing::debug!("Line {}: {}", i, line);
340        }
341
342        Ok(lines)
343    }
344
345    /// Generate a simple session description without calling subprocess
346    fn generate_simple_session_description(
347        &self,
348        messages: &[Message],
349    ) -> Result<(Message, ProviderUsage), ProviderError> {
350        // Extract the first user message text
351        let description = messages
352            .iter()
353            .find(|m| m.role == Role::User)
354            .and_then(|m| {
355                m.content.iter().find_map(|c| match c {
356                    MessageContent::Text(text_content) => Some(&text_content.text),
357                    _ => None,
358                })
359            })
360            .map(|text| {
361                // Take first few words, limit to 4 words
362                text.split_whitespace()
363                    .take(4)
364                    .collect::<Vec<_>>()
365                    .join(" ")
366            })
367            .unwrap_or_else(|| "Simple task".to_string());
368
369        if std::env::var("ASTER_CLAUDE_CODE_DEBUG").is_ok() {
370            println!("=== CLAUDE CODE PROVIDER DEBUG ===");
371            println!("Generated simple session description: {}", description);
372            println!("Skipped subprocess call for session description");
373            println!("================================");
374        }
375
376        let message = Message::new(
377            Role::Assistant,
378            chrono::Utc::now().timestamp(),
379            vec![MessageContent::text(description.clone())],
380        );
381
382        let usage = Usage::default();
383
384        Ok((
385            message,
386            ProviderUsage::new(self.model.model_name.clone(), usage),
387        ))
388    }
389}
390
391#[async_trait]
392impl Provider for ClaudeCodeProvider {
393    fn metadata() -> ProviderMetadata {
394        ProviderMetadata::new(
395            "claude-code",
396            "Claude Code CLI",
397            "Requires claude CLI installed, no MCPs. Use Anthropic provider for full features.",
398            CLAUDE_CODE_DEFAULT_MODEL,
399            CLAUDE_CODE_KNOWN_MODELS.to_vec(),
400            CLAUDE_CODE_DOC_URL,
401            vec![ConfigKey::from_value_type::<ClaudeCodeCommand>(true, false)],
402        )
403    }
404
405    fn get_name(&self) -> &str {
406        &self.name
407    }
408
409    fn get_model_config(&self) -> ModelConfig {
410        // Return the model config with appropriate context limit for Claude models
411        self.model.clone()
412    }
413
414    #[tracing::instrument(
415        skip(self, model_config, system, messages, tools),
416        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
417    )]
418    async fn complete_with_model(
419        &self,
420        model_config: &ModelConfig,
421        system: &str,
422        messages: &[Message],
423        tools: &[Tool],
424    ) -> Result<(Message, ProviderUsage), ProviderError> {
425        // Check if this is a session description request (short system prompt asking for 4 words or less)
426        if system.contains("four words or less") || system.contains("4 words or less") {
427            return self.generate_simple_session_description(messages);
428        }
429
430        let json_lines = self.execute_command(system, messages, tools).await?;
431
432        let (message, usage) = self.parse_claude_response(&json_lines)?;
433
434        // Create a dummy payload for debug tracing
435        let payload = json!({
436            "command": self.command,
437            "model": model_config.model_name,
438            "system": system,
439            "messages": messages.len()
440        });
441        let mut log = RequestLog::start(model_config, &payload)?;
442
443        let response = json!({
444            "lines": json_lines.len(),
445            "usage": usage
446        });
447
448        log.write(&response, Some(&usage))?;
449
450        Ok((
451            message,
452            ProviderUsage::new(model_config.model_name.clone(), usage),
453        ))
454    }
455}