Skip to main content

kimi_wire/
client.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::time::Duration;
4
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use tokio::sync::Mutex;
8
#[cfg(feature = "process")]
use crate::transport::MAX_PENDING_MESSAGES;
/// Cap on out-of-order messages that `read_response` will buffer while it
/// waits for one specific id. This fallback is only compiled when the
/// `process` transport (which supplies its own value) is disabled.
#[cfg(not(feature = "process"))]
const MAX_PENDING_MESSAGES: usize = 1 << 10;
13
14use crate::error::WireError;
15use crate::protocol::{
16    CancelParams, CancelResult, InitializeParams, InitializeResult, JsonRpcRequest, PromptParams,
17    PromptResult, ReplayParams, ReplayResult, SetPlanModeParams, SetPlanModeResult, SteerParams,
18    SteerResult, UserInput,
19};
20
21/// Trait for a Kimi Wire Protocol client.
22///
23/// Implementations may communicate over a child process, an in-memory channel,
24/// or any other transport.
25pub trait WireClient: Send {
26    /// Generate the next request id.
27    fn next_id(&mut self) -> String;
28
29    /// Send a JSON-RPC request.
30    fn send_request<Params: Serialize + Sync>(
31        &mut self,
32        req: &JsonRpcRequest<Params>,
33    ) -> impl Future<Output = Result<(), WireError>> + Send;
34
35    /// Read the next incoming raw wire message.
36    fn read_raw_message(&mut self) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
37
38    /// Read the next incoming raw wire message with a timeout.
39    fn read_raw_message_timeout(
40        &mut self,
41        timeout: Duration,
42    ) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
43
44    /// Send a JSON-RPC success response.
45    fn send_response<T: Serialize + Send>(
46        &mut self,
47        id: &str,
48        result: T,
49    ) -> impl Future<Output = Result<(), WireError>> + Send;
50
51    /// Send a JSON-RPC error response.
52    fn send_error(
53        &mut self,
54        id: &str,
55        code: i32,
56        message: &str,
57    ) -> impl Future<Output = Result<(), WireError>> + Send;
58
59    /// Perform the initialize handshake.
60    fn initialize(
61        &mut self,
62        params: InitializeParams,
63    ) -> impl Future<Output = Result<InitializeResult, WireError>> + Send;
64
65    /// Returns true if the initialize handshake has completed.
66    fn is_handshake_done(&self) -> bool;
67
68    /// Gracefully shut down the client.
69    fn shutdown(self) -> impl Future<Output = Result<(), WireError>> + Send;
70
71    /// Send a prompt and wait for the result.
72    fn prompt(&mut self, user_input: impl Into<UserInput> + Send) -> impl Future<Output = Result<PromptResult, WireError>> + Send {
73        async move {
74            let id = self.start_prompt(user_input).await?;
75            self.read_response(&id).await
76        }
77    }
78
79    /// Send a prompt without waiting for the result.
80    fn start_prompt(&mut self, user_input: impl Into<UserInput> + Send) -> impl Future<Output = Result<String, WireError>> + Send {
81        async move {
82            let id = self.next_id();
83            let req = JsonRpcRequest {
84                jsonrpc: crate::protocol::JsonRpcVersion::V2,
85                method: "prompt".to_string(),
86                id: id.clone(),
87                params: PromptParams {
88                    user_input: user_input.into(),
89                },
90            };
91            self.send_request(&req).await?;
92            Ok(id)
93        }
94    }
95
96    /// Replay events and requests from the current session.
97    fn replay(&mut self) -> impl Future<Output = Result<ReplayResult, WireError>> + Send {
98        async move {
99            let id = self.next_id();
100            let req = JsonRpcRequest {
101                jsonrpc: crate::protocol::JsonRpcVersion::V2,
102                method: "replay".to_string(),
103                id: id.clone(),
104                params: ReplayParams::default(),
105            };
106            self.send_request(&req).await?;
107            self.read_response(&id).await
108        }
109    }
110
111    /// Steer the current turn with additional user input.
112    fn steer(&mut self, user_input: impl Into<UserInput> + Send) -> impl Future<Output = Result<SteerResult, WireError>> + Send {
113        async move {
114            let id = self.next_id();
115            let req = JsonRpcRequest {
116                jsonrpc: crate::protocol::JsonRpcVersion::V2,
117                method: "steer".to_string(),
118                id: id.clone(),
119                params: SteerParams {
120                    user_input: user_input.into(),
121                },
122            };
123            self.send_request(&req).await?;
124            self.read_response(&id).await
125        }
126    }
127
128    /// Enable or disable plan mode.
129    fn set_plan_mode(
130        &mut self,
131        enabled: bool,
132    ) -> impl Future<Output = Result<SetPlanModeResult, WireError>> + Send {
133        async move {
134            let id = self.next_id();
135            let req = JsonRpcRequest {
136                jsonrpc: crate::protocol::JsonRpcVersion::V2,
137                method: "set_plan_mode".to_string(),
138                id: id.clone(),
139                params: SetPlanModeParams { enabled },
140            };
141            self.send_request(&req).await?;
142            self.read_response(&id).await
143        }
144    }
145
146    /// Cancel the current turn.
147    fn cancel(&mut self) -> impl Future<Output = Result<(), WireError>> + Send {
148        async move {
149            let id = self.next_id();
150            let req = JsonRpcRequest {
151                jsonrpc: crate::protocol::JsonRpcVersion::V2,
152                method: "cancel".to_string(),
153                id: id.clone(),
154                params: CancelParams::default(),
155            };
156            self.send_request(&req).await?;
157            let _: CancelResult = self.read_response(&id).await?;
158            Ok(())
159        }
160    }
161
162    /// Wait for a response matching `expected_id`, buffering out-of-order
163    /// messages internally.
164    fn read_response<T: DeserializeOwned + Send>(
165        &mut self,
166        expected_id: &str,
167    ) -> impl Future<Output = Result<T, WireError>> + Send;
168}
169
170// ============================================================================
171// InMemoryWireClient
172// ============================================================================
173
174/// In-memory wire client for unit tests.
175///
176/// Holds an internal queue of [`crate::protocol::RawWireMessage`]s that
177/// `read_raw_message` drains. Tests inject messages via [`InMemoryWireClient::inject`].
178#[derive(Debug)]
179pub struct InMemoryWireClient {
180    incoming: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
181    pending: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
182    outgoing: Mutex<Vec<serde_json::Value>>,
183    handshake_done: bool,
184    request_counter: u64,
185    default_timeout: Option<Duration>,
186}
187
188impl Default for InMemoryWireClient {
189    fn default() -> Self {
190        Self::new()
191    }
192}
193
194impl InMemoryWireClient {
195    /// Create a new in-memory client.
196    pub fn new() -> Self {
197        Self {
198            incoming: Mutex::new(VecDeque::new()),
199            pending: Mutex::new(VecDeque::new()),
200            outgoing: Mutex::new(Vec::new()),
201            handshake_done: false,
202            request_counter: 0,
203            default_timeout: None,
204        }
205    }
206
207    /// Set a default timeout applied to every `read_response` call.
208    /// Without this, `read_response` waits indefinitely for a matching id.
209    pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
210        self.default_timeout = Some(timeout);
211        self
212    }
213
214    /// Inject an incoming raw wire message for the client to read.
215    pub async fn inject(&self, msg: crate::protocol::RawWireMessage) {
216        self.incoming.lock().await.push_back(msg);
217    }
218
219    /// Access all messages sent by the client.
220    pub async fn outgoing(&self) -> Vec<serde_json::Value> {
221        self.outgoing.lock().await.clone()
222    }
223}
224
225impl WireClient for InMemoryWireClient {
226    fn next_id(&mut self) -> String {
227        self.request_counter += 1;
228        format!("req-{}", self.request_counter)
229    }
230
231    async fn send_request<Params: Serialize + Sync>(
232        &mut self,
233        req: &JsonRpcRequest<Params>,
234    ) -> Result<(), WireError> {
235        let value = serde_json::to_value(req).map_err(WireError::from)?;
236        self.outgoing.lock().await.push(value);
237        Ok(())
238    }
239
240    async fn read_raw_message(&mut self) -> Result<crate::protocol::RawWireMessage, WireError> {
241        if let Some(msg) = self.pending.lock().await.pop_front() {
242            return Ok(msg);
243        }
244        match self.incoming.lock().await.pop_front() {
245            Some(msg) => Ok(msg),
246            None => Err(WireError::StreamClosed),
247        }
248    }
249
250    async fn read_raw_message_timeout(
251        &mut self,
252        timeout: Duration,
253    ) -> Result<crate::protocol::RawWireMessage, WireError> {
254        match tokio::time::timeout(timeout, self.read_raw_message()).await {
255            Ok(msg) => msg,
256            Err(_) => Err(WireError::Timeout(timeout)),
257        }
258    }
259
260    async fn read_response<T: DeserializeOwned + Send>(
261        &mut self,
262        expected_id: &str,
263    ) -> Result<T, WireError> {
264        let fut = async {
265            loop {
266                let idx = {
267                    let lock = self.pending.lock().await;
268                    lock.iter()
269                        .position(|msg| msg.id.as_deref() == Some(expected_id))
270                };
271                if let Some(idx) = idx {
272                    let msg = self
273                        .pending
274                        .lock()
275                        .await
276                        .remove(idx)
277                        .ok_or_else(|| WireError::Internal("pending index invalid".to_string()))?;
278                    return decode_response(msg, expected_id);
279                }
280
281                match self.incoming.lock().await.pop_front() {
282                    Some(msg) if msg.id.as_deref() == Some(expected_id) => {
283                        return decode_response(msg, expected_id);
284                    }
285                    Some(other) => {
286                        let mut pending = self.pending.lock().await;
287                        if pending.len() >= MAX_PENDING_MESSAGES {
288                            return Err(WireError::Internal(format!(
289                                "pending message buffer overflow ({} entries) waiting for id {:?}",
290                                MAX_PENDING_MESSAGES, expected_id
291                            )));
292                        }
293                        pending.push_back(other);
294                    }
295                    None => return Err(WireError::StreamClosed),
296                }
297            }
298        };
299
300        match self.default_timeout {
301            Some(d) => tokio::time::timeout(d, fut)
302                .await
303                .map_err(|_| WireError::Timeout(d))?,
304            None => fut.await,
305        }
306    }
307
308    async fn send_response<T: Serialize + Send>(
309        &mut self,
310        id: &str,
311        result: T,
312    ) -> Result<(), WireError> {
313        let resp = crate::protocol::JsonRpcSuccessResponse {
314            jsonrpc: crate::protocol::JsonRpcVersion::V2,
315            id: id.to_string(),
316            result,
317        };
318        let line = format!("{}\n", serde_json::to_string(&resp).map_err(WireError::from)?);
319        self.outgoing
320            .lock()
321            .await
322            .push(serde_json::Value::String(line));
323        Ok(())
324    }
325
326    async fn send_error(
327        &mut self,
328        id: &str,
329        code: i32,
330        message: &str,
331    ) -> Result<(), WireError> {
332        let resp = crate::protocol::JsonRpcErrorResponse {
333            jsonrpc: crate::protocol::JsonRpcVersion::V2,
334            id: id.to_string(),
335            error: crate::protocol::JsonRpcError {
336                code,
337                message: message.to_string(),
338                data: None,
339            },
340        };
341        let line = format!("{}\n", serde_json::to_string(&resp).map_err(WireError::from)?);
342        self.outgoing
343            .lock()
344            .await
345            .push(serde_json::Value::String(line));
346        Ok(())
347    }
348
349    async fn initialize(
350        &mut self,
351        _params: InitializeParams,
352    ) -> Result<InitializeResult, WireError> {
353        self.handshake_done = true;
354        Ok(InitializeResult {
355            protocol_version: crate::WIRE_PROTOCOL_VERSION.to_string(),
356            server: crate::protocol::ServerInfo {
357                name: "test-server".to_string(),
358                version: "0.0.0".to_string(),
359            },
360            slash_commands: vec![],
361            external_tools: None,
362            capabilities: None,
363            hooks: None,
364        })
365    }
366
367    fn is_handshake_done(&self) -> bool {
368        self.handshake_done
369    }
370
371    async fn shutdown(self) -> Result<(), WireError> {
372        Ok(())
373    }
374}
375
376fn decode_response<T: DeserializeOwned>(
377    msg: crate::protocol::RawWireMessage,
378    _expected_id: &str,
379) -> Result<T, WireError> {
380    if let Some(error) = msg.error {
381        return Err(WireError::RequestFailed {
382            code: error.code,
383            message: error.message,
384        });
385    }
386    let result = msg
387        .result
388        .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
389    serde_json::from_value(result).map_err(WireError::from)
390}