fresh/services/remote/channel.rs
1//! Agent communication channel
2//!
3//! Handles request/response multiplexing over SSH stdin/stdout.
4//! Supports transport hot-swapping for automatic reconnection:
5//! the read/write tasks survive connection drops and resume when
6//! a new transport is provided via `replace_transport()`.
7
8use crate::services::remote::protocol::{AgentRequest, AgentResponse};
9use std::collections::HashMap;
10use std::io;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::{Arc, Mutex};
13use std::time::Duration;
14use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
15use tokio::sync::{mpsc, oneshot};
16use tracing::warn;
17
/// Default capacity for the per-request streaming data channel.
///
/// Each in-flight request gets its own bounded `mpsc` channel of this size;
/// when it fills, `handle_response` awaits on `send()` (backpressure)
/// rather than dropping data.
const DEFAULT_DATA_CHANNEL_CAPACITY: usize = 64;

/// Default timeout for remote requests. If a response is not received within
/// this duration, the request fails with `ChannelError::Timeout` and the
/// connection is marked as disconnected.
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

/// Test-only: microseconds to sleep in the consumer loop between chunks.
/// Set to a non-zero value from tests to simulate a slow consumer and
/// deterministically reproduce channel backpressure scenarios.
/// Always compiled (not cfg(test)) because integration tests need access.
pub static TEST_RECV_DELAY_US: AtomicU64 = AtomicU64::new(0);
31
/// Error type for channel operations.
#[derive(Debug, thiserror::Error)]
pub enum ChannelError {
    /// Underlying transport I/O failure.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    /// Failed to serialize a request or deserialize a response.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// The channel's internal tasks have shut down, a response channel was
    /// closed, or the connection is currently marked as disconnected.
    #[error("Channel closed")]
    ChannelClosed,

    /// The request was cancelled before completing.
    #[error("Request cancelled")]
    Cancelled,

    /// No response arrived within the configured request timeout;
    /// the connection is marked as disconnected when this is returned.
    #[error("Request timed out")]
    Timeout,

    /// The agent responded with an error payload (the inner string).
    #[error("Remote error: {0}")]
    Remote(String),
}
53
/// Per-request state kept while a response is outstanding.
struct PendingRequest {
    /// Channel used to forward streaming `data` payloads to the caller.
    data_tx: mpsc::Sender<serde_json::Value>,
    /// One-shot channel used to deliver the final result (or error string).
    result_tx: oneshot::Sender<Result<serde_json::Value, String>>,
}

/// Boxed async reader type used by the read task.
type BoxedReader = Box<dyn AsyncBufRead + Unpin + Send>;
/// Boxed async writer type used by the write task.
type BoxedWriter = Box<dyn AsyncWrite + Unpin + Send>;
66
/// Communication channel with the remote agent.
///
/// Owns two long-lived Tokio tasks (one writer, one reader) that survive
/// transport drops; `replace_transport` feeds them a fresh reader/writer
/// pair after reconnection.
pub struct AgentChannel {
    /// Sender to the write task (outgoing JSON lines).
    write_tx: mpsc::Sender<String>,
    /// Pending requests awaiting responses, keyed by request ID.
    pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
    /// Next request ID (monotonically increasing, starts at 1).
    next_id: AtomicU64,
    /// Whether the channel is currently believed to be connected.
    connected: Arc<std::sync::atomic::AtomicBool>,
    /// Runtime handle used by the `*_blocking` helpers.
    runtime_handle: tokio::runtime::Handle,
    /// Capacity for per-request streaming data channels.
    data_channel_capacity: usize,
    /// Timeout for individual requests (stored as milliseconds for atomic access).
    request_timeout_ms: AtomicU64,
    /// Sender to deliver a new reader to the read task after reconnection.
    new_reader_tx: mpsc::Sender<BoxedReader>,
    /// Sender to deliver a new writer to the write task after reconnection.
    new_writer_tx: mpsc::Sender<BoxedWriter>,
}
88
89impl AgentChannel {
90 /// Create a new channel from async read/write handles
91 ///
92 /// Must be called from within a Tokio runtime context.
93 pub fn new(
94 reader: tokio::io::BufReader<tokio::process::ChildStdout>,
95 writer: tokio::process::ChildStdin,
96 ) -> Self {
97 Self::with_capacity(reader, writer, DEFAULT_DATA_CHANNEL_CAPACITY)
98 }
99
    /// Create a new channel with a custom data channel capacity.
    ///
    /// Lower capacity makes channel overflow more likely if `try_send` is used,
    /// which is useful for stress-testing backpressure handling.
    ///
    /// Must be called from within a Tokio runtime context.
    pub fn with_capacity(
        reader: tokio::io::BufReader<tokio::process::ChildStdout>,
        writer: tokio::process::ChildStdin,
        data_channel_capacity: usize,
    ) -> Self {
        Self::from_transport(reader, writer, data_channel_capacity)
    }
111
112 /// Create a new channel from any async reader/writer pair.
113 ///
114 /// This is the generic constructor used by both production code (via
115 /// `new`/`with_capacity`) and tests (via arbitrary `AsyncBufRead`/`AsyncWrite`
116 /// implementations like `DuplexStream`).
117 ///
118 /// Must be called from within a Tokio runtime context.
119 pub fn from_transport<R, W>(reader: R, writer: W, data_channel_capacity: usize) -> Self
120 where
121 R: AsyncBufRead + Unpin + Send + 'static,
122 W: AsyncWrite + Unpin + Send + 'static,
123 {
124 let pending: Arc<Mutex<HashMap<u64, PendingRequest>>> =
125 Arc::new(Mutex::new(HashMap::new()));
126 let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
127 let runtime_handle = tokio::runtime::Handle::current();
128
129 // Channel for outgoing requests (lives for the lifetime of the AgentChannel)
130 let (write_tx, write_rx) = mpsc::channel::<String>(64);
131
132 // Channels for delivering replacement transports on reconnection.
133 // Capacity 1: at most one pending reconnection at a time.
134 let (new_reader_tx, new_reader_rx) = mpsc::channel::<BoxedReader>(1);
135 let (new_writer_tx, new_writer_rx) = mpsc::channel::<BoxedWriter>(1);
136
137 // Spawn write task (lives for the lifetime of the AgentChannel)
138 let connected_write = connected.clone();
139 tokio::spawn(Self::write_task(
140 Box::new(writer),
141 write_rx,
142 new_writer_rx,
143 connected_write,
144 ));
145
146 // Spawn read task (lives for the lifetime of the AgentChannel)
147 let pending_read = pending.clone();
148 let connected_read = connected.clone();
149 tokio::spawn(Self::read_task(
150 Box::new(reader),
151 new_reader_rx,
152 pending_read,
153 connected_read,
154 ));
155
156 Self {
157 write_tx,
158 pending,
159 next_id: AtomicU64::new(1),
160 connected,
161 runtime_handle,
162 data_channel_capacity,
163 request_timeout_ms: AtomicU64::new(DEFAULT_REQUEST_TIMEOUT.as_millis() as u64),
164 new_reader_tx,
165 new_writer_tx,
166 }
167 }
168
    /// Long-lived write task. Reads outgoing messages from `write_rx` and
    /// writes them to the current transport. On transport error or when a new
    /// transport arrives via `new_writer_rx`, switches to the new writer.
    ///
    /// Runs until `write_rx` or `new_writer_rx` closes, which happens when
    /// the owning `AgentChannel` is dropped.
    async fn write_task(
        mut writer: BoxedWriter,
        mut write_rx: mpsc::Receiver<String>,
        mut new_writer_rx: mpsc::Receiver<BoxedWriter>,
        connected: Arc<std::sync::atomic::AtomicBool>,
    ) {
        loop {
            tokio::select! {
                // Normal path: send outgoing message
                msg = write_rx.recv() => {
                    let Some(msg) = msg else { break }; // AgentChannel dropped

                    let write_ok = writer.write_all(msg.as_bytes()).await.is_ok()
                        && writer.flush().await.is_ok();

                    if !write_ok {
                        // Transport is dead: flag disconnected so new requests
                        // are rejected up front in `request_streaming`.
                        // NOTE(review): the failed `msg` is dropped, not
                        // retried on the replacement writer; its pending entry
                        // is failed later when the read task drains — confirm
                        // this loss semantics is intended.
                        connected.store(false, Ordering::SeqCst);
                        // Wait for replacement (can't select here, just block)
                        match new_writer_rx.recv().await {
                            Some(new_writer) => { writer = new_writer; continue; }
                            None => break,
                        }
                    }
                }
                // Reconnection: new transport arrived, switch immediately
                new_writer = new_writer_rx.recv() => {
                    match new_writer {
                        Some(w) => { writer = w; }
                        None => break, // AgentChannel dropped
                    }
                }
            }
        }
    }
206
    /// Long-lived read task. Reads responses from the current transport and
    /// dispatches them to pending requests. On transport error or when a new
    /// transport arrives, cleans up pending requests and switches readers.
    ///
    /// Runs until `new_reader_rx` closes, which happens when the owning
    /// `AgentChannel` is dropped.
    async fn read_task(
        mut reader: BoxedReader,
        mut new_reader_rx: mpsc::Receiver<BoxedReader>,
        pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
        connected: Arc<std::sync::atomic::AtomicBool>,
    ) {
        // Reused line buffer; cleared each iteration instead of reallocating.
        let mut line = String::new();

        loop {
            line.clear();

            tokio::select! {
                read_result = reader.read_line(&mut line) => {
                    match read_result {
                        Ok(0) | Err(_) => {
                            // EOF or error — transport is dead
                            connected.store(false, Ordering::SeqCst);
                            // Fail all in-flight requests so callers unblock.
                            Self::drain_pending(&pending);

                            // Wait for replacement reader
                            match new_reader_rx.recv().await {
                                Some(new_reader) => { reader = new_reader; continue; }
                                None => break,
                            }
                        }
                        Ok(_) => {
                            // Each line carries one JSON-encoded AgentResponse;
                            // unparseable lines are silently skipped.
                            if let Ok(resp) = serde_json::from_str::<AgentResponse>(&line) {
                                Self::handle_response(&pending, resp).await;
                            }
                        }
                    }
                }
                // Reconnection: new transport arrived, switch immediately.
                // Drain pending requests from the old connection first —
                // they were sent to the old agent and won't get responses
                // on the new one. Then mark connected so new requests can
                // be submitted.
                new_reader = new_reader_rx.recv() => {
                    match new_reader {
                        Some(r) => {
                            Self::drain_pending(&pending);
                            reader = r;
                            connected.store(true, Ordering::SeqCst);
                        }
                        None => break, // AgentChannel dropped
                    }
                }
            }
        }
    }
260
261 /// Fail all pending requests with "connection closed" so callers don't hang.
262 fn drain_pending(pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>) {
263 let mut pending = pending.lock().unwrap();
264 for (id, req) in pending.drain() {
265 match req.result_tx.send(Err("connection closed".to_string())) {
266 Ok(()) => {}
267 Err(_) => {
268 warn!("request {id}: receiver dropped during disconnect cleanup");
269 }
270 }
271 }
272 }
273
    /// Handle an incoming response.
    ///
    /// For streaming data, uses `send().await` to apply backpressure when the
    /// consumer is slower than the producer. This prevents silent data loss
    /// that occurred with `try_send` (#1059).
    ///
    /// Lock discipline: `pending` is a `std::sync::Mutex`, so it must never
    /// be held across an `.await`. The `data_tx` sender is cloned out under
    /// the lock, and the lock is released before the awaited send.
    async fn handle_response(
        pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>,
        resp: AgentResponse,
    ) {
        // Send streaming data without holding the mutex (send().await may yield)
        if let Some(data) = resp.data {
            let data_tx = {
                let pending = pending.lock().unwrap();
                pending.get(&resp.id).map(|req| req.data_tx.clone())
            };
            if let Some(tx) = data_tx {
                // send().await blocks until the consumer drains a slot, providing
                // backpressure instead of silently dropping data.
                if tx.send(data).await.is_err() {
                    // Receiver was dropped — this is unexpected since callers
                    // should hold data_rx until the stream ends. Clean up the
                    // pending entry to avoid leaking the dead request.
                    warn!("request {}: data receiver dropped mid-stream", resp.id);
                    let mut pending = pending.lock().unwrap();
                    pending.remove(&resp.id);
                    return;
                }
            }
        }

        // Handle final result/error. A response carrying `result` or `error`
        // is terminal: the pending entry is removed from the map here.
        if resp.result.is_some() || resp.error.is_some() {
            let mut pending = pending.lock().unwrap();
            if let Some(req) = pending.remove(&resp.id) {
                let outcome = if let Some(result) = resp.result {
                    req.result_tx.send(Ok(result))
                } else if let Some(error) = resp.error {
                    req.result_tx.send(Err(error))
                } else {
                    // resp matched the outer condition (result or error is Some)
                    // but neither branch fired — unreachable by construction.
                    return;
                };
                match outcome {
                    Ok(()) => {}
                    Err(_) => {
                        // Receiver was dropped — this is unexpected since
                        // callers should hold result_rx until they get a result.
                        warn!("request {}: result receiver dropped", resp.id);
                    }
                }
            }
        }
    }
328
    /// Check if the channel is connected.
    ///
    /// Flips to `false` on a failed read/write or a request timeout, and
    /// back to `true` once the read task adopts a replacement transport.
    pub fn is_connected(&self) -> bool {
        self.connected.load(Ordering::SeqCst)
    }
333
334 /// Replace the underlying transport with a new reader/writer pair.
335 ///
336 /// This is used for reconnection: after establishing a new SSH connection,
337 /// call this method to feed the new stdin/stdout to the existing read/write
338 /// tasks. The tasks will resume processing and `is_connected()` will return
339 /// `true` once the first successful read/write completes.
340 ///
341 /// The `connected` flag is set to `true` by the read task after it has
342 /// received the new reader and drained stale pending requests. This
343 /// ensures no race between draining and new request submission.
344 pub async fn replace_transport<R, W>(&self, reader: R, writer: W)
345 where
346 R: AsyncBufRead + Unpin + Send + 'static,
347 W: AsyncWrite + Unpin + Send + 'static,
348 {
349 // Send new transports to the tasks. Order matters: send writer first
350 // so the write task is ready before the read task marks connected
351 // (which allows new requests to flow).
352 // Send can only fail if the task exited (AgentChannel dropped).
353 if self.new_writer_tx.send(Box::new(writer)).await.is_err() {
354 warn!("replace_transport: write task is gone, cannot reconnect");
355 return;
356 }
357 if self.new_reader_tx.send(Box::new(reader)).await.is_err() {
358 warn!("replace_transport: read task is gone, cannot reconnect");
359 }
360 // Note: connected is set to true by the read task after it drains
361 // stale pending requests and switches to the new reader.
362 }
363
364 /// Replace the underlying transport (blocking version for non-async contexts).
365 ///
366 /// Sends the new transport to the tasks and waits until the channel is
367 /// marked as connected (i.e., the read task has drained stale requests
368 /// and is ready to receive responses on the new reader).
369 pub fn replace_transport_blocking<R, W>(&self, reader: R, writer: W)
370 where
371 R: AsyncBufRead + Unpin + Send + 'static,
372 W: AsyncWrite + Unpin + Send + 'static,
373 {
374 self.runtime_handle
375 .block_on(self.replace_transport(reader, writer));
376
377 // Yield until the read task has processed the new reader.
378 // This is typically immediate since the channel send above wakes
379 // the read task's select!, which drains pending and sets connected.
380 while !self.is_connected() {
381 std::thread::yield_now();
382 }
383 }
384
    /// Set the request timeout duration.
    ///
    /// Requests that don't receive a response within this duration will fail
    /// with `ChannelError::Timeout` and the connection will be marked as
    /// disconnected. Applies to requests issued after this call; note that
    /// sub-millisecond precision is dropped (stored as whole milliseconds).
    pub fn set_request_timeout(&self, timeout: Duration) {
        self.request_timeout_ms
            .store(timeout.as_millis() as u64, Ordering::SeqCst);
    }

    /// Get the current request timeout duration.
    fn request_timeout(&self) -> Duration {
        Duration::from_millis(self.request_timeout_ms.load(Ordering::SeqCst))
    }
399
400 /// Send a request and wait for the final result (ignoring streaming data)
401 pub async fn request(
402 &self,
403 method: &str,
404 params: serde_json::Value,
405 ) -> Result<serde_json::Value, ChannelError> {
406 let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
407
408 let timeout = self.request_timeout();
409
410 // Drain streaming data and wait for final result, with timeout.
411 let result = tokio::time::timeout(timeout, async {
412 while data_rx.recv().await.is_some() {}
413 result_rx
414 .await
415 .map_err(|_| ChannelError::ChannelClosed)?
416 .map_err(ChannelError::Remote)
417 })
418 .await;
419
420 match result {
421 Ok(inner) => inner,
422 Err(_elapsed) => {
423 warn!("request '{}' timed out after {:?}", method, timeout);
424 self.connected.store(false, Ordering::SeqCst);
425 Err(ChannelError::Timeout)
426 }
427 }
428 }
429
430 /// Send a request that may stream data
431 pub async fn request_streaming(
432 &self,
433 method: &str,
434 params: serde_json::Value,
435 ) -> Result<
436 (
437 mpsc::Receiver<serde_json::Value>,
438 oneshot::Receiver<Result<serde_json::Value, String>>,
439 ),
440 ChannelError,
441 > {
442 if !self.is_connected() {
443 return Err(ChannelError::ChannelClosed);
444 }
445
446 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
447
448 // Create channels for response
449 let (data_tx, data_rx) = mpsc::channel(self.data_channel_capacity);
450 let (result_tx, result_rx) = oneshot::channel();
451
452 // Register pending request
453 {
454 let mut pending = self.pending.lock().unwrap();
455 pending.insert(id, PendingRequest { data_tx, result_tx });
456 }
457
458 // Build and send request
459 let req = AgentRequest::new(id, method, params);
460 self.write_tx
461 .send(req.to_json_line())
462 .await
463 .map_err(|_| ChannelError::ChannelClosed)?;
464
465 Ok((data_rx, result_rx))
466 }
467
468 /// Send a request synchronously (blocking)
469 ///
470 /// This can be called from outside the Tokio runtime context.
471 pub fn request_blocking(
472 &self,
473 method: &str,
474 params: serde_json::Value,
475 ) -> Result<serde_json::Value, ChannelError> {
476 self.runtime_handle.block_on(self.request(method, params))
477 }
478
    /// Send a request and collect all streaming data along with the final result.
    ///
    /// Returns `(data_chunks, final_result)`. The whole exchange — streaming
    /// plus final result — is bounded by the configured request timeout; on
    /// expiry the channel is marked disconnected and `ChannelError::Timeout`
    /// is returned.
    pub async fn request_with_data(
        &self,
        method: &str,
        params: serde_json::Value,
    ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
        let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;

        let timeout = self.request_timeout();

        let result = tokio::time::timeout(timeout, async {
            // Collect all streaming data
            let mut data = Vec::new();
            while let Some(chunk) = data_rx.recv().await {
                data.push(chunk);

                // Test hook: simulate slow consumer for backpressure testing.
                // Zero-cost in production (atomic load + branch-not-taken).
                let delay_us = TEST_RECV_DELAY_US.load(Ordering::Relaxed);
                if delay_us > 0 {
                    tokio::time::sleep(tokio::time::Duration::from_micros(delay_us)).await;
                }
            }

            // Wait for final result
            let result = result_rx
                .await
                .map_err(|_| ChannelError::ChannelClosed)?
                .map_err(ChannelError::Remote)?;

            Ok((data, result))
        })
        .await;

        match result {
            Ok(inner) => inner,
            Err(_elapsed) => {
                warn!("streaming request timed out after {:?}", timeout);
                self.connected.store(false, Ordering::SeqCst);
                Err(ChannelError::Timeout)
            }
        }
    }
522
523 /// Send a request with streaming data, synchronously (blocking)
524 ///
525 /// This can be called from outside the Tokio runtime context.
526 pub fn request_with_data_blocking(
527 &self,
528 method: &str,
529 params: serde_json::Value,
530 ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
531 self.runtime_handle
532 .block_on(self.request_with_data(method, params))
533 }
534
535 /// Cancel a request
536 pub async fn cancel(&self, request_id: u64) -> Result<(), ChannelError> {
537 use crate::services::remote::protocol::cancel_params;
538 self.request("cancel", cancel_params(request_id)).await?;
539 Ok(())
540 }
541}
542
#[cfg(test)]
mod tests {
    // Tests are in the tests module to allow integration testing with mock agent
}