// fresh/services/remote/channel.rs
//! Agent communication channel.
//!
//! Handles request/response multiplexing over SSH stdin/stdout.
//! Supports transport hot-swapping for automatic reconnection:
//! the read/write tasks survive connection drops and resume when
//! a new transport is provided via `replace_transport()`.
use crate::services::remote::protocol::{AgentRequest, AgentResponse};
use std::collections::HashMap;
use std::io;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot};
use tracing::warn;
17
/// Default capacity for the per-request streaming data channel.
/// Bounded so a slow consumer applies backpressure to the read task
/// instead of growing an unbounded queue.
const DEFAULT_DATA_CHANNEL_CAPACITY: usize = 64;

/// Default timeout for remote requests. If a response is not received within
/// this duration, the request fails with `ChannelError::Timeout` and the
/// connection is marked as disconnected.
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);

/// Test-only: microseconds to sleep in the consumer loop between chunks
/// (read in `request_with_data`). Set to a non-zero value from tests to
/// simulate a slow consumer and deterministically reproduce channel
/// backpressure scenarios.
/// Always compiled (not `cfg(test)`) because integration tests need access.
pub static TEST_RECV_DELAY_US: AtomicU64 = AtomicU64::new(0);
31
/// Error type for channel operations.
#[derive(Debug, thiserror::Error)]
pub enum ChannelError {
    /// Underlying transport I/O failure.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    /// Request/response (de)serialization failure.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// The channel (or one of its background tasks) is no longer usable.
    #[error("Channel closed")]
    ChannelClosed,

    /// The request was cancelled before completing.
    #[error("Request cancelled")]
    Cancelled,

    /// No response arrived within the configured request timeout.
    #[error("Request timed out")]
    Timeout,

    /// The remote agent reported an error for this request.
    #[error("Remote error: {0}")]
    Remote(String),
}
53
/// Per-request state held while a response is outstanding.
struct PendingRequest {
    /// Channel for streaming data chunks (bounded; capacity set by
    /// `AgentChannel::data_channel_capacity`).
    data_tx: mpsc::Sender<serde_json::Value>,
    /// One-shot channel for the final outcome: `Ok(result)` or the remote
    /// error string.
    result_tx: oneshot::Sender<Result<serde_json::Value, String>>,
}
61
/// Boxed async reader type used by the read task (erases the concrete
/// transport so it can be hot-swapped on reconnection).
type BoxedReader = Box<dyn AsyncBufRead + Unpin + Send>;
/// Boxed async writer type used by the write task (same erasure rationale).
type BoxedWriter = Box<dyn AsyncWrite + Unpin + Send>;
66
/// Communication channel with the remote agent.
///
/// Owns the sender halves of the channels feeding two long-lived background
/// tasks (read/write); dropping the `AgentChannel` closes those senders,
/// which terminates the tasks.
pub struct AgentChannel {
    /// Sender to the write task (outgoing JSON lines).
    write_tx: mpsc::Sender<String>,
    /// Pending requests awaiting responses, keyed by request id.
    pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
    /// Next request ID (monotonically increasing, starts at 1).
    next_id: AtomicU64,
    /// Whether the channel is connected; flipped by the read/write tasks on
    /// transport failure and restored by the read task on reconnection.
    connected: Arc<std::sync::atomic::AtomicBool>,
    /// Runtime handle for the `*_blocking` entry points.
    runtime_handle: tokio::runtime::Handle,
    /// Capacity for per-request streaming data channels.
    data_channel_capacity: usize,
    /// Timeout for individual requests (stored as milliseconds so it can be
    /// read/updated atomically without a lock).
    request_timeout_ms: AtomicU64,
    /// Sender to deliver a new reader to the read task after reconnection.
    new_reader_tx: mpsc::Sender<BoxedReader>,
    /// Sender to deliver a new writer to the write task after reconnection.
    new_writer_tx: mpsc::Sender<BoxedWriter>,
}
88
89impl AgentChannel {
    /// Create a new channel from async read/write handles of a child process
    /// (the SSH-spawned agent's stdout/stdin).
    ///
    /// Uses the default streaming-channel capacity.
    /// Must be called from within a Tokio runtime context.
    pub fn new(
        reader: tokio::io::BufReader<tokio::process::ChildStdout>,
        writer: tokio::process::ChildStdin,
    ) -> Self {
        Self::with_capacity(reader, writer, DEFAULT_DATA_CHANNEL_CAPACITY)
    }
99
    /// Create a new channel with a custom data channel capacity.
    ///
    /// Lower capacity makes channel overflow more likely if `try_send` is used,
    /// which is useful for stress-testing backpressure handling.
    ///
    /// Must be called from within a Tokio runtime context.
    pub fn with_capacity(
        reader: tokio::io::BufReader<tokio::process::ChildStdout>,
        writer: tokio::process::ChildStdin,
        data_channel_capacity: usize,
    ) -> Self {
        Self::from_transport(reader, writer, data_channel_capacity)
    }
111
112    /// Create a new channel from any async reader/writer pair.
113    ///
114    /// This is the generic constructor used by both production code (via
115    /// `new`/`with_capacity`) and tests (via arbitrary `AsyncBufRead`/`AsyncWrite`
116    /// implementations like `DuplexStream`).
117    ///
118    /// Must be called from within a Tokio runtime context.
119    pub fn from_transport<R, W>(reader: R, writer: W, data_channel_capacity: usize) -> Self
120    where
121        R: AsyncBufRead + Unpin + Send + 'static,
122        W: AsyncWrite + Unpin + Send + 'static,
123    {
124        let pending: Arc<Mutex<HashMap<u64, PendingRequest>>> =
125            Arc::new(Mutex::new(HashMap::new()));
126        let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
127        let runtime_handle = tokio::runtime::Handle::current();
128
129        // Channel for outgoing requests (lives for the lifetime of the AgentChannel)
130        let (write_tx, write_rx) = mpsc::channel::<String>(64);
131
132        // Channels for delivering replacement transports on reconnection.
133        // Capacity 1: at most one pending reconnection at a time.
134        let (new_reader_tx, new_reader_rx) = mpsc::channel::<BoxedReader>(1);
135        let (new_writer_tx, new_writer_rx) = mpsc::channel::<BoxedWriter>(1);
136
137        // Spawn write task (lives for the lifetime of the AgentChannel)
138        let connected_write = connected.clone();
139        tokio::spawn(Self::write_task(
140            Box::new(writer),
141            write_rx,
142            new_writer_rx,
143            connected_write,
144        ));
145
146        // Spawn read task (lives for the lifetime of the AgentChannel)
147        let pending_read = pending.clone();
148        let connected_read = connected.clone();
149        tokio::spawn(Self::read_task(
150            Box::new(reader),
151            new_reader_rx,
152            pending_read,
153            connected_read,
154        ));
155
156        Self {
157            write_tx,
158            pending,
159            next_id: AtomicU64::new(1),
160            connected,
161            runtime_handle,
162            data_channel_capacity,
163            request_timeout_ms: AtomicU64::new(DEFAULT_REQUEST_TIMEOUT.as_millis() as u64),
164            new_reader_tx,
165            new_writer_tx,
166        }
167    }
168
    /// Long-lived write task. Reads outgoing messages from `write_rx` and
    /// writes them to the current transport. On transport error or when a new
    /// transport arrives via `new_writer_rx`, switches to the new writer.
    ///
    /// Exits only when both senders are dropped (the `AgentChannel` is gone).
    async fn write_task(
        mut writer: BoxedWriter,
        mut write_rx: mpsc::Receiver<String>,
        mut new_writer_rx: mpsc::Receiver<BoxedWriter>,
        connected: Arc<std::sync::atomic::AtomicBool>,
    ) {
        loop {
            // Both recv() calls are cancellation-safe, so losing the race in
            // select! cannot drop a message or a replacement writer.
            tokio::select! {
                // Normal path: send outgoing message
                msg = write_rx.recv() => {
                    let Some(msg) = msg else { break }; // AgentChannel dropped

                    let write_ok = writer.write_all(msg.as_bytes()).await.is_ok()
                        && writer.flush().await.is_ok();

                    if !write_ok {
                        connected.store(false, Ordering::SeqCst);
                        // Wait for replacement (can't select here, just block).
                        // NOTE(review): the failed `msg` is dropped, not retried
                        // on the new writer — its pending request is failed by
                        // the read task's drain on reconnect, or by the
                        // caller's request timeout. Confirm this is intended.
                        match new_writer_rx.recv().await {
                            Some(new_writer) => { writer = new_writer; continue; }
                            None => break,
                        }
                    }
                }
                // Reconnection: new transport arrived, switch immediately
                new_writer = new_writer_rx.recv() => {
                    match new_writer {
                        Some(w) => { writer = w; }
                        None => break, // AgentChannel dropped
                    }
                }
            }
        }
    }
206
207    /// Long-lived read task. Reads responses from the current transport and
208    /// dispatches them to pending requests. On transport error or when a new
209    /// transport arrives, cleans up pending requests and switches readers.
210    async fn read_task(
211        mut reader: BoxedReader,
212        mut new_reader_rx: mpsc::Receiver<BoxedReader>,
213        pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
214        connected: Arc<std::sync::atomic::AtomicBool>,
215    ) {
216        let mut line = String::new();
217
218        loop {
219            line.clear();
220
221            tokio::select! {
222                read_result = reader.read_line(&mut line) => {
223                    match read_result {
224                        Ok(0) | Err(_) => {
225                            // EOF or error — transport is dead
226                            connected.store(false, Ordering::SeqCst);
227                            Self::drain_pending(&pending);
228
229                            // Wait for replacement reader
230                            match new_reader_rx.recv().await {
231                                Some(new_reader) => { reader = new_reader; continue; }
232                                None => break,
233                            }
234                        }
235                        Ok(_) => {
236                            if let Ok(resp) = serde_json::from_str::<AgentResponse>(&line) {
237                                Self::handle_response(&pending, resp).await;
238                            }
239                        }
240                    }
241                }
242                // Reconnection: new transport arrived, switch immediately.
243                // Drain pending requests from the old connection first —
244                // they were sent to the old agent and won't get responses
245                // on the new one. Then mark connected so new requests can
246                // be submitted.
247                new_reader = new_reader_rx.recv() => {
248                    match new_reader {
249                        Some(r) => {
250                            Self::drain_pending(&pending);
251                            reader = r;
252                            connected.store(true, Ordering::SeqCst);
253                        }
254                        None => break, // AgentChannel dropped
255                    }
256                }
257            }
258        }
259    }
260
261    /// Fail all pending requests with "connection closed" so callers don't hang.
262    fn drain_pending(pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>) {
263        let mut pending = pending.lock().unwrap();
264        for (id, req) in pending.drain() {
265            match req.result_tx.send(Err("connection closed".to_string())) {
266                Ok(()) => {}
267                Err(_) => {
268                    warn!("request {id}: receiver dropped during disconnect cleanup");
269                }
270            }
271        }
272    }
273
    /// Handle an incoming response, routing it to the matching pending request.
    ///
    /// For streaming data, uses `send().await` to apply backpressure when the
    /// consumer is slower than the producer. This prevents silent data loss
    /// that occurred with `try_send` (#1059).
    ///
    /// Responses whose id has no pending entry are ignored (e.g. late replies
    /// from a previous connection that arrive after `drain_pending`).
    async fn handle_response(
        pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>,
        resp: AgentResponse,
    ) {
        // Send streaming data without holding the mutex: the guard is dropped
        // before the await so a blocked send can't wedge other lock users.
        if let Some(data) = resp.data {
            let data_tx = {
                let pending = pending.lock().unwrap();
                pending.get(&resp.id).map(|req| req.data_tx.clone())
            };
            if let Some(tx) = data_tx {
                // send().await blocks until the consumer drains a slot, providing
                // backpressure instead of silently dropping data.
                if tx.send(data).await.is_err() {
                    // Receiver was dropped — this is unexpected since callers
                    // should hold data_rx until the stream ends. Clean up the
                    // pending entry to avoid leaking the dead request.
                    warn!("request {}: data receiver dropped mid-stream", resp.id);
                    let mut pending = pending.lock().unwrap();
                    pending.remove(&resp.id);
                    return;
                }
            }
        }

        // Handle final result/error: either field being Some terminates the
        // request — remove the pending entry and fire the oneshot.
        if resp.result.is_some() || resp.error.is_some() {
            let mut pending = pending.lock().unwrap();
            if let Some(req) = pending.remove(&resp.id) {
                // If both fields are Some, result wins.
                let outcome = if let Some(result) = resp.result {
                    req.result_tx.send(Ok(result))
                } else if let Some(error) = resp.error {
                    req.result_tx.send(Err(error))
                } else {
                    // resp matched the outer condition (result or error is Some)
                    // but neither branch fired — unreachable by construction.
                    return;
                };
                match outcome {
                    Ok(()) => {}
                    Err(_) => {
                        // Receiver was dropped — this is unexpected since
                        // callers should hold result_rx until they get a result.
                        warn!("request {}: result receiver dropped", resp.id);
                    }
                }
            }
        }
    }
328
    /// Check if the channel is connected.
    ///
    /// Best-effort flag maintained by the background tasks: cleared on
    /// transport failure or request timeout, restored by the read task on
    /// reconnection. It may briefly lag an in-progress transport failure.
    pub fn is_connected(&self) -> bool {
        self.connected.load(Ordering::SeqCst)
    }
333
    /// Replace the underlying transport with a new reader/writer pair.
    ///
    /// This is used for reconnection: after establishing a new SSH connection,
    /// call this method to feed the new stdin/stdout to the existing read/write
    /// tasks. The tasks will resume processing on the new transport.
    ///
    /// The `connected` flag is set to `true` by the read task after it has
    /// received the new reader and drained stale pending requests. This
    /// ensures no race between draining and new request submission.
    pub async fn replace_transport<R, W>(&self, reader: R, writer: W)
    where
        R: AsyncBufRead + Unpin + Send + 'static,
        W: AsyncWrite + Unpin + Send + 'static,
    {
        // Send new transports to the tasks. Order matters: send writer first
        // so the write task is ready before the read task marks connected
        // (which allows new requests to flow).
        // Send can only fail if the task exited (AgentChannel dropped).
        if self.new_writer_tx.send(Box::new(writer)).await.is_err() {
            warn!("replace_transport: write task is gone, cannot reconnect");
            return;
        }
        if self.new_reader_tx.send(Box::new(reader)).await.is_err() {
            warn!("replace_transport: read task is gone, cannot reconnect");
        }
        // Note: connected is set to true by the read task after it drains
        // stale pending requests and switches to the new reader.
    }
363
364    /// Replace the underlying transport (blocking version for non-async contexts).
365    ///
366    /// Sends the new transport to the tasks and waits until the channel is
367    /// marked as connected (i.e., the read task has drained stale requests
368    /// and is ready to receive responses on the new reader).
369    pub fn replace_transport_blocking<R, W>(&self, reader: R, writer: W)
370    where
371        R: AsyncBufRead + Unpin + Send + 'static,
372        W: AsyncWrite + Unpin + Send + 'static,
373    {
374        self.runtime_handle
375            .block_on(self.replace_transport(reader, writer));
376
377        // Yield until the read task has processed the new reader.
378        // This is typically immediate since the channel send above wakes
379        // the read task's select!, which drains pending and sets connected.
380        while !self.is_connected() {
381            std::thread::yield_now();
382        }
383    }
384
    /// Set the request timeout duration.
    ///
    /// Requests that don't receive a response within this duration will fail
    /// with `ChannelError::Timeout` and the connection will be marked as
    /// disconnected. Stored atomically (as milliseconds), so this can be
    /// called concurrently with in-flight requests; it affects requests
    /// started after the store.
    pub fn set_request_timeout(&self, timeout: Duration) {
        self.request_timeout_ms
            .store(timeout.as_millis() as u64, Ordering::SeqCst);
    }
394
    /// Get the current request timeout duration (reconstructed from the
    /// atomically-stored millisecond value).
    fn request_timeout(&self) -> Duration {
        Duration::from_millis(self.request_timeout_ms.load(Ordering::SeqCst))
    }
399
400    /// Send a request and wait for the final result (ignoring streaming data)
401    pub async fn request(
402        &self,
403        method: &str,
404        params: serde_json::Value,
405    ) -> Result<serde_json::Value, ChannelError> {
406        let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
407
408        let timeout = self.request_timeout();
409
410        // Drain streaming data and wait for final result, with timeout.
411        let result = tokio::time::timeout(timeout, async {
412            while data_rx.recv().await.is_some() {}
413            result_rx
414                .await
415                .map_err(|_| ChannelError::ChannelClosed)?
416                .map_err(ChannelError::Remote)
417        })
418        .await;
419
420        match result {
421            Ok(inner) => inner,
422            Err(_elapsed) => {
423                warn!("request '{}' timed out after {:?}", method, timeout);
424                self.connected.store(false, Ordering::SeqCst);
425                Err(ChannelError::Timeout)
426            }
427        }
428    }
429
430    /// Send a request that may stream data
431    pub async fn request_streaming(
432        &self,
433        method: &str,
434        params: serde_json::Value,
435    ) -> Result<
436        (
437            mpsc::Receiver<serde_json::Value>,
438            oneshot::Receiver<Result<serde_json::Value, String>>,
439        ),
440        ChannelError,
441    > {
442        if !self.is_connected() {
443            return Err(ChannelError::ChannelClosed);
444        }
445
446        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
447
448        // Create channels for response
449        let (data_tx, data_rx) = mpsc::channel(self.data_channel_capacity);
450        let (result_tx, result_rx) = oneshot::channel();
451
452        // Register pending request
453        {
454            let mut pending = self.pending.lock().unwrap();
455            pending.insert(id, PendingRequest { data_tx, result_tx });
456        }
457
458        // Build and send request
459        let req = AgentRequest::new(id, method, params);
460        self.write_tx
461            .send(req.to_json_line())
462            .await
463            .map_err(|_| ChannelError::ChannelClosed)?;
464
465        Ok((data_rx, result_rx))
466    }
467
    /// Send a request synchronously (blocking).
    ///
    /// This can be called from outside the Tokio runtime context.
    /// NOTE(review): presumably must NOT be called from a runtime worker
    /// thread, since `Handle::block_on` would block it — confirm callers.
    pub fn request_blocking(
        &self,
        method: &str,
        params: serde_json::Value,
    ) -> Result<serde_json::Value, ChannelError> {
        self.runtime_handle.block_on(self.request(method, params))
    }
478
    /// Send a request and collect all streaming data along with the final result.
    ///
    /// The configured request timeout covers the entire operation (streaming
    /// phase plus final result). On timeout the channel is marked as
    /// disconnected and `ChannelError::Timeout` is returned.
    pub async fn request_with_data(
        &self,
        method: &str,
        params: serde_json::Value,
    ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
        let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;

        let timeout = self.request_timeout();

        let result = tokio::time::timeout(timeout, async {
            // Collect all streaming data until the data channel closes.
            let mut data = Vec::new();
            while let Some(chunk) = data_rx.recv().await {
                data.push(chunk);

                // Test hook: simulate slow consumer for backpressure testing.
                // Zero-cost in production (atomic load + branch-not-taken).
                let delay_us = TEST_RECV_DELAY_US.load(Ordering::Relaxed);
                if delay_us > 0 {
                    tokio::time::sleep(tokio::time::Duration::from_micros(delay_us)).await;
                }
            }

            // Wait for final result (or remote error).
            let result = result_rx
                .await
                .map_err(|_| ChannelError::ChannelClosed)?
                .map_err(ChannelError::Remote)?;

            Ok((data, result))
        })
        .await;

        match result {
            Ok(inner) => inner,
            Err(_elapsed) => {
                warn!("streaming request timed out after {:?}", timeout);
                self.connected.store(false, Ordering::SeqCst);
                Err(ChannelError::Timeout)
            }
        }
    }
522
    /// Send a request with streaming data, synchronously (blocking).
    ///
    /// This can be called from outside the Tokio runtime context.
    /// NOTE(review): presumably must NOT be called from a runtime worker
    /// thread, since `Handle::block_on` would block it — confirm callers.
    pub fn request_with_data_blocking(
        &self,
        method: &str,
        params: serde_json::Value,
    ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
        self.runtime_handle
            .block_on(self.request_with_data(method, params))
    }
534
535    /// Cancel a request
536    pub async fn cancel(&self, request_id: u64) -> Result<(), ChannelError> {
537        use crate::services::remote::protocol::cancel_params;
538        self.request("cancel", cancel_params(request_id)).await?;
539        Ok(())
540    }
541}
542
#[cfg(test)]
mod tests {
    // Intentionally empty: channel tests live in the integration tests module
    // so they can exercise a real transport against a mock agent.
}