fresh/services/remote/
channel.rs1use crate::services::remote::protocol::{AgentRequest, AgentResponse};
6use std::collections::HashMap;
7use std::io;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::{Arc, Mutex};
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
11use tokio::sync::{mpsc, oneshot};
12
13#[derive(Debug, thiserror::Error)]
15pub enum ChannelError {
16 #[error("IO error: {0}")]
17 Io(#[from] io::Error),
18
19 #[error("JSON error: {0}")]
20 Json(#[from] serde_json::Error),
21
22 #[error("Channel closed")]
23 ChannelClosed,
24
25 #[error("Request cancelled")]
26 Cancelled,
27
28 #[error("Request timed out")]
29 Timeout,
30
31 #[error("Remote error: {0}")]
32 Remote(String),
33}
34
35struct PendingRequest {
37 data_tx: mpsc::Sender<serde_json::Value>,
39 result_tx: oneshot::Sender<Result<serde_json::Value, String>>,
41}
42
43pub struct AgentChannel {
45 write_tx: mpsc::Sender<String>,
47 pending: Arc<Mutex<HashMap<u64, PendingRequest>>>,
49 next_id: AtomicU64,
51 connected: Arc<std::sync::atomic::AtomicBool>,
53 runtime_handle: tokio::runtime::Handle,
55}
56
57impl AgentChannel {
58 pub fn new(
62 mut reader: tokio::io::BufReader<tokio::process::ChildStdout>,
63 mut writer: tokio::process::ChildStdin,
64 ) -> Self {
65 let pending: Arc<Mutex<HashMap<u64, PendingRequest>>> =
66 Arc::new(Mutex::new(HashMap::new()));
67 let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
68 let runtime_handle = tokio::runtime::Handle::current();
70
71 let (write_tx, mut write_rx) = mpsc::channel::<String>(64);
73
74 let connected_write = connected.clone();
76 tokio::spawn(async move {
77 while let Some(msg) = write_rx.recv().await {
78 if writer.write_all(msg.as_bytes()).await.is_err() {
79 connected_write.store(false, Ordering::SeqCst);
80 break;
81 }
82 if writer.flush().await.is_err() {
83 connected_write.store(false, Ordering::SeqCst);
84 break;
85 }
86 }
87 });
88
89 let pending_read = pending.clone();
91 let connected_read = connected.clone();
92 tokio::spawn(async move {
93 let mut line = String::new();
94 loop {
95 line.clear();
96 match reader.read_line(&mut line).await {
97 Ok(0) => {
98 connected_read.store(false, Ordering::SeqCst);
100 break;
101 }
102 Ok(_) => {
103 if let Ok(resp) = serde_json::from_str::<AgentResponse>(&line) {
104 Self::handle_response(&pending_read, resp);
105 }
106 }
107 Err(_) => {
108 connected_read.store(false, Ordering::SeqCst);
109 break;
110 }
111 }
112 }
113
114 let mut pending = pending_read.lock().unwrap();
116 for (_, req) in pending.drain() {
117 let _ = req.result_tx.send(Err("connection closed".to_string()));
118 }
119 });
120
121 Self {
122 write_tx,
123 pending,
124 next_id: AtomicU64::new(1),
125 connected,
126 runtime_handle,
127 }
128 }
129
130 fn handle_response(pending: &Arc<Mutex<HashMap<u64, PendingRequest>>>, resp: AgentResponse) {
132 let mut pending = pending.lock().unwrap();
133
134 if let Some(req) = pending.get(&resp.id) {
135 if let Some(data) = resp.data {
136 let _ = req.data_tx.try_send(data);
138 }
139
140 if let Some(result) = resp.result {
141 if let Some(req) = pending.remove(&resp.id) {
143 let _ = req.result_tx.send(Ok(result));
144 }
145 } else if let Some(error) = resp.error {
146 if let Some(req) = pending.remove(&resp.id) {
148 let _ = req.result_tx.send(Err(error));
149 }
150 }
151 }
152 }
153
154 pub fn is_connected(&self) -> bool {
156 self.connected.load(Ordering::SeqCst)
157 }
158
159 pub async fn request(
161 &self,
162 method: &str,
163 params: serde_json::Value,
164 ) -> Result<serde_json::Value, ChannelError> {
165 let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
166
167 while data_rx.recv().await.is_some() {}
169
170 result_rx
172 .await
173 .map_err(|_| ChannelError::ChannelClosed)?
174 .map_err(ChannelError::Remote)
175 }
176
177 pub async fn request_streaming(
179 &self,
180 method: &str,
181 params: serde_json::Value,
182 ) -> Result<
183 (
184 mpsc::Receiver<serde_json::Value>,
185 oneshot::Receiver<Result<serde_json::Value, String>>,
186 ),
187 ChannelError,
188 > {
189 if !self.is_connected() {
190 return Err(ChannelError::ChannelClosed);
191 }
192
193 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
194
195 let (data_tx, data_rx) = mpsc::channel(64);
197 let (result_tx, result_rx) = oneshot::channel();
198
199 {
201 let mut pending = self.pending.lock().unwrap();
202 pending.insert(id, PendingRequest { data_tx, result_tx });
203 }
204
205 let req = AgentRequest::new(id, method, params);
207 self.write_tx
208 .send(req.to_json_line())
209 .await
210 .map_err(|_| ChannelError::ChannelClosed)?;
211
212 Ok((data_rx, result_rx))
213 }
214
215 pub fn request_blocking(
219 &self,
220 method: &str,
221 params: serde_json::Value,
222 ) -> Result<serde_json::Value, ChannelError> {
223 self.runtime_handle.block_on(self.request(method, params))
224 }
225
226 pub async fn request_with_data(
228 &self,
229 method: &str,
230 params: serde_json::Value,
231 ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
232 let (mut data_rx, result_rx) = self.request_streaming(method, params).await?;
233
234 let mut data = Vec::new();
236 while let Some(chunk) = data_rx.recv().await {
237 data.push(chunk);
238 }
239
240 let result = result_rx
242 .await
243 .map_err(|_| ChannelError::ChannelClosed)?
244 .map_err(ChannelError::Remote)?;
245
246 Ok((data, result))
247 }
248
249 pub fn request_with_data_blocking(
253 &self,
254 method: &str,
255 params: serde_json::Value,
256 ) -> Result<(Vec<serde_json::Value>, serde_json::Value), ChannelError> {
257 self.runtime_handle
258 .block_on(self.request_with_data(method, params))
259 }
260
261 pub async fn cancel(&self, request_id: u64) -> Result<(), ChannelError> {
263 use crate::services::remote::protocol::cancel_params;
264 self.request("cancel", cancel_params(request_id)).await?;
265 Ok(())
266 }
267}
268
#[cfg(test)]
mod tests {
    // No tests are defined in this module.
}