1use std::collections::VecDeque;
2use std::future::Future;
3use std::time::Duration;
4
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use tokio::sync::Mutex;
8
9use crate::error::WireError;
10use crate::protocol::{
11 CancelParams, CancelResult, InitializeParams, InitializeResult, JsonRpcRequest, PromptParams,
12 PromptResult, ReplayParams, ReplayResult, SetPlanModeParams, SetPlanModeResult, SteerParams,
13 SteerResult, UserInput,
14};
15
16pub trait WireClient: Send {
21 fn next_id(&mut self) -> String;
23
24 fn send_request<Params: Serialize + Sync>(
26 &mut self,
27 req: &JsonRpcRequest<Params>,
28 ) -> impl Future<Output = Result<(), WireError>> + Send;
29
30 fn read_raw_message(&mut self) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
32
33 fn read_raw_message_timeout(
35 &mut self,
36 timeout: Duration,
37 ) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
38
39 fn send_response<T: Serialize + Send>(
41 &mut self,
42 id: &str,
43 result: T,
44 ) -> impl Future<Output = Result<(), WireError>> + Send;
45
46 fn send_error(
48 &mut self,
49 id: &str,
50 code: i32,
51 message: &str,
52 ) -> impl Future<Output = Result<(), WireError>> + Send;
53
54 fn initialize(
56 &mut self,
57 params: InitializeParams,
58 ) -> impl Future<Output = Result<InitializeResult, WireError>> + Send;
59
60 fn is_handshake_done(&self) -> bool;
62
63 fn shutdown(self) -> impl Future<Output = Result<(), WireError>> + Send;
65
66 fn prompt(&mut self, user_input: impl Into<UserInput> + Send) -> impl Future<Output = Result<PromptResult, WireError>> + Send {
68 async move {
69 let id = self.start_prompt(user_input).await?;
70 self.read_response(&id).await
71 }
72 }
73
74 fn start_prompt(&mut self, user_input: impl Into<UserInput> + Send) -> impl Future<Output = Result<String, WireError>> + Send {
76 async move {
77 let id = self.next_id();
78 let req = JsonRpcRequest {
79 jsonrpc: crate::protocol::JsonRpcVersion::default(),
80 method: "prompt".to_string(),
81 id: id.clone(),
82 params: PromptParams {
83 user_input: user_input.into(),
84 },
85 };
86 self.send_request(&req).await?;
87 Ok(id)
88 }
89 }
90
91 fn replay(&mut self) -> impl Future<Output = Result<ReplayResult, WireError>> + Send {
93 async move {
94 let id = self.next_id();
95 let req = JsonRpcRequest {
96 jsonrpc: crate::protocol::JsonRpcVersion::default(),
97 method: "replay".to_string(),
98 id: id.clone(),
99 params: ReplayParams::default(),
100 };
101 self.send_request(&req).await?;
102 self.read_response(&id).await
103 }
104 }
105
106 fn steer(&mut self, user_input: impl Into<UserInput> + Send) -> impl Future<Output = Result<SteerResult, WireError>> + Send {
108 async move {
109 let id = self.next_id();
110 let req = JsonRpcRequest {
111 jsonrpc: crate::protocol::JsonRpcVersion::default(),
112 method: "steer".to_string(),
113 id: id.clone(),
114 params: SteerParams {
115 user_input: user_input.into(),
116 },
117 };
118 self.send_request(&req).await?;
119 self.read_response(&id).await
120 }
121 }
122
123 fn set_plan_mode(
125 &mut self,
126 enabled: bool,
127 ) -> impl Future<Output = Result<SetPlanModeResult, WireError>> + Send {
128 async move {
129 let id = self.next_id();
130 let req = JsonRpcRequest {
131 jsonrpc: crate::protocol::JsonRpcVersion::default(),
132 method: "set_plan_mode".to_string(),
133 id: id.clone(),
134 params: SetPlanModeParams { enabled },
135 };
136 self.send_request(&req).await?;
137 self.read_response(&id).await
138 }
139 }
140
141 fn cancel(&mut self) -> impl Future<Output = Result<(), WireError>> + Send {
143 async move {
144 let id = self.next_id();
145 let req = JsonRpcRequest {
146 jsonrpc: crate::protocol::JsonRpcVersion::default(),
147 method: "cancel".to_string(),
148 id: id.clone(),
149 params: CancelParams::default(),
150 };
151 self.send_request(&req).await?;
152 let _: CancelResult = self.read_response(&id).await?;
153 Ok(())
154 }
155 }
156
157 fn read_response<T: DeserializeOwned + Send>(
160 &mut self,
161 expected_id: &str,
162 ) -> impl Future<Output = Result<T, WireError>> + Send;
163}
164
165#[derive(Debug)]
174pub struct InMemoryWireClient {
175 incoming: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
176 pending: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
177 outgoing: Mutex<Vec<serde_json::Value>>,
178 handshake_done: bool,
179 request_counter: u64,
180}
181
182impl Default for InMemoryWireClient {
183 fn default() -> Self {
184 Self::new()
185 }
186}
187
188impl InMemoryWireClient {
189 pub fn new() -> Self {
191 Self {
192 incoming: Mutex::new(VecDeque::new()),
193 pending: Mutex::new(VecDeque::new()),
194 outgoing: Mutex::new(Vec::new()),
195 handshake_done: false,
196 request_counter: 0,
197 }
198 }
199
200 pub async fn inject(&self, msg: crate::protocol::RawWireMessage) {
202 self.incoming.lock().await.push_back(msg);
203 }
204
205 pub async fn outgoing(&self) -> Vec<serde_json::Value> {
207 self.outgoing.lock().await.clone()
208 }
209}
210
211impl WireClient for InMemoryWireClient {
212 fn next_id(&mut self) -> String {
213 self.request_counter += 1;
214 format!("req-{}", self.request_counter)
215 }
216
217 async fn send_request<Params: Serialize + Sync>(
218 &mut self,
219 req: &JsonRpcRequest<Params>,
220 ) -> Result<(), WireError> {
221 let value = serde_json::to_value(req).map_err(WireError::from)?;
222 self.outgoing.lock().await.push(value);
223 Ok(())
224 }
225
226 async fn read_raw_message(&mut self) -> Result<crate::protocol::RawWireMessage, WireError> {
227 if let Some(msg) = self.pending.lock().await.pop_front() {
228 return Ok(msg);
229 }
230 match self.incoming.lock().await.pop_front() {
231 Some(msg) => Ok(msg),
232 None => Err(WireError::StreamClosed),
233 }
234 }
235
236 async fn read_raw_message_timeout(
237 &mut self,
238 timeout: Duration,
239 ) -> Result<crate::protocol::RawWireMessage, WireError> {
240 match tokio::time::timeout(timeout, self.read_raw_message()).await {
241 Ok(msg) => msg,
242 Err(_) => Err(WireError::Timeout(timeout)),
243 }
244 }
245
246 async fn read_response<T: DeserializeOwned + Send>(
247 &mut self,
248 expected_id: &str,
249 ) -> Result<T, WireError> {
250 loop {
251 let idx = {
252 let lock = self.pending.lock().await;
253 lock.iter()
254 .position(|msg| msg.id.as_deref() == Some(expected_id))
255 };
256 if let Some(idx) = idx {
257 let msg = self
258 .pending
259 .lock()
260 .await
261 .remove(idx)
262 .ok_or_else(|| WireError::Internal("pending index invalid".to_string()))?;
263 return decode_response(msg, expected_id);
264 }
265
266 match self.incoming.lock().await.pop_front() {
267 Some(msg) if msg.id.as_deref() == Some(expected_id) => {
268 return decode_response(msg, expected_id);
269 }
270 Some(other) => {
271 self.pending.lock().await.push_back(other);
272 }
273 None => return Err(WireError::StreamClosed),
274 }
275 }
276 }
277
278 async fn send_response<T: Serialize + Send>(
279 &mut self,
280 id: &str,
281 result: T,
282 ) -> Result<(), WireError> {
283 let resp = crate::protocol::JsonRpcSuccessResponse {
284 jsonrpc: crate::protocol::JsonRpcVersion::default(),
285 id: id.to_string(),
286 result,
287 };
288 let line = format!("{}\n", serde_json::to_string(&resp).map_err(WireError::from)?);
289 self.outgoing
290 .lock()
291 .await
292 .push(serde_json::Value::String(line));
293 Ok(())
294 }
295
296 async fn send_error(
297 &mut self,
298 id: &str,
299 code: i32,
300 message: &str,
301 ) -> Result<(), WireError> {
302 let resp = crate::protocol::JsonRpcErrorResponse {
303 jsonrpc: crate::protocol::JsonRpcVersion::default(),
304 id: id.to_string(),
305 error: crate::protocol::JsonRpcError {
306 code,
307 message: message.to_string(),
308 data: None,
309 },
310 };
311 let line = format!("{}\n", serde_json::to_string(&resp).map_err(WireError::from)?);
312 self.outgoing
313 .lock()
314 .await
315 .push(serde_json::Value::String(line));
316 Ok(())
317 }
318
319 async fn initialize(
320 &mut self,
321 _params: InitializeParams,
322 ) -> Result<InitializeResult, WireError> {
323 self.handshake_done = true;
324 Ok(InitializeResult {
325 protocol_version: crate::WIRE_PROTOCOL_VERSION.to_string(),
326 server: crate::protocol::ServerInfo {
327 name: "test-server".to_string(),
328 version: "0.0.0".to_string(),
329 },
330 slash_commands: vec![],
331 external_tools: None,
332 capabilities: None,
333 hooks: None,
334 })
335 }
336
337 fn is_handshake_done(&self) -> bool {
338 self.handshake_done
339 }
340
341 async fn shutdown(self) -> Result<(), WireError> {
342 Ok(())
343 }
344}
345
346fn decode_response<T: DeserializeOwned>(
347 msg: crate::protocol::RawWireMessage,
348 _expected_id: &str,
349) -> Result<T, WireError> {
350 if let Some(error) = msg.error {
351 return Err(WireError::RequestFailed {
352 code: error.code,
353 message: error.message,
354 });
355 }
356 let result = msg
357 .result
358 .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
359 serde_json::from_value(result).map_err(WireError::from)
360}