1use std::collections::VecDeque;
2use std::future::Future;
3use std::time::Duration;
4
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use tokio::sync::Mutex;
8
#[cfg(feature = "process")]
use crate::transport::MAX_PENDING_MESSAGES;
/// Cap on out-of-order messages the in-memory client's `read_response` may stash
/// while waiting for a specific response id. With the `process` feature enabled the
/// value is imported from `crate::transport` instead; this local fallback is only
/// compiled when that import is not.
#[cfg(not(feature = "process"))]
const MAX_PENDING_MESSAGES: usize = 1024;
13
14use crate::error::WireError;
15use crate::protocol::{
16 CancelParams, CancelResult, InitializeParams, InitializeResult, JsonRpcRequest, PromptParams,
17 PromptResult, ReplayParams, ReplayResult, SetPlanModeParams, SetPlanModeResult, SteerParams,
18 SteerResult, UserInput,
19};
20
21pub trait WireClient: Send {
26 fn next_id(&mut self) -> String;
28
29 fn send_request<Params: Serialize + Sync>(
31 &mut self,
32 req: &JsonRpcRequest<Params>,
33 ) -> impl Future<Output = Result<(), WireError>> + Send;
34
35 fn read_raw_message(
37 &mut self,
38 ) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
39
40 fn read_raw_message_timeout(
42 &mut self,
43 timeout: Duration,
44 ) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
45
46 fn send_response<T: Serialize + Send>(
48 &mut self,
49 id: &str,
50 result: T,
51 ) -> impl Future<Output = Result<(), WireError>> + Send;
52
53 fn send_error(
55 &mut self,
56 id: &str,
57 code: i32,
58 message: &str,
59 ) -> impl Future<Output = Result<(), WireError>> + Send;
60
61 fn initialize(
63 &mut self,
64 params: InitializeParams,
65 ) -> impl Future<Output = Result<InitializeResult, WireError>> + Send;
66
67 fn is_handshake_done(&self) -> bool;
69
70 fn shutdown(self) -> impl Future<Output = Result<(), WireError>> + Send;
72
73 fn prompt(
75 &mut self,
76 user_input: impl Into<UserInput> + Send,
77 ) -> impl Future<Output = Result<PromptResult, WireError>> + Send {
78 async move {
79 let id = self.start_prompt(user_input).await?;
80 self.read_response(&id).await
81 }
82 }
83
84 fn start_prompt(
86 &mut self,
87 user_input: impl Into<UserInput> + Send,
88 ) -> impl Future<Output = Result<String, WireError>> + Send {
89 async move {
90 let id = self.next_id();
91 let req = JsonRpcRequest {
92 jsonrpc: crate::protocol::JsonRpcVersion::V2,
93 method: "prompt".to_string(),
94 id: id.clone(),
95 params: PromptParams {
96 user_input: user_input.into(),
97 },
98 };
99 self.send_request(&req).await?;
100 Ok(id)
101 }
102 }
103
104 fn replay(&mut self) -> impl Future<Output = Result<ReplayResult, WireError>> + Send {
106 async move {
107 let id = self.next_id();
108 let req = JsonRpcRequest {
109 jsonrpc: crate::protocol::JsonRpcVersion::V2,
110 method: "replay".to_string(),
111 id: id.clone(),
112 params: ReplayParams::default(),
113 };
114 self.send_request(&req).await?;
115 self.read_response(&id).await
116 }
117 }
118
119 fn steer(
121 &mut self,
122 user_input: impl Into<UserInput> + Send,
123 ) -> impl Future<Output = Result<SteerResult, WireError>> + Send {
124 async move {
125 let id = self.next_id();
126 let req = JsonRpcRequest {
127 jsonrpc: crate::protocol::JsonRpcVersion::V2,
128 method: "steer".to_string(),
129 id: id.clone(),
130 params: SteerParams {
131 user_input: user_input.into(),
132 },
133 };
134 self.send_request(&req).await?;
135 self.read_response(&id).await
136 }
137 }
138
139 fn set_plan_mode(
141 &mut self,
142 enabled: bool,
143 ) -> impl Future<Output = Result<SetPlanModeResult, WireError>> + Send {
144 async move {
145 let id = self.next_id();
146 let req = JsonRpcRequest {
147 jsonrpc: crate::protocol::JsonRpcVersion::V2,
148 method: "set_plan_mode".to_string(),
149 id: id.clone(),
150 params: SetPlanModeParams { enabled },
151 };
152 self.send_request(&req).await?;
153 self.read_response(&id).await
154 }
155 }
156
157 fn cancel(&mut self) -> impl Future<Output = Result<(), WireError>> + Send {
159 async move {
160 let id = self.next_id();
161 let req = JsonRpcRequest {
162 jsonrpc: crate::protocol::JsonRpcVersion::V2,
163 method: "cancel".to_string(),
164 id: id.clone(),
165 params: CancelParams::default(),
166 };
167 self.send_request(&req).await?;
168 let _: CancelResult = self.read_response(&id).await?;
169 Ok(())
170 }
171 }
172
173 fn read_response<T: DeserializeOwned + Send>(
176 &mut self,
177 expected_id: &str,
178 ) -> impl Future<Output = Result<T, WireError>> + Send;
179}
180
/// Transport-free [`WireClient`] that keeps all traffic in memory.
///
/// Incoming messages are queued with [`InMemoryWireClient::inject`]; everything the
/// client sends is recorded and can be inspected with [`InMemoryWireClient::outgoing`].
#[derive(Debug)]
pub struct InMemoryWireClient {
    /// Messages queued by `inject`, consumed front-first by the read methods.
    incoming: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
    /// Messages `read_response` read while waiting for a different id; the read
    /// methods deliver these before anything still in `incoming`.
    pending: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
    /// Log of everything sent: requests, success responses and error responses.
    outgoing: Mutex<Vec<serde_json::Value>>,
    /// Set to `true` by `initialize`.
    handshake_done: bool,
    /// Counter behind `next_id`; incremented before each id is formatted.
    request_counter: u64,
    /// Optional deadline applied to each `read_response` call.
    default_timeout: Option<Duration>,
}
198
199impl Default for InMemoryWireClient {
200 fn default() -> Self {
201 Self::new()
202 }
203}
204
205impl InMemoryWireClient {
206 #[must_use]
208 pub fn new() -> Self {
209 Self {
210 incoming: Mutex::new(VecDeque::new()),
211 pending: Mutex::new(VecDeque::new()),
212 outgoing: Mutex::new(Vec::new()),
213 handshake_done: false,
214 request_counter: 0,
215 default_timeout: None,
216 }
217 }
218
219 #[must_use]
222 pub const fn with_default_timeout(mut self, timeout: Duration) -> Self {
223 self.default_timeout = Some(timeout);
224 self
225 }
226
227 pub async fn inject(&self, msg: crate::protocol::RawWireMessage) {
229 self.incoming.lock().await.push_back(msg);
230 }
231
232 pub async fn outgoing(&self) -> Vec<serde_json::Value> {
234 self.outgoing.lock().await.clone()
235 }
236}
237
238impl WireClient for InMemoryWireClient {
239 fn next_id(&mut self) -> String {
240 self.request_counter += 1;
241 format!("req-{}", self.request_counter)
242 }
243
244 async fn send_request<Params: Serialize + Sync>(
245 &mut self,
246 req: &JsonRpcRequest<Params>,
247 ) -> Result<(), WireError> {
248 let value = serde_json::to_value(req).map_err(WireError::from)?;
249 self.outgoing.lock().await.push(value);
250 Ok(())
251 }
252
253 async fn read_raw_message(&mut self) -> Result<crate::protocol::RawWireMessage, WireError> {
254 let msg = self.pending.lock().await.pop_front();
255 if let Some(msg) = msg {
256 return Ok(msg);
257 }
258 self.incoming
259 .lock()
260 .await
261 .pop_front()
262 .map_or_else(|| Err(WireError::StreamClosed), Ok)
263 }
264
265 async fn read_raw_message_timeout(
266 &mut self,
267 timeout: Duration,
268 ) -> Result<crate::protocol::RawWireMessage, WireError> {
269 tokio::time::timeout(timeout, self.read_raw_message())
270 .await
271 .map_or(Err(WireError::Timeout(timeout)), |msg| msg)
272 }
273
274 async fn read_response<T: DeserializeOwned + Send>(
275 &mut self,
276 expected_id: &str,
277 ) -> Result<T, WireError> {
278 let fut = async {
279 loop {
280 let idx = {
281 let lock = self.pending.lock().await;
282 lock.iter()
283 .position(|msg| msg.id.as_deref() == Some(expected_id))
284 };
285 if let Some(idx) = idx {
286 let msg =
287 self.pending.lock().await.remove(idx).ok_or_else(|| {
288 WireError::Internal("pending index invalid".to_string())
289 })?;
290 return decode_response(msg, expected_id);
291 }
292
293 let msg = self.incoming.lock().await.pop_front();
294 match msg {
295 Some(msg) if msg.id.as_deref() == Some(expected_id) => {
296 return decode_response(msg, expected_id);
297 }
298 Some(other) => {
299 let mut pending = self.pending.lock().await;
300 if pending.len() >= MAX_PENDING_MESSAGES {
301 return Err(WireError::Internal(format!(
302 "pending message buffer overflow ({MAX_PENDING_MESSAGES} entries) waiting for id {expected_id:?}"
303 )));
304 }
305 pending.push_back(other);
306 }
307 None => return Err(WireError::StreamClosed),
308 }
309 }
310 };
311
312 match self.default_timeout {
313 Some(d) => tokio::time::timeout(d, fut)
314 .await
315 .map_err(|_| WireError::Timeout(d))?,
316 None => fut.await,
317 }
318 }
319
320 async fn send_response<T: Serialize + Send>(
321 &mut self,
322 id: &str,
323 result: T,
324 ) -> Result<(), WireError> {
325 let resp = crate::protocol::JsonRpcSuccessResponse {
326 jsonrpc: crate::protocol::JsonRpcVersion::V2,
327 id: id.to_string(),
328 result,
329 };
330 let line = format!(
331 "{}\n",
332 serde_json::to_string(&resp).map_err(WireError::from)?
333 );
334 self.outgoing
335 .lock()
336 .await
337 .push(serde_json::Value::String(line));
338 Ok(())
339 }
340
341 async fn send_error(&mut self, id: &str, code: i32, message: &str) -> Result<(), WireError> {
342 let resp = crate::protocol::JsonRpcErrorResponse {
343 jsonrpc: crate::protocol::JsonRpcVersion::V2,
344 id: id.to_string(),
345 error: crate::protocol::JsonRpcError {
346 code,
347 message: message.to_string(),
348 data: None,
349 },
350 };
351 let line = format!(
352 "{}\n",
353 serde_json::to_string(&resp).map_err(WireError::from)?
354 );
355 self.outgoing
356 .lock()
357 .await
358 .push(serde_json::Value::String(line));
359 Ok(())
360 }
361
362 async fn initialize(
363 &mut self,
364 _params: InitializeParams,
365 ) -> Result<InitializeResult, WireError> {
366 self.handshake_done = true;
367 Ok(InitializeResult {
368 protocol_version: crate::WIRE_PROTOCOL_VERSION.to_string(),
369 server: crate::protocol::ServerInfo {
370 name: "test-server".to_string(),
371 version: "0.0.0".to_string(),
372 },
373 slash_commands: vec![],
374 external_tools: None,
375 capabilities: None,
376 hooks: None,
377 })
378 }
379
380 fn is_handshake_done(&self) -> bool {
381 self.handshake_done
382 }
383
384 async fn shutdown(self) -> Result<(), WireError> {
385 Ok(())
386 }
387}
388
389fn decode_response<T: DeserializeOwned>(
390 msg: crate::protocol::RawWireMessage,
391 _expected_id: &str,
392) -> Result<T, WireError> {
393 if let Some(error) = msg.error {
394 return Err(WireError::RequestFailed {
395 code: error.code,
396 message: error.message,
397 });
398 }
399 let result = msg
400 .result
401 .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
402 serde_json::from_value(result).map_err(WireError::from)
403}