// acp_http_adapter/process.rs

1use std::collections::{HashMap, VecDeque};
2use std::convert::Infallible;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use axum::response::sse::Event;
8use futures::{stream, Stream, StreamExt};
9use serde_json::{json, Value};
10use thiserror::Error;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::process::{Child, ChildStdin, Command};
13use tokio::sync::{broadcast, oneshot, Mutex};
14use tokio_stream::wrappers::BroadcastStream;
15
16use crate::registry::LaunchSpec;
17
/// Maximum number of stream messages retained in the replay ring buffer.
const RING_BUFFER_SIZE: usize = 1024;
/// Maximum number of agent stderr lines retained for exit diagnostics.
const STDERR_TAIL_SIZE: usize = 16;
20
21#[derive(Debug, Error)]
22pub enum AdapterError {
23    #[error("failed to spawn subprocess: {0}")]
24    Spawn(std::io::Error),
25    #[error("failed to capture subprocess stdin")]
26    MissingStdin,
27    #[error("failed to capture subprocess stdout")]
28    MissingStdout,
29    #[error("failed to capture subprocess stderr")]
30    MissingStderr,
31    #[error("invalid json-rpc envelope")]
32    InvalidEnvelope,
33    #[error("failed to serialize json-rpc message: {0}")]
34    Serialize(serde_json::Error),
35    #[error("failed to write subprocess stdin: {0}")]
36    Write(std::io::Error),
37    #[error("agent process exited before responding")]
38    Exited {
39        exit_code: Option<i32>,
40        stderr: Option<String>,
41    },
42    #[error("timeout waiting for response")]
43    Timeout,
44}
45
46#[derive(Debug)]
47pub enum PostOutcome {
48    Response(Value),
49    Accepted,
50}
51
52#[derive(Debug, Clone)]
53struct StreamMessage {
54    sequence: u64,
55    payload: Value,
56}
57
58#[derive(Debug)]
59pub struct AdapterRuntime {
60    stdin: Arc<Mutex<ChildStdin>>,
61    child: Arc<Mutex<Child>>,
62    pending: Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
63    sender: broadcast::Sender<StreamMessage>,
64    ring: Arc<Mutex<VecDeque<StreamMessage>>>,
65    sequence: Arc<AtomicU64>,
66    request_timeout: Duration,
67    shutting_down: AtomicBool,
68    spawned_at: Instant,
69    first_stdout: Arc<AtomicBool>,
70    stderr_tail: Arc<Mutex<VecDeque<String>>>,
71}
72
73impl AdapterRuntime {
74    pub async fn start(
75        launch: LaunchSpec,
76        request_timeout: Duration,
77    ) -> Result<Self, AdapterError> {
78        let spawn_start = Instant::now();
79
80        let mut command = Command::new(&launch.program);
81        command
82            .args(&launch.args)
83            .stdin(std::process::Stdio::piped())
84            .stdout(std::process::Stdio::piped())
85            .stderr(std::process::Stdio::piped());
86
87        for (key, value) in &launch.env {
88            command.env(key, value);
89        }
90
91        tracing::info!(
92            program = ?launch.program,
93            args = ?launch.args,
94            "spawning agent process"
95        );
96
97        let mut child = command.spawn().map_err(|err| {
98            tracing::error!(
99                program = ?launch.program,
100                error = %err,
101                "failed to spawn agent process"
102            );
103            AdapterError::Spawn(err)
104        })?;
105
106        let pid = child.id().unwrap_or(0);
107        let spawn_elapsed = spawn_start.elapsed();
108        tracing::info!(
109            pid = pid,
110            elapsed_ms = spawn_elapsed.as_millis() as u64,
111            "agent process spawned"
112        );
113
114        let stdin = child.stdin.take().ok_or(AdapterError::MissingStdin)?;
115        let stdout = child.stdout.take().ok_or(AdapterError::MissingStdout)?;
116        let stderr = child.stderr.take().ok_or(AdapterError::MissingStderr)?;
117
118        let (sender, _rx) = broadcast::channel(512);
119        let runtime = Self {
120            stdin: Arc::new(Mutex::new(stdin)),
121            child: Arc::new(Mutex::new(child)),
122            pending: Arc::new(Mutex::new(HashMap::new())),
123            sender,
124            ring: Arc::new(Mutex::new(VecDeque::with_capacity(RING_BUFFER_SIZE))),
125            sequence: Arc::new(AtomicU64::new(0)),
126            request_timeout,
127            shutting_down: AtomicBool::new(false),
128            spawned_at: spawn_start,
129            first_stdout: Arc::new(AtomicBool::new(false)),
130            stderr_tail: Arc::new(Mutex::new(VecDeque::with_capacity(STDERR_TAIL_SIZE))),
131        };
132
133        runtime.spawn_stdout_loop(stdout);
134        runtime.spawn_stderr_loop(stderr);
135        runtime.spawn_exit_watcher();
136
137        Ok(runtime)
138    }
139
140    pub async fn post(&self, payload: Value) -> Result<PostOutcome, AdapterError> {
141        if !payload.is_object() {
142            return Err(AdapterError::InvalidEnvelope);
143        }
144
145        let method: String = payload
146            .get("method")
147            .and_then(|v| v.as_str())
148            .unwrap_or("<none>")
149            .to_string();
150        let has_method = payload.get("method").is_some();
151        let id = payload.get("id");
152
153        if has_method && id.is_some() {
154            let id_value = id.expect("checked");
155            let key = id_key(id_value);
156            let (tx, rx) = oneshot::channel();
157
158            let pending_count = self.pending.lock().await.len();
159            tracing::info!(
160                method = %method,
161                id = %key,
162                pending_count = pending_count,
163                "post: request → agent (awaiting response)"
164            );
165
166            self.pending.lock().await.insert(key.clone(), tx);
167
168            let write_start = Instant::now();
169            if let Err(err) = self.send_to_subprocess(&payload).await {
170                tracing::error!(
171                    method = %method,
172                    id = %key,
173                    error = %err,
174                    "post: failed to write to agent stdin"
175                );
176                self.pending.lock().await.remove(&key);
177                return Err(err);
178            }
179            let write_ms = write_start.elapsed().as_millis() as u64;
180            tracing::debug!(
181                method = %method,
182                id = %key,
183                write_ms = write_ms,
184                "post: stdin write complete, waiting for response"
185            );
186
187            let wait_start = Instant::now();
188            match tokio::time::timeout(self.request_timeout, rx).await {
189                Ok(Ok(response)) => {
190                    let wait_ms = wait_start.elapsed().as_millis() as u64;
191                    tracing::info!(
192                        method = %method,
193                        id = %key,
194                        response_ms = wait_ms,
195                        total_ms = write_ms + wait_ms,
196                        "post: got response from agent"
197                    );
198                    Ok(PostOutcome::Response(response))
199                }
200                Ok(Err(_)) => {
201                    let wait_ms = wait_start.elapsed().as_millis() as u64;
202                    tracing::error!(
203                        method = %method,
204                        id = %key,
205                        wait_ms = wait_ms,
206                        "post: response channel dropped (agent process may have exited)"
207                    );
208                    self.pending.lock().await.remove(&key);
209                    if let Some((exit_code, stderr)) = self.try_process_exit_info().await {
210                        tracing::error!(
211                            method = %method,
212                            id = %key,
213                            exit_code = ?exit_code,
214                            stderr = ?stderr,
215                            "post: agent process exited before response channel completed"
216                        );
217                        return Err(AdapterError::Exited { exit_code, stderr });
218                    }
219                    Err(AdapterError::Timeout)
220                }
221                Err(_) => {
222                    let pending_keys: Vec<String> =
223                        self.pending.lock().await.keys().cloned().collect();
224                    tracing::error!(
225                        method = %method,
226                        id = %key,
227                        timeout_ms = self.request_timeout.as_millis() as u64,
228                        age_ms = self.spawned_at.elapsed().as_millis() as u64,
229                        pending_keys = ?pending_keys,
230                        first_stdout_seen = self.first_stdout.load(Ordering::Relaxed),
231                        "post: TIMEOUT waiting for agent response"
232                    );
233                    self.pending.lock().await.remove(&key);
234                    if let Some((exit_code, stderr)) = self.try_process_exit_info().await {
235                        tracing::error!(
236                            method = %method,
237                            id = %key,
238                            exit_code = ?exit_code,
239                            stderr = ?stderr,
240                            "post: agent process exited before timeout completed"
241                        );
242                        return Err(AdapterError::Exited { exit_code, stderr });
243                    }
244                    Err(AdapterError::Timeout)
245                }
246            }
247        } else {
248            tracing::debug!(
249                method = %method,
250                "post: notification → agent (fire-and-forget)"
251            );
252            self.send_to_subprocess(&payload).await?;
253            Ok(PostOutcome::Accepted)
254        }
255    }
256
257    async fn subscribe(
258        &self,
259        last_event_id: Option<u64>,
260    ) -> (Vec<(u64, Value)>, broadcast::Receiver<StreamMessage>) {
261        let replay = {
262            let ring = self.ring.lock().await;
263            ring.iter()
264                .filter(|message| {
265                    if let Some(last_event_id) = last_event_id {
266                        message.sequence > last_event_id
267                    } else {
268                        true
269                    }
270                })
271                .map(|message| (message.sequence, message.payload.clone()))
272                .collect::<Vec<_>>()
273        };
274        (replay, self.sender.subscribe())
275    }
276
277    pub async fn sse_stream(
278        self: Arc<Self>,
279        last_event_id: Option<u64>,
280    ) -> impl Stream<Item = Result<Event, Infallible>> + Send + 'static {
281        let (replay, rx) = self.subscribe(last_event_id).await;
282        let replay_stream = stream::iter(replay.into_iter().map(|(sequence, payload)| {
283            let event = Event::default()
284                .event("message")
285                .id(sequence.to_string())
286                .data(payload.to_string());
287            Ok(event)
288        }));
289
290        let live_stream = BroadcastStream::new(rx).filter_map(|item| async move {
291            match item {
292                Ok(message) => {
293                    let event = Event::default()
294                        .event("message")
295                        .id(message.sequence.to_string())
296                        .data(message.payload.to_string());
297                    Some(Ok(event))
298                }
299                Err(_) => None,
300            }
301        });
302
303        replay_stream.chain(live_stream)
304    }
305
306    /// Stream of raw JSON-RPC `Value` payloads (without SSE framing).
307    /// Useful for consumers that need to inspect the payload contents
308    /// rather than forward them as SSE events.
309    pub async fn value_stream(
310        self: Arc<Self>,
311        last_event_id: Option<u64>,
312    ) -> impl Stream<Item = Value> + Send + 'static {
313        let (replay, rx) = self.subscribe(last_event_id).await;
314        let replay_stream = stream::iter(replay.into_iter().map(|(_sequence, payload)| payload));
315        let live_stream = BroadcastStream::new(rx).filter_map(|item| async move {
316            match item {
317                Ok(message) => Some(message.payload),
318                Err(_) => None,
319            }
320        });
321        replay_stream.chain(live_stream)
322    }
323
324    pub async fn shutdown(&self) {
325        if self.shutting_down.swap(true, Ordering::SeqCst) {
326            return;
327        }
328
329        tracing::info!(
330            age_ms = self.spawned_at.elapsed().as_millis() as u64,
331            "shutting down agent process"
332        );
333
334        self.pending.lock().await.clear();
335        let mut child = self.child.lock().await;
336        match child.try_wait() {
337            Ok(Some(_)) => {}
338            Ok(None) => {
339                let _ = child.kill().await;
340                let _ = child.wait().await;
341            }
342            Err(_) => {
343                let _ = child.kill().await;
344            }
345        }
346    }
347
348    fn spawn_stdout_loop(&self, stdout: tokio::process::ChildStdout) {
349        let pending = self.pending.clone();
350        let sender = self.sender.clone();
351        let ring = self.ring.clone();
352        let sequence = self.sequence.clone();
353        let spawned_at = self.spawned_at;
354        let first_stdout = self.first_stdout.clone();
355
356        tokio::spawn(async move {
357            let mut lines = BufReader::new(stdout).lines();
358            let mut line_count: u64 = 0;
359
360            while let Ok(Some(line)) = lines.next_line().await {
361                let trimmed = line.trim();
362                if trimmed.is_empty() {
363                    continue;
364                }
365
366                line_count += 1;
367
368                if !first_stdout.swap(true, Ordering::Relaxed) {
369                    tracing::info!(
370                        first_stdout_ms = spawned_at.elapsed().as_millis() as u64,
371                        line_bytes = trimmed.len(),
372                        "agent process: first stdout line received"
373                    );
374                }
375
376                let payload = match serde_json::from_str::<Value>(trimmed) {
377                    Ok(payload) => payload,
378                    Err(err) => {
379                        tracing::warn!(
380                            error = %err,
381                            line_number = line_count,
382                            raw = %if trimmed.len() > 200 {
383                                format!("{}...", &trimmed[..200])
384                            } else {
385                                trimmed.to_string()
386                            },
387                            "agent stdout: invalid JSON"
388                        );
389                        json!({
390                            "jsonrpc": "2.0",
391                            "method": "_adapter/invalid_stdout",
392                            "params": {
393                                "error": err.to_string(),
394                                "raw": trimmed,
395                            }
396                        })
397                    }
398                };
399
400                let is_response = payload.get("id").is_some() && payload.get("method").is_none();
401                if is_response {
402                    let key = id_key(payload.get("id").expect("checked"));
403                    let has_error = payload.get("error").is_some();
404                    if let Some(tx) = pending.lock().await.remove(&key) {
405                        tracing::debug!(
406                            id = %key,
407                            has_error = has_error,
408                            age_ms = spawned_at.elapsed().as_millis() as u64,
409                            "agent stdout: response matched to pending request"
410                        );
411                        let _ = tx.send(payload.clone());
412                        // Also broadcast the response so SSE/notification subscribers
413                        // see it in order after preceding notifications. This lets the
414                        // SSE translation task detect turn completion after all
415                        // session/update events have been processed.
416                        let seq = sequence.fetch_add(1, Ordering::SeqCst) + 1;
417                        let message = StreamMessage {
418                            sequence: seq,
419                            payload,
420                        };
421                        {
422                            let mut guard = ring.lock().await;
423                            guard.push_back(message.clone());
424                            while guard.len() > RING_BUFFER_SIZE {
425                                guard.pop_front();
426                            }
427                        }
428                        let _ = sender.send(message);
429                        continue;
430                    } else {
431                        tracing::warn!(
432                            id = %key,
433                            has_error = has_error,
434                            "agent stdout: response has no matching pending request (orphan)"
435                        );
436                    }
437                }
438
439                let method = payload
440                    .get("method")
441                    .and_then(|v| v.as_str())
442                    .unwrap_or("<none>");
443                tracing::debug!(
444                    method = method,
445                    line_number = line_count,
446                    "agent stdout: notification/event → SSE broadcast"
447                );
448
449                let seq = sequence.fetch_add(1, Ordering::SeqCst) + 1;
450                let message = StreamMessage {
451                    sequence: seq,
452                    payload,
453                };
454
455                {
456                    let mut guard = ring.lock().await;
457                    guard.push_back(message.clone());
458                    while guard.len() > RING_BUFFER_SIZE {
459                        guard.pop_front();
460                    }
461                }
462
463                let _ = sender.send(message);
464            }
465
466            tracing::info!(
467                total_lines = line_count,
468                age_ms = spawned_at.elapsed().as_millis() as u64,
469                "agent stdout: stream ended"
470            );
471        });
472    }
473
474    fn spawn_stderr_loop(&self, stderr: tokio::process::ChildStderr) {
475        let spawned_at = self.spawned_at;
476        let stderr_tail = self.stderr_tail.clone();
477
478        tokio::spawn(async move {
479            let mut lines = BufReader::new(stderr).lines();
480            let mut line_count: u64 = 0;
481
482            while let Ok(Some(line)) = lines.next_line().await {
483                line_count += 1;
484                {
485                    let mut tail = stderr_tail.lock().await;
486                    tail.push_back(line.clone());
487                    while tail.len() > STDERR_TAIL_SIZE {
488                        tail.pop_front();
489                    }
490                }
491                tracing::info!(
492                    line_number = line_count,
493                    age_ms = spawned_at.elapsed().as_millis() as u64,
494                    "agent stderr: {}",
495                    line
496                );
497            }
498
499            tracing::debug!(
500                total_lines = line_count,
501                age_ms = spawned_at.elapsed().as_millis() as u64,
502                "agent stderr: stream ended"
503            );
504        });
505    }
506
507    fn spawn_exit_watcher(&self) {
508        let child = self.child.clone();
509        let sender = self.sender.clone();
510        let ring = self.ring.clone();
511        let sequence = self.sequence.clone();
512        let spawned_at = self.spawned_at;
513        let pending = self.pending.clone();
514
515        tokio::spawn(async move {
516            let status = {
517                let mut guard = child.lock().await;
518                guard.wait().await.ok()
519            };
520
521            let age_ms = spawned_at.elapsed().as_millis() as u64;
522            let pending_count = pending.lock().await.len();
523
524            if let Some(status) = status {
525                tracing::warn!(
526                    success = status.success(),
527                    code = status.code(),
528                    age_ms = age_ms,
529                    pending_requests = pending_count,
530                    "agent process exited"
531                );
532
533                let payload = json!({
534                    "jsonrpc": "2.0",
535                    "method": "_adapter/agent_exited",
536                    "params": {
537                        "success": status.success(),
538                        "code": status.code(),
539                    }
540                });
541
542                let seq = sequence.fetch_add(1, Ordering::SeqCst) + 1;
543                let message = StreamMessage {
544                    sequence: seq,
545                    payload,
546                };
547
548                {
549                    let mut guard = ring.lock().await;
550                    guard.push_back(message.clone());
551                    while guard.len() > RING_BUFFER_SIZE {
552                        guard.pop_front();
553                    }
554                }
555
556                let _ = sender.send(message);
557            } else {
558                tracing::error!(
559                    age_ms = age_ms,
560                    pending_requests = pending_count,
561                    "agent process: failed to get exit status"
562                );
563            }
564        });
565    }
566
567    async fn send_to_subprocess(&self, payload: &Value) -> Result<(), AdapterError> {
568        let method = payload
569            .get("method")
570            .and_then(|v| v.as_str())
571            .unwrap_or("<none>");
572        let id = payload.get("id").map(|v| v.to_string()).unwrap_or_default();
573
574        tracing::debug!(
575            method = method,
576            id = %id,
577            bytes = serde_json::to_vec(payload).map(|b| b.len()).unwrap_or(0),
578            "stdin: writing message to agent"
579        );
580
581        let mut stdin = self.stdin.lock().await;
582        let bytes = serde_json::to_vec(payload).map_err(AdapterError::Serialize)?;
583        stdin.write_all(&bytes).await.map_err(|err| {
584            tracing::error!(method = method, id = %id, error = %err, "stdin: write_all failed");
585            AdapterError::Write(err)
586        })?;
587        stdin.write_all(b"\n").await.map_err(|err| {
588            tracing::error!(method = method, id = %id, error = %err, "stdin: newline write failed");
589            AdapterError::Write(err)
590        })?;
591        stdin.flush().await.map_err(|err| {
592            tracing::error!(method = method, id = %id, error = %err, "stdin: flush failed");
593            AdapterError::Write(err)
594        })?;
595
596        tracing::debug!(method = method, id = %id, "stdin: write+flush complete");
597        Ok(())
598    }
599
600    async fn try_process_exit_info(&self) -> Option<(Option<i32>, Option<String>)> {
601        let mut child = self.child.lock().await;
602        match child.try_wait() {
603            Ok(Some(status)) => {
604                let exit_code = status.code();
605                drop(child);
606                let stderr = self.stderr_tail_summary().await;
607                Some((exit_code, stderr))
608            }
609            Ok(None) => None,
610            Err(_) => None,
611        }
612    }
613
614    pub async fn stderr_tail_summary(&self) -> Option<String> {
615        let tail = self.stderr_tail.lock().await;
616        if tail.is_empty() {
617            return None;
618        }
619        Some(tail.iter().cloned().collect::<Vec<_>>().join("\n"))
620    }
621}
622
623fn id_key(value: &Value) -> String {
624    serde_json::to_string(value).unwrap_or_else(|_| "null".to_string())
625}