Skip to main content

kimi_wire/
client.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::time::Duration;
4
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use tokio::sync::Mutex;
8
// Upper bound on out-of-order messages buffered while waiting for a response.
// Without the `process` feature a local fallback is used; with it, the limit
// defined by the transport module is reused instead.
#[cfg(not(feature = "process"))]
const MAX_PENDING_MESSAGES: usize = 1024;
#[cfg(feature = "process")]
use crate::transport::MAX_PENDING_MESSAGES;
13
14use crate::error::WireError;
15use crate::protocol::{
16    CancelParams, CancelResult, InitializeParams, InitializeResult, JsonRpcRequest, PromptParams,
17    PromptResult, ReplayParams, ReplayResult, SetPlanModeParams, SetPlanModeResult, SteerParams,
18    SteerResult, UserInput,
19};
20
21/// Trait for a Kimi Wire Protocol client.
22///
23/// Implementations may communicate over a child process, an in-memory channel,
24/// or any other transport.
25pub trait WireClient: Send {
26    /// Generate the next request id.
27    fn next_id(&mut self) -> String;
28
29    /// Send a JSON-RPC request.
30    fn send_request<Params: Serialize + Sync>(
31        &mut self,
32        req: &JsonRpcRequest<Params>,
33    ) -> impl Future<Output = Result<(), WireError>> + Send;
34
35    /// Read the next incoming raw wire message.
36    fn read_raw_message(
37        &mut self,
38    ) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
39
40    /// Read the next incoming raw wire message with a timeout.
41    fn read_raw_message_timeout(
42        &mut self,
43        timeout: Duration,
44    ) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
45
46    /// Send a JSON-RPC success response.
47    fn send_response<T: Serialize + Send>(
48        &mut self,
49        id: &str,
50        result: T,
51    ) -> impl Future<Output = Result<(), WireError>> + Send;
52
53    /// Send a JSON-RPC error response.
54    fn send_error(
55        &mut self,
56        id: &str,
57        code: i32,
58        message: &str,
59    ) -> impl Future<Output = Result<(), WireError>> + Send;
60
61    /// Perform the initialize handshake.
62    fn initialize(
63        &mut self,
64        params: InitializeParams,
65    ) -> impl Future<Output = Result<InitializeResult, WireError>> + Send;
66
67    /// Returns true if the initialize handshake has completed.
68    fn is_handshake_done(&self) -> bool;
69
70    /// Gracefully shut down the client.
71    fn shutdown(self) -> impl Future<Output = Result<(), WireError>> + Send;
72
73    /// Send a prompt and wait for the result.
74    fn prompt(
75        &mut self,
76        user_input: impl Into<UserInput> + Send,
77    ) -> impl Future<Output = Result<PromptResult, WireError>> + Send {
78        async move {
79            let id = self.start_prompt(user_input).await?;
80            self.read_response(&id).await
81        }
82    }
83
84    /// Send a prompt without waiting for the result.
85    fn start_prompt(
86        &mut self,
87        user_input: impl Into<UserInput> + Send,
88    ) -> impl Future<Output = Result<String, WireError>> + Send {
89        async move {
90            let id = self.next_id();
91            let req = JsonRpcRequest {
92                jsonrpc: crate::protocol::JsonRpcVersion::V2,
93                method: "prompt".to_string(),
94                id: id.clone(),
95                params: PromptParams {
96                    user_input: user_input.into(),
97                },
98            };
99            self.send_request(&req).await?;
100            Ok(id)
101        }
102    }
103
104    /// Replay events and requests from the current session.
105    fn replay(&mut self) -> impl Future<Output = Result<ReplayResult, WireError>> + Send {
106        async move {
107            let id = self.next_id();
108            let req = JsonRpcRequest {
109                jsonrpc: crate::protocol::JsonRpcVersion::V2,
110                method: "replay".to_string(),
111                id: id.clone(),
112                params: ReplayParams::default(),
113            };
114            self.send_request(&req).await?;
115            self.read_response(&id).await
116        }
117    }
118
119    /// Steer the current turn with additional user input.
120    fn steer(
121        &mut self,
122        user_input: impl Into<UserInput> + Send,
123    ) -> impl Future<Output = Result<SteerResult, WireError>> + Send {
124        async move {
125            let id = self.next_id();
126            let req = JsonRpcRequest {
127                jsonrpc: crate::protocol::JsonRpcVersion::V2,
128                method: "steer".to_string(),
129                id: id.clone(),
130                params: SteerParams {
131                    user_input: user_input.into(),
132                },
133            };
134            self.send_request(&req).await?;
135            self.read_response(&id).await
136        }
137    }
138
139    /// Enable or disable plan mode.
140    fn set_plan_mode(
141        &mut self,
142        enabled: bool,
143    ) -> impl Future<Output = Result<SetPlanModeResult, WireError>> + Send {
144        async move {
145            let id = self.next_id();
146            let req = JsonRpcRequest {
147                jsonrpc: crate::protocol::JsonRpcVersion::V2,
148                method: "set_plan_mode".to_string(),
149                id: id.clone(),
150                params: SetPlanModeParams { enabled },
151            };
152            self.send_request(&req).await?;
153            self.read_response(&id).await
154        }
155    }
156
157    /// Cancel the current turn.
158    fn cancel(&mut self) -> impl Future<Output = Result<(), WireError>> + Send {
159        async move {
160            let id = self.next_id();
161            let req = JsonRpcRequest {
162                jsonrpc: crate::protocol::JsonRpcVersion::V2,
163                method: "cancel".to_string(),
164                id: id.clone(),
165                params: CancelParams::default(),
166            };
167            self.send_request(&req).await?;
168            let _: CancelResult = self.read_response(&id).await?;
169            Ok(())
170        }
171    }
172
173    /// Wait for a response matching `expected_id`, buffering out-of-order
174    /// messages internally.
175    fn read_response<T: DeserializeOwned + Send>(
176        &mut self,
177        expected_id: &str,
178    ) -> impl Future<Output = Result<T, WireError>> + Send;
179}
180
181// ============================================================================
182// InMemoryWireClient
183// ============================================================================
184
185/// In-memory wire client for unit tests.
186///
187/// Holds an internal queue of [`crate::protocol::RawWireMessage`]s that
188/// `read_raw_message` drains. Tests inject messages via [`InMemoryWireClient::inject`].
189#[derive(Debug)]
190pub struct InMemoryWireClient {
191    incoming: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
192    pending: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
193    outgoing: Mutex<Vec<serde_json::Value>>,
194    handshake_done: bool,
195    request_counter: u64,
196    default_timeout: Option<Duration>,
197}
198
199impl Default for InMemoryWireClient {
200    fn default() -> Self {
201        Self::new()
202    }
203}
204
205impl InMemoryWireClient {
206    /// Create a new in-memory client.
207    #[must_use]
208    pub fn new() -> Self {
209        Self {
210            incoming: Mutex::new(VecDeque::new()),
211            pending: Mutex::new(VecDeque::new()),
212            outgoing: Mutex::new(Vec::new()),
213            handshake_done: false,
214            request_counter: 0,
215            default_timeout: None,
216        }
217    }
218
219    /// Set a default timeout applied to every `read_response` call.
220    /// Without this, `read_response` waits indefinitely for a matching id.
221    #[must_use]
222    pub const fn with_default_timeout(mut self, timeout: Duration) -> Self {
223        self.default_timeout = Some(timeout);
224        self
225    }
226
227    /// Inject an incoming raw wire message for the client to read.
228    pub async fn inject(&self, msg: crate::protocol::RawWireMessage) {
229        self.incoming.lock().await.push_back(msg);
230    }
231
232    /// Access all messages sent by the client.
233    pub async fn outgoing(&self) -> Vec<serde_json::Value> {
234        self.outgoing.lock().await.clone()
235    }
236}
237
238impl WireClient for InMemoryWireClient {
239    fn next_id(&mut self) -> String {
240        self.request_counter += 1;
241        format!("req-{}", self.request_counter)
242    }
243
244    async fn send_request<Params: Serialize + Sync>(
245        &mut self,
246        req: &JsonRpcRequest<Params>,
247    ) -> Result<(), WireError> {
248        let value = serde_json::to_value(req).map_err(WireError::from)?;
249        self.outgoing.lock().await.push(value);
250        Ok(())
251    }
252
253    async fn read_raw_message(&mut self) -> Result<crate::protocol::RawWireMessage, WireError> {
254        let msg = self.pending.lock().await.pop_front();
255        if let Some(msg) = msg {
256            return Ok(msg);
257        }
258        let msg = self.incoming.lock().await.pop_front();
259        match msg {
260            Some(msg) => Ok(msg),
261            None => Err(WireError::StreamClosed),
262        }
263    }
264
265    async fn read_raw_message_timeout(
266        &mut self,
267        timeout: Duration,
268    ) -> Result<crate::protocol::RawWireMessage, WireError> {
269        match tokio::time::timeout(timeout, self.read_raw_message()).await {
270            Ok(msg) => msg,
271            Err(_) => Err(WireError::Timeout(timeout)),
272        }
273    }
274
275    async fn read_response<T: DeserializeOwned + Send>(
276        &mut self,
277        expected_id: &str,
278    ) -> Result<T, WireError> {
279        let fut = async {
280            loop {
281                let idx = {
282                    let lock = self.pending.lock().await;
283                    lock.iter()
284                        .position(|msg| msg.id.as_deref() == Some(expected_id))
285                };
286                if let Some(idx) = idx {
287                    let msg =
288                        self.pending.lock().await.remove(idx).ok_or_else(|| {
289                            WireError::Internal("pending index invalid".to_string())
290                        })?;
291                    return decode_response(msg, expected_id);
292                }
293
294                let msg = self.incoming.lock().await.pop_front();
295                match msg {
296                    Some(msg) if msg.id.as_deref() == Some(expected_id) => {
297                        return decode_response(msg, expected_id);
298                    }
299                    Some(other) => {
300                        let mut pending = self.pending.lock().await;
301                        if pending.len() >= MAX_PENDING_MESSAGES {
302                            return Err(WireError::Internal(format!(
303                                "pending message buffer overflow ({} entries) waiting for id {:?}",
304                                MAX_PENDING_MESSAGES, expected_id
305                            )));
306                        }
307                        pending.push_back(other);
308                    }
309                    None => return Err(WireError::StreamClosed),
310                }
311            }
312        };
313
314        match self.default_timeout {
315            Some(d) => tokio::time::timeout(d, fut)
316                .await
317                .map_err(|_| WireError::Timeout(d))?,
318            None => fut.await,
319        }
320    }
321
322    async fn send_response<T: Serialize + Send>(
323        &mut self,
324        id: &str,
325        result: T,
326    ) -> Result<(), WireError> {
327        let resp = crate::protocol::JsonRpcSuccessResponse {
328            jsonrpc: crate::protocol::JsonRpcVersion::V2,
329            id: id.to_string(),
330            result,
331        };
332        let line = format!(
333            "{}\n",
334            serde_json::to_string(&resp).map_err(WireError::from)?
335        );
336        self.outgoing
337            .lock()
338            .await
339            .push(serde_json::Value::String(line));
340        Ok(())
341    }
342
343    async fn send_error(&mut self, id: &str, code: i32, message: &str) -> Result<(), WireError> {
344        let resp = crate::protocol::JsonRpcErrorResponse {
345            jsonrpc: crate::protocol::JsonRpcVersion::V2,
346            id: id.to_string(),
347            error: crate::protocol::JsonRpcError {
348                code,
349                message: message.to_string(),
350                data: None,
351            },
352        };
353        let line = format!(
354            "{}\n",
355            serde_json::to_string(&resp).map_err(WireError::from)?
356        );
357        self.outgoing
358            .lock()
359            .await
360            .push(serde_json::Value::String(line));
361        Ok(())
362    }
363
364    async fn initialize(
365        &mut self,
366        _params: InitializeParams,
367    ) -> Result<InitializeResult, WireError> {
368        self.handshake_done = true;
369        Ok(InitializeResult {
370            protocol_version: crate::WIRE_PROTOCOL_VERSION.to_string(),
371            server: crate::protocol::ServerInfo {
372                name: "test-server".to_string(),
373                version: "0.0.0".to_string(),
374            },
375            slash_commands: vec![],
376            external_tools: None,
377            capabilities: None,
378            hooks: None,
379        })
380    }
381
382    fn is_handshake_done(&self) -> bool {
383        self.handshake_done
384    }
385
386    async fn shutdown(self) -> Result<(), WireError> {
387        Ok(())
388    }
389}
390
391fn decode_response<T: DeserializeOwned>(
392    msg: crate::protocol::RawWireMessage,
393    _expected_id: &str,
394) -> Result<T, WireError> {
395    if let Some(error) = msg.error {
396        return Err(WireError::RequestFailed {
397            code: error.code,
398            message: error.message,
399        });
400    }
401    let result = msg
402        .result
403        .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
404    serde_json::from_value(result).map_err(WireError::from)
405}