Skip to main content

kimi_wire/
transport.rs

1//! Transport layer for the Kimi Wire protocol.
2//!
3//! The [`crate::transport::Transport`] trait abstracts how raw JSON lines are read and written,
4//! allowing the same [`WireClient`](crate::client::WireClient) logic to run
5//! over stdio, in-memory buffers, or custom channels.
6
7use std::collections::VecDeque;
8use std::path::Path;
9use std::time::Duration;
10
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStdin, ChildStdout};
15use tokio_util::codec::{FramedRead, LinesCodec};
16use tokio_util::sync::CancellationToken;
17
18use crate::client::WireClient;
19use crate::error::WireError;
20use crate::protocol::{
21    InitializeParams, InitializeResult, JsonRpcErrorResponse, JsonRpcRequest,
22    JsonRpcSuccessResponse, RawWireMessage,
23};
24
/// Upper bound, in bytes, on a single wire line (16 MiB).
///
/// The Kimi wire protocol frames every message as one JSON document followed
/// by a newline. A reader without a cap would keep buffering for as long as
/// the peer withholds the newline, so a misbehaving peer could exhaust memory
/// through `read_line`. [`ChildProcessTransport`] enforces this limit on the
/// child's stdout.
pub const MAX_WIRE_LINE_LENGTH: usize = 16 << 20;
31
32/// Async transport for reading and writing newline-delimited JSON.
33pub trait Transport: Send {
34    /// Read the next line from the transport.
35    fn read_line(
36        &mut self,
37    ) -> impl std::future::Future<Output = Result<Option<String>, WireError>> + Send;
38
39    /// Write a line to the transport.
40    fn write_line(
41        &mut self,
42        line: &str,
43    ) -> impl std::future::Future<Output = Result<(), WireError>> + Send;
44}
45
46// ============================================================================
47// TransportWireClient
48// ============================================================================
49
50/// A [`WireClient`] implementation backed by any [`Transport`].
51pub struct TransportWireClient<T: Transport> {
52    transport: T,
53    request_id_counter: u64,
54    handshake_done: bool,
55    pending_messages: VecDeque<RawWireMessage>,
56}
57
58impl<T: Transport> TransportWireClient<T> {
59    /// Create a new client wrapping the given transport.
60    pub fn new(transport: T) -> Self {
61        Self {
62            transport,
63            request_id_counter: 0,
64            handshake_done: false,
65            pending_messages: VecDeque::new(),
66        }
67    }
68
69    /// Consume the client and return the underlying transport.
70    pub fn into_transport(self) -> T {
71        self.transport
72    }
73}
74
75impl<T: Transport> WireClient for TransportWireClient<T> {
76    fn next_id(&mut self) -> String {
77        self.request_id_counter += 1;
78        format!("req-{}", self.request_id_counter)
79    }
80
81    async fn send_request<Params: Serialize + Sync>(
82        &mut self,
83        req: &JsonRpcRequest<Params>,
84    ) -> Result<(), WireError> {
85        let line = serde_json::to_string(req).map_err(WireError::from)?;
86        self.transport.write_line(&line).await
87    }
88
89    async fn read_raw_message(&mut self) -> Result<RawWireMessage, WireError> {
90        if let Some(msg) = self.pending_messages.pop_front() {
91            return Ok(msg);
92        }
93        let line = match self.transport.read_line().await? {
94            Some(line) => line,
95            None => return Err(WireError::StreamClosed),
96        };
97        serde_json::from_str(&line).map_err(WireError::from)
98    }
99
100    async fn read_raw_message_timeout(
101        &mut self,
102        timeout: Duration,
103    ) -> Result<RawWireMessage, WireError> {
104        match tokio::time::timeout(timeout, self.read_raw_message()).await {
105            Ok(msg) => msg,
106            Err(_) => Err(WireError::Timeout(timeout)),
107        }
108    }
109
110    async fn read_response<Res: DeserializeOwned + Send>(
111        &mut self,
112        expected_id: &str,
113    ) -> Result<Res, WireError> {
114        loop {
115            if let Some(idx) = self
116                .pending_messages
117                .iter()
118                .position(|msg| msg.id.as_deref() == Some(expected_id))
119            {
120                let msg = self
121                    .pending_messages
122                    .remove(idx)
123                    .ok_or_else(|| WireError::Internal("pending index invalid".to_string()))?;
124                return decode_raw_response(msg, expected_id);
125            }
126
127            match self.read_raw_message().await? {
128                msg if msg.id.as_deref() == Some(expected_id) => {
129                    return decode_raw_response(msg, expected_id);
130                }
131                other => {
132                    self.pending_messages.push_back(other);
133                }
134            }
135        }
136    }
137
138    async fn send_response<Res: Serialize + Send>(
139        &mut self,
140        id: &str,
141        result: Res,
142    ) -> Result<(), WireError> {
143        let resp = JsonRpcSuccessResponse {
144            jsonrpc: crate::protocol::JsonRpcVersion::default(),
145            id: id.to_string(),
146            result,
147        };
148        let line = serde_json::to_string(&resp).map_err(WireError::from)?;
149        self.transport.write_line(&line).await
150    }
151
152    async fn send_error(
153        &mut self,
154        id: &str,
155        code: i32,
156        message: &str,
157    ) -> Result<(), WireError> {
158        let resp = JsonRpcErrorResponse {
159            jsonrpc: crate::protocol::JsonRpcVersion::default(),
160            id: id.to_string(),
161            error: crate::protocol::JsonRpcError {
162                code,
163                message: message.to_string(),
164                data: None,
165            },
166        };
167        let line = serde_json::to_string(&resp).map_err(WireError::from)?;
168        self.transport.write_line(&line).await
169    }
170
171    async fn initialize(
172        &mut self,
173        params: InitializeParams,
174    ) -> Result<InitializeResult, WireError> {
175        let id = self.next_id();
176        let req = JsonRpcRequest {
177            jsonrpc: crate::protocol::JsonRpcVersion::default(),
178            method: "initialize".to_string(),
179            id: id.clone(),
180            params,
181        };
182        self.send_request(&req).await?;
183
184        let line = match self.transport.read_line().await? {
185            Some(line) => line,
186            None => return Err(WireError::StreamClosed),
187        };
188
189        // Check for method-not-found error (-32601)
190        if let Ok(error_resp) = serde_json::from_str::<JsonRpcErrorResponse>(&line) {
191            if error_resp.error.code == crate::protocol::METHOD_NOT_FOUND {
192                tracing::warn!(
193                    code = error_resp.error.code,
194                    "Server does not support initialize, falling back to legacy no-handshake mode"
195                );
196                self.handshake_done = true;
197                return Ok(InitializeResult {
198                    protocol_version: crate::WIRE_PROTOCOL_LEGACY_VERSION.to_string(),
199                    server: crate::protocol::ServerInfo {
200                        name: "unknown".to_string(),
201                        version: "unknown".to_string(),
202                    },
203                    slash_commands: vec![],
204                    external_tools: None,
205                    capabilities: None,
206                    hooks: None,
207                });
208            }
209            return Err(WireError::RequestFailed {
210                code: error_resp.error.code,
211                message: error_resp.error.message,
212            });
213        }
214
215        let resp: JsonRpcSuccessResponse<InitializeResult> =
216            serde_json::from_str(&line).map_err(WireError::from)?;
217        self.handshake_done = true;
218        Ok(resp.result)
219    }
220
221    fn is_handshake_done(&self) -> bool {
222        self.handshake_done
223    }
224
225    async fn shutdown(self) -> Result<(), WireError> {
226        Ok(())
227    }
228}
229
230fn decode_raw_response<T: DeserializeOwned>(
231    msg: RawWireMessage,
232    _expected_id: &str,
233) -> Result<T, WireError> {
234    if let Some(error) = msg.error {
235        return Err(WireError::RequestFailed {
236            code: error.code,
237            message: error.message,
238        });
239    }
240    let result = msg
241        .result
242        .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
243    serde_json::from_value(result).map_err(WireError::from)
244}
245
246// ============================================================================
247// ChildProcessTransport
248// ============================================================================
249
/// A transport backed by a child process's stdin/stdout.
///
/// Fields are dropped in declaration order after `Drop::drop` runs, so the
/// child handle is released first.
pub struct ChildProcessTransport {
    // Never read after construction, hence the `dead_code` allowance. The
    // handle is kept only so the process lives exactly as long as this
    // transport: it is spawned with `kill_on_drop(true)`, so dropping this
    // field kills it.
    #[allow(dead_code)]
    child: Child,
    // Child's stdin; each outgoing line plus a trailing `\n` is written here.
    stdin: ChildStdin,
    // Child's stdout, split into lines by a `LinesCodec` capped at
    // `MAX_WIRE_LINE_LENGTH`.
    stdout_reader: FramedRead<ChildStdout, LinesCodec>,
    // Background task that forwards stderr lines to `tracing`; `None` if the
    // child's stderr handle was not available.
    stderr_handle: Option<tokio::task::JoinHandle<()>>,
    // Cancelled on drop to tell the stderr task to stop.
    cancel_token: CancellationToken,
}
259
260impl ChildProcessTransport {
261    /// Spawn a new `kimi` process in wire mode.
262    ///
263    /// # Errors
264    ///
265    /// Returns [`WireError::SpawnFailed`] if the process cannot be started.
266    pub async fn spawn(
267        kimi_binary: &str,
268        work_dir: Option<&Path>,
269        session: Option<&str>,
270        model: Option<&str>,
271    ) -> Result<Self, WireError> {
272        let mut child = None;
273        for attempt in 0..3 {
274            let mut cmd = tokio::process::Command::new(kimi_binary);
275            cmd.arg("--wire");
276            if let Some(dir) = work_dir {
277                cmd.arg("--work-dir").arg(dir);
278            }
279            if let Some(s) = session {
280                cmd.arg("--session").arg(s);
281            }
282            if let Some(m) = model {
283                cmd.arg("--model").arg(m);
284            }
285            cmd.stdin(std::process::Stdio::piped())
286                .stdout(std::process::Stdio::piped())
287                .stderr(std::process::Stdio::piped());
288
289            match cmd.kill_on_drop(true).spawn() {
290                Ok(spawned) => {
291                    child = Some(spawned);
292                    break;
293                }
294                Err(err) if err.raw_os_error() == Some(26) && attempt < 2 => {
295                    tokio::time::sleep(Duration::from_millis(25)).await;
296                }
297                Err(err) => {
298                    return Err(WireError::SpawnFailed(err.to_string()));
299                }
300            }
301        }
302
303        let mut child = child
304            .ok_or_else(|| WireError::SpawnFailed("all spawn attempts failed".to_string()))?;
305        let stdin = child
306            .stdin
307            .take()
308            .ok_or_else(|| WireError::SpawnFailed("no stdin".to_string()))?;
309        let stdout = child
310            .stdout
311            .take()
312            .ok_or_else(|| WireError::SpawnFailed("no stdout".to_string()))?;
313        let stdout_reader = FramedRead::new(
314            stdout,
315            LinesCodec::new_with_max_length(MAX_WIRE_LINE_LENGTH),
316        );
317
318        let cancel_token = CancellationToken::new();
319        let stderr_cancel = cancel_token.clone();
320        let stderr_handle = child.stderr.take().map(|stderr| {
321            tokio::spawn(async move {
322                let mut reader = BufReader::new(stderr).lines();
323                loop {
324                    tokio::select! {
325                        biased;
326                        _ = stderr_cancel.cancelled() => break,
327                        line = reader.next_line() => {
328                            match line {
329                                Ok(Some(line)) => {
330                                    #[cfg(feature = "redact")]
331                                    tracing::warn!(target: "kimi.stderr", "{}", crate::protocol::redact::scrub_secret_patterns(&line));
332                                    #[cfg(not(feature = "redact"))]
333                                    tracing::warn!(target: "kimi.stderr", "{line}");
334                                }
335                                _ => break,
336                            }
337                        }
338                    }
339                }
340            })
341        });
342
343        Ok(Self {
344            child,
345            stdin,
346            stdout_reader,
347            stderr_handle,
348            cancel_token,
349        })
350    }
351}
352
353impl Transport for ChildProcessTransport {
354    async fn read_line(&mut self) -> Result<Option<String>, WireError> {
355        use tokio_stream::StreamExt;
356        match self.stdout_reader.next().await {
357            Some(Ok(line)) => Ok(Some(line)),
358            Some(Err(e)) => Err(WireError::Io(e.to_string())),
359            None => Ok(None),
360        }
361    }
362
363    async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
364        self.stdin.write_all(line.as_bytes()).await?;
365        self.stdin.write_all(b"\n").await?;
366        self.stdin.flush().await?;
367        Ok(())
368    }
369}
370
371impl Drop for ChildProcessTransport {
372    fn drop(&mut self) {
373        self.cancel_token.cancel();
374        if let Some(handle) = self.stderr_handle.take() {
375            handle.abort();
376        }
377    }
378}
379
380// ============================================================================
381// ChannelTransport
382// ============================================================================
383
384/// A transport backed by in-memory channels for testing.
385#[derive(Debug)]
386pub struct ChannelTransport {
387    rx: tokio::sync::mpsc::UnboundedReceiver<String>,
388    tx: tokio::sync::mpsc::UnboundedSender<String>,
389}
390
391impl ChannelTransport {
392    /// Create a new pair of connected transports.
393    pub fn pair() -> (Self, Self) {
394        let (tx1, rx1) = tokio::sync::mpsc::unbounded_channel();
395        let (tx2, rx2) = tokio::sync::mpsc::unbounded_channel();
396        (
397            Self { rx: rx1, tx: tx2 },
398            Self { rx: rx2, tx: tx1 },
399        )
400    }
401}
402
403impl Transport for ChannelTransport {
404    async fn read_line(&mut self) -> Result<Option<String>, WireError> {
405        Ok(self.rx.recv().await)
406    }
407
408    async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
409        self.tx
410            .send(line.to_string())
411            .map_err(|_| WireError::StreamClosed)
412    }
413}