1use std::collections::VecDeque;
2use std::future::Future;
3use std::time::Duration;
4
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use tokio::sync::Mutex;
8
9#[cfg(feature = "process")]
10use crate::transport::MAX_PENDING_MESSAGES;
/// Maximum number of out-of-order messages parked while waiting for a specific response id.
///
/// Fallback (1024) used only when the `process` feature is disabled; with that feature on,
/// the limit is imported from `crate::transport` instead.
#[cfg(not(feature = "process"))]
const MAX_PENDING_MESSAGES: usize = 1 << 10;
13
14use crate::error::WireError;
15use crate::protocol::{
16 CancelParams, CancelResult, InitializeParams, InitializeResult, JsonRpcRequest, PromptParams,
17 PromptResult, ReplayParams, ReplayResult, SetPlanModeParams, SetPlanModeResult, SteerParams,
18 SteerResult, UserInput,
19};
20
21pub trait WireClient: Send {
26 fn next_id(&mut self) -> String;
28
29 fn send_request<Params: Serialize + Sync>(
31 &mut self,
32 req: &JsonRpcRequest<Params>,
33 ) -> impl Future<Output = Result<(), WireError>> + Send;
34
35 fn read_raw_message(&mut self) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
37
38 fn read_raw_message_timeout(
40 &mut self,
41 timeout: Duration,
42 ) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
43
44 fn send_response<T: Serialize + Send>(
46 &mut self,
47 id: &str,
48 result: T,
49 ) -> impl Future<Output = Result<(), WireError>> + Send;
50
51 fn send_error(
53 &mut self,
54 id: &str,
55 code: i32,
56 message: &str,
57 ) -> impl Future<Output = Result<(), WireError>> + Send;
58
59 fn initialize(
61 &mut self,
62 params: InitializeParams,
63 ) -> impl Future<Output = Result<InitializeResult, WireError>> + Send;
64
65 fn is_handshake_done(&self) -> bool;
67
68 fn shutdown(self) -> impl Future<Output = Result<(), WireError>> + Send;
70
71 fn prompt(&mut self, user_input: impl Into<UserInput> + Send) -> impl Future<Output = Result<PromptResult, WireError>> + Send {
73 async move {
74 let id = self.start_prompt(user_input).await?;
75 self.read_response(&id).await
76 }
77 }
78
79 fn start_prompt(&mut self, user_input: impl Into<UserInput> + Send) -> impl Future<Output = Result<String, WireError>> + Send {
81 async move {
82 let id = self.next_id();
83 let req = JsonRpcRequest {
84 jsonrpc: crate::protocol::JsonRpcVersion::V2,
85 method: "prompt".to_string(),
86 id: id.clone(),
87 params: PromptParams {
88 user_input: user_input.into(),
89 },
90 };
91 self.send_request(&req).await?;
92 Ok(id)
93 }
94 }
95
96 fn replay(&mut self) -> impl Future<Output = Result<ReplayResult, WireError>> + Send {
98 async move {
99 let id = self.next_id();
100 let req = JsonRpcRequest {
101 jsonrpc: crate::protocol::JsonRpcVersion::V2,
102 method: "replay".to_string(),
103 id: id.clone(),
104 params: ReplayParams::default(),
105 };
106 self.send_request(&req).await?;
107 self.read_response(&id).await
108 }
109 }
110
111 fn steer(&mut self, user_input: impl Into<UserInput> + Send) -> impl Future<Output = Result<SteerResult, WireError>> + Send {
113 async move {
114 let id = self.next_id();
115 let req = JsonRpcRequest {
116 jsonrpc: crate::protocol::JsonRpcVersion::V2,
117 method: "steer".to_string(),
118 id: id.clone(),
119 params: SteerParams {
120 user_input: user_input.into(),
121 },
122 };
123 self.send_request(&req).await?;
124 self.read_response(&id).await
125 }
126 }
127
128 fn set_plan_mode(
130 &mut self,
131 enabled: bool,
132 ) -> impl Future<Output = Result<SetPlanModeResult, WireError>> + Send {
133 async move {
134 let id = self.next_id();
135 let req = JsonRpcRequest {
136 jsonrpc: crate::protocol::JsonRpcVersion::V2,
137 method: "set_plan_mode".to_string(),
138 id: id.clone(),
139 params: SetPlanModeParams { enabled },
140 };
141 self.send_request(&req).await?;
142 self.read_response(&id).await
143 }
144 }
145
146 fn cancel(&mut self) -> impl Future<Output = Result<(), WireError>> + Send {
148 async move {
149 let id = self.next_id();
150 let req = JsonRpcRequest {
151 jsonrpc: crate::protocol::JsonRpcVersion::V2,
152 method: "cancel".to_string(),
153 id: id.clone(),
154 params: CancelParams::default(),
155 };
156 self.send_request(&req).await?;
157 let _: CancelResult = self.read_response(&id).await?;
158 Ok(())
159 }
160 }
161
162 fn read_response<T: DeserializeOwned + Send>(
165 &mut self,
166 expected_id: &str,
167 ) -> impl Future<Output = Result<T, WireError>> + Send;
168}
169
/// [`WireClient`] backed by in-memory queues instead of a real transport.
///
/// Reads are served from messages queued with `inject`. Everything the client sends is
/// recorded and can be inspected with `outgoing`. NOTE(review): appears intended for tests
/// (its `initialize` reports a `test-server`); confirm before using elsewhere.
#[derive(Debug)]
pub struct InMemoryWireClient {
    // Messages queued by `inject`, consumed front-to-back by reads.
    incoming: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
    // Messages pulled from `incoming` while `read_response` was waiting for a different id;
    // drained before `incoming` by `read_raw_message`.
    pending: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
    // Record of everything sent, in send order.
    outgoing: Mutex<Vec<serde_json::Value>>,
    // Set to true by `initialize`.
    handshake_done: bool,
    // Number of ids issued so far; `next_id` increments it and formats `req-{n}`.
    request_counter: u64,
    // Optional deadline applied to `read_response`.
    default_timeout: Option<Duration>,
}
187
188impl Default for InMemoryWireClient {
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
194impl InMemoryWireClient {
195 pub fn new() -> Self {
197 Self {
198 incoming: Mutex::new(VecDeque::new()),
199 pending: Mutex::new(VecDeque::new()),
200 outgoing: Mutex::new(Vec::new()),
201 handshake_done: false,
202 request_counter: 0,
203 default_timeout: None,
204 }
205 }
206
207 pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
210 self.default_timeout = Some(timeout);
211 self
212 }
213
214 pub async fn inject(&self, msg: crate::protocol::RawWireMessage) {
216 self.incoming.lock().await.push_back(msg);
217 }
218
219 pub async fn outgoing(&self) -> Vec<serde_json::Value> {
221 self.outgoing.lock().await.clone()
222 }
223}
224
225impl WireClient for InMemoryWireClient {
226 fn next_id(&mut self) -> String {
227 self.request_counter += 1;
228 format!("req-{}", self.request_counter)
229 }
230
231 async fn send_request<Params: Serialize + Sync>(
232 &mut self,
233 req: &JsonRpcRequest<Params>,
234 ) -> Result<(), WireError> {
235 let value = serde_json::to_value(req).map_err(WireError::from)?;
236 self.outgoing.lock().await.push(value);
237 Ok(())
238 }
239
240 async fn read_raw_message(&mut self) -> Result<crate::protocol::RawWireMessage, WireError> {
241 if let Some(msg) = self.pending.lock().await.pop_front() {
242 return Ok(msg);
243 }
244 match self.incoming.lock().await.pop_front() {
245 Some(msg) => Ok(msg),
246 None => Err(WireError::StreamClosed),
247 }
248 }
249
250 async fn read_raw_message_timeout(
251 &mut self,
252 timeout: Duration,
253 ) -> Result<crate::protocol::RawWireMessage, WireError> {
254 match tokio::time::timeout(timeout, self.read_raw_message()).await {
255 Ok(msg) => msg,
256 Err(_) => Err(WireError::Timeout(timeout)),
257 }
258 }
259
260 async fn read_response<T: DeserializeOwned + Send>(
261 &mut self,
262 expected_id: &str,
263 ) -> Result<T, WireError> {
264 let fut = async {
265 loop {
266 let idx = {
267 let lock = self.pending.lock().await;
268 lock.iter()
269 .position(|msg| msg.id.as_deref() == Some(expected_id))
270 };
271 if let Some(idx) = idx {
272 let msg = self
273 .pending
274 .lock()
275 .await
276 .remove(idx)
277 .ok_or_else(|| WireError::Internal("pending index invalid".to_string()))?;
278 return decode_response(msg, expected_id);
279 }
280
281 match self.incoming.lock().await.pop_front() {
282 Some(msg) if msg.id.as_deref() == Some(expected_id) => {
283 return decode_response(msg, expected_id);
284 }
285 Some(other) => {
286 let mut pending = self.pending.lock().await;
287 if pending.len() >= MAX_PENDING_MESSAGES {
288 return Err(WireError::Internal(format!(
289 "pending message buffer overflow ({} entries) waiting for id {:?}",
290 MAX_PENDING_MESSAGES, expected_id
291 )));
292 }
293 pending.push_back(other);
294 }
295 None => return Err(WireError::StreamClosed),
296 }
297 }
298 };
299
300 match self.default_timeout {
301 Some(d) => tokio::time::timeout(d, fut)
302 .await
303 .map_err(|_| WireError::Timeout(d))?,
304 None => fut.await,
305 }
306 }
307
308 async fn send_response<T: Serialize + Send>(
309 &mut self,
310 id: &str,
311 result: T,
312 ) -> Result<(), WireError> {
313 let resp = crate::protocol::JsonRpcSuccessResponse {
314 jsonrpc: crate::protocol::JsonRpcVersion::V2,
315 id: id.to_string(),
316 result,
317 };
318 let line = format!("{}\n", serde_json::to_string(&resp).map_err(WireError::from)?);
319 self.outgoing
320 .lock()
321 .await
322 .push(serde_json::Value::String(line));
323 Ok(())
324 }
325
326 async fn send_error(
327 &mut self,
328 id: &str,
329 code: i32,
330 message: &str,
331 ) -> Result<(), WireError> {
332 let resp = crate::protocol::JsonRpcErrorResponse {
333 jsonrpc: crate::protocol::JsonRpcVersion::V2,
334 id: id.to_string(),
335 error: crate::protocol::JsonRpcError {
336 code,
337 message: message.to_string(),
338 data: None,
339 },
340 };
341 let line = format!("{}\n", serde_json::to_string(&resp).map_err(WireError::from)?);
342 self.outgoing
343 .lock()
344 .await
345 .push(serde_json::Value::String(line));
346 Ok(())
347 }
348
349 async fn initialize(
350 &mut self,
351 _params: InitializeParams,
352 ) -> Result<InitializeResult, WireError> {
353 self.handshake_done = true;
354 Ok(InitializeResult {
355 protocol_version: crate::WIRE_PROTOCOL_VERSION.to_string(),
356 server: crate::protocol::ServerInfo {
357 name: "test-server".to_string(),
358 version: "0.0.0".to_string(),
359 },
360 slash_commands: vec![],
361 external_tools: None,
362 capabilities: None,
363 hooks: None,
364 })
365 }
366
367 fn is_handshake_done(&self) -> bool {
368 self.handshake_done
369 }
370
371 async fn shutdown(self) -> Result<(), WireError> {
372 Ok(())
373 }
374}
375
376fn decode_response<T: DeserializeOwned>(
377 msg: crate::protocol::RawWireMessage,
378 _expected_id: &str,
379) -> Result<T, WireError> {
380 if let Some(error) = msg.error {
381 return Err(WireError::RequestFailed {
382 code: error.code,
383 message: error.message,
384 });
385 }
386 let result = msg
387 .result
388 .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
389 serde_json::from_value(result).map_err(WireError::from)
390}