Skip to main content

fresh/services/remote/
channel.rs

//! Agent communication channel
//!
//! Handles request/response multiplexing over SSH stdin/stdout.
4
5use crate::services::remote::protocol::{AgentRequest, AgentResponse};
6use std::collections::HashMap;
7use std::io;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::{Arc, Mutex};
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
11use tokio::sync::{mpsc, oneshot};
12
13/// Error type for channel operations
14#[derive(Debug, thiserror::Error)]
15pub enum ChannelError {
16    #[error("IO error: {0}")]
17    Io(#[from] io::Error),
18
19    #[error("JSON error: {0}")]
20    Json(#[from] serde_json::Error),
21
22    #[error("Channel closed")]
23    ChannelClosed,
24
25    #[error("Request cancelled")]
26    Cancelled,
27
28    #[error("Request timed out")]
29    Timeout,
30
31    #[error("Remote error: {0}")]
32    Remote(String),
33}
34
35/// Pending request state
36struct PendingRequest {
37    /// Channel for streaming data
38    data_tx: mpsc::Sender<serde_json::Value>,
39    /// Channel for final result
40    result_tx: oneshot::Sender<Result<serde_json::Value, String>>,
41}
42
43/// Communication channel with the remote agent
44pub struct AgentChannel {
45    /// Sender to the write task
46    write_tx: mpsc::Sender<String>,
47    /// Pending requests awaiting responses
48    pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
49    /// Next request ID
50    next_id: AtomicU64,
51    /// Whether the channel is connected
52    connected: Arc<std::sync::atomic::AtomicBool>,
53    /// Runtime handle for blocking operations
54    runtime_handle: tokio::runtime::Handle,
55}
56
57impl AgentChannel {
58    /// Create a new channel from async read/write handles
59    ///
60    /// Must be called from within a Tokio runtime context.
61    pub fn new(
62        mut reader: tokio::io::BufReader<tokio::process::ChildStdout>,
63        mut writer: tokio::process::ChildStdin,
64    ) -> Self {
65        let pending: Arc<Mutex<HashMap<u64, PendingRequest>>> =
66            Arc::new(Mutex::new(HashMap::new()));
67        let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
68        // Capture the runtime handle for later use in blocking operations
69        let runtime_handle = tokio::runtime::Handle::current();
70
71        // Channel for outgoing requests
72        let (write_tx, mut write_rx) = mpsc::channel::<String>(64);
73
74        // Spawn write task
75        let connected_write = connected.clone();
76        tokio::spawn(async move {
77            while let Some(msg) = write_rx.recv().await {
78                if writer.write_all(msg.as_bytes()).await.is_err() {
79                    connected_write.store(false, Ordering::SeqCst);
80                    break;
81                }
82                if writer.flush().await.is_err() {
83                    connected_write.store(false, Ordering::SeqCst);
84                    break;
85                }
86            }
87        });
88
89        // Spawn read task
90        let pending_read = pending.clone();
91        let connected_read = connected.clone();
92        tokio::spawn(async move {
93            let mut line = String::new();
94            loop {
95                line.clear();
96                match reader.read_line(&mut line).await {
97                    Ok(0) => {
98                        // EOF
99                        connected_read.store(false, Ordering::SeqCst);
100                        break;
101                    }
102                    Ok(_) => {
103                        if let Ok(resp) = serde_json::from_str::<AgentResponse>(&line) {
104                            Self::handle_response(&pending_read, resp);
105                        }
106                    }
107                    Err(_) => {
108                        connected_read.store(false, Ordering::SeqCst);
109                        break;
110                    }
111                }
112            }
113
114            // Clean up pending requests on disconnect
115            let mut pending = pending_read.lock().unwrap();
116            for (_, req) in pending.drain() {
117                let _ = req.result_tx.send(Err("connection closed".to_string()));
118            }
119        });
120
121        Self {
122            write_tx,
123            pending,
124            next_id: AtomicU64::new(1),
125            connected,
126            runtime_handle,
127        }
128    }
129
130    /// Handle an incoming response
131    fn handle_response(pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>, resp: AgentResponse) {
132        let mut pending = pending.lock().unwrap();
133
134        if let Some(req) = pending.get(&resp.id) {
135            if let Some(data) = resp.data {
136                // Streaming data - send to channel (ignore if receiver dropped)
137                let _ = req.data_tx.try_send(data);
138            }
139
140            if let Some(result) = resp.result {
141                // Success - complete request
142                if let Some(req) = pending.remove(&resp.id) {
143                    let _ = req.result_tx.send(Ok(result));
144                }
145            } else if let Some(error) = resp.error {
146                // Error - complete request
147                if let Some(req) = pending.remove(&resp.id) {
148                    let _ = req.result_tx.send(Err(error));
149                }
150            }
151        }
152    }
153
154    /// Check if the channel is connected
155    pub fn is_connected(&self) -> bool {
156        self.connected.load(Ordering::SeqCst)
157    }
158
159    /// Send a request and wait for the final result (ignoring streaming data)
160    pub async fn request(
161        &self,
162        method: &str,
163        params: serde_json::Value,
164    ) -> Result<serde_json::Value, ChannelError> {
165        let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
166
167        // Drain streaming data
168        while data_rx.recv().await.is_some() {}
169
170        // Wait for final result
171        result_rx
172            .await
173            .map_err(|_| ChannelError::ChannelClosed)?
174            .map_err(ChannelError::Remote)
175    }
176
177    /// Send a request that may stream data
178    pub async fn request_streaming(
179        &self,
180        method: &str,
181        params: serde_json::Value,
182    ) -> Result<
183        (
184            mpsc::Receiver<serde_json::Value>,
185            oneshot::Receiver<Result<serde_json::Value, String>>,
186        ),
187        ChannelError,
188    > {
189        if !self.is_connected() {
190            return Err(ChannelError::ChannelClosed);
191        }
192
193        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
194
195        // Create channels for response
196        let (data_tx, data_rx) = mpsc::channel(64);
197        let (result_tx, result_rx) = oneshot::channel();
198
199        // Register pending request
200        {
201            let mut pending = self.pending.lock().unwrap();
202            pending.insert(id, PendingRequest { data_tx, result_tx });
203        }
204
205        // Build and send request
206        let req = AgentRequest::new(id, method, params);
207        self.write_tx
208            .send(req.to_json_line())
209            .await
210            .map_err(|_| ChannelError::ChannelClosed)?;
211
212        Ok((data_rx, result_rx))
213    }
214
215    /// Send a request synchronously (blocking)
216    ///
217    /// This can be called from outside the Tokio runtime context.
218    pub fn request_blocking(
219        &self,
220        method: &str,
221        params: serde_json::Value,
222    ) -> Result<serde_json::Value, ChannelError> {
223        self.runtime_handle.block_on(self.request(method, params))
224    }
225
226    /// Send a request and collect all streaming data along with the final result
227    pub async fn request_with_data(
228        &self,
229        method: &str,
230        params: serde_json::Value,
231    ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
232        let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
233
234        // Collect all streaming data
235        let mut data = Vec::new();
236        while let Some(chunk) = data_rx.recv().await {
237            data.push(chunk);
238        }
239
240        // Wait for final result
241        let result = result_rx
242            .await
243            .map_err(|_| ChannelError::ChannelClosed)?
244            .map_err(ChannelError::Remote)?;
245
246        Ok((data, result))
247    }
248
249    /// Send a request with streaming data, synchronously (blocking)
250    ///
251    /// This can be called from outside the Tokio runtime context.
252    pub fn request_with_data_blocking(
253        &self,
254        method: &str,
255        params: serde_json::Value,
256    ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
257        self.runtime_handle
258            .block_on(self.request_with_data(method, params))
259    }
260
261    /// Cancel a request
262    pub async fn cancel(&self, request_id: u64) -> Result<(), ChannelError> {
263        use crate::services::remote::protocol::cancel_params;
264        self.request("cancel", cancel_params(request_id)).await?;
265        Ok(())
266    }
267}
268
#[cfg(test)]
mod tests {
    // Intentionally empty: the tests live in the tests module so they can run
    // as integration tests against a mock agent.
}