Skip to main content

kimi_wire/
transport.rs

1//! Transport layer for the Kimi Wire protocol.
2//!
3//! The [`crate::transport::Transport`] trait abstracts how raw JSON lines are read and written,
4//! allowing the same [`WireClient`](crate::client::WireClient) logic to run
5//! over stdio, in-memory buffers, or custom channels.
6
7use std::collections::VecDeque;
8use std::path::Path;
9use std::time::Duration;
10
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStdin, ChildStdout};
15use tokio_util::codec::{FramedRead, LinesCodec};
16use tokio_util::sync::CancellationToken;
17
18use crate::client::WireClient;
19use crate::error::WireError;
20use crate::protocol::{
21    InitializeParams, InitializeResult, JsonRpcErrorResponse, JsonRpcRequest,
22    JsonRpcSuccessResponse, RawWireMessage,
23};
24
/// Upper bound, in bytes, on a single wire line (16 MiB).
///
/// The Kimi wire protocol carries one JSON message per newline-terminated
/// line. Capping the line length keeps a peer that never sends a newline from
/// forcing the line reader to buffer without limit.
pub const MAX_WIRE_LINE_LENGTH: usize = 16 << 20;
31
/// Upper bound on wire messages held aside while a caller waits for one
/// particular response id.
///
/// Messages with other ids are queued so they are not lost. The cap keeps a
/// misbehaving peer that keeps emitting unrelated ids from growing that queue
/// until memory runs out.
pub const MAX_PENDING_MESSAGES: usize = 1 << 10;
38
39/// Returns `true` for errors where a retry might succeed.
40fn is_transient_error(err: &WireError) -> bool {
41    matches!(err, WireError::Io(_) | WireError::Timeout(_))
42}
43
44/// Async transport for reading and writing newline-delimited JSON.
45pub trait Transport: Send {
46    /// Read the next line from the transport.
47    fn read_line(
48        &mut self,
49    ) -> impl std::future::Future<Output = Result<Option<String>, WireError>> + Send;
50
51    /// Write a line to the transport.
52    fn write_line(
53        &mut self,
54        line: &str,
55    ) -> impl std::future::Future<Output = Result<(), WireError>> + Send;
56
57    /// Gracefully close the transport.
58    ///
59    /// Default implementation returns `Ok(())`. Implementations that wrap a
60    /// child process, network socket, or other resource should override this
61    /// to release the resource cleanly. Called by
62    /// `TransportWireClient::shutdown`.
63    fn shutdown(self) -> impl std::future::Future<Output = Result<(), WireError>> + Send
64    where
65        Self: Sized,
66    {
67        async { Ok(()) }
68    }
69}
70
71// ============================================================================
72// TransportWireClient
73// ============================================================================
74
75/// A [`WireClient`] implementation backed by any [`Transport`].
76pub struct TransportWireClient<T: Transport> {
77    transport: T,
78    request_id_counter: u64,
79    handshake_done: bool,
80    pending_messages: VecDeque<RawWireMessage>,
81    default_timeout: Option<Duration>,
82    max_io_retries: u32,
83}
84
85impl<T: Transport> std::fmt::Debug for TransportWireClient<T> {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("TransportWireClient")
88            .field("request_id_counter", &self.request_id_counter)
89            .field("handshake_done", &self.handshake_done)
90            .field("pending_messages", &self.pending_messages.len())
91            .field("default_timeout", &self.default_timeout)
92            .field("max_io_retries", &self.max_io_retries)
93            .finish_non_exhaustive()
94    }
95}
96
97impl<T: Transport> TransportWireClient<T> {
98    /// Create a new client wrapping the given transport.
99    pub const fn new(transport: T) -> Self {
100        Self {
101            transport,
102            request_id_counter: 0,
103            handshake_done: false,
104            pending_messages: VecDeque::new(),
105            default_timeout: None,
106            max_io_retries: 0,
107        }
108    }
109
110    /// Consume the client and return the underlying transport.
111    pub fn into_transport(self) -> T {
112        self.transport
113    }
114
115    /// Set a default timeout applied to every `read_response` call.
116    /// Without this, `read_response` waits indefinitely for a matching id.
117    #[must_use]
118    pub const fn with_default_timeout(mut self, timeout: Duration) -> Self {
119        self.default_timeout = Some(timeout);
120        self
121    }
122
123    /// Set the maximum number of retries for transient I/O errors during
124    /// `read_response`. Each retry waits exponentially longer
125    /// (`50ms * 2^attempt`).
126    #[must_use]
127    pub const fn with_max_io_retries(mut self, retries: u32) -> Self {
128        self.max_io_retries = if retries > 5 { 5 } else { retries };
129        self
130    }
131
132    async fn read_line_with_retry(&mut self) -> Result<Option<String>, WireError> {
133        let mut attempt = 0;
134        loop {
135            match self.transport.read_line().await {
136                Ok(result) => return Ok(result),
137                Err(ref e) if attempt < self.max_io_retries && is_transient_error(e) => {
138                    attempt += 1;
139                    let delay = Duration::from_millis(50 * 2_u64.pow(attempt));
140                    tracing::debug!(error = %e, attempt, ?delay, "transient transport read error, retrying");
141                    tokio::time::sleep(delay).await;
142                }
143                Err(e) => return Err(e),
144            }
145        }
146    }
147}
148
149impl<T: Transport> WireClient for TransportWireClient<T> {
150    fn next_id(&mut self) -> String {
151        self.request_id_counter += 1;
152        format!("req-{}", self.request_id_counter)
153    }
154
155    async fn send_request<Params: Serialize + Sync>(
156        &mut self,
157        req: &JsonRpcRequest<Params>,
158    ) -> Result<(), WireError> {
159        let line = serde_json::to_string(req).map_err(WireError::from)?;
160        self.transport.write_line(&line).await
161    }
162
163    async fn read_raw_message(&mut self) -> Result<RawWireMessage, WireError> {
164        if let Some(msg) = self.pending_messages.pop_front() {
165            return Ok(msg);
166        }
167        let line = match self.transport.read_line().await? {
168            Some(line) => line,
169            None => return Err(WireError::StreamClosed),
170        };
171        serde_json::from_str(&line).map_err(WireError::from)
172    }
173
174    async fn read_raw_message_timeout(
175        &mut self,
176        timeout: Duration,
177    ) -> Result<RawWireMessage, WireError> {
178        match tokio::time::timeout(timeout, self.read_raw_message()).await {
179            Ok(msg) => msg,
180            Err(_) => Err(WireError::Timeout(timeout)),
181        }
182    }
183
184    async fn read_response<Res: DeserializeOwned + Send>(
185        &mut self,
186        expected_id: &str,
187    ) -> Result<Res, WireError> {
188        let timeout = self.default_timeout;
189        let fut = async {
190            loop {
191                if let Some(idx) = self
192                    .pending_messages
193                    .iter()
194                    .position(|msg| msg.id.as_deref() == Some(expected_id))
195                {
196                    let msg = self
197                        .pending_messages
198                        .remove(idx)
199                        .ok_or_else(|| WireError::Internal("pending index invalid".to_string()))?;
200                    return decode_raw_response(msg, expected_id);
201                }
202
203                let line = match self.read_line_with_retry().await? {
204                    Some(line) => line,
205                    None => return Err(WireError::StreamClosed),
206                };
207                let msg: RawWireMessage = serde_json::from_str(&line).map_err(WireError::from)?;
208                if msg.id.as_deref() == Some(expected_id) {
209                    return decode_raw_response(msg, expected_id);
210                }
211                if self.pending_messages.len() >= MAX_PENDING_MESSAGES {
212                    return Err(WireError::Internal(format!(
213                        "pending message buffer overflow ({} entries) waiting for id {:?}",
214                        MAX_PENDING_MESSAGES, expected_id
215                    )));
216                }
217                self.pending_messages.push_back(msg);
218            }
219        };
220
221        match timeout {
222            Some(d) => tokio::time::timeout(d, fut)
223                .await
224                .map_err(|_| WireError::Timeout(d))?,
225            None => fut.await,
226        }
227    }
228
229    async fn send_response<Res: Serialize + Send>(
230        &mut self,
231        id: &str,
232        result: Res,
233    ) -> Result<(), WireError> {
234        let resp = JsonRpcSuccessResponse {
235            jsonrpc: crate::protocol::JsonRpcVersion::V2,
236            id: id.to_string(),
237            result,
238        };
239        let line = serde_json::to_string(&resp).map_err(WireError::from)?;
240        self.transport.write_line(&line).await
241    }
242
243    async fn send_error(&mut self, id: &str, code: i32, message: &str) -> Result<(), WireError> {
244        let resp = JsonRpcErrorResponse {
245            jsonrpc: crate::protocol::JsonRpcVersion::V2,
246            id: id.to_string(),
247            error: crate::protocol::JsonRpcError {
248                code,
249                message: message.to_string(),
250                data: None,
251            },
252        };
253        let line = serde_json::to_string(&resp).map_err(WireError::from)?;
254        self.transport.write_line(&line).await
255    }
256
257    async fn initialize(
258        &mut self,
259        params: InitializeParams,
260    ) -> Result<InitializeResult, WireError> {
261        let id = self.next_id();
262        let req = JsonRpcRequest {
263            jsonrpc: crate::protocol::JsonRpcVersion::V2,
264            method: "initialize".to_string(),
265            id: id.clone(),
266            params,
267        };
268        self.send_request(&req).await?;
269
270        let line = match self.transport.read_line().await? {
271            Some(line) => line,
272            None => return Err(WireError::StreamClosed),
273        };
274
275        // Check for method-not-found error (-32601)
276        if let Ok(error_resp) = serde_json::from_str::<JsonRpcErrorResponse>(&line) {
277            if error_resp.error.code == crate::protocol::METHOD_NOT_FOUND {
278                tracing::warn!(
279                    code = error_resp.error.code,
280                    "Server does not support initialize, falling back to legacy no-handshake mode"
281                );
282                self.handshake_done = true;
283                return Ok(InitializeResult {
284                    protocol_version: crate::WIRE_PROTOCOL_LEGACY_VERSION.to_string(),
285                    server: crate::protocol::ServerInfo {
286                        name: "unknown".to_string(),
287                        version: "unknown".to_string(),
288                    },
289                    slash_commands: vec![],
290                    external_tools: None,
291                    capabilities: None,
292                    hooks: None,
293                });
294            }
295            return Err(WireError::RequestFailed {
296                code: error_resp.error.code,
297                message: error_resp.error.message,
298            });
299        }
300
301        let resp: JsonRpcSuccessResponse<InitializeResult> =
302            serde_json::from_str(&line).map_err(WireError::from)?;
303        self.handshake_done = true;
304        Ok(resp.result)
305    }
306
307    fn is_handshake_done(&self) -> bool {
308        self.handshake_done
309    }
310
311    async fn shutdown(self) -> Result<(), WireError> {
312        self.transport.shutdown().await
313    }
314}
315
316fn decode_raw_response<T: DeserializeOwned>(
317    msg: RawWireMessage,
318    _expected_id: &str,
319) -> Result<T, WireError> {
320    if let Some(error) = msg.error {
321        return Err(WireError::RequestFailed {
322            code: error.code,
323            message: error.message,
324        });
325    }
326    let result = msg
327        .result
328        .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
329    serde_json::from_value(result).map_err(WireError::from)
330}
331
332// ============================================================================
333// ChildProcessTransport
334// ============================================================================
335
336/// A transport backed by a child process's stdin/stdout.
337#[derive(Debug)]
338pub struct ChildProcessTransport {
339    child: Option<Child>,
340    stdin: Option<ChildStdin>,
341    stdout_reader: FramedRead<ChildStdout, LinesCodec>,
342    stderr_handle: Option<tokio::task::JoinHandle<()>>,
343    cancel_token: CancellationToken,
344}
345
346impl ChildProcessTransport {
347    /// Spawn a new `kimi` process in wire mode.
348    ///
349    /// # Errors
350    ///
351    /// Returns [`WireError::SpawnFailed`] if the process cannot be started.
352    pub async fn spawn(
353        kimi_binary: &str,
354        work_dir: Option<&Path>,
355        session: Option<&str>,
356        model: Option<&str>,
357    ) -> Result<Self, WireError> {
358        let mut child = None;
359        for attempt in 0..3 {
360            let mut cmd = tokio::process::Command::new(kimi_binary);
361            cmd.arg("--wire");
362            if let Some(dir) = work_dir {
363                cmd.arg("--work-dir").arg(dir);
364            }
365            if let Some(s) = session {
366                cmd.arg("--session").arg(s);
367            }
368            if let Some(m) = model {
369                cmd.arg("--model").arg(m);
370            }
371            cmd.stdin(std::process::Stdio::piped())
372                .stdout(std::process::Stdio::piped())
373                .stderr(std::process::Stdio::piped());
374
375            match cmd.kill_on_drop(true).spawn() {
376                Ok(spawned) => {
377                    child = Some(spawned);
378                    break;
379                }
380                // ETXTBSY (Text file busy) on Unix-like systems — the binary may
381                // still be written by another process. Retry a couple of times.
382                Err(err) if err.raw_os_error() == Some(26) && attempt < 2 => {
383                    tokio::time::sleep(Duration::from_millis(25)).await;
384                }
385                Err(err) => {
386                    return Err(WireError::SpawnFailed(err.to_string()));
387                }
388            }
389        }
390
391        let mut child =
392            child.ok_or_else(|| WireError::SpawnFailed("all spawn attempts failed".to_string()))?;
393        let stdin = child
394            .stdin
395            .take()
396            .ok_or_else(|| WireError::SpawnFailed("no stdin".to_string()))?;
397        let stdout = child
398            .stdout
399            .take()
400            .ok_or_else(|| WireError::SpawnFailed("no stdout".to_string()))?;
401        let stdout_reader = FramedRead::new(
402            stdout,
403            LinesCodec::new_with_max_length(MAX_WIRE_LINE_LENGTH),
404        );
405
406        let cancel_token = CancellationToken::new();
407        let stderr_cancel = cancel_token.clone();
408        let stderr_handle = child.stderr.take().map(|stderr| {
409            tokio::spawn(async move {
410                let mut reader = BufReader::new(stderr).lines();
411                loop {
412                    tokio::select! {
413                        biased;
414                        _ = stderr_cancel.cancelled() => break,
415                        line = reader.next_line() => {
416                            match line {
417                                Ok(Some(line)) => {
418                                    #[cfg(feature = "redact")]
419                                    tracing::warn!(target: "kimi.stderr", "{}", crate::protocol::redact::scrub_secret_patterns(&line));
420                                    #[cfg(not(feature = "redact"))]
421                                    tracing::warn!(target: "kimi.stderr", "{line}");
422                                }
423                                _ => break,
424                            }
425                        }
426                    }
427                }
428            })
429        });
430
431        tracing::info!(
432            kimi_binary,
433            ?work_dir,
434            ?session,
435            ?model,
436            "child process transport spawned"
437        );
438        Ok(Self {
439            child: Some(child),
440            stdin: Some(stdin),
441            stdout_reader,
442            stderr_handle,
443            cancel_token,
444        })
445    }
446}
447
448impl Transport for ChildProcessTransport {
449    async fn read_line(&mut self) -> Result<Option<String>, WireError> {
450        use tokio_stream::StreamExt;
451        match self.stdout_reader.next().await {
452            Some(Ok(line)) => {
453                tracing::trace!(len = line.len(), "read line from child process transport");
454                Ok(Some(line))
455            }
456            Some(Err(e)) => Err(WireError::Io(e.to_string())),
457            None => Ok(None),
458        }
459    }
460
461    async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
462        let stdin = self.stdin.as_mut().ok_or(WireError::StreamClosed)?;
463        stdin.write_all(line.as_bytes()).await?;
464        stdin.write_all(b"\n").await?;
465        stdin.flush().await?;
466        tracing::trace!(len = line.len(), "wrote line to child process transport");
467        Ok(())
468    }
469
470    async fn shutdown(mut self) -> Result<(), WireError> {
471        tracing::info!("shutting down child process transport");
472        // Close stdin so the child sees EOF.
473        drop(self.stdin.take());
474
475        // Wait up to 3 seconds for the child to exit gracefully.
476        let grace = Duration::from_secs(3);
477        if let Some(mut child) = self.child.take() {
478            match tokio::time::timeout(grace, child.wait()).await {
479                Ok(Ok(_)) => {}
480                Ok(Err(_)) => {}
481                Err(_) => {
482                    // Best-effort kill after graceful shutdown timed out.
483                    // Safe to ignore: child is already unresponsive.
484                    #[allow(unused_must_use)]
485                    let _ = child.kill().await;
486                }
487            }
488        }
489
490        // Abort the stderr task and cancel the token.
491        self.cancel_token.cancel();
492        if let Some(handle) = self.stderr_handle.take() {
493            handle.abort();
494        }
495
496        Ok(())
497    }
498}
499
500impl Drop for ChildProcessTransport {
501    fn drop(&mut self) {
502        self.cancel_token.cancel();
503        if let Some(handle) = self.stderr_handle.take() {
504            handle.abort();
505        }
506    }
507}
508
509// ============================================================================
510// ChannelTransport
511// ============================================================================
512
513/// A transport backed by in-memory channels for testing.
514#[derive(Debug)]
515pub struct ChannelTransport {
516    rx: tokio::sync::mpsc::UnboundedReceiver<String>,
517    tx: tokio::sync::mpsc::UnboundedSender<String>,
518}
519
520impl ChannelTransport {
521    /// Create a new pair of connected transports.
522    #[must_use]
523    pub fn pair() -> (Self, Self) {
524        let (tx1, rx1) = tokio::sync::mpsc::unbounded_channel();
525        let (tx2, rx2) = tokio::sync::mpsc::unbounded_channel();
526        (Self { rx: rx1, tx: tx2 }, Self { rx: rx2, tx: tx1 })
527    }
528}
529
530impl Transport for ChannelTransport {
531    async fn read_line(&mut self) -> Result<Option<String>, WireError> {
532        Ok(self.rx.recv().await)
533    }
534
535    async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
536        self.tx
537            .send(line.to_string())
538            .map_err(|_| WireError::StreamClosed)
539    }
540}