1use std::collections::{HashMap, VecDeque};
2use std::convert::Infallible;
3use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use axum::response::sse::Event;
8use futures::{stream, Stream, StreamExt};
9use serde_json::{json, Value};
10use thiserror::Error;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tokio::process::{Child, ChildStdin, Command};
13use tokio::sync::{broadcast, oneshot, Mutex};
14use tokio_stream::wrappers::BroadcastStream;
15
16use crate::registry::LaunchSpec;
17
/// Maximum number of recent stream messages retained for replay to
/// subscribers that reconnect with a `last_event_id`.
const RING_BUFFER_SIZE: usize = 1 << 10;
19
/// Errors produced while spawning an agent subprocess or exchanging
/// JSON-RPC messages with it.
#[derive(Debug, Error)]
pub enum AdapterError {
    /// Starting the subprocess failed; wraps the I/O error from spawning.
    #[error("failed to spawn subprocess: {0}")]
    Spawn(std::io::Error),
    /// The spawned child exposed no stdin pipe.
    #[error("failed to capture subprocess stdin")]
    MissingStdin,
    /// The spawned child exposed no stdout pipe.
    #[error("failed to capture subprocess stdout")]
    MissingStdout,
    /// The spawned child exposed no stderr pipe.
    #[error("failed to capture subprocess stderr")]
    MissingStderr,
    /// A posted payload was not a JSON object.
    #[error("invalid json-rpc envelope")]
    InvalidEnvelope,
    /// Encoding an outgoing message as JSON failed.
    #[error("failed to serialize json-rpc message: {0}")]
    Serialize(serde_json::Error),
    /// Writing or flushing the subprocess stdin failed.
    #[error("failed to write subprocess stdin: {0}")]
    Write(std::io::Error),
    /// No response arrived within the request timeout, or the wait for a
    /// response was abandoned because its channel was dropped.
    #[error("timeout waiting for response")]
    Timeout,
}
39
/// Successful result of forwarding one message to the agent.
#[derive(Debug)]
pub enum PostOutcome {
    /// The message was a request (it had both `method` and `id`); holds the
    /// agent's response with the matching id.
    Response(Value),
    /// The message was written without waiting for any reply.
    Accepted,
}
45
/// A message delivered to stream subscribers, tagged with the sequence number
/// used as its SSE event id and for `last_event_id` replay filtering.
#[derive(Debug, Clone)]
struct StreamMessage {
    /// Number taken from the runtime's shared counter when the message was
    /// published (the first message gets 1).
    sequence: u64,
    /// The JSON message read from the agent's stdout or synthesized by the
    /// adapter itself.
    payload: Value,
}
51
/// Handle to a spawned agent subprocess that exchanges newline-delimited
/// JSON-RPC messages over its stdio.
///
/// Messages are written to the agent's stdin. Lines read from its stdout are
/// either handed to the request waiting for them or broadcast to stream
/// subscribers, and recent messages are kept for replay.
#[derive(Debug)]
pub struct AdapterRuntime {
    /// Agent stdin; locked for the whole duration of each message write.
    stdin: Arc<Mutex<ChildStdin>>,
    /// Child process handle, shared with the exit-watcher task and `shutdown`.
    child: Arc<Mutex<Child>>,
    /// Requests awaiting a response, keyed by `id_key` of the request id.
    pending: Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
    /// Broadcast channel feeding live stream subscribers.
    sender: broadcast::Sender<StreamMessage>,
    /// The most recent messages (at most `RING_BUFFER_SIZE`) kept for replay.
    ring: Arc<Mutex<VecDeque<StreamMessage>>>,
    /// Counter from which stream message sequence numbers are taken.
    sequence: Arc<AtomicU64>,
    /// How long `post` waits for the response to a request.
    request_timeout: Duration,
    /// Set by the first `shutdown` call so later calls return immediately.
    shutting_down: AtomicBool,
    /// Instant taken just before spawning; used for age/latency log fields.
    spawned_at: Instant,
    /// Becomes true once the first non-empty stdout line has been read.
    first_stdout: Arc<AtomicBool>,
}
65
66impl AdapterRuntime {
67 pub async fn start(
68 launch: LaunchSpec,
69 request_timeout: Duration,
70 ) -> Result<Self, AdapterError> {
71 let spawn_start = Instant::now();
72
73 let mut command = Command::new(&launch.program);
74 command
75 .args(&launch.args)
76 .stdin(std::process::Stdio::piped())
77 .stdout(std::process::Stdio::piped())
78 .stderr(std::process::Stdio::piped());
79
80 for (key, value) in &launch.env {
81 command.env(key, value);
82 }
83
84 tracing::info!(
85 program = ?launch.program,
86 args = ?launch.args,
87 "spawning agent process"
88 );
89
90 let mut child = command.spawn().map_err(|err| {
91 tracing::error!(
92 program = ?launch.program,
93 error = %err,
94 "failed to spawn agent process"
95 );
96 AdapterError::Spawn(err)
97 })?;
98
99 let pid = child.id().unwrap_or(0);
100 let spawn_elapsed = spawn_start.elapsed();
101 tracing::info!(
102 pid = pid,
103 elapsed_ms = spawn_elapsed.as_millis() as u64,
104 "agent process spawned"
105 );
106
107 let stdin = child.stdin.take().ok_or(AdapterError::MissingStdin)?;
108 let stdout = child.stdout.take().ok_or(AdapterError::MissingStdout)?;
109 let stderr = child.stderr.take().ok_or(AdapterError::MissingStderr)?;
110
111 let (sender, _rx) = broadcast::channel(512);
112 let runtime = Self {
113 stdin: Arc::new(Mutex::new(stdin)),
114 child: Arc::new(Mutex::new(child)),
115 pending: Arc::new(Mutex::new(HashMap::new())),
116 sender,
117 ring: Arc::new(Mutex::new(VecDeque::with_capacity(RING_BUFFER_SIZE))),
118 sequence: Arc::new(AtomicU64::new(0)),
119 request_timeout,
120 shutting_down: AtomicBool::new(false),
121 spawned_at: spawn_start,
122 first_stdout: Arc::new(AtomicBool::new(false)),
123 };
124
125 runtime.spawn_stdout_loop(stdout);
126 runtime.spawn_stderr_loop(stderr);
127 runtime.spawn_exit_watcher();
128
129 Ok(runtime)
130 }
131
132 pub async fn post(&self, payload: Value) -> Result<PostOutcome, AdapterError> {
133 if !payload.is_object() {
134 return Err(AdapterError::InvalidEnvelope);
135 }
136
137 let method: String = payload
138 .get("method")
139 .and_then(|v| v.as_str())
140 .unwrap_or("<none>")
141 .to_string();
142 let has_method = payload.get("method").is_some();
143 let id = payload.get("id");
144
145 if has_method && id.is_some() {
146 let id_value = id.expect("checked");
147 let key = id_key(id_value);
148 let (tx, rx) = oneshot::channel();
149
150 let pending_count = self.pending.lock().await.len();
151 tracing::info!(
152 method = %method,
153 id = %key,
154 pending_count = pending_count,
155 "post: request → agent (awaiting response)"
156 );
157
158 self.pending.lock().await.insert(key.clone(), tx);
159
160 let write_start = Instant::now();
161 if let Err(err) = self.send_to_subprocess(&payload).await {
162 tracing::error!(
163 method = %method,
164 id = %key,
165 error = %err,
166 "post: failed to write to agent stdin"
167 );
168 self.pending.lock().await.remove(&key);
169 return Err(err);
170 }
171 let write_ms = write_start.elapsed().as_millis() as u64;
172 tracing::debug!(
173 method = %method,
174 id = %key,
175 write_ms = write_ms,
176 "post: stdin write complete, waiting for response"
177 );
178
179 let wait_start = Instant::now();
180 match tokio::time::timeout(self.request_timeout, rx).await {
181 Ok(Ok(response)) => {
182 let wait_ms = wait_start.elapsed().as_millis() as u64;
183 tracing::info!(
184 method = %method,
185 id = %key,
186 response_ms = wait_ms,
187 total_ms = write_ms + wait_ms,
188 "post: got response from agent"
189 );
190 Ok(PostOutcome::Response(response))
191 }
192 Ok(Err(_)) => {
193 let wait_ms = wait_start.elapsed().as_millis() as u64;
194 tracing::error!(
195 method = %method,
196 id = %key,
197 wait_ms = wait_ms,
198 "post: response channel dropped (agent process may have exited)"
199 );
200 self.pending.lock().await.remove(&key);
201 Err(AdapterError::Timeout)
202 }
203 Err(_) => {
204 let pending_keys: Vec<String> =
205 self.pending.lock().await.keys().cloned().collect();
206 tracing::error!(
207 method = %method,
208 id = %key,
209 timeout_ms = self.request_timeout.as_millis() as u64,
210 age_ms = self.spawned_at.elapsed().as_millis() as u64,
211 pending_keys = ?pending_keys,
212 first_stdout_seen = self.first_stdout.load(Ordering::Relaxed),
213 "post: TIMEOUT waiting for agent response"
214 );
215 self.pending.lock().await.remove(&key);
216 Err(AdapterError::Timeout)
217 }
218 }
219 } else {
220 tracing::debug!(
221 method = %method,
222 "post: notification → agent (fire-and-forget)"
223 );
224 self.send_to_subprocess(&payload).await?;
225 Ok(PostOutcome::Accepted)
226 }
227 }
228
229 async fn subscribe(
230 &self,
231 last_event_id: Option<u64>,
232 ) -> (Vec<(u64, Value)>, broadcast::Receiver<StreamMessage>) {
233 let replay = {
234 let ring = self.ring.lock().await;
235 ring.iter()
236 .filter(|message| {
237 if let Some(last_event_id) = last_event_id {
238 message.sequence > last_event_id
239 } else {
240 true
241 }
242 })
243 .map(|message| (message.sequence, message.payload.clone()))
244 .collect::<Vec<_>>()
245 };
246 (replay, self.sender.subscribe())
247 }
248
249 pub async fn sse_stream(
250 self: Arc<Self>,
251 last_event_id: Option<u64>,
252 ) -> impl Stream<Item = Result<Event, Infallible>> + Send + 'static {
253 let (replay, rx) = self.subscribe(last_event_id).await;
254 let replay_stream = stream::iter(replay.into_iter().map(|(sequence, payload)| {
255 let event = Event::default()
256 .event("message")
257 .id(sequence.to_string())
258 .data(payload.to_string());
259 Ok(event)
260 }));
261
262 let live_stream = BroadcastStream::new(rx).filter_map(|item| async move {
263 match item {
264 Ok(message) => {
265 let event = Event::default()
266 .event("message")
267 .id(message.sequence.to_string())
268 .data(message.payload.to_string());
269 Some(Ok(event))
270 }
271 Err(_) => None,
272 }
273 });
274
275 replay_stream.chain(live_stream)
276 }
277
278 pub async fn value_stream(
282 self: Arc<Self>,
283 last_event_id: Option<u64>,
284 ) -> impl Stream<Item = Value> + Send + 'static {
285 let (replay, rx) = self.subscribe(last_event_id).await;
286 let replay_stream = stream::iter(replay.into_iter().map(|(_sequence, payload)| payload));
287 let live_stream = BroadcastStream::new(rx).filter_map(|item| async move {
288 match item {
289 Ok(message) => Some(message.payload),
290 Err(_) => None,
291 }
292 });
293 replay_stream.chain(live_stream)
294 }
295
296 pub async fn shutdown(&self) {
297 if self.shutting_down.swap(true, Ordering::SeqCst) {
298 return;
299 }
300
301 tracing::info!(
302 age_ms = self.spawned_at.elapsed().as_millis() as u64,
303 "shutting down agent process"
304 );
305
306 self.pending.lock().await.clear();
307 let mut child = self.child.lock().await;
308 match child.try_wait() {
309 Ok(Some(_)) => {}
310 Ok(None) => {
311 let _ = child.kill().await;
312 let _ = child.wait().await;
313 }
314 Err(_) => {
315 let _ = child.kill().await;
316 }
317 }
318 }
319
320 fn spawn_stdout_loop(&self, stdout: tokio::process::ChildStdout) {
321 let pending = self.pending.clone();
322 let sender = self.sender.clone();
323 let ring = self.ring.clone();
324 let sequence = self.sequence.clone();
325 let spawned_at = self.spawned_at;
326 let first_stdout = self.first_stdout.clone();
327
328 tokio::spawn(async move {
329 let mut lines = BufReader::new(stdout).lines();
330 let mut line_count: u64 = 0;
331
332 while let Ok(Some(line)) = lines.next_line().await {
333 let trimmed = line.trim();
334 if trimmed.is_empty() {
335 continue;
336 }
337
338 line_count += 1;
339
340 if !first_stdout.swap(true, Ordering::Relaxed) {
341 tracing::info!(
342 first_stdout_ms = spawned_at.elapsed().as_millis() as u64,
343 line_bytes = trimmed.len(),
344 "agent process: first stdout line received"
345 );
346 }
347
348 let payload = match serde_json::from_str::<Value>(trimmed) {
349 Ok(payload) => payload,
350 Err(err) => {
351 tracing::warn!(
352 error = %err,
353 line_number = line_count,
354 raw = %if trimmed.len() > 200 {
355 format!("{}...", &trimmed[..200])
356 } else {
357 trimmed.to_string()
358 },
359 "agent stdout: invalid JSON"
360 );
361 json!({
362 "jsonrpc": "2.0",
363 "method": "_adapter/invalid_stdout",
364 "params": {
365 "error": err.to_string(),
366 "raw": trimmed,
367 }
368 })
369 }
370 };
371
372 let is_response = payload.get("id").is_some() && payload.get("method").is_none();
373 if is_response {
374 let key = id_key(payload.get("id").expect("checked"));
375 let has_error = payload.get("error").is_some();
376 if let Some(tx) = pending.lock().await.remove(&key) {
377 tracing::debug!(
378 id = %key,
379 has_error = has_error,
380 age_ms = spawned_at.elapsed().as_millis() as u64,
381 "agent stdout: response matched to pending request"
382 );
383 let _ = tx.send(payload.clone());
384 let seq = sequence.fetch_add(1, Ordering::SeqCst) + 1;
389 let message = StreamMessage {
390 sequence: seq,
391 payload,
392 };
393 {
394 let mut guard = ring.lock().await;
395 guard.push_back(message.clone());
396 while guard.len() > RING_BUFFER_SIZE {
397 guard.pop_front();
398 }
399 }
400 let _ = sender.send(message);
401 continue;
402 } else {
403 tracing::warn!(
404 id = %key,
405 has_error = has_error,
406 "agent stdout: response has no matching pending request (orphan)"
407 );
408 }
409 }
410
411 let method = payload
412 .get("method")
413 .and_then(|v| v.as_str())
414 .unwrap_or("<none>");
415 tracing::debug!(
416 method = method,
417 line_number = line_count,
418 "agent stdout: notification/event → SSE broadcast"
419 );
420
421 let seq = sequence.fetch_add(1, Ordering::SeqCst) + 1;
422 let message = StreamMessage {
423 sequence: seq,
424 payload,
425 };
426
427 {
428 let mut guard = ring.lock().await;
429 guard.push_back(message.clone());
430 while guard.len() > RING_BUFFER_SIZE {
431 guard.pop_front();
432 }
433 }
434
435 let _ = sender.send(message);
436 }
437
438 tracing::info!(
439 total_lines = line_count,
440 age_ms = spawned_at.elapsed().as_millis() as u64,
441 "agent stdout: stream ended"
442 );
443 });
444 }
445
446 fn spawn_stderr_loop(&self, stderr: tokio::process::ChildStderr) {
447 let spawned_at = self.spawned_at;
448
449 tokio::spawn(async move {
450 let mut lines = BufReader::new(stderr).lines();
451 let mut line_count: u64 = 0;
452
453 while let Ok(Some(line)) = lines.next_line().await {
454 line_count += 1;
455 tracing::info!(
456 line_number = line_count,
457 age_ms = spawned_at.elapsed().as_millis() as u64,
458 "agent stderr: {}",
459 line
460 );
461 }
462
463 tracing::debug!(
464 total_lines = line_count,
465 age_ms = spawned_at.elapsed().as_millis() as u64,
466 "agent stderr: stream ended"
467 );
468 });
469 }
470
471 fn spawn_exit_watcher(&self) {
472 let child = self.child.clone();
473 let sender = self.sender.clone();
474 let ring = self.ring.clone();
475 let sequence = self.sequence.clone();
476 let spawned_at = self.spawned_at;
477 let pending = self.pending.clone();
478
479 tokio::spawn(async move {
480 let status = {
481 let mut guard = child.lock().await;
482 guard.wait().await.ok()
483 };
484
485 let age_ms = spawned_at.elapsed().as_millis() as u64;
486 let pending_count = pending.lock().await.len();
487
488 if let Some(status) = status {
489 tracing::warn!(
490 success = status.success(),
491 code = status.code(),
492 age_ms = age_ms,
493 pending_requests = pending_count,
494 "agent process exited"
495 );
496
497 let payload = json!({
498 "jsonrpc": "2.0",
499 "method": "_adapter/agent_exited",
500 "params": {
501 "success": status.success(),
502 "code": status.code(),
503 }
504 });
505
506 let seq = sequence.fetch_add(1, Ordering::SeqCst) + 1;
507 let message = StreamMessage {
508 sequence: seq,
509 payload,
510 };
511
512 {
513 let mut guard = ring.lock().await;
514 guard.push_back(message.clone());
515 while guard.len() > RING_BUFFER_SIZE {
516 guard.pop_front();
517 }
518 }
519
520 let _ = sender.send(message);
521 } else {
522 tracing::error!(
523 age_ms = age_ms,
524 pending_requests = pending_count,
525 "agent process: failed to get exit status"
526 );
527 }
528 });
529 }
530
531 async fn send_to_subprocess(&self, payload: &Value) -> Result<(), AdapterError> {
532 let method = payload
533 .get("method")
534 .and_then(|v| v.as_str())
535 .unwrap_or("<none>");
536 let id = payload.get("id").map(|v| v.to_string()).unwrap_or_default();
537
538 tracing::debug!(
539 method = method,
540 id = %id,
541 bytes = serde_json::to_vec(payload).map(|b| b.len()).unwrap_or(0),
542 "stdin: writing message to agent"
543 );
544
545 let mut stdin = self.stdin.lock().await;
546 let bytes = serde_json::to_vec(payload).map_err(AdapterError::Serialize)?;
547 stdin.write_all(&bytes).await.map_err(|err| {
548 tracing::error!(method = method, id = %id, error = %err, "stdin: write_all failed");
549 AdapterError::Write(err)
550 })?;
551 stdin.write_all(b"\n").await.map_err(|err| {
552 tracing::error!(method = method, id = %id, error = %err, "stdin: newline write failed");
553 AdapterError::Write(err)
554 })?;
555 stdin.flush().await.map_err(|err| {
556 tracing::error!(method = method, id = %id, error = %err, "stdin: flush failed");
557 AdapterError::Write(err)
558 })?;
559
560 tracing::debug!(method = method, id = %id, "stdin: write+flush complete");
561 Ok(())
562 }
563}
564
565fn id_key(value: &Value) -> String {
566 serde_json::to_string(value).unwrap_or_else(|_| "null".to_string())
567}