// fresh/services/remote/channel.rs

//! Agent communication channel
//!
//! Handles request/response multiplexing over SSH stdin/stdout.

5use crate::services::remote::protocol::{AgentRequest, AgentResponse};
6use std::collections::HashMap;
7use std::io;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::{Arc, Mutex};
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
11use tokio::sync::{mpsc, oneshot};
12use tracing::warn;
13
/// Default capacity for the per-request streaming data channel.
const DEFAULT_DATA_CHANNEL_CAPACITY: usize = 64;

/// Test-only: microseconds to sleep in the consumer loop between chunks.
/// Set to a non-zero value from tests to simulate a slow consumer and
/// deterministically reproduce channel backpressure scenarios.
/// Always compiled (not cfg(test)) because integration tests need access.
/// Zero-cost when unset: a relaxed atomic load plus an untaken branch.
pub static TEST_RECV_DELAY_US: AtomicU64 = AtomicU64::new(0);

23/// Error type for channel operations
24#[derive(Debug, thiserror::Error)]
25pub enum ChannelError {
26    #[error("IO error: {0}")]
27    Io(#[from] io::Error),
28
29    #[error("JSON error: {0}")]
30    Json(#[from] serde_json::Error),
31
32    #[error("Channel closed")]
33    ChannelClosed,
34
35    #[error("Request cancelled")]
36    Cancelled,
37
38    #[error("Request timed out")]
39    Timeout,
40
41    #[error("Remote error: {0}")]
42    Remote(String),
43}
44
45/// Pending request state
46struct PendingRequest {
47    /// Channel for streaming data
48    data_tx: mpsc::Sender<serde_json::Value>,
49    /// Channel for final result
50    result_tx: oneshot::Sender<Result<serde_json::Value, String>>,
51}
52
53/// Communication channel with the remote agent
54pub struct AgentChannel {
55    /// Sender to the write task
56    write_tx: mpsc::Sender<String>,
57    /// Pending requests awaiting responses
58    pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
59    /// Next request ID
60    next_id: AtomicU64,
61    /// Whether the channel is connected
62    connected: Arc<std::sync::atomic::AtomicBool>,
63    /// Runtime handle for blocking operations
64    runtime_handle: tokio::runtime::Handle,
65    /// Capacity for per-request streaming data channels
66    data_channel_capacity: usize,
67}
68
69impl AgentChannel {
70    /// Create a new channel from async read/write handles
71    ///
72    /// Must be called from within a Tokio runtime context.
73    pub fn new(
74        reader: tokio::io::BufReader<tokio::process::ChildStdout>,
75        writer: tokio::process::ChildStdin,
76    ) -> Self {
77        Self::with_capacity(reader, writer, DEFAULT_DATA_CHANNEL_CAPACITY)
78    }
79
80    /// Create a new channel with a custom data channel capacity.
81    ///
82    /// Lower capacity makes channel overflow more likely if `try_send` is used,
83    /// which is useful for stress-testing backpressure handling.
84    pub fn with_capacity(
85        mut reader: tokio::io::BufReader<tokio::process::ChildStdout>,
86        mut writer: tokio::process::ChildStdin,
87        data_channel_capacity: usize,
88    ) -> Self {
89        let pending: Arc<Mutex<HashMap<u64, PendingRequest>>> =
90            Arc::new(Mutex::new(HashMap::new()));
91        let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
92        // Capture the runtime handle for later use in blocking operations
93        let runtime_handle = tokio::runtime::Handle::current();
94
95        // Channel for outgoing requests
96        let (write_tx, mut write_rx) = mpsc::channel::<String>(64);
97
98        // Spawn write task
99        let connected_write = connected.clone();
100        tokio::spawn(async move {
101            while let Some(msg) = write_rx.recv().await {
102                if writer.write_all(msg.as_bytes()).await.is_err() {
103                    connected_write.store(false, Ordering::SeqCst);
104                    break;
105                }
106                if writer.flush().await.is_err() {
107                    connected_write.store(false, Ordering::SeqCst);
108                    break;
109                }
110            }
111        });
112
113        // Spawn read task
114        let pending_read = pending.clone();
115        let connected_read = connected.clone();
116        tokio::spawn(async move {
117            let mut line = String::new();
118            loop {
119                line.clear();
120                match reader.read_line(&mut line).await {
121                    Ok(0) => {
122                        // EOF
123                        connected_read.store(false, Ordering::SeqCst);
124                        break;
125                    }
126                    Ok(_) => {
127                        if let Ok(resp) = serde_json::from_str::<AgentResponse>(&line) {
128                            Self::handle_response(&pending_read, resp).await;
129                        }
130                    }
131                    Err(_) => {
132                        connected_read.store(false, Ordering::SeqCst);
133                        break;
134                    }
135                }
136            }
137
138            // Clean up pending requests on disconnect.
139            let mut pending = pending_read.lock().unwrap();
140            for (id, req) in pending.drain() {
141                match req.result_tx.send(Err("connection closed".to_string())) {
142                    Ok(()) => {}
143                    Err(_) => {
144                        // Receiver was dropped before we could notify it.
145                        // This is unexpected — callers should hold their
146                        // receivers until the operation completes.
147                        warn!("request {id}: receiver dropped during disconnect cleanup");
148                    }
149                }
150            }
151        });
152
153        Self {
154            write_tx,
155            pending,
156            next_id: AtomicU64::new(1),
157            connected,
158            runtime_handle,
159            data_channel_capacity,
160        }
161    }
162
163    /// Handle an incoming response.
164    ///
165    /// For streaming data, uses `send().await` to apply backpressure when the
166    /// consumer is slower than the producer. This prevents silent data loss
167    /// that occurred with `try_send` (#1059).
168    async fn handle_response(
169        pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>,
170        resp: AgentResponse,
171    ) {
172        // Send streaming data without holding the mutex (send().await may yield)
173        if let Some(data) = resp.data {
174            let data_tx = {
175                let pending = pending.lock().unwrap();
176                pending.get(&resp.id).map(|req| req.data_tx.clone())
177            };
178            if let Some(tx) = data_tx {
179                // send().await blocks until the consumer drains a slot, providing
180                // backpressure instead of silently dropping data.
181                if tx.send(data).await.is_err() {
182                    // Receiver was dropped — this is unexpected since callers
183                    // should hold data_rx until the stream ends. Clean up the
184                    // pending entry to avoid leaking the dead request.
185                    warn!("request {}: data receiver dropped mid-stream", resp.id);
186                    let mut pending = pending.lock().unwrap();
187                    pending.remove(&resp.id);
188                    return;
189                }
190            }
191        }
192
193        // Handle final result/error
194        if resp.result.is_some() || resp.error.is_some() {
195            let mut pending = pending.lock().unwrap();
196            if let Some(req) = pending.remove(&resp.id) {
197                let outcome = if let Some(result) = resp.result {
198                    req.result_tx.send(Ok(result))
199                } else if let Some(error) = resp.error {
200                    req.result_tx.send(Err(error))
201                } else {
202                    // resp matched the outer condition (result or error is Some)
203                    // but neither branch fired — unreachable by construction.
204                    return;
205                };
206                match outcome {
207                    Ok(()) => {}
208                    Err(_) => {
209                        // Receiver was dropped — this is unexpected since
210                        // callers should hold result_rx until they get a result.
211                        warn!("request {}: result receiver dropped", resp.id);
212                    }
213                }
214            }
215        }
216    }
217
218    /// Check if the channel is connected
219    pub fn is_connected(&self) -> bool {
220        self.connected.load(Ordering::SeqCst)
221    }
222
223    /// Send a request and wait for the final result (ignoring streaming data)
224    pub async fn request(
225        &self,
226        method: &str,
227        params: serde_json::Value,
228    ) -> Result<serde_json::Value, ChannelError> {
229        let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
230
231        // Drain streaming data
232        while data_rx.recv().await.is_some() {}
233
234        // Wait for final result
235        result_rx
236            .await
237            .map_err(|_| ChannelError::ChannelClosed)?
238            .map_err(ChannelError::Remote)
239    }
240
241    /// Send a request that may stream data
242    pub async fn request_streaming(
243        &self,
244        method: &str,
245        params: serde_json::Value,
246    ) -> Result<
247        (
248            mpsc::Receiver<serde_json::Value>,
249            oneshot::Receiver<Result<serde_json::Value, String>>,
250        ),
251        ChannelError,
252    > {
253        if !self.is_connected() {
254            return Err(ChannelError::ChannelClosed);
255        }
256
257        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
258
259        // Create channels for response
260        let (data_tx, data_rx) = mpsc::channel(self.data_channel_capacity);
261        let (result_tx, result_rx) = oneshot::channel();
262
263        // Register pending request
264        {
265            let mut pending = self.pending.lock().unwrap();
266            pending.insert(id, PendingRequest { data_tx, result_tx });
267        }
268
269        // Build and send request
270        let req = AgentRequest::new(id, method, params);
271        self.write_tx
272            .send(req.to_json_line())
273            .await
274            .map_err(|_| ChannelError::ChannelClosed)?;
275
276        Ok((data_rx, result_rx))
277    }
278
279    /// Send a request synchronously (blocking)
280    ///
281    /// This can be called from outside the Tokio runtime context.
282    pub fn request_blocking(
283        &self,
284        method: &str,
285        params: serde_json::Value,
286    ) -> Result<serde_json::Value, ChannelError> {
287        self.runtime_handle.block_on(self.request(method, params))
288    }
289
290    /// Send a request and collect all streaming data along with the final result
291    pub async fn request_with_data(
292        &self,
293        method: &str,
294        params: serde_json::Value,
295    ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
296        let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
297
298        // Collect all streaming data
299        let mut data = Vec::new();
300        while let Some(chunk) = data_rx.recv().await {
301            data.push(chunk);
302
303            // Test hook: simulate slow consumer for backpressure testing.
304            // Zero-cost in production (atomic load + branch-not-taken).
305            let delay_us = TEST_RECV_DELAY_US.load(Ordering::Relaxed);
306            if delay_us > 0 {
307                tokio::time::sleep(tokio::time::Duration::from_micros(delay_us)).await;
308            }
309        }
310
311        // Wait for final result
312        let result = result_rx
313            .await
314            .map_err(|_| ChannelError::ChannelClosed)?
315            .map_err(ChannelError::Remote)?;
316
317        Ok((data, result))
318    }
319
320    /// Send a request with streaming data, synchronously (blocking)
321    ///
322    /// This can be called from outside the Tokio runtime context.
323    pub fn request_with_data_blocking(
324        &self,
325        method: &str,
326        params: serde_json::Value,
327    ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
328        self.runtime_handle
329            .block_on(self.request_with_data(method, params))
330    }
331
332    /// Cancel a request
333    pub async fn cancel(&self, request_id: u64) -> Result<(), ChannelError> {
334        use crate::services::remote::protocol::cancel_params;
335        self.request("cancel", cancel_params(request_id)).await?;
336        Ok(())
337    }
338}
339
#[cfg(test)]
mod tests {
    // Tests live in the integration tests module so they can exercise the
    // channel end-to-end against a mock agent process.
}