Skip to main content

kimi_wire/
transport.rs

1//! Transport layer for the Kimi Wire protocol.
2//!
3//! The [`crate::transport::Transport`] trait abstracts how raw JSON lines are read and written,
4//! allowing the same [`WireClient`](crate::client::WireClient) logic to run
5//! over stdio, in-memory buffers, or custom channels.
6
7use std::collections::VecDeque;
8use std::path::Path;
9use std::time::Duration;
10
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStdin, ChildStdout};
15use tokio_util::codec::{FramedRead, LinesCodec};
16use tokio_util::sync::CancellationToken;
17
18use crate::client::WireClient;
19use crate::error::WireError;
20use crate::protocol::{
21    InitializeParams, InitializeResult, JsonRpcErrorResponse, JsonRpcRequest,
22    JsonRpcSuccessResponse, RawWireMessage,
23};
24
/// Upper bound, in bytes, on the length of a single wire line (16 MiB).
///
/// Kimi wire messages are framed as one newline-terminated JSON document per
/// line. Capping the line length keeps a peer that never sends a newline from
/// forcing `read_line` to keep buffering until the process runs out of
/// memory.
pub const MAX_WIRE_LINE_LENGTH: usize = 16 << 20;
31
/// Upper bound on how many out-of-order wire messages may be parked while a
/// caller waits for one specific response id.
///
/// The cap stops a misbehaving peer that keeps emitting unrelated ids from
/// growing `pending_messages` until the process runs out of memory.
pub const MAX_PENDING_MESSAGES: usize = 1 << 10;
38
39/// Async transport for reading and writing newline-delimited JSON.
40pub trait Transport: Send {
41    /// Read the next line from the transport.
42    fn read_line(
43        &mut self,
44    ) -> impl std::future::Future<Output = Result<Option<String>, WireError>> + Send;
45
46    /// Write a line to the transport.
47    fn write_line(
48        &mut self,
49        line: &str,
50    ) -> impl std::future::Future<Output = Result<(), WireError>> + Send;
51
52    /// Gracefully close the transport.
53    ///
54    /// Default implementation returns `Ok(())`. Implementations that wrap a
55    /// child process, network socket, or other resource should override this
56    /// to release the resource cleanly. Called by
57    /// `TransportWireClient::shutdown`.
58    fn shutdown(self) -> impl std::future::Future<Output = Result<(), WireError>> + Send
59    where
60        Self: Sized,
61    {
62        async { Ok(()) }
63    }
64}
65
66// ============================================================================
67// TransportWireClient
68// ============================================================================
69
70/// A [`WireClient`] implementation backed by any [`Transport`].
71pub struct TransportWireClient<T: Transport> {
72    transport: T,
73    request_id_counter: u64,
74    handshake_done: bool,
75    pending_messages: VecDeque<RawWireMessage>,
76    default_timeout: Option<Duration>,
77}
78
79impl<T: Transport> TransportWireClient<T> {
80    /// Create a new client wrapping the given transport.
81    pub fn new(transport: T) -> Self {
82        Self {
83            transport,
84            request_id_counter: 0,
85            handshake_done: false,
86            pending_messages: VecDeque::new(),
87            default_timeout: None,
88        }
89    }
90
91    /// Consume the client and return the underlying transport.
92    pub fn into_transport(self) -> T {
93        self.transport
94    }
95
96    /// Set a default timeout applied to every `read_response` call.
97    /// Without this, `read_response` waits indefinitely for a matching id.
98    pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
99        self.default_timeout = Some(timeout);
100        self
101    }
102}
103
104impl<T: Transport> WireClient for TransportWireClient<T> {
105    fn next_id(&mut self) -> String {
106        self.request_id_counter += 1;
107        format!("req-{}", self.request_id_counter)
108    }
109
110    async fn send_request<Params: Serialize + Sync>(
111        &mut self,
112        req: &JsonRpcRequest<Params>,
113    ) -> Result<(), WireError> {
114        let line = serde_json::to_string(req).map_err(WireError::from)?;
115        self.transport.write_line(&line).await
116    }
117
118    async fn read_raw_message(&mut self) -> Result<RawWireMessage, WireError> {
119        if let Some(msg) = self.pending_messages.pop_front() {
120            return Ok(msg);
121        }
122        let line = match self.transport.read_line().await? {
123            Some(line) => line,
124            None => return Err(WireError::StreamClosed),
125        };
126        serde_json::from_str(&line).map_err(WireError::from)
127    }
128
129    async fn read_raw_message_timeout(
130        &mut self,
131        timeout: Duration,
132    ) -> Result<RawWireMessage, WireError> {
133        match tokio::time::timeout(timeout, self.read_raw_message()).await {
134            Ok(msg) => msg,
135            Err(_) => Err(WireError::Timeout(timeout)),
136        }
137    }
138
139    async fn read_response<Res: DeserializeOwned + Send>(
140        &mut self,
141        expected_id: &str,
142    ) -> Result<Res, WireError> {
143        let fut = async {
144            loop {
145                if let Some(idx) = self
146                    .pending_messages
147                    .iter()
148                    .position(|msg| msg.id.as_deref() == Some(expected_id))
149                {
150                    let msg = self
151                        .pending_messages
152                        .remove(idx)
153                        .ok_or_else(|| WireError::Internal("pending index invalid".to_string()))?;
154                    return decode_raw_response(msg, expected_id);
155                }
156
157                let line = match self.transport.read_line().await? {
158                    Some(line) => line,
159                    None => return Err(WireError::StreamClosed),
160                };
161                let msg: RawWireMessage = serde_json::from_str(&line).map_err(WireError::from)?;
162                if msg.id.as_deref() == Some(expected_id) {
163                    return decode_raw_response(msg, expected_id);
164                }
165                if self.pending_messages.len() >= MAX_PENDING_MESSAGES {
166                    return Err(WireError::Internal(format!(
167                        "pending message buffer overflow ({} entries) waiting for id {:?}",
168                        MAX_PENDING_MESSAGES, expected_id
169                    )));
170                }
171                self.pending_messages.push_back(msg);
172            }
173        };
174
175        match self.default_timeout {
176            Some(d) => tokio::time::timeout(d, fut)
177                .await
178                .map_err(|_| WireError::Timeout(d))?,
179            None => fut.await,
180        }
181    }
182
183    async fn send_response<Res: Serialize + Send>(
184        &mut self,
185        id: &str,
186        result: Res,
187    ) -> Result<(), WireError> {
188        let resp = JsonRpcSuccessResponse {
189            jsonrpc: crate::protocol::JsonRpcVersion::V2,
190            id: id.to_string(),
191            result,
192        };
193        let line = serde_json::to_string(&resp).map_err(WireError::from)?;
194        self.transport.write_line(&line).await
195    }
196
197    async fn send_error(
198        &mut self,
199        id: &str,
200        code: i32,
201        message: &str,
202    ) -> Result<(), WireError> {
203        let resp = JsonRpcErrorResponse {
204            jsonrpc: crate::protocol::JsonRpcVersion::V2,
205            id: id.to_string(),
206            error: crate::protocol::JsonRpcError {
207                code,
208                message: message.to_string(),
209                data: None,
210            },
211        };
212        let line = serde_json::to_string(&resp).map_err(WireError::from)?;
213        self.transport.write_line(&line).await
214    }
215
216    async fn initialize(
217        &mut self,
218        params: InitializeParams,
219    ) -> Result<InitializeResult, WireError> {
220        let id = self.next_id();
221        let req = JsonRpcRequest {
222            jsonrpc: crate::protocol::JsonRpcVersion::V2,
223            method: "initialize".to_string(),
224            id: id.clone(),
225            params,
226        };
227        self.send_request(&req).await?;
228
229        let line = match self.transport.read_line().await? {
230            Some(line) => line,
231            None => return Err(WireError::StreamClosed),
232        };
233
234        // Check for method-not-found error (-32601)
235        if let Ok(error_resp) = serde_json::from_str::<JsonRpcErrorResponse>(&line) {
236            if error_resp.error.code == crate::protocol::METHOD_NOT_FOUND {
237                tracing::warn!(
238                    code = error_resp.error.code,
239                    "Server does not support initialize, falling back to legacy no-handshake mode"
240                );
241                self.handshake_done = true;
242                return Ok(InitializeResult {
243                    protocol_version: crate::WIRE_PROTOCOL_LEGACY_VERSION.to_string(),
244                    server: crate::protocol::ServerInfo {
245                        name: "unknown".to_string(),
246                        version: "unknown".to_string(),
247                    },
248                    slash_commands: vec![],
249                    external_tools: None,
250                    capabilities: None,
251                    hooks: None,
252                });
253            }
254            return Err(WireError::RequestFailed {
255                code: error_resp.error.code,
256                message: error_resp.error.message,
257            });
258        }
259
260        let resp: JsonRpcSuccessResponse<InitializeResult> =
261            serde_json::from_str(&line).map_err(WireError::from)?;
262        self.handshake_done = true;
263        Ok(resp.result)
264    }
265
266    fn is_handshake_done(&self) -> bool {
267        self.handshake_done
268    }
269
270    async fn shutdown(self) -> Result<(), WireError> {
271        self.transport.shutdown().await
272    }
273}
274
275fn decode_raw_response<T: DeserializeOwned>(
276    msg: RawWireMessage,
277    _expected_id: &str,
278) -> Result<T, WireError> {
279    if let Some(error) = msg.error {
280        return Err(WireError::RequestFailed {
281            code: error.code,
282            message: error.message,
283        });
284    }
285    let result = msg
286        .result
287        .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
288    serde_json::from_value(result).map_err(WireError::from)
289}
290
291// ============================================================================
292// ChildProcessTransport
293// ============================================================================
294
295/// A transport backed by a child process's stdin/stdout.
296#[derive(Debug)]
297pub struct ChildProcessTransport {
298    child: Option<Child>,
299    stdin: Option<ChildStdin>,
300    stdout_reader: FramedRead<ChildStdout, LinesCodec>,
301    stderr_handle: Option<tokio::task::JoinHandle<()>>,
302    cancel_token: CancellationToken,
303}
304
305impl ChildProcessTransport {
306    /// Spawn a new `kimi` process in wire mode.
307    ///
308    /// # Errors
309    ///
310    /// Returns [`WireError::SpawnFailed`] if the process cannot be started.
311    pub async fn spawn(
312        kimi_binary: &str,
313        work_dir: Option<&Path>,
314        session: Option<&str>,
315        model: Option<&str>,
316    ) -> Result<Self, WireError> {
317        let mut child = None;
318        for attempt in 0..3 {
319            let mut cmd = tokio::process::Command::new(kimi_binary);
320            cmd.arg("--wire");
321            if let Some(dir) = work_dir {
322                cmd.arg("--work-dir").arg(dir);
323            }
324            if let Some(s) = session {
325                cmd.arg("--session").arg(s);
326            }
327            if let Some(m) = model {
328                cmd.arg("--model").arg(m);
329            }
330            cmd.stdin(std::process::Stdio::piped())
331                .stdout(std::process::Stdio::piped())
332                .stderr(std::process::Stdio::piped());
333
334            match cmd.kill_on_drop(true).spawn() {
335                Ok(spawned) => {
336                    child = Some(spawned);
337                    break;
338                }
339                Err(err) if err.raw_os_error() == Some(26) && attempt < 2 => {
340                    tokio::time::sleep(Duration::from_millis(25)).await;
341                }
342                Err(err) => {
343                    return Err(WireError::SpawnFailed(err.to_string()));
344                }
345            }
346        }
347
348        let mut child = child
349            .ok_or_else(|| WireError::SpawnFailed("all spawn attempts failed".to_string()))?;
350        let stdin = child
351            .stdin
352            .take()
353            .ok_or_else(|| WireError::SpawnFailed("no stdin".to_string()))?;
354        let stdout = child
355            .stdout
356            .take()
357            .ok_or_else(|| WireError::SpawnFailed("no stdout".to_string()))?;
358        let stdout_reader = FramedRead::new(
359            stdout,
360            LinesCodec::new_with_max_length(MAX_WIRE_LINE_LENGTH),
361        );
362
363        let cancel_token = CancellationToken::new();
364        let stderr_cancel = cancel_token.clone();
365        let stderr_handle = child.stderr.take().map(|stderr| {
366            tokio::spawn(async move {
367                let mut reader = BufReader::new(stderr).lines();
368                loop {
369                    tokio::select! {
370                        biased;
371                        _ = stderr_cancel.cancelled() => break,
372                        line = reader.next_line() => {
373                            match line {
374                                Ok(Some(line)) => {
375                                    #[cfg(feature = "redact")]
376                                    tracing::warn!(target: "kimi.stderr", "{}", crate::protocol::redact::scrub_secret_patterns(&line));
377                                    #[cfg(not(feature = "redact"))]
378                                    tracing::warn!(target: "kimi.stderr", "{line}");
379                                }
380                                _ => break,
381                            }
382                        }
383                    }
384                }
385            })
386        });
387
388        Ok(Self {
389            child: Some(child),
390            stdin: Some(stdin),
391            stdout_reader,
392            stderr_handle,
393            cancel_token,
394        })
395    }
396}
397
398impl Transport for ChildProcessTransport {
399    async fn read_line(&mut self) -> Result<Option<String>, WireError> {
400        use tokio_stream::StreamExt;
401        match self.stdout_reader.next().await {
402            Some(Ok(line)) => Ok(Some(line)),
403            Some(Err(e)) => Err(WireError::Io(e.to_string())),
404            None => Ok(None),
405        }
406    }
407
408    async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
409        let stdin = self
410            .stdin
411            .as_mut()
412            .ok_or(WireError::StreamClosed)?;
413        stdin.write_all(line.as_bytes()).await?;
414        stdin.write_all(b"\n").await?;
415        stdin.flush().await?;
416        Ok(())
417    }
418
419    async fn shutdown(mut self) -> Result<(), WireError> {
420        // Close stdin so the child sees EOF.
421        drop(self.stdin.take());
422
423        // Wait up to 3 seconds for the child to exit gracefully.
424        let grace = Duration::from_secs(3);
425        if let Some(mut child) = self.child.take() {
426            match tokio::time::timeout(grace, child.wait()).await {
427                Ok(Ok(_)) => {}
428                Ok(Err(_)) => {}
429                Err(_) => {
430                    let _ = child.kill().await;
431                }
432            }
433        }
434
435        // Abort the stderr task and cancel the token.
436        self.cancel_token.cancel();
437        if let Some(handle) = self.stderr_handle.take() {
438            handle.abort();
439        }
440
441        Ok(())
442    }
443}
444
445impl Drop for ChildProcessTransport {
446    fn drop(&mut self) {
447        self.cancel_token.cancel();
448        if let Some(handle) = self.stderr_handle.take() {
449            handle.abort();
450        }
451    }
452}
453
454// ============================================================================
455// ChannelTransport
456// ============================================================================
457
458/// A transport backed by in-memory channels for testing.
459#[derive(Debug)]
460pub struct ChannelTransport {
461    rx: tokio::sync::mpsc::UnboundedReceiver<String>,
462    tx: tokio::sync::mpsc::UnboundedSender<String>,
463}
464
465impl ChannelTransport {
466    /// Create a new pair of connected transports.
467    pub fn pair() -> (Self, Self) {
468        let (tx1, rx1) = tokio::sync::mpsc::unbounded_channel();
469        let (tx2, rx2) = tokio::sync::mpsc::unbounded_channel();
470        (
471            Self { rx: rx1, tx: tx2 },
472            Self { rx: rx2, tx: tx1 },
473        )
474    }
475}
476
477impl Transport for ChannelTransport {
478    async fn read_line(&mut self) -> Result<Option<String>, WireError> {
479        Ok(self.rx.recv().await)
480    }
481
482    async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
483        self.tx
484            .send(line.to_string())
485            .map_err(|_| WireError::StreamClosed)
486    }
487}