Skip to main content

kimi_wire/
client.rs

1use std::collections::VecDeque;
2use std::future::Future;
3use std::time::Duration;
4
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use tokio::sync::Mutex;
8
9use crate::error::WireError;
10use crate::protocol::{
11    CancelParams, CancelResult, InitializeParams, InitializeResult, JsonRpcRequest, PromptParams,
12    PromptResult, ReplayParams, ReplayResult, SetPlanModeParams, SetPlanModeResult, SteerParams,
13    SteerResult, UserInput,
14};
15
16/// Trait for a Kimi Wire Protocol client.
17///
18/// Implementations may communicate over a child process, an in-memory channel,
19/// or any other transport.
20pub trait WireClient: Send {
21    /// Generate the next request id.
22    fn next_id(&mut self) -> String;
23
24    /// Send a JSON-RPC request.
25    fn send_request<Params: Serialize + Sync>(
26        &mut self,
27        req: &JsonRpcRequest<Params>,
28    ) -> impl Future<Output = Result<(), WireError>> + Send;
29
30    /// Read the next incoming raw wire message.
31    fn read_raw_message(&mut self) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
32
33    /// Read the next incoming raw wire message with a timeout.
34    fn read_raw_message_timeout(
35        &mut self,
36        timeout: Duration,
37    ) -> impl Future<Output = Result<crate::protocol::RawWireMessage, WireError>> + Send;
38
39    /// Send a JSON-RPC success response.
40    fn send_response<T: Serialize + Send>(
41        &mut self,
42        id: &str,
43        result: T,
44    ) -> impl Future<Output = Result<(), WireError>> + Send;
45
46    /// Send a JSON-RPC error response.
47    fn send_error(
48        &mut self,
49        id: &str,
50        code: i32,
51        message: &str,
52    ) -> impl Future<Output = Result<(), WireError>> + Send;
53
54    /// Perform the initialize handshake.
55    fn initialize(
56        &mut self,
57        params: InitializeParams,
58    ) -> impl Future<Output = Result<InitializeResult, WireError>> + Send;
59
60    /// Returns true if the initialize handshake has completed.
61    fn is_handshake_done(&self) -> bool;
62
63    /// Gracefully shut down the client.
64    fn shutdown(self) -> impl Future<Output = Result<(), WireError>> + Send;
65
66    /// Send a prompt and wait for the result.
67    fn prompt(&mut self, user_input: impl Into<UserInput> + Send) -> impl Future<Output = Result<PromptResult, WireError>> + Send {
68        async move {
69            let id = self.start_prompt(user_input).await?;
70            self.read_response(&id).await
71        }
72    }
73
74    /// Send a prompt without waiting for the result.
75    fn start_prompt(&mut self, user_input: impl Into<UserInput> + Send) -> impl Future<Output = Result<String, WireError>> + Send {
76        async move {
77            let id = self.next_id();
78            let req = JsonRpcRequest {
79                jsonrpc: crate::protocol::JsonRpcVersion::default(),
80                method: "prompt".to_string(),
81                id: id.clone(),
82                params: PromptParams {
83                    user_input: user_input.into(),
84                },
85            };
86            self.send_request(&req).await?;
87            Ok(id)
88        }
89    }
90
91    /// Replay events and requests from the current session.
92    fn replay(&mut self) -> impl Future<Output = Result<ReplayResult, WireError>> + Send {
93        async move {
94            let id = self.next_id();
95            let req = JsonRpcRequest {
96                jsonrpc: crate::protocol::JsonRpcVersion::default(),
97                method: "replay".to_string(),
98                id: id.clone(),
99                params: ReplayParams::default(),
100            };
101            self.send_request(&req).await?;
102            self.read_response(&id).await
103        }
104    }
105
106    /// Steer the current turn with additional user input.
107    fn steer(&mut self, user_input: impl Into<UserInput> + Send) -> impl Future<Output = Result<SteerResult, WireError>> + Send {
108        async move {
109            let id = self.next_id();
110            let req = JsonRpcRequest {
111                jsonrpc: crate::protocol::JsonRpcVersion::default(),
112                method: "steer".to_string(),
113                id: id.clone(),
114                params: SteerParams {
115                    user_input: user_input.into(),
116                },
117            };
118            self.send_request(&req).await?;
119            self.read_response(&id).await
120        }
121    }
122
123    /// Enable or disable plan mode.
124    fn set_plan_mode(
125        &mut self,
126        enabled: bool,
127    ) -> impl Future<Output = Result<SetPlanModeResult, WireError>> + Send {
128        async move {
129            let id = self.next_id();
130            let req = JsonRpcRequest {
131                jsonrpc: crate::protocol::JsonRpcVersion::default(),
132                method: "set_plan_mode".to_string(),
133                id: id.clone(),
134                params: SetPlanModeParams { enabled },
135            };
136            self.send_request(&req).await?;
137            self.read_response(&id).await
138        }
139    }
140
141    /// Cancel the current turn.
142    fn cancel(&mut self) -> impl Future<Output = Result<(), WireError>> + Send {
143        async move {
144            let id = self.next_id();
145            let req = JsonRpcRequest {
146                jsonrpc: crate::protocol::JsonRpcVersion::default(),
147                method: "cancel".to_string(),
148                id: id.clone(),
149                params: CancelParams::default(),
150            };
151            self.send_request(&req).await?;
152            let _: CancelResult = self.read_response(&id).await?;
153            Ok(())
154        }
155    }
156
157    /// Wait for a response matching `expected_id`, buffering out-of-order
158    /// messages internally.
159    fn read_response<T: DeserializeOwned + Send>(
160        &mut self,
161        expected_id: &str,
162    ) -> impl Future<Output = Result<T, WireError>> + Send;
163}
164
165// ============================================================================
166// InMemoryWireClient
167// ============================================================================
168
169/// In-memory wire client for unit tests.
170///
171/// Holds an internal queue of [`crate::protocol::RawWireMessage`]s that
172/// `read_raw_message` drains. Tests inject messages via [`InMemoryWireClient::inject`].
173#[derive(Debug)]
174pub struct InMemoryWireClient {
175    incoming: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
176    pending: Mutex<VecDeque<crate::protocol::RawWireMessage>>,
177    outgoing: Mutex<Vec<serde_json::Value>>,
178    handshake_done: bool,
179    request_counter: u64,
180}
181
182impl Default for InMemoryWireClient {
183    fn default() -> Self {
184        Self::new()
185    }
186}
187
188impl InMemoryWireClient {
189    /// Create a new in-memory client.
190    pub fn new() -> Self {
191        Self {
192            incoming: Mutex::new(VecDeque::new()),
193            pending: Mutex::new(VecDeque::new()),
194            outgoing: Mutex::new(Vec::new()),
195            handshake_done: false,
196            request_counter: 0,
197        }
198    }
199
200    /// Inject an incoming raw wire message for the client to read.
201    pub async fn inject(&self, msg: crate::protocol::RawWireMessage) {
202        self.incoming.lock().await.push_back(msg);
203    }
204
205    /// Access all messages sent by the client.
206    pub async fn outgoing(&self) -> Vec<serde_json::Value> {
207        self.outgoing.lock().await.clone()
208    }
209}
210
211impl WireClient for InMemoryWireClient {
212    fn next_id(&mut self) -> String {
213        self.request_counter += 1;
214        format!("req-{}", self.request_counter)
215    }
216
217    async fn send_request<Params: Serialize + Sync>(
218        &mut self,
219        req: &JsonRpcRequest<Params>,
220    ) -> Result<(), WireError> {
221        let value = serde_json::to_value(req).map_err(WireError::from)?;
222        self.outgoing.lock().await.push(value);
223        Ok(())
224    }
225
226    async fn read_raw_message(&mut self) -> Result<crate::protocol::RawWireMessage, WireError> {
227        if let Some(msg) = self.pending.lock().await.pop_front() {
228            return Ok(msg);
229        }
230        match self.incoming.lock().await.pop_front() {
231            Some(msg) => Ok(msg),
232            None => Err(WireError::StreamClosed),
233        }
234    }
235
236    async fn read_raw_message_timeout(
237        &mut self,
238        timeout: Duration,
239    ) -> Result<crate::protocol::RawWireMessage, WireError> {
240        match tokio::time::timeout(timeout, self.read_raw_message()).await {
241            Ok(msg) => msg,
242            Err(_) => Err(WireError::Timeout(timeout)),
243        }
244    }
245
246    async fn read_response<T: DeserializeOwned + Send>(
247        &mut self,
248        expected_id: &str,
249    ) -> Result<T, WireError> {
250        loop {
251            let idx = {
252                let lock = self.pending.lock().await;
253                lock.iter()
254                    .position(|msg| msg.id.as_deref() == Some(expected_id))
255            };
256            if let Some(idx) = idx {
257                let msg = self
258                    .pending
259                    .lock()
260                    .await
261                    .remove(idx)
262                    .ok_or_else(|| WireError::Internal("pending index invalid".to_string()))?;
263                return decode_response(msg, expected_id);
264            }
265
266            match self.incoming.lock().await.pop_front() {
267                Some(msg) if msg.id.as_deref() == Some(expected_id) => {
268                    return decode_response(msg, expected_id);
269                }
270                Some(other) => {
271                    self.pending.lock().await.push_back(other);
272                }
273                None => return Err(WireError::StreamClosed),
274            }
275        }
276    }
277
278    async fn send_response<T: Serialize + Send>(
279        &mut self,
280        id: &str,
281        result: T,
282    ) -> Result<(), WireError> {
283        let resp = crate::protocol::JsonRpcSuccessResponse {
284            jsonrpc: crate::protocol::JsonRpcVersion::default(),
285            id: id.to_string(),
286            result,
287        };
288        let line = format!("{}\n", serde_json::to_string(&resp).map_err(WireError::from)?);
289        self.outgoing
290            .lock()
291            .await
292            .push(serde_json::Value::String(line));
293        Ok(())
294    }
295
296    async fn send_error(
297        &mut self,
298        id: &str,
299        code: i32,
300        message: &str,
301    ) -> Result<(), WireError> {
302        let resp = crate::protocol::JsonRpcErrorResponse {
303            jsonrpc: crate::protocol::JsonRpcVersion::default(),
304            id: id.to_string(),
305            error: crate::protocol::JsonRpcError {
306                code,
307                message: message.to_string(),
308                data: None,
309            },
310        };
311        let line = format!("{}\n", serde_json::to_string(&resp).map_err(WireError::from)?);
312        self.outgoing
313            .lock()
314            .await
315            .push(serde_json::Value::String(line));
316        Ok(())
317    }
318
319    async fn initialize(
320        &mut self,
321        _params: InitializeParams,
322    ) -> Result<InitializeResult, WireError> {
323        self.handshake_done = true;
324        Ok(InitializeResult {
325            protocol_version: crate::WIRE_PROTOCOL_VERSION.to_string(),
326            server: crate::protocol::ServerInfo {
327                name: "test-server".to_string(),
328                version: "0.0.0".to_string(),
329            },
330            slash_commands: vec![],
331            external_tools: None,
332            capabilities: None,
333            hooks: None,
334        })
335    }
336
337    fn is_handshake_done(&self) -> bool {
338        self.handshake_done
339    }
340
341    async fn shutdown(self) -> Result<(), WireError> {
342        Ok(())
343    }
344}
345
346fn decode_response<T: DeserializeOwned>(
347    msg: crate::protocol::RawWireMessage,
348    _expected_id: &str,
349) -> Result<T, WireError> {
350    if let Some(error) = msg.error {
351        return Err(WireError::RequestFailed {
352            code: error.code,
353            message: error.message,
354        });
355    }
356    let result = msg
357        .result
358        .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
359    serde_json::from_value(result).map_err(WireError::from)
360}