fresh/services/remote/channel.rs
1//! Agent communication channel
2//!
3//! Handles request/response multiplexing over SSH stdin/stdout.
4//! Supports transport hot-swapping for automatic reconnection:
5//! the read/write tasks survive connection drops and resume when
6//! a new transport is provided via `replace_transport()`.
7
8use crate::services::remote::protocol::{AgentRequest, AgentResponse};
9use std::collections::HashMap;
10use std::io;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::{Arc, Mutex};
13use std::time::Duration;
14use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
15use tokio::sync::{mpsc, oneshot};
16use tracing::warn;
17
/// How many streamed `data` payloads may queue per request before the read
/// task has to wait for the consumer (used unless a custom capacity is given).
const DEFAULT_DATA_CHANNEL_CAPACITY: usize = 64;
20
/// Request deadline applied until `set_request_timeout` overrides it.
///
/// A request that gets no final response in time fails with
/// `ChannelError::Timeout`, and the channel is flagged as disconnected.
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(10_000);
25
/// Test hook: per-chunk pause, in microseconds, inserted by the consumer loop
/// of `request_with_data`.
///
/// Zero (the default) disables it. Tests raise it to emulate a slow consumer
/// so streaming backpressure can be reproduced deterministically. It is not
/// gated behind `cfg(test)` because integration tests must be able to set it.
pub static TEST_RECV_DELAY_US: AtomicU64 = AtomicU64::new(0u64);
31
32/// Error type for channel operations
33#[derive(Debug, thiserror::Error)]
34pub enum ChannelError {
35 #[error("IO error: {0}")]
36 Io(#[from] io::Error),
37
38 #[error("JSON error: {0}")]
39 Json(#[from] serde_json::Error),
40
41 #[error("Channel closed")]
42 ChannelClosed,
43
44 #[error("Request cancelled")]
45 Cancelled,
46
47 #[error("Request timed out")]
48 Timeout,
49
50 #[error("Remote error: {0}")]
51 Remote(String),
52}
53
54/// Pending request state
55struct PendingRequest {
56 /// Channel for streaming data
57 data_tx: mpsc::Sender<serde_json::Value>,
58 /// Channel for final result
59 result_tx: oneshot::Sender<Result<serde_json::Value, String>>,
60}
61
62/// Boxed async reader type used by the read task.
63type BoxedReader = Box<dyn AsyncBufRead + Unpin + Send>;
64/// Boxed async writer type used by the write task.
65type BoxedWriter = Box<dyn AsyncWrite + Unpin + Send>;
66
/// Counter handing out a unique id to every `AgentChannel` created in this
/// process (starting at 1).
///
/// The id lets the editor route an `AsyncMessage::RemoteReconnected` to the
/// window whose authority owns the channel, while the channel itself stays
/// unaware of windows.
static NEXT_CHANNEL_ID: AtomicU64 = AtomicU64::new(1u64);
71
72/// Communication channel with the remote agent
73pub struct AgentChannel {
74 /// Stable identity for this channel, assigned at creation. Survives
75 /// transport hot-swaps (the channel object is reused across reconnects),
76 /// so it's a durable key for "this remote session".
77 id: u64,
78 /// Notified once each time the transport is hot-swapped back in
79 /// (`replace_transport`). The editor spawns a forwarder that turns each
80 /// notification into an `AsyncMessage::RemoteReconnected`, so a silent
81 /// background reconnect reaches the app event-driven rather than by
82 /// polling `is_connected()`.
83 reconnect_notify: Arc<tokio::sync::Notify>,
84 /// Sender to the write task
85 write_tx: mpsc::Sender<String>,
86 /// Pending requests awaiting responses
87 pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
88 /// Next request ID
89 next_id: AtomicU64,
90 /// Whether the channel is connected
91 connected: Arc<std::sync::atomic::AtomicBool>,
92 /// Runtime handle for blocking operations
93 runtime_handle: tokio::runtime::Handle,
94 /// Capacity for per-request streaming data channels
95 data_channel_capacity: usize,
96 /// Timeout for individual requests (stored as milliseconds for atomic access)
97 request_timeout_ms: AtomicU64,
98 /// Sender to deliver a new reader to the read task after reconnection
99 new_reader_tx: mpsc::Sender<BoxedReader>,
100 /// Sender to deliver a new writer to the write task after reconnection
101 new_writer_tx: mpsc::Sender<BoxedWriter>,
102}
103
104impl AgentChannel {
105 /// Create a new channel from async read/write handles
106 ///
107 /// Must be called from within a Tokio runtime context.
108 pub fn new(
109 reader: tokio::io::BufReader<tokio::process::ChildStdout>,
110 writer: tokio::process::ChildStdin,
111 ) -> Self {
112 Self::with_capacity(reader, writer, DEFAULT_DATA_CHANNEL_CAPACITY)
113 }
114
115 /// Create a new channel with a custom data channel capacity.
116 ///
117 /// Lower capacity makes channel overflow more likely if `try_send` is used,
118 /// which is useful for stress-testing backpressure handling.
119 pub fn with_capacity(
120 reader: tokio::io::BufReader<tokio::process::ChildStdout>,
121 writer: tokio::process::ChildStdin,
122 data_channel_capacity: usize,
123 ) -> Self {
124 Self::from_transport(reader, writer, data_channel_capacity)
125 }
126
127 /// Create a new channel from any async reader/writer pair.
128 ///
129 /// This is the generic constructor used by both production code (via
130 /// `new`/`with_capacity`) and tests (via arbitrary `AsyncBufRead`/`AsyncWrite`
131 /// implementations like `DuplexStream`).
132 ///
133 /// Must be called from within a Tokio runtime context.
134 pub fn from_transport<R, W>(reader: R, writer: W, data_channel_capacity: usize) -> Self
135 where
136 R: AsyncBufRead + Unpin + Send + 'static,
137 W: AsyncWrite + Unpin + Send + 'static,
138 {
139 let pending: Arc<Mutex<HashMap<u64, PendingRequest>>> =
140 Arc::new(Mutex::new(HashMap::new()));
141 let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
142 let runtime_handle = tokio::runtime::Handle::current();
143
144 // Channel for outgoing requests (lives for the lifetime of the AgentChannel)
145 let (write_tx, write_rx) = mpsc::channel::<String>(64);
146
147 // Channels for delivering replacement transports on reconnection.
148 // Capacity 1: at most one pending reconnection at a time.
149 let (new_reader_tx, new_reader_rx) = mpsc::channel::<BoxedReader>(1);
150 let (new_writer_tx, new_writer_rx) = mpsc::channel::<BoxedWriter>(1);
151
152 // Spawn write task (lives for the lifetime of the AgentChannel)
153 let connected_write = connected.clone();
154 tokio::spawn(Self::write_task(
155 Box::new(writer),
156 write_rx,
157 new_writer_rx,
158 connected_write,
159 ));
160
161 // Spawn read task (lives for the lifetime of the AgentChannel)
162 let pending_read = pending.clone();
163 let connected_read = connected.clone();
164 tokio::spawn(Self::read_task(
165 Box::new(reader),
166 new_reader_rx,
167 pending_read,
168 connected_read,
169 ));
170
171 Self {
172 id: NEXT_CHANNEL_ID.fetch_add(1, Ordering::Relaxed),
173 reconnect_notify: Arc::new(tokio::sync::Notify::new()),
174 write_tx,
175 pending,
176 next_id: AtomicU64::new(1),
177 connected,
178 runtime_handle,
179 data_channel_capacity,
180 request_timeout_ms: AtomicU64::new(DEFAULT_REQUEST_TIMEOUT.as_millis() as u64),
181 new_reader_tx,
182 new_writer_tx,
183 }
184 }
185
186 /// Long-lived write task. Reads outgoing messages from `write_rx` and
187 /// writes them to the current transport. On transport error or when a new
188 /// transport arrives via `new_writer_rx`, switches to the new writer.
189 async fn write_task(
190 mut writer: BoxedWriter,
191 mut write_rx: mpsc::Receiver<String>,
192 mut new_writer_rx: mpsc::Receiver<BoxedWriter>,
193 connected: Arc<std::sync::atomic::AtomicBool>,
194 ) {
195 loop {
196 tokio::select! {
197 // Normal path: send outgoing message
198 msg = write_rx.recv() => {
199 let Some(msg) = msg else { break }; // AgentChannel dropped
200
201 let write_ok = writer.write_all(msg.as_bytes()).await.is_ok()
202 && writer.flush().await.is_ok();
203
204 if !write_ok {
205 connected.store(false, Ordering::SeqCst);
206 // Wait for replacement (can't select here, just block)
207 match new_writer_rx.recv().await {
208 Some(new_writer) => { writer = new_writer; continue; }
209 None => break,
210 }
211 }
212 }
213 // Reconnection: new transport arrived, switch immediately
214 new_writer = new_writer_rx.recv() => {
215 match new_writer {
216 Some(w) => { writer = w; }
217 None => break, // AgentChannel dropped
218 }
219 }
220 }
221 }
222 }
223
224 /// Long-lived read task. Reads responses from the current transport and
225 /// dispatches them to pending requests. On transport error or when a new
226 /// transport arrives, cleans up pending requests and switches readers.
227 async fn read_task(
228 mut reader: BoxedReader,
229 mut new_reader_rx: mpsc::Receiver<BoxedReader>,
230 pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
231 connected: Arc<std::sync::atomic::AtomicBool>,
232 ) {
233 let mut line = String::new();
234
235 loop {
236 line.clear();
237
238 tokio::select! {
239 read_result = reader.read_line(&mut line) => {
240 match read_result {
241 Ok(0) | Err(_) => {
242 // EOF or error — transport is dead
243 connected.store(false, Ordering::SeqCst);
244 Self::drain_pending(&pending);
245
246 // Wait for replacement reader
247 match new_reader_rx.recv().await {
248 Some(new_reader) => { reader = new_reader; continue; }
249 None => break,
250 }
251 }
252 Ok(_) => {
253 if let Ok(resp) = serde_json::from_str::<AgentResponse>(&line) {
254 Self::handle_response(&pending, resp).await;
255 }
256 }
257 }
258 }
259 // Reconnection: new transport arrived, switch immediately.
260 // Drain pending requests from the old connection first —
261 // they were sent to the old agent and won't get responses
262 // on the new one. Then mark connected so new requests can
263 // be submitted.
264 new_reader = new_reader_rx.recv() => {
265 match new_reader {
266 Some(r) => {
267 Self::drain_pending(&pending);
268 reader = r;
269 connected.store(true, Ordering::SeqCst);
270 }
271 None => break, // AgentChannel dropped
272 }
273 }
274 }
275 }
276 }
277
278 /// Fail all pending requests with "connection closed" so callers don't hang.
279 fn drain_pending(pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>) {
280 let mut pending = pending.lock().unwrap();
281 for (id, req) in pending.drain() {
282 match req.result_tx.send(Err("connection closed".to_string())) {
283 Ok(()) => {}
284 Err(_) => {
285 warn!("request {id}: receiver dropped during disconnect cleanup");
286 }
287 }
288 }
289 }
290
291 /// Handle an incoming response.
292 ///
293 /// For streaming data, uses `send().await` to apply backpressure when the
294 /// consumer is slower than the producer. This prevents silent data loss
295 /// that occurred with `try_send` (#1059).
296 async fn handle_response(
297 pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>,
298 resp: AgentResponse,
299 ) {
300 // Send streaming data without holding the mutex (send().await may yield)
301 if let Some(data) = resp.data {
302 let data_tx = {
303 let pending = pending.lock().unwrap();
304 pending.get(&resp.id).map(|req| req.data_tx.clone())
305 };
306 if let Some(tx) = data_tx {
307 // send().await blocks until the consumer drains a slot, providing
308 // backpressure instead of silently dropping data.
309 if tx.send(data).await.is_err() {
310 // Receiver was dropped — this is unexpected since callers
311 // should hold data_rx until the stream ends. Clean up the
312 // pending entry to avoid leaking the dead request.
313 warn!("request {}: data receiver dropped mid-stream", resp.id);
314 let mut pending = pending.lock().unwrap();
315 pending.remove(&resp.id);
316 return;
317 }
318 }
319 }
320
321 // Handle final result/error
322 if resp.result.is_some() || resp.error.is_some() {
323 let mut pending = pending.lock().unwrap();
324 if let Some(req) = pending.remove(&resp.id) {
325 let outcome = if let Some(result) = resp.result {
326 req.result_tx.send(Ok(result))
327 } else if let Some(error) = resp.error {
328 req.result_tx.send(Err(error))
329 } else {
330 // resp matched the outer condition (result or error is Some)
331 // but neither branch fired — unreachable by construction.
332 return;
333 };
334 match outcome {
335 Ok(()) => {}
336 Err(_) => {
337 // Receiver was dropped — this is unexpected since
338 // callers should hold result_rx until they get a result.
339 warn!("request {}: result receiver dropped", resp.id);
340 }
341 }
342 }
343 }
344 }
345
346 /// Check if the channel is connected
347 pub fn is_connected(&self) -> bool {
348 self.connected.load(Ordering::SeqCst)
349 }
350
351 /// Replace the underlying transport with a new reader/writer pair.
352 ///
353 /// This is used for reconnection: after establishing a new SSH connection,
354 /// call this method to feed the new stdin/stdout to the existing read/write
355 /// tasks. The tasks will resume processing and `is_connected()` will return
356 /// `true` once the first successful read/write completes.
357 ///
358 /// The `connected` flag is set to `true` by the read task after it has
359 /// received the new reader and drained stale pending requests. This
360 /// ensures no race between draining and new request submission.
361 pub async fn replace_transport<R, W>(&self, reader: R, writer: W)
362 where
363 R: AsyncBufRead + Unpin + Send + 'static,
364 W: AsyncWrite + Unpin + Send + 'static,
365 {
366 // Send new transports to the tasks. Order matters: send writer first
367 // so the write task is ready before the read task marks connected
368 // (which allows new requests to flow).
369 // Send can only fail if the task exited (AgentChannel dropped).
370 if self.new_writer_tx.send(Box::new(writer)).await.is_err() {
371 warn!("replace_transport: write task is gone, cannot reconnect");
372 return;
373 }
374 if self.new_reader_tx.send(Box::new(reader)).await.is_err() {
375 warn!("replace_transport: read task is gone, cannot reconnect");
376 }
377 // The carrier was just hot-swapped back in: wake anyone watching for a
378 // reconnect (the editor's forwarder → `AsyncMessage::RemoteReconnected`,
379 // which respawns embedded terminals that died with the old carrier).
380 // Fired here rather than when `connected` flips true because the
381 // terminal respawn opens its own fresh carrier and doesn't depend on
382 // the agent channel's drain completing.
383 //
384 // `notify_one` (not `notify_waiters`) so a swap that lands in the gap
385 // between the forwarder's send and its next `notified()` still stores a
386 // permit and is delivered — reconnect events can't be dropped. Multiple
387 // swaps coalesce to one permit, which is fine: reattach is idempotent.
388 self.reconnect_notify.notify_one();
389 // Note: connected is set to true by the read task after it drains
390 // stale pending requests and switches to the new reader.
391 }
392
393 /// Stable identity for this channel (see the `id` field).
394 pub fn id(&self) -> u64 {
395 self.id
396 }
397
398 /// A handle that is notified once per successful transport hot-swap. The
399 /// editor awaits it to drive event-driven reconnect handling.
400 pub fn reconnect_notify(&self) -> Arc<tokio::sync::Notify> {
401 self.reconnect_notify.clone()
402 }
403
404 /// Replace the underlying transport (blocking version for non-async contexts).
405 ///
406 /// Sends the new transport to the tasks and waits until the channel is
407 /// marked as connected (i.e., the read task has drained stale requests
408 /// and is ready to receive responses on the new reader).
409 pub fn replace_transport_blocking<R, W>(&self, reader: R, writer: W)
410 where
411 R: AsyncBufRead + Unpin + Send + 'static,
412 W: AsyncWrite + Unpin + Send + 'static,
413 {
414 self.runtime_handle
415 .block_on(self.replace_transport(reader, writer));
416
417 // Yield until the read task has processed the new reader.
418 // This is typically immediate since the channel send above wakes
419 // the read task's select!, which drains pending and sets connected.
420 while !self.is_connected() {
421 std::thread::yield_now();
422 }
423 }
424
425 /// Set the request timeout duration.
426 ///
427 /// Requests that don't receive a response within this duration will fail
428 /// with `ChannelError::Timeout` and the connection will be marked as
429 /// disconnected.
430 pub fn set_request_timeout(&self, timeout: Duration) {
431 self.request_timeout_ms
432 .store(timeout.as_millis() as u64, Ordering::SeqCst);
433 }
434
435 /// Get the current request timeout duration.
436 fn request_timeout(&self) -> Duration {
437 Duration::from_millis(self.request_timeout_ms.load(Ordering::SeqCst))
438 }
439
440 /// Send a request and wait for the final result (ignoring streaming data)
441 pub async fn request(
442 &self,
443 method: &str,
444 params: serde_json::Value,
445 ) -> Result<serde_json::Value, ChannelError> {
446 let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
447
448 let timeout = self.request_timeout();
449
450 // Drain streaming data and wait for final result, with timeout.
451 let result = tokio::time::timeout(timeout, async {
452 while data_rx.recv().await.is_some() {}
453 result_rx
454 .await
455 .map_err(|_| ChannelError::ChannelClosed)?
456 .map_err(ChannelError::Remote)
457 })
458 .await;
459
460 match result {
461 Ok(inner) => inner,
462 Err(_elapsed) => {
463 warn!("request '{}' timed out after {:?}", method, timeout);
464 self.connected.store(false, Ordering::SeqCst);
465 Err(ChannelError::Timeout)
466 }
467 }
468 }
469
470 /// Send a request that may stream data
471 pub async fn request_streaming(
472 &self,
473 method: &str,
474 params: serde_json::Value,
475 ) -> Result<
476 (
477 mpsc::Receiver<serde_json::Value>,
478 oneshot::Receiver<Result<serde_json::Value, String>>,
479 ),
480 ChannelError,
481 > {
482 if !self.is_connected() {
483 return Err(ChannelError::ChannelClosed);
484 }
485
486 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
487
488 // Create channels for response
489 let (data_tx, data_rx) = mpsc::channel(self.data_channel_capacity);
490 let (result_tx, result_rx) = oneshot::channel();
491
492 // Register pending request
493 {
494 let mut pending = self.pending.lock().unwrap();
495 pending.insert(id, PendingRequest { data_tx, result_tx });
496 }
497
498 // Build and send request
499 let req = AgentRequest::new(id, method, params);
500 self.write_tx
501 .send(req.to_json_line())
502 .await
503 .map_err(|_| ChannelError::ChannelClosed)?;
504
505 Ok((data_rx, result_rx))
506 }
507
508 /// Send a request synchronously (blocking)
509 ///
510 /// This can be called from outside the Tokio runtime context.
511 pub fn request_blocking(
512 &self,
513 method: &str,
514 params: serde_json::Value,
515 ) -> Result<serde_json::Value, ChannelError> {
516 self.runtime_handle.block_on(self.request(method, params))
517 }
518
519 /// Send a request and collect all streaming data along with the final result
520 pub async fn request_with_data(
521 &self,
522 method: &str,
523 params: serde_json::Value,
524 ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
525 let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
526
527 let timeout = self.request_timeout();
528
529 let result = tokio::time::timeout(timeout, async {
530 // Collect all streaming data
531 let mut data = Vec::new();
532 while let Some(chunk) = data_rx.recv().await {
533 data.push(chunk);
534
535 // Test hook: simulate slow consumer for backpressure testing.
536 // Zero-cost in production (atomic load + branch-not-taken).
537 let delay_us = TEST_RECV_DELAY_US.load(Ordering::Relaxed);
538 if delay_us > 0 {
539 tokio::time::sleep(tokio::time::Duration::from_micros(delay_us)).await;
540 }
541 }
542
543 // Wait for final result
544 let result = result_rx
545 .await
546 .map_err(|_| ChannelError::ChannelClosed)?
547 .map_err(ChannelError::Remote)?;
548
549 Ok((data, result))
550 })
551 .await;
552
553 match result {
554 Ok(inner) => inner,
555 Err(_elapsed) => {
556 warn!("streaming request timed out after {:?}", timeout);
557 self.connected.store(false, Ordering::SeqCst);
558 Err(ChannelError::Timeout)
559 }
560 }
561 }
562
563 /// Send a request with streaming data, synchronously (blocking)
564 ///
565 /// This can be called from outside the Tokio runtime context.
566 pub fn request_with_data_blocking(
567 &self,
568 method: &str,
569 params: serde_json::Value,
570 ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
571 self.runtime_handle
572 .block_on(self.request_with_data(method, params))
573 }
574
575 /// Send a streaming request synchronously, returning receivers for
576 /// incremental processing.
577 ///
578 /// Unlike `request_with_data_blocking` which collects all data into
579 /// memory, this returns the raw receivers so callers can process each
580 /// chunk as it arrives (e.g., for `walk_files` where the server sends
581 /// file paths in batches).
582 ///
583 /// Use `data_rx.blocking_recv()` to receive chunks from a sync context.
584 #[allow(clippy::type_complexity)]
585 pub fn request_streaming_blocking(
586 &self,
587 method: &str,
588 params: serde_json::Value,
589 ) -> Result<
590 (
591 mpsc::Receiver<serde_json::Value>,
592 oneshot::Receiver<Result<serde_json::Value, String>>,
593 ),
594 ChannelError,
595 > {
596 self.runtime_handle
597 .block_on(self.request_streaming(method, params))
598 }
599
600 /// Cancel a request
601 pub async fn cancel(&self, request_id: u64) -> Result<(), ChannelError> {
602 use crate::services::remote::protocol::cancel_params;
603 self.request("cancel", cancel_params(request_id)).await?;
604 Ok(())
605 }
606}
607
#[cfg(test)]
mod tests {
    // Deliberately empty: channel behaviour is exercised by the separate
    // tests module, which drives it against a mock agent.
}