Skip to main content

kimi_wire/
transport.rs

1//! Transport layer for the Kimi Wire protocol.
2//!
3//! The [`crate::transport::Transport`] trait abstracts how raw JSON lines are read and written,
4//! allowing the same [`WireClient`](crate::client::WireClient) logic to run
5//! over stdio, in-memory buffers, or custom channels.
6
7use std::collections::VecDeque;
8use std::path::Path;
9use std::time::Duration;
10
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStdin, ChildStdout};
15use tokio_util::codec::{FramedRead, LinesCodec};
16use tokio_util::sync::CancellationToken;
17
18use crate::client::WireClient;
19use crate::error::WireError;
20use crate::protocol::{
21    InitializeParams, InitializeResult, JsonRpcErrorResponse, JsonRpcRequest,
22    JsonRpcSuccessResponse, RawWireMessage,
23};
24
/// Upper bound, in bytes, on the length of a single wire line (16 MiB).
///
/// Every Kimi wire message is one newline-terminated JSON line. The cap
/// exists so that a peer which never sends a newline cannot make the reader
/// buffer an unbounded amount of data and exhaust memory.
pub const MAX_WIRE_LINE_LENGTH: usize = 16 << 20;
31
/// Upper bound on how many out-of-order wire messages may be parked while a
/// caller waits for one specific response id.
///
/// The cap exists so that a misbehaving peer which keeps emitting unrelated
/// ids cannot grow `pending_messages` until memory is exhausted.
pub const MAX_PENDING_MESSAGES: usize = 1_024;
38
39/// Returns `true` for errors where a retry might succeed.
40const fn is_transient_error(err: &WireError) -> bool {
41    matches!(err, WireError::Io(_) | WireError::Timeout(_))
42}
43
44/// Async transport for reading and writing newline-delimited JSON.
45pub trait Transport: Send {
46    /// Read the next line from the transport.
47    fn read_line(
48        &mut self,
49    ) -> impl std::future::Future<Output = Result<Option<String>, WireError>> + Send;
50
51    /// Write a line to the transport.
52    fn write_line(
53        &mut self,
54        line: &str,
55    ) -> impl std::future::Future<Output = Result<(), WireError>> + Send;
56
57    /// Gracefully close the transport.
58    ///
59    /// Default implementation returns `Ok(())`. Implementations that wrap a
60    /// child process, network socket, or other resource should override this
61    /// to release the resource cleanly. Called by
62    /// `TransportWireClient::shutdown`.
63    fn shutdown(self) -> impl std::future::Future<Output = Result<(), WireError>> + Send
64    where
65        Self: Sized,
66    {
67        async { Ok(()) }
68    }
69}
70
71// ============================================================================
72// TransportWireClient
73// ============================================================================
74
75/// A [`WireClient`] implementation backed by any [`Transport`].
76pub struct TransportWireClient<T: Transport> {
77    transport: T,
78    request_id_counter: u64,
79    handshake_done: bool,
80    pending_messages: VecDeque<RawWireMessage>,
81    default_timeout: Option<Duration>,
82    max_io_retries: u32,
83}
84
85impl<T: Transport> std::fmt::Debug for TransportWireClient<T> {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("TransportWireClient")
88            .field("request_id_counter", &self.request_id_counter)
89            .field("handshake_done", &self.handshake_done)
90            .field("pending_messages", &self.pending_messages.len())
91            .field("default_timeout", &self.default_timeout)
92            .field("max_io_retries", &self.max_io_retries)
93            .finish_non_exhaustive()
94    }
95}
96
97impl<T: Transport> TransportWireClient<T> {
98    /// Create a new client wrapping the given transport.
99    pub const fn new(transport: T) -> Self {
100        Self {
101            transport,
102            request_id_counter: 0,
103            handshake_done: false,
104            pending_messages: VecDeque::new(),
105            default_timeout: None,
106            max_io_retries: 0,
107        }
108    }
109
110    /// Consume the client and return the underlying transport.
111    pub fn into_transport(self) -> T {
112        self.transport
113    }
114
115    /// Set a default timeout applied to every `read_response` call.
116    /// Without this, `read_response` waits indefinitely for a matching id.
117    #[must_use]
118    pub const fn with_default_timeout(mut self, timeout: Duration) -> Self {
119        self.default_timeout = Some(timeout);
120        self
121    }
122
123    /// Set the maximum number of retries for transient I/O errors during
124    /// `read_response`. Each retry waits exponentially longer
125    /// (`50ms * 2^attempt`).
126    #[must_use]
127    pub const fn with_max_io_retries(mut self, retries: u32) -> Self {
128        self.max_io_retries = if retries > 5 { 5 } else { retries };
129        self
130    }
131
132    async fn read_line_with_retry(&mut self) -> Result<Option<String>, WireError> {
133        let mut attempt = 0;
134        loop {
135            match self.transport.read_line().await {
136                Ok(result) => return Ok(result),
137                Err(ref e) if attempt < self.max_io_retries && is_transient_error(e) => {
138                    attempt += 1;
139                    let delay = Duration::from_millis(50 * 2_u64.pow(attempt));
140                    tracing::debug!(error = %e, attempt, ?delay, "transient transport read error, retrying");
141                    tokio::time::sleep(delay).await;
142                }
143                Err(e) => return Err(e),
144            }
145        }
146    }
147}
148
149impl<T: Transport> WireClient for TransportWireClient<T> {
150    fn next_id(&mut self) -> String {
151        self.request_id_counter += 1;
152        format!("req-{}", self.request_id_counter)
153    }
154
155    async fn send_request<Params: Serialize + Sync>(
156        &mut self,
157        req: &JsonRpcRequest<Params>,
158    ) -> Result<(), WireError> {
159        let line = serde_json::to_string(req).map_err(WireError::from)?;
160        self.transport.write_line(&line).await
161    }
162
163    async fn read_raw_message(&mut self) -> Result<RawWireMessage, WireError> {
164        if let Some(msg) = self.pending_messages.pop_front() {
165            return Ok(msg);
166        }
167        let Some(line) = self.transport.read_line().await? else {
168            return Err(WireError::StreamClosed);
169        };
170        serde_json::from_str(&line).map_err(WireError::from)
171    }
172
173    async fn read_raw_message_timeout(
174        &mut self,
175        timeout: Duration,
176    ) -> Result<RawWireMessage, WireError> {
177        tokio::time::timeout(timeout, self.read_raw_message())
178            .await
179            .map_or(Err(WireError::Timeout(timeout)), |msg| msg)
180    }
181
182    async fn read_response<Res: DeserializeOwned + Send>(
183        &mut self,
184        expected_id: &str,
185    ) -> Result<Res, WireError> {
186        let timeout = self.default_timeout;
187        let fut = async {
188            loop {
189                if let Some(idx) = self
190                    .pending_messages
191                    .iter()
192                    .position(|msg| msg.id.as_deref() == Some(expected_id))
193                {
194                    let msg = self
195                        .pending_messages
196                        .remove(idx)
197                        .ok_or_else(|| WireError::Internal("pending index invalid".to_string()))?;
198                    return decode_raw_response(msg, expected_id);
199                }
200
201                let Some(line) = self.read_line_with_retry().await? else {
202                    return Err(WireError::StreamClosed);
203                };
204                let msg: RawWireMessage = serde_json::from_str(&line).map_err(WireError::from)?;
205                if msg.id.as_deref() == Some(expected_id) {
206                    return decode_raw_response(msg, expected_id);
207                }
208                if self.pending_messages.len() >= MAX_PENDING_MESSAGES {
209                    return Err(WireError::Internal(format!(
210                        "pending message buffer overflow ({MAX_PENDING_MESSAGES} entries) waiting for id {expected_id:?}"
211                    )));
212                }
213                self.pending_messages.push_back(msg);
214            }
215        };
216
217        match timeout {
218            Some(d) => tokio::time::timeout(d, fut)
219                .await
220                .map_err(|_| WireError::Timeout(d))?,
221            None => fut.await,
222        }
223    }
224
225    async fn send_response<Res: Serialize + Send>(
226        &mut self,
227        id: &str,
228        result: Res,
229    ) -> Result<(), WireError> {
230        let resp = JsonRpcSuccessResponse {
231            jsonrpc: crate::protocol::JsonRpcVersion::V2,
232            id: id.to_string(),
233            result,
234        };
235        let line = serde_json::to_string(&resp).map_err(WireError::from)?;
236        self.transport.write_line(&line).await
237    }
238
239    async fn send_error(&mut self, id: &str, code: i32, message: &str) -> Result<(), WireError> {
240        let resp = JsonRpcErrorResponse {
241            jsonrpc: crate::protocol::JsonRpcVersion::V2,
242            id: id.to_string(),
243            error: crate::protocol::JsonRpcError {
244                code,
245                message: message.to_string(),
246                data: None,
247            },
248        };
249        let line = serde_json::to_string(&resp).map_err(WireError::from)?;
250        self.transport.write_line(&line).await
251    }
252
253    async fn initialize(
254        &mut self,
255        params: InitializeParams,
256    ) -> Result<InitializeResult, WireError> {
257        let id = self.next_id();
258        let req = JsonRpcRequest {
259            jsonrpc: crate::protocol::JsonRpcVersion::V2,
260            method: "initialize".to_string(),
261            id: id.clone(),
262            params,
263        };
264        self.send_request(&req).await?;
265
266        let Some(line) = self.transport.read_line().await? else {
267            return Err(WireError::StreamClosed);
268        };
269
270        // Check for method-not-found error (-32601)
271        if let Ok(error_resp) = serde_json::from_str::<JsonRpcErrorResponse>(&line) {
272            if error_resp.error.code == crate::protocol::METHOD_NOT_FOUND {
273                tracing::warn!(
274                    code = error_resp.error.code,
275                    "Server does not support initialize, falling back to legacy no-handshake mode"
276                );
277                self.handshake_done = true;
278                return Ok(InitializeResult {
279                    protocol_version: crate::WIRE_PROTOCOL_LEGACY_VERSION.to_string(),
280                    server: crate::protocol::ServerInfo {
281                        name: "unknown".to_string(),
282                        version: "unknown".to_string(),
283                    },
284                    slash_commands: vec![],
285                    external_tools: None,
286                    capabilities: None,
287                    hooks: None,
288                });
289            }
290            return Err(WireError::RequestFailed {
291                code: error_resp.error.code,
292                message: error_resp.error.message,
293            });
294        }
295
296        let resp: JsonRpcSuccessResponse<InitializeResult> =
297            serde_json::from_str(&line).map_err(WireError::from)?;
298        self.handshake_done = true;
299        Ok(resp.result)
300    }
301
302    fn is_handshake_done(&self) -> bool {
303        self.handshake_done
304    }
305
306    async fn shutdown(self) -> Result<(), WireError> {
307        self.transport.shutdown().await
308    }
309}
310
311fn decode_raw_response<T: DeserializeOwned>(
312    msg: RawWireMessage,
313    _expected_id: &str,
314) -> Result<T, WireError> {
315    if let Some(error) = msg.error {
316        return Err(WireError::RequestFailed {
317            code: error.code,
318            message: error.message,
319        });
320    }
321    let result = msg
322        .result
323        .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
324    serde_json::from_value(result).map_err(WireError::from)
325}
326
327// ============================================================================
328// ChildProcessTransport
329// ============================================================================
330
331/// A transport backed by a child process's stdin/stdout.
332#[derive(Debug)]
333pub struct ChildProcessTransport {
334    child: Option<Child>,
335    stdin: Option<ChildStdin>,
336    stdout_reader: FramedRead<ChildStdout, LinesCodec>,
337    stderr_handle: Option<tokio::task::JoinHandle<()>>,
338    cancel_token: CancellationToken,
339}
340
341impl ChildProcessTransport {
342    /// Spawn a new `kimi` process in wire mode.
343    ///
344    /// # Errors
345    ///
346    /// Returns [`WireError::SpawnFailed`] if the process cannot be started.
347    pub async fn spawn(
348        kimi_binary: &str,
349        work_dir: Option<&Path>,
350        session: Option<&str>,
351        model: Option<&str>,
352    ) -> Result<Self, WireError> {
353        let mut child = None;
354        for attempt in 0..3 {
355            let mut cmd = tokio::process::Command::new(kimi_binary);
356            cmd.arg("--wire");
357            if let Some(dir) = work_dir {
358                cmd.arg("--work-dir").arg(dir);
359            }
360            if let Some(s) = session {
361                cmd.arg("--session").arg(s);
362            }
363            if let Some(m) = model {
364                cmd.arg("--model").arg(m);
365            }
366            cmd.stdin(std::process::Stdio::piped())
367                .stdout(std::process::Stdio::piped())
368                .stderr(std::process::Stdio::piped());
369
370            match cmd.kill_on_drop(true).spawn() {
371                Ok(spawned) => {
372                    child = Some(spawned);
373                    break;
374                }
375                // ETXTBSY (Text file busy) on Unix-like systems — the binary may
376                // still be written by another process. Retry a couple of times.
377                Err(err) if err.raw_os_error() == Some(26) && attempt < 2 => {
378                    tokio::time::sleep(Duration::from_millis(25)).await;
379                }
380                Err(err) => {
381                    return Err(WireError::SpawnFailed(err.to_string()));
382                }
383            }
384        }
385
386        let mut child =
387            child.ok_or_else(|| WireError::SpawnFailed("all spawn attempts failed".to_string()))?;
388        let stdin = child
389            .stdin
390            .take()
391            .ok_or_else(|| WireError::SpawnFailed("no stdin".to_string()))?;
392        let stdout = child
393            .stdout
394            .take()
395            .ok_or_else(|| WireError::SpawnFailed("no stdout".to_string()))?;
396        let stdout_reader = FramedRead::new(
397            stdout,
398            LinesCodec::new_with_max_length(MAX_WIRE_LINE_LENGTH),
399        );
400
401        let cancel_token = CancellationToken::new();
402        let stderr_cancel = cancel_token.clone();
403        let stderr_handle = child.stderr.take().map(|stderr| {
404            tokio::spawn(async move {
405                let mut reader = BufReader::new(stderr).lines();
406                loop {
407                    tokio::select! {
408                        biased;
409                        () = stderr_cancel.cancelled() => break,
410                        line = reader.next_line() => {
411                            match line {
412                                Ok(Some(line)) => {
413                                    #[cfg(feature = "redact")]
414                                    tracing::warn!(target: "kimi.stderr", "{}", crate::protocol::redact::scrub_secret_patterns(&line));
415                                    #[cfg(not(feature = "redact"))]
416                                    tracing::warn!(target: "kimi.stderr", "{line}");
417                                }
418                                _ => break,
419                            }
420                        }
421                    }
422                }
423            })
424        });
425
426        tracing::info!(
427            kimi_binary,
428            ?work_dir,
429            ?session,
430            ?model,
431            "child process transport spawned"
432        );
433        Ok(Self {
434            child: Some(child),
435            stdin: Some(stdin),
436            stdout_reader,
437            stderr_handle,
438            cancel_token,
439        })
440    }
441}
442
443impl Transport for ChildProcessTransport {
444    async fn read_line(&mut self) -> Result<Option<String>, WireError> {
445        use tokio_stream::StreamExt;
446        match self.stdout_reader.next().await {
447            Some(Ok(line)) => {
448                tracing::trace!(len = line.len(), "read line from child process transport");
449                Ok(Some(line))
450            }
451            Some(Err(e)) => Err(WireError::Io(e.to_string())),
452            None => Ok(None),
453        }
454    }
455
456    async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
457        let stdin = self.stdin.as_mut().ok_or(WireError::StreamClosed)?;
458        stdin.write_all(line.as_bytes()).await?;
459        stdin.write_all(b"\n").await?;
460        stdin.flush().await?;
461        tracing::trace!(len = line.len(), "wrote line to child process transport");
462        Ok(())
463    }
464
465    async fn shutdown(mut self) -> Result<(), WireError> {
466        tracing::info!("shutting down child process transport");
467        // Close stdin so the child sees EOF.
468        drop(self.stdin.take());
469
470        // Wait up to 3 seconds for the child to exit gracefully.
471        let grace = Duration::from_secs(3);
472        if let Some(mut child) = self.child.take() {
473            match tokio::time::timeout(grace, child.wait()).await {
474                Ok(Ok(_) | Err(_)) => {}
475                Err(_) => {
476                    // Best-effort kill after graceful shutdown timed out.
477                    // Safe to ignore: child is already unresponsive.
478                    #[allow(unused_must_use)]
479                    let _ = child.kill().await;
480                }
481            }
482        }
483
484        // Abort the stderr task and cancel the token.
485        self.cancel_token.cancel();
486        if let Some(handle) = self.stderr_handle.take() {
487            handle.abort();
488        }
489
490        Ok(())
491    }
492}
493
494impl Drop for ChildProcessTransport {
495    fn drop(&mut self) {
496        self.cancel_token.cancel();
497        if let Some(handle) = self.stderr_handle.take() {
498            handle.abort();
499        }
500    }
501}
502
503// ============================================================================
504// ChannelTransport
505// ============================================================================
506
507/// A transport backed by in-memory channels for testing.
508#[derive(Debug)]
509pub struct ChannelTransport {
510    rx: tokio::sync::mpsc::UnboundedReceiver<String>,
511    tx: tokio::sync::mpsc::UnboundedSender<String>,
512}
513
514impl ChannelTransport {
515    /// Create a new pair of connected transports.
516    #[must_use]
517    pub fn pair() -> (Self, Self) {
518        let (tx1, rx1) = tokio::sync::mpsc::unbounded_channel();
519        let (tx2, rx2) = tokio::sync::mpsc::unbounded_channel();
520        (Self { rx: rx1, tx: tx2 }, Self { rx: rx2, tx: tx1 })
521    }
522}
523
524impl Transport for ChannelTransport {
525    async fn read_line(&mut self) -> Result<Option<String>, WireError> {
526        Ok(self.rx.recv().await)
527    }
528
529    async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
530        self.tx
531            .send(line.to_string())
532            .map_err(|_| WireError::StreamClosed)
533    }
534}