1use std::collections::VecDeque;
2use std::future::Future;
3use std::time::Duration;
4
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use tokio::sync::Mutex;
8
// Upper bound on how many out-of-order messages `read_response` will park in
// the pending buffer while it waits for a specific response id. With the
// `process` feature enabled the limit is the one defined by the transport
// module; otherwise this file defines its own fallback value.
#[cfg(feature = "process")]
use crate::transport::MAX_PENDING_MESSAGES;
#[cfg(not(feature = "process"))]
const MAX_PENDING_MESSAGES: usize = 1024;
13
14use crate::error::WireError;
15use crate::protocol::{
16 CancelParams, CancelResult, InitializeParams, InitializeResult, JsonRpcRequest, PromptParams,
17 PromptResult, ReplayParams, ReplayResult, SetPlanModeParams, SetPlanModeResult, SteerParams,
18 SteerResult, UserInput,
19};
20
21pub trait WireClient: Send {
26 fn next_id(&mut self) -> String;
28
29 fn send_request<Params: Serialize + Sync>(
31 &mut self,
32 req: &JsonRpcRequest<Params>,
33 ) -> impl Future<Output = Result<(), WireError>> + Send;
34
35 fn read_raw_message(
37 &mut self,
38 ) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
39
40 fn read_raw_message_timeout(
42 &mut self,
43 timeout: Duration,
44 ) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
45
46 fn send_response<T: Serialize + Send>(
48 &mut self,
49 id: &str,
50 result: T,
51 ) -> impl Future<Output = Result<(), WireError>> + Send;
52
53 fn send_error(
55 &mut self,
56 id: &str,
57 code: i32,
58 message: &str,
59 ) -> impl Future<Output = Result<(), WireError>> + Send;
60
61 fn initialize(
63 &mut self,
64 params: InitializeParams,
65 ) -> impl Future<Output = Result<InitializeResult, WireError>> + Send;
66
67 fn is_handshake_done(&self) -> bool;
69
70 fn shutdown(self) -> impl Future<Output = Result<(), WireError>> + Send;
72
73 fn prompt(
75 &mut self,
76 user_input: impl Into<UserInput> + Send,
77 ) -> impl Future<Output = Result<PromptResult, WireError>> + Send {
78 async move {
79 let id = self.start_prompt(user_input).await?;
80 self.read_response(&id).await
81 }
82 }
83
84 fn start_prompt(
86 &mut self,
87 user_input: impl Into<UserInput> + Send,
88 ) -> impl Future<Output = Result<String, WireError>> + Send {
89 async move {
90 let id = self.next_id();
91 let req = JsonRpcRequest {
92 jsonrpc: crate::protocol::JsonRpcVersion::V2,
93 method: "prompt".to_string(),
94 id: id.clone(),
95 params: PromptParams {
96 user_input: user_input.into(),
97 },
98 };
99 self.send_request(&req).await?;
100 Ok(id)
101 }
102 }
103
104 fn replay(&mut self) -> impl Future<Output = Result<ReplayResult, WireError>> + Send {
106 async move {
107 let id = self.next_id();
108 let req = JsonRpcRequest {
109 jsonrpc: crate::protocol::JsonRpcVersion::V2,
110 method: "replay".to_string(),
111 id: id.clone(),
112 params: ReplayParams::default(),
113 };
114 self.send_request(&req).await?;
115 self.read_response(&id).await
116 }
117 }
118
119 fn steer(
121 &mut self,
122 user_input: impl Into<UserInput> + Send,
123 ) -> impl Future<Output = Result<SteerResult, WireError>> + Send {
124 async move {
125 let id = self.next_id();
126 let req = JsonRpcRequest {
127 jsonrpc: crate::protocol::JsonRpcVersion::V2,
128 method: "steer".to_string(),
129 id: id.clone(),
130 params: SteerParams {
131 user_input: user_input.into(),
132 },
133 };
134 self.send_request(&req).await?;
135 self.read_response(&id).await
136 }
137 }
138
139 fn set_plan_mode(
141 &mut self,
142 enabled: bool,
143 ) -> impl Future<Output = Result<SetPlanModeResult, WireError>> + Send {
144 async move {
145 let id = self.next_id();
146 let req = JsonRpcRequest {
147 jsonrpc: crate::protocol::JsonRpcVersion::V2,
148 method: "set_plan_mode".to_string(),
149 id: id.clone(),
150 params: SetPlanModeParams { enabled },
151 };
152 self.send_request(&req).await?;
153 self.read_response(&id).await
154 }
155 }
156
157 fn cancel(&mut self) -> impl Future<Output = Result<(), WireError>> + Send {
159 async move {
160 let id = self.next_id();
161 let req = JsonRpcRequest {
162 jsonrpc: crate::protocol::JsonRpcVersion::V2,
163 method: "cancel".to_string(),
164 id: id.clone(),
165 params: CancelParams::default(),
166 };
167 self.send_request(&req).await?;
168 let _: CancelResult = self.read_response(&id).await?;
169 Ok(())
170 }
171 }
172
173 fn read_response<T: DeserializeOwned + Send>(
176 &mut self,
177 expected_id: &str,
178 ) -> impl Future<Output = Result<T, WireError>> + Send;
179}
180
/// A [`WireClient`] that talks to no real peer: incoming messages are queued
/// with `inject`, and everything "sent" is recorded for inspection via
/// `outgoing`. Its canned `initialize` reply names a `test-server`, so it
/// appears intended for tests.
#[derive(Debug)]
pub struct InMemoryWireClient {
    /// Messages queued by `inject`, consumed front-first by the read methods.
    incoming: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
    /// Messages read while `read_response` waited for a different id; the
    /// read methods serve these before anything in `incoming`.
    pending: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
    /// Record of everything sent: requests as structured JSON values,
    /// responses/errors as newline-terminated JSON strings.
    outgoing: Mutex<Vec<serde_json::Value>>,
    /// Set to `true` by `initialize`.
    handshake_done: bool,
    /// Counter behind `next_id` (ids are `req-1`, `req-2`, ...).
    request_counter: u64,
    /// Optional deadline applied to each `read_response` call.
    default_timeout: Option<Duration>,
}
198
199impl Default for InMemoryWireClient {
200 fn default() -> Self {
201 Self::new()
202 }
203}
204
205impl InMemoryWireClient {
206 #[must_use]
208 pub fn new() -> Self {
209 Self {
210 incoming: Mutex::new(VecDeque::new()),
211 pending: Mutex::new(VecDeque::new()),
212 outgoing: Mutex::new(Vec::new()),
213 handshake_done: false,
214 request_counter: 0,
215 default_timeout: None,
216 }
217 }
218
219 #[must_use]
222 pub const fn with_default_timeout(mut self, timeout: Duration) -> Self {
223 self.default_timeout = Some(timeout);
224 self
225 }
226
227 pub async fn inject(&self, msg: crate::protocol::RawWireMessage) {
229 self.incoming.lock().await.push_back(msg);
230 }
231
232 pub async fn outgoing(&self) -> Vec<serde_json::Value> {
234 self.outgoing.lock().await.clone()
235 }
236}
237
238impl WireClient for InMemoryWireClient {
239 fn next_id(&mut self) -> String {
240 self.request_counter += 1;
241 format!("req-{}", self.request_counter)
242 }
243
244 async fn send_request<Params: Serialize + Sync>(
245 &mut self,
246 req: &JsonRpcRequest<Params>,
247 ) -> Result<(), WireError> {
248 let value = serde_json::to_value(req).map_err(WireError::from)?;
249 self.outgoing.lock().await.push(value);
250 Ok(())
251 }
252
253 async fn read_raw_message(&mut self) -> Result<crate::protocol::RawWireMessage, WireError> {
254 let msg = self.pending.lock().await.pop_front();
255 if let Some(msg) = msg {
256 return Ok(msg);
257 }
258 let msg = self.incoming.lock().await.pop_front();
259 match msg {
260 Some(msg) => Ok(msg),
261 None => Err(WireError::StreamClosed),
262 }
263 }
264
265 async fn read_raw_message_timeout(
266 &mut self,
267 timeout: Duration,
268 ) -> Result<crate::protocol::RawWireMessage, WireError> {
269 match tokio::time::timeout(timeout, self.read_raw_message()).await {
270 Ok(msg) => msg,
271 Err(_) => Err(WireError::Timeout(timeout)),
272 }
273 }
274
275 async fn read_response<T: DeserializeOwned + Send>(
276 &mut self,
277 expected_id: &str,
278 ) -> Result<T, WireError> {
279 let fut = async {
280 loop {
281 let idx = {
282 let lock = self.pending.lock().await;
283 lock.iter()
284 .position(|msg| msg.id.as_deref() == Some(expected_id))
285 };
286 if let Some(idx) = idx {
287 let msg =
288 self.pending.lock().await.remove(idx).ok_or_else(|| {
289 WireError::Internal("pending index invalid".to_string())
290 })?;
291 return decode_response(msg, expected_id);
292 }
293
294 let msg = self.incoming.lock().await.pop_front();
295 match msg {
296 Some(msg) if msg.id.as_deref() == Some(expected_id) => {
297 return decode_response(msg, expected_id);
298 }
299 Some(other) => {
300 let mut pending = self.pending.lock().await;
301 if pending.len() >= MAX_PENDING_MESSAGES {
302 return Err(WireError::Internal(format!(
303 "pending message buffer overflow ({} entries) waiting for id {:?}",
304 MAX_PENDING_MESSAGES, expected_id
305 )));
306 }
307 pending.push_back(other);
308 }
309 None => return Err(WireError::StreamClosed),
310 }
311 }
312 };
313
314 match self.default_timeout {
315 Some(d) => tokio::time::timeout(d, fut)
316 .await
317 .map_err(|_| WireError::Timeout(d))?,
318 None => fut.await,
319 }
320 }
321
322 async fn send_response<T: Serialize + Send>(
323 &mut self,
324 id: &str,
325 result: T,
326 ) -> Result<(), WireError> {
327 let resp = crate::protocol::JsonRpcSuccessResponse {
328 jsonrpc: crate::protocol::JsonRpcVersion::V2,
329 id: id.to_string(),
330 result,
331 };
332 let line = format!(
333 "{}\n",
334 serde_json::to_string(&resp).map_err(WireError::from)?
335 );
336 self.outgoing
337 .lock()
338 .await
339 .push(serde_json::Value::String(line));
340 Ok(())
341 }
342
343 async fn send_error(&mut self, id: &str, code: i32, message: &str) -> Result<(), WireError> {
344 let resp = crate::protocol::JsonRpcErrorResponse {
345 jsonrpc: crate::protocol::JsonRpcVersion::V2,
346 id: id.to_string(),
347 error: crate::protocol::JsonRpcError {
348 code,
349 message: message.to_string(),
350 data: None,
351 },
352 };
353 let line = format!(
354 "{}\n",
355 serde_json::to_string(&resp).map_err(WireError::from)?
356 );
357 self.outgoing
358 .lock()
359 .await
360 .push(serde_json::Value::String(line));
361 Ok(())
362 }
363
364 async fn initialize(
365 &mut self,
366 _params: InitializeParams,
367 ) -> Result<InitializeResult, WireError> {
368 self.handshake_done = true;
369 Ok(InitializeResult {
370 protocol_version: crate::WIRE_PROTOCOL_VERSION.to_string(),
371 server: crate::protocol::ServerInfo {
372 name: "test-server".to_string(),
373 version: "0.0.0".to_string(),
374 },
375 slash_commands: vec![],
376 external_tools: None,
377 capabilities: None,
378 hooks: None,
379 })
380 }
381
382 fn is_handshake_done(&self) -> bool {
383 self.handshake_done
384 }
385
386 async fn shutdown(self) -> Result<(), WireError> {
387 Ok(())
388 }
389}
390
391fn decode_response<T: DeserializeOwned>(
392 msg: crate::protocol::RawWireMessage,
393 _expected_id: &str,
394) -> Result<T, WireError> {
395 if let Some(error) = msg.error {
396 return Err(WireError::RequestFailed {
397 code: error.code,
398 message: error.message,
399 });
400 }
401 let result = msg
402 .result
403 .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
404 serde_json::from_value(result).map_err(WireError::from)
405}