Skip to main content

fresh/services/remote/
channel.rs

//! Agent communication channel
//!
//! Handles request/response multiplexing over SSH stdin/stdout.
//! Supports transport hot-swapping for automatic reconnection:
//! the read/write tasks survive connection drops and resume when
//! a new transport is provided via `replace_transport()`.
7
8use crate::services::remote::protocol::{AgentRequest, AgentResponse};
9use std::collections::HashMap;
10use std::io;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::{Arc, Mutex};
13use std::time::Duration;
14use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
15use tokio::sync::{mpsc, oneshot};
16use tracing::warn;
17
/// Number of streamed `data` chunks buffered per request before the read task
/// has to wait for the consumer to catch up.
const DEFAULT_DATA_CHANNEL_CAPACITY: usize = 64;

/// How long a request may wait for its final response. When it expires the
/// request fails with `ChannelError::Timeout` and the channel is flagged as
/// disconnected.
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Test hook: a delay, in microseconds, applied by the consumer loop after each
/// received chunk. Tests set it to a non-zero value to emulate a slow consumer
/// and reproduce channel backpressure deterministically; zero means no delay.
///
/// Always compiled (rather than gated on `cfg(test)`) so integration tests can
/// reach it.
pub static TEST_RECV_DELAY_US: AtomicU64 = AtomicU64::new(0);
31
32/// Error type for channel operations
33#[derive(Debug, thiserror::Error)]
34pub enum ChannelError {
35    #[error("IO error: {0}")]
36    Io(#[from] io::Error),
37
38    #[error("JSON error: {0}")]
39    Json(#[from] serde_json::Error),
40
41    #[error("Channel closed")]
42    ChannelClosed,
43
44    #[error("Request cancelled")]
45    Cancelled,
46
47    #[error("Request timed out")]
48    Timeout,
49
50    #[error("Remote error: {0}")]
51    Remote(String),
52}
53
54/// Pending request state
55struct PendingRequest {
56    /// Channel for streaming data
57    data_tx: mpsc::Sender<serde_json::Value>,
58    /// Channel for final result
59    result_tx: oneshot::Sender<Result<serde_json::Value, String>>,
60}
61
62/// Boxed async reader type used by the read task.
63type BoxedReader = Box<dyn AsyncBufRead + Unpin + Send>;
64/// Boxed async writer type used by the write task.
65type BoxedWriter = Box<dyn AsyncWrite + Unpin + Send>;
66
/// Process-wide counter (starting at 1) that hands every channel a unique,
/// stable id. The editor uses that id to route an
/// `AsyncMessage::RemoteReconnected` to the window whose authority owns the
/// channel, so the channel itself never needs to know about windows.
static NEXT_CHANNEL_ID: AtomicU64 = AtomicU64::new(1);
71
72/// Communication channel with the remote agent
73pub struct AgentChannel {
74    /// Stable identity for this channel, assigned at creation. Survives
75    /// transport hot-swaps (the channel object is reused across reconnects),
76    /// so it's a durable key for "this remote session".
77    id: u64,
78    /// Notified once each time the transport is hot-swapped back in
79    /// (`replace_transport`). The editor spawns a forwarder that turns each
80    /// notification into an `AsyncMessage::RemoteReconnected`, so a silent
81    /// background reconnect reaches the app event-driven rather than by
82    /// polling `is_connected()`.
83    reconnect_notify: Arc<tokio::sync::Notify>,
84    /// Sender to the write task
85    write_tx: mpsc::Sender<String>,
86    /// Pending requests awaiting responses
87    pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
88    /// Next request ID
89    next_id: AtomicU64,
90    /// Whether the channel is connected
91    connected: Arc<std::sync::atomic::AtomicBool>,
92    /// Runtime handle for blocking operations
93    runtime_handle: tokio::runtime::Handle,
94    /// Capacity for per-request streaming data channels
95    data_channel_capacity: usize,
96    /// Timeout for individual requests (stored as milliseconds for atomic access)
97    request_timeout_ms: AtomicU64,
98    /// Sender to deliver a new reader to the read task after reconnection
99    new_reader_tx: mpsc::Sender<BoxedReader>,
100    /// Sender to deliver a new writer to the write task after reconnection
101    new_writer_tx: mpsc::Sender<BoxedWriter>,
102}
103
104impl AgentChannel {
105    /// Create a new channel from async read/write handles
106    ///
107    /// Must be called from within a Tokio runtime context.
108    pub fn new(
109        reader: tokio::io::BufReader<tokio::process::ChildStdout>,
110        writer: tokio::process::ChildStdin,
111    ) -> Self {
112        Self::with_capacity(reader, writer, DEFAULT_DATA_CHANNEL_CAPACITY)
113    }
114
115    /// Create a new channel with a custom data channel capacity.
116    ///
117    /// Lower capacity makes channel overflow more likely if `try_send` is used,
118    /// which is useful for stress-testing backpressure handling.
119    pub fn with_capacity(
120        reader: tokio::io::BufReader<tokio::process::ChildStdout>,
121        writer: tokio::process::ChildStdin,
122        data_channel_capacity: usize,
123    ) -> Self {
124        Self::from_transport(reader, writer, data_channel_capacity)
125    }
126
127    /// Create a new channel from any async reader/writer pair.
128    ///
129    /// This is the generic constructor used by both production code (via
130    /// `new`/`with_capacity`) and tests (via arbitrary `AsyncBufRead`/`AsyncWrite`
131    /// implementations like `DuplexStream`).
132    ///
133    /// Must be called from within a Tokio runtime context.
134    pub fn from_transport<R, W>(reader: R, writer: W, data_channel_capacity: usize) -> Self
135    where
136        R: AsyncBufRead + Unpin + Send + 'static,
137        W: AsyncWrite + Unpin + Send + 'static,
138    {
139        let pending: Arc<Mutex<HashMap<u64, PendingRequest>>> =
140            Arc::new(Mutex::new(HashMap::new()));
141        let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
142        let runtime_handle = tokio::runtime::Handle::current();
143
144        // Channel for outgoing requests (lives for the lifetime of the AgentChannel)
145        let (write_tx, write_rx) = mpsc::channel::<String>(64);
146
147        // Channels for delivering replacement transports on reconnection.
148        // Capacity 1: at most one pending reconnection at a time.
149        let (new_reader_tx, new_reader_rx) = mpsc::channel::<BoxedReader>(1);
150        let (new_writer_tx, new_writer_rx) = mpsc::channel::<BoxedWriter>(1);
151
152        // Spawn write task (lives for the lifetime of the AgentChannel)
153        let connected_write = connected.clone();
154        tokio::spawn(Self::write_task(
155            Box::new(writer),
156            write_rx,
157            new_writer_rx,
158            connected_write,
159        ));
160
161        // Spawn read task (lives for the lifetime of the AgentChannel)
162        let pending_read = pending.clone();
163        let connected_read = connected.clone();
164        tokio::spawn(Self::read_task(
165            Box::new(reader),
166            new_reader_rx,
167            pending_read,
168            connected_read,
169        ));
170
171        Self {
172            id: NEXT_CHANNEL_ID.fetch_add(1, Ordering::Relaxed),
173            reconnect_notify: Arc::new(tokio::sync::Notify::new()),
174            write_tx,
175            pending,
176            next_id: AtomicU64::new(1),
177            connected,
178            runtime_handle,
179            data_channel_capacity,
180            request_timeout_ms: AtomicU64::new(DEFAULT_REQUEST_TIMEOUT.as_millis() as u64),
181            new_reader_tx,
182            new_writer_tx,
183        }
184    }
185
186    /// Long-lived write task. Reads outgoing messages from `write_rx` and
187    /// writes them to the current transport. On transport error or when a new
188    /// transport arrives via `new_writer_rx`, switches to the new writer.
189    async fn write_task(
190        mut writer: BoxedWriter,
191        mut write_rx: mpsc::Receiver<String>,
192        mut new_writer_rx: mpsc::Receiver<BoxedWriter>,
193        connected: Arc<std::sync::atomic::AtomicBool>,
194    ) {
195        loop {
196            tokio::select! {
197                // Normal path: send outgoing message
198                msg = write_rx.recv() => {
199                    let Some(msg) = msg else { break }; // AgentChannel dropped
200
201                    let write_ok = writer.write_all(msg.as_bytes()).await.is_ok()
202                        && writer.flush().await.is_ok();
203
204                    if !write_ok {
205                        connected.store(false, Ordering::SeqCst);
206                        // Wait for replacement (can't select here, just block)
207                        match new_writer_rx.recv().await {
208                            Some(new_writer) => { writer = new_writer; continue; }
209                            None => break,
210                        }
211                    }
212                }
213                // Reconnection: new transport arrived, switch immediately
214                new_writer = new_writer_rx.recv() => {
215                    match new_writer {
216                        Some(w) => { writer = w; }
217                        None => break, // AgentChannel dropped
218                    }
219                }
220            }
221        }
222    }
223
224    /// Long-lived read task. Reads responses from the current transport and
225    /// dispatches them to pending requests. On transport error or when a new
226    /// transport arrives, cleans up pending requests and switches readers.
227    async fn read_task(
228        mut reader: BoxedReader,
229        mut new_reader_rx: mpsc::Receiver<BoxedReader>,
230        pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
231        connected: Arc<std::sync::atomic::AtomicBool>,
232    ) {
233        let mut line = String::new();
234
235        loop {
236            line.clear();
237
238            tokio::select! {
239                read_result = reader.read_line(&mut line) => {
240                    match read_result {
241                        Ok(0) | Err(_) => {
242                            // EOF or error — transport is dead
243                            connected.store(false, Ordering::SeqCst);
244                            Self::drain_pending(&pending);
245
246                            // Wait for replacement reader
247                            match new_reader_rx.recv().await {
248                                Some(new_reader) => { reader = new_reader; continue; }
249                                None => break,
250                            }
251                        }
252                        Ok(_) => {
253                            if let Ok(resp) = serde_json::from_str::<AgentResponse>(&line) {
254                                Self::handle_response(&pending, resp).await;
255                            }
256                        }
257                    }
258                }
259                // Reconnection: new transport arrived, switch immediately.
260                // Drain pending requests from the old connection first —
261                // they were sent to the old agent and won't get responses
262                // on the new one. Then mark connected so new requests can
263                // be submitted.
264                new_reader = new_reader_rx.recv() => {
265                    match new_reader {
266                        Some(r) => {
267                            Self::drain_pending(&pending);
268                            reader = r;
269                            connected.store(true, Ordering::SeqCst);
270                        }
271                        None => break, // AgentChannel dropped
272                    }
273                }
274            }
275        }
276    }
277
278    /// Fail all pending requests with "connection closed" so callers don't hang.
279    fn drain_pending(pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>) {
280        let mut pending = pending.lock().unwrap();
281        for (id, req) in pending.drain() {
282            match req.result_tx.send(Err("connection closed".to_string())) {
283                Ok(()) => {}
284                Err(_) => {
285                    warn!("request {id}: receiver dropped during disconnect cleanup");
286                }
287            }
288        }
289    }
290
291    /// Handle an incoming response.
292    ///
293    /// For streaming data, uses `send().await` to apply backpressure when the
294    /// consumer is slower than the producer. This prevents silent data loss
295    /// that occurred with `try_send` (#1059).
296    async fn handle_response(
297        pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>,
298        resp: AgentResponse,
299    ) {
300        // Send streaming data without holding the mutex (send().await may yield)
301        if let Some(data) = resp.data {
302            let data_tx = {
303                let pending = pending.lock().unwrap();
304                pending.get(&resp.id).map(|req| req.data_tx.clone())
305            };
306            if let Some(tx) = data_tx {
307                // send().await blocks until the consumer drains a slot, providing
308                // backpressure instead of silently dropping data.
309                if tx.send(data).await.is_err() {
310                    // Receiver was dropped — this is unexpected since callers
311                    // should hold data_rx until the stream ends. Clean up the
312                    // pending entry to avoid leaking the dead request.
313                    warn!("request {}: data receiver dropped mid-stream", resp.id);
314                    let mut pending = pending.lock().unwrap();
315                    pending.remove(&resp.id);
316                    return;
317                }
318            }
319        }
320
321        // Handle final result/error
322        if resp.result.is_some() || resp.error.is_some() {
323            let mut pending = pending.lock().unwrap();
324            if let Some(req) = pending.remove(&resp.id) {
325                let outcome = if let Some(result) = resp.result {
326                    req.result_tx.send(Ok(result))
327                } else if let Some(error) = resp.error {
328                    req.result_tx.send(Err(error))
329                } else {
330                    // resp matched the outer condition (result or error is Some)
331                    // but neither branch fired — unreachable by construction.
332                    return;
333                };
334                match outcome {
335                    Ok(()) => {}
336                    Err(_) => {
337                        // Receiver was dropped — this is unexpected since
338                        // callers should hold result_rx until they get a result.
339                        warn!("request {}: result receiver dropped", resp.id);
340                    }
341                }
342            }
343        }
344    }
345
346    /// Check if the channel is connected
347    pub fn is_connected(&self) -> bool {
348        self.connected.load(Ordering::SeqCst)
349    }
350
351    /// Replace the underlying transport with a new reader/writer pair.
352    ///
353    /// This is used for reconnection: after establishing a new SSH connection,
354    /// call this method to feed the new stdin/stdout to the existing read/write
355    /// tasks. The tasks will resume processing and `is_connected()` will return
356    /// `true` once the first successful read/write completes.
357    ///
358    /// The `connected` flag is set to `true` by the read task after it has
359    /// received the new reader and drained stale pending requests. This
360    /// ensures no race between draining and new request submission.
361    pub async fn replace_transport<R, W>(&self, reader: R, writer: W)
362    where
363        R: AsyncBufRead + Unpin + Send + 'static,
364        W: AsyncWrite + Unpin + Send + 'static,
365    {
366        // Send new transports to the tasks. Order matters: send writer first
367        // so the write task is ready before the read task marks connected
368        // (which allows new requests to flow).
369        // Send can only fail if the task exited (AgentChannel dropped).
370        if self.new_writer_tx.send(Box::new(writer)).await.is_err() {
371            warn!("replace_transport: write task is gone, cannot reconnect");
372            return;
373        }
374        if self.new_reader_tx.send(Box::new(reader)).await.is_err() {
375            warn!("replace_transport: read task is gone, cannot reconnect");
376        }
377        // The carrier was just hot-swapped back in: wake anyone watching for a
378        // reconnect (the editor's forwarder → `AsyncMessage::RemoteReconnected`,
379        // which respawns embedded terminals that died with the old carrier).
380        // Fired here rather than when `connected` flips true because the
381        // terminal respawn opens its own fresh carrier and doesn't depend on
382        // the agent channel's drain completing.
383        //
384        // `notify_one` (not `notify_waiters`) so a swap that lands in the gap
385        // between the forwarder's send and its next `notified()` still stores a
386        // permit and is delivered — reconnect events can't be dropped. Multiple
387        // swaps coalesce to one permit, which is fine: reattach is idempotent.
388        self.reconnect_notify.notify_one();
389        // Note: connected is set to true by the read task after it drains
390        // stale pending requests and switches to the new reader.
391    }
392
393    /// Stable identity for this channel (see the `id` field).
394    pub fn id(&self) -> u64 {
395        self.id
396    }
397
398    /// A handle that is notified once per successful transport hot-swap. The
399    /// editor awaits it to drive event-driven reconnect handling.
400    pub fn reconnect_notify(&self) -> Arc<tokio::sync::Notify> {
401        self.reconnect_notify.clone()
402    }
403
404    /// Replace the underlying transport (blocking version for non-async contexts).
405    ///
406    /// Sends the new transport to the tasks and waits until the channel is
407    /// marked as connected (i.e., the read task has drained stale requests
408    /// and is ready to receive responses on the new reader).
409    pub fn replace_transport_blocking<R, W>(&self, reader: R, writer: W)
410    where
411        R: AsyncBufRead + Unpin + Send + 'static,
412        W: AsyncWrite + Unpin + Send + 'static,
413    {
414        self.runtime_handle
415            .block_on(self.replace_transport(reader, writer));
416
417        // Yield until the read task has processed the new reader.
418        // This is typically immediate since the channel send above wakes
419        // the read task's select!, which drains pending and sets connected.
420        while !self.is_connected() {
421            std::thread::yield_now();
422        }
423    }
424
425    /// Set the request timeout duration.
426    ///
427    /// Requests that don't receive a response within this duration will fail
428    /// with `ChannelError::Timeout` and the connection will be marked as
429    /// disconnected.
430    pub fn set_request_timeout(&self, timeout: Duration) {
431        self.request_timeout_ms
432            .store(timeout.as_millis() as u64, Ordering::SeqCst);
433    }
434
435    /// Get the current request timeout duration.
436    fn request_timeout(&self) -> Duration {
437        Duration::from_millis(self.request_timeout_ms.load(Ordering::SeqCst))
438    }
439
440    /// Send a request and wait for the final result (ignoring streaming data)
441    pub async fn request(
442        &self,
443        method: &str,
444        params: serde_json::Value,
445    ) -> Result<serde_json::Value, ChannelError> {
446        let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
447
448        let timeout = self.request_timeout();
449
450        // Drain streaming data and wait for final result, with timeout.
451        let result = tokio::time::timeout(timeout, async {
452            while data_rx.recv().await.is_some() {}
453            result_rx
454                .await
455                .map_err(|_| ChannelError::ChannelClosed)?
456                .map_err(ChannelError::Remote)
457        })
458        .await;
459
460        match result {
461            Ok(inner) => inner,
462            Err(_elapsed) => {
463                warn!("request '{}' timed out after {:?}", method, timeout);
464                self.connected.store(false, Ordering::SeqCst);
465                Err(ChannelError::Timeout)
466            }
467        }
468    }
469
470    /// Send a request that may stream data
471    pub async fn request_streaming(
472        &self,
473        method: &str,
474        params: serde_json::Value,
475    ) -> Result<
476        (
477            mpsc::Receiver<serde_json::Value>,
478            oneshot::Receiver<Result<serde_json::Value, String>>,
479        ),
480        ChannelError,
481    > {
482        if !self.is_connected() {
483            return Err(ChannelError::ChannelClosed);
484        }
485
486        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
487
488        // Create channels for response
489        let (data_tx, data_rx) = mpsc::channel(self.data_channel_capacity);
490        let (result_tx, result_rx) = oneshot::channel();
491
492        // Register pending request
493        {
494            let mut pending = self.pending.lock().unwrap();
495            pending.insert(id, PendingRequest { data_tx, result_tx });
496        }
497
498        // Build and send request
499        let req = AgentRequest::new(id, method, params);
500        self.write_tx
501            .send(req.to_json_line())
502            .await
503            .map_err(|_| ChannelError::ChannelClosed)?;
504
505        Ok((data_rx, result_rx))
506    }
507
508    /// Send a request synchronously (blocking)
509    ///
510    /// This can be called from outside the Tokio runtime context.
511    pub fn request_blocking(
512        &self,
513        method: &str,
514        params: serde_json::Value,
515    ) -> Result<serde_json::Value, ChannelError> {
516        self.runtime_handle.block_on(self.request(method, params))
517    }
518
519    /// Send a request and collect all streaming data along with the final result
520    pub async fn request_with_data(
521        &self,
522        method: &str,
523        params: serde_json::Value,
524    ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
525        let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
526
527        let timeout = self.request_timeout();
528
529        let result = tokio::time::timeout(timeout, async {
530            // Collect all streaming data
531            let mut data = Vec::new();
532            while let Some(chunk) = data_rx.recv().await {
533                data.push(chunk);
534
535                // Test hook: simulate slow consumer for backpressure testing.
536                // Zero-cost in production (atomic load + branch-not-taken).
537                let delay_us = TEST_RECV_DELAY_US.load(Ordering::Relaxed);
538                if delay_us > 0 {
539                    tokio::time::sleep(tokio::time::Duration::from_micros(delay_us)).await;
540                }
541            }
542
543            // Wait for final result
544            let result = result_rx
545                .await
546                .map_err(|_| ChannelError::ChannelClosed)?
547                .map_err(ChannelError::Remote)?;
548
549            Ok((data, result))
550        })
551        .await;
552
553        match result {
554            Ok(inner) => inner,
555            Err(_elapsed) => {
556                warn!("streaming request timed out after {:?}", timeout);
557                self.connected.store(false, Ordering::SeqCst);
558                Err(ChannelError::Timeout)
559            }
560        }
561    }
562
563    /// Send a request with streaming data, synchronously (blocking)
564    ///
565    /// This can be called from outside the Tokio runtime context.
566    pub fn request_with_data_blocking(
567        &self,
568        method: &str,
569        params: serde_json::Value,
570    ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
571        self.runtime_handle
572            .block_on(self.request_with_data(method, params))
573    }
574
575    /// Send a streaming request synchronously, returning receivers for
576    /// incremental processing.
577    ///
578    /// Unlike `request_with_data_blocking` which collects all data into
579    /// memory, this returns the raw receivers so callers can process each
580    /// chunk as it arrives (e.g., for `walk_files` where the server sends
581    /// file paths in batches).
582    ///
583    /// Use `data_rx.blocking_recv()` to receive chunks from a sync context.
584    #[allow(clippy::type_complexity)]
585    pub fn request_streaming_blocking(
586        &self,
587        method: &str,
588        params: serde_json::Value,
589    ) -> Result<
590        (
591            mpsc::Receiver<serde_json::Value>,
592            oneshot::Receiver<Result<serde_json::Value, String>>,
593        ),
594        ChannelError,
595    > {
596        self.runtime_handle
597            .block_on(self.request_streaming(method, params))
598    }
599
600    /// Cancel a request
601    pub async fn cancel(&self, request_id: u64) -> Result<(), ChannelError> {
602        use crate::services::remote::protocol::cancel_params;
603        self.request("cancel", cancel_params(request_id)).await?;
604        Ok(())
605    }
606}
607
#[cfg(test)]
mod tests {
    // Intentionally empty: this module is exercised by integration tests that
    // drive the channel against a mock agent.
}