Skip to main content

acp_http_adapter/
process.rs

1use std::collections::{HashMap, VecDeque};
2use std::convert::Infallible;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use axum::response::sse::Event;
8use futures::{stream, Stream, StreamExt};
9use serde_json::{json, Value};
10use thiserror::Error;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::process::{Child, ChildStdin, Command};
13use tokio::sync::{broadcast, oneshot, Mutex};
14use tokio_stream::wrappers::BroadcastStream;
15
16use crate::registry::LaunchSpec;
17
/// Maximum number of recent stream messages kept in memory so that a
/// subscriber reconnecting with a last-seen event id can be replayed what it
/// missed. Older messages are evicted first.
const RING_BUFFER_SIZE: usize = 1 << 10;
19
20#[derive(Debug, Error)]
21pub enum AdapterError {
22    #[error("failed to spawn subprocess: {0}")]
23    Spawn(std::io::Error),
24    #[error("failed to capture subprocess stdin")]
25    MissingStdin,
26    #[error("failed to capture subprocess stdout")]
27    MissingStdout,
28    #[error("failed to capture subprocess stderr")]
29    MissingStderr,
30    #[error("invalid json-rpc envelope")]
31    InvalidEnvelope,
32    #[error("failed to serialize json-rpc message: {0}")]
33    Serialize(serde_json::Error),
34    #[error("failed to write subprocess stdin: {0}")]
35    Write(std::io::Error),
36    #[error("timeout waiting for response")]
37    Timeout,
38}
39
40#[derive(Debug)]
41pub enum PostOutcome {
42    Response(Value),
43    Accepted,
44}
45
46#[derive(Debug, Clone)]
47struct StreamMessage {
48    sequence: u64,
49    payload: Value,
50}
51
52#[derive(Debug)]
53pub struct AdapterRuntime {
54    stdin: Arc<Mutex<ChildStdin>>,
55    child: Arc<Mutex<Child>>,
56    pending: Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
57    sender: broadcast::Sender<StreamMessage>,
58    ring: Arc<Mutex<VecDeque<StreamMessage>>>,
59    sequence: Arc<AtomicU64>,
60    request_timeout: Duration,
61    shutting_down: AtomicBool,
62    spawned_at: Instant,
63    first_stdout: Arc<AtomicBool>,
64}
65
66impl AdapterRuntime {
67    pub async fn start(
68        launch: LaunchSpec,
69        request_timeout: Duration,
70    ) -> Result<Self, AdapterError> {
71        let spawn_start = Instant::now();
72
73        let mut command = Command::new(&launch.program);
74        command
75            .args(&launch.args)
76            .stdin(std::process::Stdio::piped())
77            .stdout(std::process::Stdio::piped())
78            .stderr(std::process::Stdio::piped());
79
80        for (key, value) in &launch.env {
81            command.env(key, value);
82        }
83
84        tracing::info!(
85            program = ?launch.program,
86            args = ?launch.args,
87            "spawning agent process"
88        );
89
90        let mut child = command.spawn().map_err(|err| {
91            tracing::error!(
92                program = ?launch.program,
93                error = %err,
94                "failed to spawn agent process"
95            );
96            AdapterError::Spawn(err)
97        })?;
98
99        let pid = child.id().unwrap_or(0);
100        let spawn_elapsed = spawn_start.elapsed();
101        tracing::info!(
102            pid = pid,
103            elapsed_ms = spawn_elapsed.as_millis() as u64,
104            "agent process spawned"
105        );
106
107        let stdin = child.stdin.take().ok_or(AdapterError::MissingStdin)?;
108        let stdout = child.stdout.take().ok_or(AdapterError::MissingStdout)?;
109        let stderr = child.stderr.take().ok_or(AdapterError::MissingStderr)?;
110
111        let (sender, _rx) = broadcast::channel(512);
112        let runtime = Self {
113            stdin: Arc::new(Mutex::new(stdin)),
114            child: Arc::new(Mutex::new(child)),
115            pending: Arc::new(Mutex::new(HashMap::new())),
116            sender,
117            ring: Arc::new(Mutex::new(VecDeque::with_capacity(RING_BUFFER_SIZE))),
118            sequence: Arc::new(AtomicU64::new(0)),
119            request_timeout,
120            shutting_down: AtomicBool::new(false),
121            spawned_at: spawn_start,
122            first_stdout: Arc::new(AtomicBool::new(false)),
123        };
124
125        runtime.spawn_stdout_loop(stdout);
126        runtime.spawn_stderr_loop(stderr);
127        runtime.spawn_exit_watcher();
128
129        Ok(runtime)
130    }
131
132    pub async fn post(&self, payload: Value) -> Result<PostOutcome, AdapterError> {
133        if !payload.is_object() {
134            return Err(AdapterError::InvalidEnvelope);
135        }
136
137        let method: String = payload
138            .get("method")
139            .and_then(|v| v.as_str())
140            .unwrap_or("<none>")
141            .to_string();
142        let has_method = payload.get("method").is_some();
143        let id = payload.get("id");
144
145        if has_method && id.is_some() {
146            let id_value = id.expect("checked");
147            let key = id_key(id_value);
148            let (tx, rx) = oneshot::channel();
149
150            let pending_count = self.pending.lock().await.len();
151            tracing::info!(
152                method = %method,
153                id = %key,
154                pending_count = pending_count,
155                "post: request → agent (awaiting response)"
156            );
157
158            self.pending.lock().await.insert(key.clone(), tx);
159
160            let write_start = Instant::now();
161            if let Err(err) = self.send_to_subprocess(&payload).await {
162                tracing::error!(
163                    method = %method,
164                    id = %key,
165                    error = %err,
166                    "post: failed to write to agent stdin"
167                );
168                self.pending.lock().await.remove(&key);
169                return Err(err);
170            }
171            let write_ms = write_start.elapsed().as_millis() as u64;
172            tracing::debug!(
173                method = %method,
174                id = %key,
175                write_ms = write_ms,
176                "post: stdin write complete, waiting for response"
177            );
178
179            let wait_start = Instant::now();
180            match tokio::time::timeout(self.request_timeout, rx).await {
181                Ok(Ok(response)) => {
182                    let wait_ms = wait_start.elapsed().as_millis() as u64;
183                    tracing::info!(
184                        method = %method,
185                        id = %key,
186                        response_ms = wait_ms,
187                        total_ms = write_ms + wait_ms,
188                        "post: got response from agent"
189                    );
190                    Ok(PostOutcome::Response(response))
191                }
192                Ok(Err(_)) => {
193                    let wait_ms = wait_start.elapsed().as_millis() as u64;
194                    tracing::error!(
195                        method = %method,
196                        id = %key,
197                        wait_ms = wait_ms,
198                        "post: response channel dropped (agent process may have exited)"
199                    );
200                    self.pending.lock().await.remove(&key);
201                    Err(AdapterError::Timeout)
202                }
203                Err(_) => {
204                    let pending_keys: Vec<String> =
205                        self.pending.lock().await.keys().cloned().collect();
206                    tracing::error!(
207                        method = %method,
208                        id = %key,
209                        timeout_ms = self.request_timeout.as_millis() as u64,
210                        age_ms = self.spawned_at.elapsed().as_millis() as u64,
211                        pending_keys = ?pending_keys,
212                        first_stdout_seen = self.first_stdout.load(Ordering::Relaxed),
213                        "post: TIMEOUT waiting for agent response"
214                    );
215                    self.pending.lock().await.remove(&key);
216                    Err(AdapterError::Timeout)
217                }
218            }
219        } else {
220            tracing::debug!(
221                method = %method,
222                "post: notification → agent (fire-and-forget)"
223            );
224            self.send_to_subprocess(&payload).await?;
225            Ok(PostOutcome::Accepted)
226        }
227    }
228
229    async fn subscribe(
230        &self,
231        last_event_id: Option<u64>,
232    ) -> (Vec<(u64, Value)>, broadcast::Receiver<StreamMessage>) {
233        let replay = {
234            let ring = self.ring.lock().await;
235            ring.iter()
236                .filter(|message| {
237                    if let Some(last_event_id) = last_event_id {
238                        message.sequence > last_event_id
239                    } else {
240                        true
241                    }
242                })
243                .map(|message| (message.sequence, message.payload.clone()))
244                .collect::<Vec<_>>()
245        };
246        (replay, self.sender.subscribe())
247    }
248
249    pub async fn sse_stream(
250        self: Arc<Self>,
251        last_event_id: Option<u64>,
252    ) -> impl Stream<Item = Result<Event, Infallible>> + Send + 'static {
253        let (replay, rx) = self.subscribe(last_event_id).await;
254        let replay_stream = stream::iter(replay.into_iter().map(|(sequence, payload)| {
255            let event = Event::default()
256                .event("message")
257                .id(sequence.to_string())
258                .data(payload.to_string());
259            Ok(event)
260        }));
261
262        let live_stream = BroadcastStream::new(rx).filter_map(|item| async move {
263            match item {
264                Ok(message) => {
265                    let event = Event::default()
266                        .event("message")
267                        .id(message.sequence.to_string())
268                        .data(message.payload.to_string());
269                    Some(Ok(event))
270                }
271                Err(_) => None,
272            }
273        });
274
275        replay_stream.chain(live_stream)
276    }
277
278    /// Stream of raw JSON-RPC `Value` payloads (without SSE framing).
279    /// Useful for consumers that need to inspect the payload contents
280    /// rather than forward them as SSE events.
281    pub async fn value_stream(
282        self: Arc<Self>,
283        last_event_id: Option<u64>,
284    ) -> impl Stream<Item = Value> + Send + 'static {
285        let (replay, rx) = self.subscribe(last_event_id).await;
286        let replay_stream = stream::iter(replay.into_iter().map(|(_sequence, payload)| payload));
287        let live_stream = BroadcastStream::new(rx).filter_map(|item| async move {
288            match item {
289                Ok(message) => Some(message.payload),
290                Err(_) => None,
291            }
292        });
293        replay_stream.chain(live_stream)
294    }
295
296    pub async fn shutdown(&self) {
297        if self.shutting_down.swap(true, Ordering::SeqCst) {
298            return;
299        }
300
301        tracing::info!(
302            age_ms = self.spawned_at.elapsed().as_millis() as u64,
303            "shutting down agent process"
304        );
305
306        self.pending.lock().await.clear();
307        let mut child = self.child.lock().await;
308        match child.try_wait() {
309            Ok(Some(_)) => {}
310            Ok(None) => {
311                let _ = child.kill().await;
312                let _ = child.wait().await;
313            }
314            Err(_) => {
315                let _ = child.kill().await;
316            }
317        }
318    }
319
320    fn spawn_stdout_loop(&self, stdout: tokio::process::ChildStdout) {
321        let pending = self.pending.clone();
322        let sender = self.sender.clone();
323        let ring = self.ring.clone();
324        let sequence = self.sequence.clone();
325        let spawned_at = self.spawned_at;
326        let first_stdout = self.first_stdout.clone();
327
328        tokio::spawn(async move {
329            let mut lines = BufReader::new(stdout).lines();
330            let mut line_count: u64 = 0;
331
332            while let Ok(Some(line)) = lines.next_line().await {
333                let trimmed = line.trim();
334                if trimmed.is_empty() {
335                    continue;
336                }
337
338                line_count += 1;
339
340                if !first_stdout.swap(true, Ordering::Relaxed) {
341                    tracing::info!(
342                        first_stdout_ms = spawned_at.elapsed().as_millis() as u64,
343                        line_bytes = trimmed.len(),
344                        "agent process: first stdout line received"
345                    );
346                }
347
348                let payload = match serde_json::from_str::<Value>(trimmed) {
349                    Ok(payload) => payload,
350                    Err(err) => {
351                        tracing::warn!(
352                            error = %err,
353                            line_number = line_count,
354                            raw = %if trimmed.len() > 200 {
355                                format!("{}...", &trimmed[..200])
356                            } else {
357                                trimmed.to_string()
358                            },
359                            "agent stdout: invalid JSON"
360                        );
361                        json!({
362                            "jsonrpc": "2.0",
363                            "method": "_adapter/invalid_stdout",
364                            "params": {
365                                "error": err.to_string(),
366                                "raw": trimmed,
367                            }
368                        })
369                    }
370                };
371
372                let is_response = payload.get("id").is_some() && payload.get("method").is_none();
373                if is_response {
374                    let key = id_key(payload.get("id").expect("checked"));
375                    let has_error = payload.get("error").is_some();
376                    if let Some(tx) = pending.lock().await.remove(&key) {
377                        tracing::debug!(
378                            id = %key,
379                            has_error = has_error,
380                            age_ms = spawned_at.elapsed().as_millis() as u64,
381                            "agent stdout: response matched to pending request"
382                        );
383                        let _ = tx.send(payload.clone());
384                        // Also broadcast the response so SSE/notification subscribers
385                        // see it in order after preceding notifications. This lets the
386                        // SSE translation task detect turn completion after all
387                        // session/update events have been processed.
388                        let seq = sequence.fetch_add(1, Ordering::SeqCst) + 1;
389                        let message = StreamMessage {
390                            sequence: seq,
391                            payload,
392                        };
393                        {
394                            let mut guard = ring.lock().await;
395                            guard.push_back(message.clone());
396                            while guard.len() > RING_BUFFER_SIZE {
397                                guard.pop_front();
398                            }
399                        }
400                        let _ = sender.send(message);
401                        continue;
402                    } else {
403                        tracing::warn!(
404                            id = %key,
405                            has_error = has_error,
406                            "agent stdout: response has no matching pending request (orphan)"
407                        );
408                    }
409                }
410
411                let method = payload
412                    .get("method")
413                    .and_then(|v| v.as_str())
414                    .unwrap_or("<none>");
415                tracing::debug!(
416                    method = method,
417                    line_number = line_count,
418                    "agent stdout: notification/event → SSE broadcast"
419                );
420
421                let seq = sequence.fetch_add(1, Ordering::SeqCst) + 1;
422                let message = StreamMessage {
423                    sequence: seq,
424                    payload,
425                };
426
427                {
428                    let mut guard = ring.lock().await;
429                    guard.push_back(message.clone());
430                    while guard.len() > RING_BUFFER_SIZE {
431                        guard.pop_front();
432                    }
433                }
434
435                let _ = sender.send(message);
436            }
437
438            tracing::info!(
439                total_lines = line_count,
440                age_ms = spawned_at.elapsed().as_millis() as u64,
441                "agent stdout: stream ended"
442            );
443        });
444    }
445
446    fn spawn_stderr_loop(&self, stderr: tokio::process::ChildStderr) {
447        let spawned_at = self.spawned_at;
448
449        tokio::spawn(async move {
450            let mut lines = BufReader::new(stderr).lines();
451            let mut line_count: u64 = 0;
452
453            while let Ok(Some(line)) = lines.next_line().await {
454                line_count += 1;
455                tracing::info!(
456                    line_number = line_count,
457                    age_ms = spawned_at.elapsed().as_millis() as u64,
458                    "agent stderr: {}",
459                    line
460                );
461            }
462
463            tracing::debug!(
464                total_lines = line_count,
465                age_ms = spawned_at.elapsed().as_millis() as u64,
466                "agent stderr: stream ended"
467            );
468        });
469    }
470
471    fn spawn_exit_watcher(&self) {
472        let child = self.child.clone();
473        let sender = self.sender.clone();
474        let ring = self.ring.clone();
475        let sequence = self.sequence.clone();
476        let spawned_at = self.spawned_at;
477        let pending = self.pending.clone();
478
479        tokio::spawn(async move {
480            let status = {
481                let mut guard = child.lock().await;
482                guard.wait().await.ok()
483            };
484
485            let age_ms = spawned_at.elapsed().as_millis() as u64;
486            let pending_count = pending.lock().await.len();
487
488            if let Some(status) = status {
489                tracing::warn!(
490                    success = status.success(),
491                    code = status.code(),
492                    age_ms = age_ms,
493                    pending_requests = pending_count,
494                    "agent process exited"
495                );
496
497                let payload = json!({
498                    "jsonrpc": "2.0",
499                    "method": "_adapter/agent_exited",
500                    "params": {
501                        "success": status.success(),
502                        "code": status.code(),
503                    }
504                });
505
506                let seq = sequence.fetch_add(1, Ordering::SeqCst) + 1;
507                let message = StreamMessage {
508                    sequence: seq,
509                    payload,
510                };
511
512                {
513                    let mut guard = ring.lock().await;
514                    guard.push_back(message.clone());
515                    while guard.len() > RING_BUFFER_SIZE {
516                        guard.pop_front();
517                    }
518                }
519
520                let _ = sender.send(message);
521            } else {
522                tracing::error!(
523                    age_ms = age_ms,
524                    pending_requests = pending_count,
525                    "agent process: failed to get exit status"
526                );
527            }
528        });
529    }
530
531    async fn send_to_subprocess(&self, payload: &Value) -> Result<(), AdapterError> {
532        let method = payload
533            .get("method")
534            .and_then(|v| v.as_str())
535            .unwrap_or("<none>");
536        let id = payload.get("id").map(|v| v.to_string()).unwrap_or_default();
537
538        tracing::debug!(
539            method = method,
540            id = %id,
541            bytes = serde_json::to_vec(payload).map(|b| b.len()).unwrap_or(0),
542            "stdin: writing message to agent"
543        );
544
545        let mut stdin = self.stdin.lock().await;
546        let bytes = serde_json::to_vec(payload).map_err(AdapterError::Serialize)?;
547        stdin.write_all(&bytes).await.map_err(|err| {
548            tracing::error!(method = method, id = %id, error = %err, "stdin: write_all failed");
549            AdapterError::Write(err)
550        })?;
551        stdin.write_all(b"\n").await.map_err(|err| {
552            tracing::error!(method = method, id = %id, error = %err, "stdin: newline write failed");
553            AdapterError::Write(err)
554        })?;
555        stdin.flush().await.map_err(|err| {
556            tracing::error!(method = method, id = %id, error = %err, "stdin: flush failed");
557            AdapterError::Write(err)
558        })?;
559
560        tracing::debug!(method = method, id = %id, "stdin: write+flush complete");
561        Ok(())
562    }
563}
564
565fn id_key(value: &Value) -> String {
566    serde_json::to_string(value).unwrap_or_else(|_| "null".to_string())
567}