1use std::collections::VecDeque;
8use std::path::Path;
9use std::time::Duration;
10
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStdin, ChildStdout};
15use tokio_util::codec::{FramedRead, LinesCodec};
16use tokio_util::sync::CancellationToken;
17
18use crate::client::WireClient;
19use crate::error::WireError;
20use crate::protocol::{
21 InitializeParams, InitializeResult, JsonRpcErrorResponse, JsonRpcRequest,
22 JsonRpcSuccessResponse, RawWireMessage,
23};
24
/// Longest line, in bytes, accepted from the child process's stdout (16 MiB).
///
/// Passed to the line codec as its maximum frame length when the child process
/// transport is created.
pub const MAX_WIRE_LINE_LENGTH: usize = 16 << 20;
31
/// Maximum number of out-of-order messages that may be parked while waiting for
/// a specific response; exceeding it aborts the wait with an internal error.
pub const MAX_PENDING_MESSAGES: usize = 1_024;
38
39const fn is_transient_error(err: &WireError) -> bool {
41 matches!(err, WireError::Io(_) | WireError::Timeout(_))
42}
43
44pub trait Transport: Send {
46 fn read_line(
48 &mut self,
49 ) -> impl std::future::Future<Output = Result<Option<String>, WireError>> + Send;
50
51 fn write_line(
53 &mut self,
54 line: &str,
55 ) -> impl std::future::Future<Output = Result<(), WireError>> + Send;
56
57 fn shutdown(self) -> impl std::future::Future<Output = Result<(), WireError>> + Send
64 where
65 Self: Sized,
66 {
67 async { Ok(()) }
68 }
69}
70
71pub struct TransportWireClient<T: Transport> {
77 transport: T,
78 request_id_counter: u64,
79 handshake_done: bool,
80 pending_messages: VecDeque<RawWireMessage>,
81 default_timeout: Option<Duration>,
82 max_io_retries: u32,
83}
84
85impl<T: Transport> std::fmt::Debug for TransportWireClient<T> {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("TransportWireClient")
88 .field("request_id_counter", &self.request_id_counter)
89 .field("handshake_done", &self.handshake_done)
90 .field("pending_messages", &self.pending_messages.len())
91 .field("default_timeout", &self.default_timeout)
92 .field("max_io_retries", &self.max_io_retries)
93 .finish_non_exhaustive()
94 }
95}
96
97impl<T: Transport> TransportWireClient<T> {
98 pub const fn new(transport: T) -> Self {
100 Self {
101 transport,
102 request_id_counter: 0,
103 handshake_done: false,
104 pending_messages: VecDeque::new(),
105 default_timeout: None,
106 max_io_retries: 0,
107 }
108 }
109
110 pub fn into_transport(self) -> T {
112 self.transport
113 }
114
115 #[must_use]
118 pub const fn with_default_timeout(mut self, timeout: Duration) -> Self {
119 self.default_timeout = Some(timeout);
120 self
121 }
122
123 #[must_use]
127 pub const fn with_max_io_retries(mut self, retries: u32) -> Self {
128 self.max_io_retries = if retries > 5 { 5 } else { retries };
129 self
130 }
131
132 async fn read_line_with_retry(&mut self) -> Result<Option<String>, WireError> {
133 let mut attempt = 0;
134 loop {
135 match self.transport.read_line().await {
136 Ok(result) => return Ok(result),
137 Err(ref e) if attempt < self.max_io_retries && is_transient_error(e) => {
138 attempt += 1;
139 let delay = Duration::from_millis(50 * 2_u64.pow(attempt));
140 tracing::debug!(error = %e, attempt, ?delay, "transient transport read error, retrying");
141 tokio::time::sleep(delay).await;
142 }
143 Err(e) => return Err(e),
144 }
145 }
146 }
147}
148
149impl<T: Transport> WireClient for TransportWireClient<T> {
150 fn next_id(&mut self) -> String {
151 self.request_id_counter += 1;
152 format!("req-{}", self.request_id_counter)
153 }
154
155 async fn send_request<Params: Serialize + Sync>(
156 &mut self,
157 req: &JsonRpcRequest<Params>,
158 ) -> Result<(), WireError> {
159 let line = serde_json::to_string(req).map_err(WireError::from)?;
160 self.transport.write_line(&line).await
161 }
162
163 async fn read_raw_message(&mut self) -> Result<RawWireMessage, WireError> {
164 if let Some(msg) = self.pending_messages.pop_front() {
165 return Ok(msg);
166 }
167 let Some(line) = self.transport.read_line().await? else {
168 return Err(WireError::StreamClosed);
169 };
170 serde_json::from_str(&line).map_err(WireError::from)
171 }
172
173 async fn read_raw_message_timeout(
174 &mut self,
175 timeout: Duration,
176 ) -> Result<RawWireMessage, WireError> {
177 tokio::time::timeout(timeout, self.read_raw_message())
178 .await
179 .map_or(Err(WireError::Timeout(timeout)), |msg| msg)
180 }
181
182 async fn read_response<Res: DeserializeOwned + Send>(
183 &mut self,
184 expected_id: &str,
185 ) -> Result<Res, WireError> {
186 let timeout = self.default_timeout;
187 let fut = async {
188 loop {
189 if let Some(idx) = self
190 .pending_messages
191 .iter()
192 .position(|msg| msg.id.as_deref() == Some(expected_id))
193 {
194 let msg = self
195 .pending_messages
196 .remove(idx)
197 .ok_or_else(|| WireError::Internal("pending index invalid".to_string()))?;
198 return decode_raw_response(msg, expected_id);
199 }
200
201 let Some(line) = self.read_line_with_retry().await? else {
202 return Err(WireError::StreamClosed);
203 };
204 let msg: RawWireMessage = serde_json::from_str(&line).map_err(WireError::from)?;
205 if msg.id.as_deref() == Some(expected_id) {
206 return decode_raw_response(msg, expected_id);
207 }
208 if self.pending_messages.len() >= MAX_PENDING_MESSAGES {
209 return Err(WireError::Internal(format!(
210 "pending message buffer overflow ({MAX_PENDING_MESSAGES} entries) waiting for id {expected_id:?}"
211 )));
212 }
213 self.pending_messages.push_back(msg);
214 }
215 };
216
217 match timeout {
218 Some(d) => tokio::time::timeout(d, fut)
219 .await
220 .map_err(|_| WireError::Timeout(d))?,
221 None => fut.await,
222 }
223 }
224
225 async fn send_response<Res: Serialize + Send>(
226 &mut self,
227 id: &str,
228 result: Res,
229 ) -> Result<(), WireError> {
230 let resp = JsonRpcSuccessResponse {
231 jsonrpc: crate::protocol::JsonRpcVersion::V2,
232 id: id.to_string(),
233 result,
234 };
235 let line = serde_json::to_string(&resp).map_err(WireError::from)?;
236 self.transport.write_line(&line).await
237 }
238
239 async fn send_error(&mut self, id: &str, code: i32, message: &str) -> Result<(), WireError> {
240 let resp = JsonRpcErrorResponse {
241 jsonrpc: crate::protocol::JsonRpcVersion::V2,
242 id: id.to_string(),
243 error: crate::protocol::JsonRpcError {
244 code,
245 message: message.to_string(),
246 data: None,
247 },
248 };
249 let line = serde_json::to_string(&resp).map_err(WireError::from)?;
250 self.transport.write_line(&line).await
251 }
252
253 async fn initialize(
254 &mut self,
255 params: InitializeParams,
256 ) -> Result<InitializeResult, WireError> {
257 let id = self.next_id();
258 let req = JsonRpcRequest {
259 jsonrpc: crate::protocol::JsonRpcVersion::V2,
260 method: "initialize".to_string(),
261 id: id.clone(),
262 params,
263 };
264 self.send_request(&req).await?;
265
266 let Some(line) = self.transport.read_line().await? else {
267 return Err(WireError::StreamClosed);
268 };
269
270 if let Ok(error_resp) = serde_json::from_str::<JsonRpcErrorResponse>(&line) {
272 if error_resp.error.code == crate::protocol::METHOD_NOT_FOUND {
273 tracing::warn!(
274 code = error_resp.error.code,
275 "Server does not support initialize, falling back to legacy no-handshake mode"
276 );
277 self.handshake_done = true;
278 return Ok(InitializeResult {
279 protocol_version: crate::WIRE_PROTOCOL_LEGACY_VERSION.to_string(),
280 server: crate::protocol::ServerInfo {
281 name: "unknown".to_string(),
282 version: "unknown".to_string(),
283 },
284 slash_commands: vec![],
285 external_tools: None,
286 capabilities: None,
287 hooks: None,
288 });
289 }
290 return Err(WireError::RequestFailed {
291 code: error_resp.error.code,
292 message: error_resp.error.message,
293 });
294 }
295
296 let resp: JsonRpcSuccessResponse<InitializeResult> =
297 serde_json::from_str(&line).map_err(WireError::from)?;
298 self.handshake_done = true;
299 Ok(resp.result)
300 }
301
302 fn is_handshake_done(&self) -> bool {
303 self.handshake_done
304 }
305
306 async fn shutdown(self) -> Result<(), WireError> {
307 self.transport.shutdown().await
308 }
309}
310
311fn decode_raw_response<T: DeserializeOwned>(
312 msg: RawWireMessage,
313 _expected_id: &str,
314) -> Result<T, WireError> {
315 if let Some(error) = msg.error {
316 return Err(WireError::RequestFailed {
317 code: error.code,
318 message: error.message,
319 });
320 }
321 let result = msg
322 .result
323 .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
324 serde_json::from_value(result).map_err(WireError::from)
325}
326
327#[derive(Debug)]
333pub struct ChildProcessTransport {
334 child: Option<Child>,
335 stdin: Option<ChildStdin>,
336 stdout_reader: FramedRead<ChildStdout, LinesCodec>,
337 stderr_handle: Option<tokio::task::JoinHandle<()>>,
338 cancel_token: CancellationToken,
339}
340
341impl ChildProcessTransport {
342 pub async fn spawn(
348 kimi_binary: &str,
349 work_dir: Option<&Path>,
350 session: Option<&str>,
351 model: Option<&str>,
352 ) -> Result<Self, WireError> {
353 let mut child = None;
354 for attempt in 0..3 {
355 let mut cmd = tokio::process::Command::new(kimi_binary);
356 cmd.arg("--wire");
357 if let Some(dir) = work_dir {
358 cmd.arg("--work-dir").arg(dir);
359 }
360 if let Some(s) = session {
361 cmd.arg("--session").arg(s);
362 }
363 if let Some(m) = model {
364 cmd.arg("--model").arg(m);
365 }
366 cmd.stdin(std::process::Stdio::piped())
367 .stdout(std::process::Stdio::piped())
368 .stderr(std::process::Stdio::piped());
369
370 match cmd.kill_on_drop(true).spawn() {
371 Ok(spawned) => {
372 child = Some(spawned);
373 break;
374 }
375 Err(err) if err.raw_os_error() == Some(26) && attempt < 2 => {
378 tokio::time::sleep(Duration::from_millis(25)).await;
379 }
380 Err(err) => {
381 return Err(WireError::SpawnFailed(err.to_string()));
382 }
383 }
384 }
385
386 let mut child =
387 child.ok_or_else(|| WireError::SpawnFailed("all spawn attempts failed".to_string()))?;
388 let stdin = child
389 .stdin
390 .take()
391 .ok_or_else(|| WireError::SpawnFailed("no stdin".to_string()))?;
392 let stdout = child
393 .stdout
394 .take()
395 .ok_or_else(|| WireError::SpawnFailed("no stdout".to_string()))?;
396 let stdout_reader = FramedRead::new(
397 stdout,
398 LinesCodec::new_with_max_length(MAX_WIRE_LINE_LENGTH),
399 );
400
401 let cancel_token = CancellationToken::new();
402 let stderr_cancel = cancel_token.clone();
403 let stderr_handle = child.stderr.take().map(|stderr| {
404 tokio::spawn(async move {
405 let mut reader = BufReader::new(stderr).lines();
406 loop {
407 tokio::select! {
408 biased;
409 () = stderr_cancel.cancelled() => break,
410 line = reader.next_line() => {
411 match line {
412 Ok(Some(line)) => {
413 #[cfg(feature = "redact")]
414 tracing::warn!(target: "kimi.stderr", "{}", crate::protocol::redact::scrub_secret_patterns(&line));
415 #[cfg(not(feature = "redact"))]
416 tracing::warn!(target: "kimi.stderr", "{line}");
417 }
418 _ => break,
419 }
420 }
421 }
422 }
423 })
424 });
425
426 tracing::info!(
427 kimi_binary,
428 ?work_dir,
429 ?session,
430 ?model,
431 "child process transport spawned"
432 );
433 Ok(Self {
434 child: Some(child),
435 stdin: Some(stdin),
436 stdout_reader,
437 stderr_handle,
438 cancel_token,
439 })
440 }
441}
442
443impl Transport for ChildProcessTransport {
444 async fn read_line(&mut self) -> Result<Option<String>, WireError> {
445 use tokio_stream::StreamExt;
446 match self.stdout_reader.next().await {
447 Some(Ok(line)) => {
448 tracing::trace!(len = line.len(), "read line from child process transport");
449 Ok(Some(line))
450 }
451 Some(Err(e)) => Err(WireError::Io(e.to_string())),
452 None => Ok(None),
453 }
454 }
455
456 async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
457 let stdin = self.stdin.as_mut().ok_or(WireError::StreamClosed)?;
458 stdin.write_all(line.as_bytes()).await?;
459 stdin.write_all(b"\n").await?;
460 stdin.flush().await?;
461 tracing::trace!(len = line.len(), "wrote line to child process transport");
462 Ok(())
463 }
464
465 async fn shutdown(mut self) -> Result<(), WireError> {
466 tracing::info!("shutting down child process transport");
467 drop(self.stdin.take());
469
470 let grace = Duration::from_secs(3);
472 if let Some(mut child) = self.child.take() {
473 match tokio::time::timeout(grace, child.wait()).await {
474 Ok(Ok(_) | Err(_)) => {}
475 Err(_) => {
476 #[allow(unused_must_use)]
479 let _ = child.kill().await;
480 }
481 }
482 }
483
484 self.cancel_token.cancel();
486 if let Some(handle) = self.stderr_handle.take() {
487 handle.abort();
488 }
489
490 Ok(())
491 }
492}
493
494impl Drop for ChildProcessTransport {
495 fn drop(&mut self) {
496 self.cancel_token.cancel();
497 if let Some(handle) = self.stderr_handle.take() {
498 handle.abort();
499 }
500 }
501}
502
503#[derive(Debug)]
509pub struct ChannelTransport {
510 rx: tokio::sync::mpsc::UnboundedReceiver<String>,
511 tx: tokio::sync::mpsc::UnboundedSender<String>,
512}
513
514impl ChannelTransport {
515 #[must_use]
517 pub fn pair() -> (Self, Self) {
518 let (tx1, rx1) = tokio::sync::mpsc::unbounded_channel();
519 let (tx2, rx2) = tokio::sync::mpsc::unbounded_channel();
520 (Self { rx: rx1, tx: tx2 }, Self { rx: rx2, tx: tx1 })
521 }
522}
523
524impl Transport for ChannelTransport {
525 async fn read_line(&mut self) -> Result<Option<String>, WireError> {
526 Ok(self.rx.recv().await)
527 }
528
529 async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
530 self.tx
531 .send(line.to_string())
532 .map_err(|_| WireError::StreamClosed)
533 }
534}