Skip to main content

hermes_bot/agent/
claude.rs

1use crate::agent::protocol::{self, ContentBlock, StreamMessage};
2use crate::agent::{Agent, AgentEvent, AgentHandle};
3use crate::error::{HermesError, Result};
4use async_trait::async_trait;
5use std::path::Path;
6use std::process::Stdio;
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::process::Command;
9use tokio::sync::{mpsc, oneshot};
10use tracing::{debug, error, info, warn};
11
12pub struct ClaudeAgent;
13
14impl ClaudeAgent {
15    fn build_command(
16        repo_path: &Path,
17        allowed_tools: &[String],
18        system_prompt: Option<&str>,
19        resume_session_id: Option<&str>,
20        model: Option<&str>,
21    ) -> Command {
22        let mut cmd = Command::new("claude");
23        cmd.current_dir(repo_path);
24
25        cmd.arg("--output-format").arg("stream-json");
26        cmd.arg("--input-format").arg("stream-json");
27        cmd.arg("--verbose");
28
29        if let Some(sid) = resume_session_id {
30            cmd.arg("--resume").arg(sid);
31        }
32
33        // Only pass --model for new sessions; the CLI remembers for --resume.
34        if resume_session_id.is_none()
35            && let Some(m) = model
36        {
37            cmd.arg("--model").arg(m);
38        }
39
40        for tool in allowed_tools {
41            cmd.arg("--allowedTools").arg(tool);
42        }
43
44        if let Some(sp) = system_prompt {
45            cmd.arg("--append-system-prompt").arg(sp);
46        }
47
48        cmd.stdin(Stdio::piped());
49        cmd.stdout(Stdio::piped());
50        cmd.stderr(Stdio::piped());
51
52        cmd
53    }
54}
55
56#[async_trait]
57impl Agent for ClaudeAgent {
58    async fn spawn(
59        &self,
60        repo_path: &Path,
61        allowed_tools: &[String],
62        system_prompt: Option<&str>,
63        resume_session_id: Option<&str>,
64        model: Option<&str>,
65    ) -> Result<AgentHandle> {
66        let mut cmd = Self::build_command(
67            repo_path,
68            allowed_tools,
69            system_prompt,
70            resume_session_id,
71            model,
72        );
73        debug!("Spawning claude CLI: {:?}", cmd);
74
75        let mut child = cmd.spawn().map_err(|e| {
76            if e.kind() == std::io::ErrorKind::NotFound {
77                HermesError::ClaudeNotFound
78            } else {
79                HermesError::AgentSpawnFailed {
80                    reason: e.to_string(),
81                }
82            }
83        })?;
84
85        let stdin = child
86            .stdin
87            .take()
88            .ok_or_else(|| HermesError::AgentSpawnFailed {
89                reason: "stdin was not piped".into(),
90            })?;
91        let stdout = child
92            .stdout
93            .take()
94            .ok_or_else(|| HermesError::AgentSpawnFailed {
95                reason: "stdout was not piped".into(),
96            })?;
97        let stderr = child
98            .stderr
99            .take()
100            .ok_or_else(|| HermesError::AgentSpawnFailed {
101                reason: "stderr was not piped".into(),
102            })?;
103
104        // Channels: agent events (out) and user messages (in).
105        let (event_tx, event_rx) = mpsc::channel::<AgentEvent>(256);
106        let (user_tx, mut user_rx) = mpsc::channel::<String>(64);
107        let (kill_tx, kill_rx) = oneshot::channel::<()>();
108
109        // ── Stdout reader task ────────────────────────────────────────
110        let event_tx_stdout = event_tx.clone();
111        let mut stdin_writer = stdin;
112        // Shared mpsc for stdin writes (user messages + control responses).
113        let (stdin_tx, mut stdin_rx) = mpsc::channel::<String>(64);
114
115        tokio::spawn(async move {
116            let reader = BufReader::new(stdout);
117            let mut lines = reader.lines();
118
119            while let Ok(Some(line)) = lines.next_line().await {
120                let msg = match protocol::parse_line(&line) {
121                    Some(m) => m,
122                    None => continue,
123                };
124
125                match msg {
126                    StreamMessage::System(sys) => {
127                        if sys.subtype.as_deref() == Some("init") {
128                            let session_id = sys.session_id.unwrap_or_default();
129                            let model = sys.model.unwrap_or_default();
130                            info!("Claude session init: id={}, model={}", session_id, model);
131                            let _ = event_tx_stdout
132                                .send(AgentEvent::SessionInit { session_id, model })
133                                .await;
134                        }
135                    }
136                    StreamMessage::Assistant(assistant) => {
137                        if let Some(body) = assistant.message {
138                            for block in body.content {
139                                match block {
140                                    ContentBlock::Text { text } => {
141                                        let _ = event_tx_stdout.send(AgentEvent::Text(text)).await;
142                                    }
143                                    ContentBlock::ToolUse { name, input, .. } => {
144                                        let _ = event_tx_stdout
145                                            .send(AgentEvent::ToolUse { name, input })
146                                            .await;
147                                    }
148                                    ContentBlock::ToolResult { .. }
149                                    | ContentBlock::Thinking { .. }
150                                    | ContentBlock::Unknown => {}
151                                }
152                            }
153                        }
154                    }
155                    StreamMessage::Result(result) => {
156                        let _ = event_tx_stdout
157                            .send(AgentEvent::TurnComplete {
158                                result: result.result,
159                                subtype: result.subtype.unwrap_or_else(|| "success".to_string()),
160                                num_turns: result.num_turns.unwrap_or(0),
161                                duration_ms: result.duration_ms.unwrap_or(0),
162                                is_error: result.is_error.unwrap_or(false),
163                                session_id: result.session_id.unwrap_or_default(),
164                            })
165                            .await;
166                    }
167                    StreamMessage::ControlRequest(ctrl) => {
168                        if let Some(request_id) = ctrl.request_id {
169                            let tool_name = ctrl
170                                .request
171                                .as_ref()
172                                .and_then(|r| r.tool_name.as_deref())
173                                .unwrap_or("unknown");
174
175                            if tool_name == "AskUserQuestion" {
176                                // Forward the question to Slack instead of denying.
177                                let questions = ctrl
178                                    .request
179                                    .as_ref()
180                                    .and_then(|r| r.tool_input.clone())
181                                    .unwrap_or_default();
182                                info!(
183                                    "AskUserQuestion control_request (request_id={})",
184                                    request_id
185                                );
186                                let _ = event_tx_stdout
187                                    .send(AgentEvent::QuestionPending {
188                                        request_id,
189                                        questions,
190                                    })
191                                    .await;
192                            } else {
193                                // Forward unapproved tool requests to Slack for
194                                // interactive approval (mirrors the AskUserQuestion flow).
195                                info!(
196                                    "Tool approval requested: {} (request_id={})",
197                                    tool_name, request_id
198                                );
199                                let tool_input = ctrl
200                                    .request
201                                    .as_ref()
202                                    .and_then(|r| r.tool_input.clone())
203                                    .unwrap_or_default();
204                                let _ = event_tx_stdout
205                                    .send(AgentEvent::ToolApprovalPending {
206                                        request_id,
207                                        tool_name: tool_name.to_string(),
208                                        tool_input,
209                                    })
210                                    .await;
211                            }
212                        }
213                    }
214                    StreamMessage::ToolProgress(tp) => {
215                        let tool_name = tp.tool_name.unwrap_or_default();
216                        let _ = event_tx_stdout
217                            .send(AgentEvent::ToolProgress { tool_name })
218                            .await;
219                    }
220                    StreamMessage::User(_)
221                    | StreamMessage::StreamEvent(_)
222                    | StreamMessage::Unknown => {}
223                }
224            }
225
226            // Stdout closed — process likely exited.
227            debug!("Claude stdout reader finished");
228        });
229
230        // ── Stdin writer task ─────────────────────────────────────────
231        let stdin_tx_user = stdin_tx.clone();
232        tokio::spawn(async move {
233            while let Some(msg) = user_rx.recv().await {
234                match protocol::user_message(&msg, None) {
235                    Ok(json) => {
236                        let _ = stdin_tx_user.send(json).await;
237                    }
238                    Err(e) => {
239                        error!("Failed to serialize user message: {}", e);
240                        // Continue processing other messages instead of crashing.
241                    }
242                }
243            }
244        });
245
246        // ── Consolidated stdin writer ─────────────────────────────────
247        tokio::spawn(async move {
248            while let Some(line) = stdin_rx.recv().await {
249                if let Err(e) = stdin_writer.write_all(line.as_bytes()).await {
250                    warn!("Failed to write to claude stdin: {}", e);
251                    break;
252                }
253                if let Err(e) = stdin_writer.write_all(b"\n").await {
254                    warn!("Failed to write newline to claude stdin: {}", e);
255                    break;
256                }
257                if let Err(e) = stdin_writer.flush().await {
258                    warn!("Failed to flush claude stdin: {}", e);
259                    break;
260                }
261            }
262        });
263
264        // ── Stderr drainer ────────────────────────────────────────────
265        tokio::spawn(async move {
266            let reader = BufReader::new(stderr);
267            let mut lines = reader.lines();
268            while let Ok(Some(line)) = lines.next_line().await {
269                debug!("claude stderr: {}", line);
270            }
271        });
272
273        // ── Kill listener + process wait ──────────────────────────────
274        let event_tx_exit = event_tx;
275        tokio::spawn(async move {
276            tokio::select! {
277                _ = kill_rx => {
278                    info!("Kill signal received, terminating claude process");
279                    let _ = child.kill().await;
280                }
281                status = child.wait() => {
282                    match status {
283                        Ok(s) => {
284                            let code = s.code();
285                            debug!("Claude process exited with code: {:?}", code);
286                            let _ = event_tx_exit.send(AgentEvent::ProcessExited { code }).await;
287                        }
288                        Err(e) => {
289                            error!("Error waiting for claude process: {}", e);
290                            let _ = event_tx_exit.send(AgentEvent::ProcessExited { code: None }).await;
291                        }
292                    }
293                }
294            }
295        });
296
297        Ok(AgentHandle {
298            sender: user_tx,
299            receiver: event_rx,
300            kill_tx: Some(kill_tx),
301            session_id: None,
302            stdin_tx,
303        })
304    }
305}