1use crate::services::remote::protocol::{AgentRequest, AgentResponse};
6use std::collections::HashMap;
7use std::io;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::{Arc, Mutex};
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
11use tokio::sync::{mpsc, oneshot};
12use tracing::warn;
13
/// Default buffer capacity of each request's streaming-data channel.
const DEFAULT_DATA_CHANNEL_CAPACITY: usize = 64;

/// Test hook: artificial delay (microseconds) inserted after each received
/// data chunk in `request_with_data`; 0 disables it. Intended for tests that
/// exercise backpressure — not for production use.
pub static TEST_RECV_DELAY_US: AtomicU64 = AtomicU64::new(0);
22
/// Errors produced by [`AgentChannel`] operations.
#[derive(Debug, thiserror::Error)]
pub enum ChannelError {
    /// Underlying I/O failure on the child process pipes.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    /// Failed to serialize a request or deserialize a response.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// The channel is disconnected, or the request can no longer complete.
    #[error("Channel closed")]
    ChannelClosed,

    /// The request was cancelled before completing.
    #[error("Request cancelled")]
    Cancelled,

    /// The request did not complete within its deadline.
    #[error("Request timed out")]
    Timeout,

    /// The remote agent reported an error (its message is carried verbatim).
    #[error("Remote error: {0}")]
    Remote(String),
}
44
/// Book-keeping for a request that has been sent but not yet completed.
struct PendingRequest {
    /// Streams incremental `data` payloads to the waiting caller.
    data_tx: mpsc::Sender<serde_json::Value>,
    /// Delivers the final result (or remote error string) exactly once.
    result_tx: oneshot::Sender<Result<serde_json::Value, String>>,
}
52
/// Bidirectional newline-delimited-JSON channel to a child agent process.
pub struct AgentChannel {
    /// Serialized request lines, consumed by the background writer task.
    write_tx: mpsc::Sender<String>,
    /// In-flight requests keyed by request id.
    pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
    /// Monotonic id generator for outgoing requests (starts at 1).
    next_id: AtomicU64,
    /// Cleared once either background task observes EOF or an I/O error.
    connected: Arc<std::sync::atomic::AtomicBool>,
    /// Runtime handle captured at construction; used by the `_blocking` APIs.
    runtime_handle: tokio::runtime::Handle,
    /// Buffer capacity of each request's streaming-data channel.
    data_channel_capacity: usize,
}
68
69impl AgentChannel {
70 pub fn new(
74 reader: tokio::io::BufReader<tokio::process::ChildStdout>,
75 writer: tokio::process::ChildStdin,
76 ) -> Self {
77 Self::with_capacity(reader, writer, DEFAULT_DATA_CHANNEL_CAPACITY)
78 }
79
80 pub fn with_capacity(
85 mut reader: tokio::io::BufReader<tokio::process::ChildStdout>,
86 mut writer: tokio::process::ChildStdin,
87 data_channel_capacity: usize,
88 ) -> Self {
89 let pending: Arc<Mutex<HashMap<u64, PendingRequest>>> =
90 Arc::new(Mutex::new(HashMap::new()));
91 let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
92 let runtime_handle = tokio::runtime::Handle::current();
94
95 let (write_tx, mut write_rx) = mpsc::channel::<String>(64);
97
98 let connected_write = connected.clone();
100 tokio::spawn(async move {
101 while let Some(msg) = write_rx.recv().await {
102 if writer.write_all(msg.as_bytes()).await.is_err() {
103 connected_write.store(false, Ordering::SeqCst);
104 break;
105 }
106 if writer.flush().await.is_err() {
107 connected_write.store(false, Ordering::SeqCst);
108 break;
109 }
110 }
111 });
112
113 let pending_read = pending.clone();
115 let connected_read = connected.clone();
116 tokio::spawn(async move {
117 let mut line = String::new();
118 loop {
119 line.clear();
120 match reader.read_line(&mut line).await {
121 Ok(0) => {
122 connected_read.store(false, Ordering::SeqCst);
124 break;
125 }
126 Ok(_) => {
127 if let Ok(resp) = serde_json::from_str::<AgentResponse>(&line) {
128 Self::handle_response(&pending_read, resp).await;
129 }
130 }
131 Err(_) => {
132 connected_read.store(false, Ordering::SeqCst);
133 break;
134 }
135 }
136 }
137
138 let mut pending = pending_read.lock().unwrap();
140 for (id, req) in pending.drain() {
141 match req.result_tx.send(Err("connection closed".to_string())) {
142 Ok(()) => {}
143 Err(_) => {
144 warn!("request {id}: receiver dropped during disconnect cleanup");
148 }
149 }
150 }
151 });
152
153 Self {
154 write_tx,
155 pending,
156 next_id: AtomicU64::new(1),
157 connected,
158 runtime_handle,
159 data_channel_capacity,
160 }
161 }
162
163 async fn handle_response(
169 pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>,
170 resp: AgentResponse,
171 ) {
172 if let Some(data) = resp.data {
174 let data_tx = {
175 let pending = pending.lock().unwrap();
176 pending.get(&resp.id).map(|req| req.data_tx.clone())
177 };
178 if let Some(tx) = data_tx {
179 if tx.send(data).await.is_err() {
182 warn!("request {}: data receiver dropped mid-stream", resp.id);
186 let mut pending = pending.lock().unwrap();
187 pending.remove(&resp.id);
188 return;
189 }
190 }
191 }
192
193 if resp.result.is_some() || resp.error.is_some() {
195 let mut pending = pending.lock().unwrap();
196 if let Some(req) = pending.remove(&resp.id) {
197 let outcome = if let Some(result) = resp.result {
198 req.result_tx.send(Ok(result))
199 } else if let Some(error) = resp.error {
200 req.result_tx.send(Err(error))
201 } else {
202 return;
205 };
206 match outcome {
207 Ok(()) => {}
208 Err(_) => {
209 warn!("request {}: result receiver dropped", resp.id);
212 }
213 }
214 }
215 }
216 }
217
218 pub fn is_connected(&self) -> bool {
220 self.connected.load(Ordering::SeqCst)
221 }
222
223 pub async fn request(
225 &self,
226 method: &str,
227 params: serde_json::Value,
228 ) -> Result<serde_json::Value, ChannelError> {
229 let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
230
231 while data_rx.recv().await.is_some() {}
233
234 result_rx
236 .await
237 .map_err(|_| ChannelError::ChannelClosed)?
238 .map_err(ChannelError::Remote)
239 }
240
241 pub async fn request_streaming(
243 &self,
244 method: &str,
245 params: serde_json::Value,
246 ) -> Result<
247 (
248 mpsc::Receiver<serde_json::Value>,
249 oneshot::Receiver<Result<serde_json::Value, String>>,
250 ),
251 ChannelError,
252 > {
253 if !self.is_connected() {
254 return Err(ChannelError::ChannelClosed);
255 }
256
257 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
258
259 let (data_tx, data_rx) = mpsc::channel(self.data_channel_capacity);
261 let (result_tx, result_rx) = oneshot::channel();
262
263 {
265 let mut pending = self.pending.lock().unwrap();
266 pending.insert(id, PendingRequest { data_tx, result_tx });
267 }
268
269 let req = AgentRequest::new(id, method, params);
271 self.write_tx
272 .send(req.to_json_line())
273 .await
274 .map_err(|_| ChannelError::ChannelClosed)?;
275
276 Ok((data_rx, result_rx))
277 }
278
279 pub fn request_blocking(
283 &self,
284 method: &str,
285 params: serde_json::Value,
286 ) -> Result<serde_json::Value, ChannelError> {
287 self.runtime_handle.block_on(self.request(method, params))
288 }
289
290 pub async fn request_with_data(
292 &self,
293 method: &str,
294 params: serde_json::Value,
295 ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
296 let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
297
298 let mut data = Vec::new();
300 while let Some(chunk) = data_rx.recv().await {
301 data.push(chunk);
302
303 let delay_us = TEST_RECV_DELAY_US.load(Ordering::Relaxed);
306 if delay_us > 0 {
307 tokio::time::sleep(tokio::time::Duration::from_micros(delay_us)).await;
308 }
309 }
310
311 let result = result_rx
313 .await
314 .map_err(|_| ChannelError::ChannelClosed)?
315 .map_err(ChannelError::Remote)?;
316
317 Ok((data, result))
318 }
319
320 pub fn request_with_data_blocking(
324 &self,
325 method: &str,
326 params: serde_json::Value,
327 ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
328 self.runtime_handle
329 .block_on(self.request_with_data(method, params))
330 }
331
332 pub async fn cancel(&self, request_id: u64) -> Result<(), ChannelError> {
334 use crate::services::remote::protocol::cancel_params;
335 self.request("cancel", cancel_params(request_id)).await?;
336 Ok(())
337 }
338}
339
#[cfg(test)]
mod tests {
    // Intentionally empty: exercising AgentChannel requires a live child
    // process with piped stdin/stdout, so coverage presumably lives in
    // integration tests — TODO confirm.
}