1use std::collections::VecDeque;
8use std::path::Path;
9use std::time::Duration;
10
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStdin, ChildStdout};
15use tokio_util::codec::{FramedRead, LinesCodec};
16use tokio_util::sync::CancellationToken;
17
18use crate::client::WireClient;
19use crate::error::WireError;
20use crate::protocol::{
21 InitializeParams, InitializeResult, JsonRpcErrorResponse, JsonRpcRequest,
22 JsonRpcSuccessResponse, RawWireMessage,
23};
24
/// Maximum length, in bytes, of a single line read from a child process's stdout
/// (16 MiB). Passed to `LinesCodec::new_with_max_length` when framing the child's output.
pub const MAX_WIRE_LINE_LENGTH: usize = 16 << 20;

/// Maximum number of out-of-order messages `read_response` will buffer while waiting for
/// a specific response id before giving up with an overflow error.
pub const MAX_PENDING_MESSAGES: usize = 1 << 10;
38
39fn is_transient_error(err: &WireError) -> bool {
41 matches!(err, WireError::Io(_) | WireError::Timeout(_))
42}
43
/// A bidirectional, line-oriented message channel used by [`TransportWireClient`].
///
/// Each call exchanges one line of text without its terminator: the child-process
/// implementation in this file strips newlines on read (via `LinesCodec`) and appends
/// `"\n"` on write.
pub trait Transport: Send {
    /// Reads the next line from the peer.
    ///
    /// Resolves to `Ok(None)` when the peer has closed the stream; callers in this file
    /// translate that into `WireError::StreamClosed`.
    fn read_line(
        &mut self,
    ) -> impl std::future::Future<Output = Result<Option<String>, WireError>> + Send;

    /// Writes one line to the peer.
    ///
    /// NOTE(review): callers pass `serde_json::to_string` output, which never contains a raw
    /// newline; implementations appear to rely on that — confirm before passing other text.
    fn write_line(
        &mut self,
        line: &str,
    ) -> impl std::future::Future<Output = Result<(), WireError>> + Send;

    /// Consumes the transport and releases its resources.
    ///
    /// The default implementation does nothing and always succeeds.
    fn shutdown(self) -> impl std::future::Future<Output = Result<(), WireError>> + Send
    where
        Self: Sized,
    {
        async { Ok(()) }
    }
}
70
71pub struct TransportWireClient<T: Transport> {
77 transport: T,
78 request_id_counter: u64,
79 handshake_done: bool,
80 pending_messages: VecDeque<RawWireMessage>,
81 default_timeout: Option<Duration>,
82 max_io_retries: u32,
83}
84
85impl<T: Transport> std::fmt::Debug for TransportWireClient<T> {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct("TransportWireClient")
88 .field("request_id_counter", &self.request_id_counter)
89 .field("handshake_done", &self.handshake_done)
90 .field("pending_messages", &self.pending_messages.len())
91 .field("default_timeout", &self.default_timeout)
92 .field("max_io_retries", &self.max_io_retries)
93 .finish_non_exhaustive()
94 }
95}
96
97impl<T: Transport> TransportWireClient<T> {
98 pub const fn new(transport: T) -> Self {
100 Self {
101 transport,
102 request_id_counter: 0,
103 handshake_done: false,
104 pending_messages: VecDeque::new(),
105 default_timeout: None,
106 max_io_retries: 0,
107 }
108 }
109
110 pub fn into_transport(self) -> T {
112 self.transport
113 }
114
115 #[must_use]
118 pub const fn with_default_timeout(mut self, timeout: Duration) -> Self {
119 self.default_timeout = Some(timeout);
120 self
121 }
122
123 #[must_use]
127 pub const fn with_max_io_retries(mut self, retries: u32) -> Self {
128 self.max_io_retries = if retries > 5 { 5 } else { retries };
129 self
130 }
131
132 async fn read_line_with_retry(&mut self) -> Result<Option<String>, WireError> {
133 let mut attempt = 0;
134 loop {
135 match self.transport.read_line().await {
136 Ok(result) => return Ok(result),
137 Err(ref e) if attempt < self.max_io_retries && is_transient_error(e) => {
138 attempt += 1;
139 let delay = Duration::from_millis(50 * 2_u64.pow(attempt));
140 tracing::debug!(error = %e, attempt, ?delay, "transient transport read error, retrying");
141 tokio::time::sleep(delay).await;
142 }
143 Err(e) => return Err(e),
144 }
145 }
146 }
147}
148
149impl<T: Transport> WireClient for TransportWireClient<T> {
150 fn next_id(&mut self) -> String {
151 self.request_id_counter += 1;
152 format!("req-{}", self.request_id_counter)
153 }
154
155 async fn send_request<Params: Serialize + Sync>(
156 &mut self,
157 req: &JsonRpcRequest<Params>,
158 ) -> Result<(), WireError> {
159 let line = serde_json::to_string(req).map_err(WireError::from)?;
160 self.transport.write_line(&line).await
161 }
162
163 async fn read_raw_message(&mut self) -> Result<RawWireMessage, WireError> {
164 if let Some(msg) = self.pending_messages.pop_front() {
165 return Ok(msg);
166 }
167 let line = match self.transport.read_line().await? {
168 Some(line) => line,
169 None => return Err(WireError::StreamClosed),
170 };
171 serde_json::from_str(&line).map_err(WireError::from)
172 }
173
174 async fn read_raw_message_timeout(
175 &mut self,
176 timeout: Duration,
177 ) -> Result<RawWireMessage, WireError> {
178 match tokio::time::timeout(timeout, self.read_raw_message()).await {
179 Ok(msg) => msg,
180 Err(_) => Err(WireError::Timeout(timeout)),
181 }
182 }
183
184 async fn read_response<Res: DeserializeOwned + Send>(
185 &mut self,
186 expected_id: &str,
187 ) -> Result<Res, WireError> {
188 let timeout = self.default_timeout;
189 let fut = async {
190 loop {
191 if let Some(idx) = self
192 .pending_messages
193 .iter()
194 .position(|msg| msg.id.as_deref() == Some(expected_id))
195 {
196 let msg = self
197 .pending_messages
198 .remove(idx)
199 .ok_or_else(|| WireError::Internal("pending index invalid".to_string()))?;
200 return decode_raw_response(msg, expected_id);
201 }
202
203 let line = match self.read_line_with_retry().await? {
204 Some(line) => line,
205 None => return Err(WireError::StreamClosed),
206 };
207 let msg: RawWireMessage = serde_json::from_str(&line).map_err(WireError::from)?;
208 if msg.id.as_deref() == Some(expected_id) {
209 return decode_raw_response(msg, expected_id);
210 }
211 if self.pending_messages.len() >= MAX_PENDING_MESSAGES {
212 return Err(WireError::Internal(format!(
213 "pending message buffer overflow ({} entries) waiting for id {:?}",
214 MAX_PENDING_MESSAGES, expected_id
215 )));
216 }
217 self.pending_messages.push_back(msg);
218 }
219 };
220
221 match timeout {
222 Some(d) => tokio::time::timeout(d, fut)
223 .await
224 .map_err(|_| WireError::Timeout(d))?,
225 None => fut.await,
226 }
227 }
228
229 async fn send_response<Res: Serialize + Send>(
230 &mut self,
231 id: &str,
232 result: Res,
233 ) -> Result<(), WireError> {
234 let resp = JsonRpcSuccessResponse {
235 jsonrpc: crate::protocol::JsonRpcVersion::V2,
236 id: id.to_string(),
237 result,
238 };
239 let line = serde_json::to_string(&resp).map_err(WireError::from)?;
240 self.transport.write_line(&line).await
241 }
242
243 async fn send_error(&mut self, id: &str, code: i32, message: &str) -> Result<(), WireError> {
244 let resp = JsonRpcErrorResponse {
245 jsonrpc: crate::protocol::JsonRpcVersion::V2,
246 id: id.to_string(),
247 error: crate::protocol::JsonRpcError {
248 code,
249 message: message.to_string(),
250 data: None,
251 },
252 };
253 let line = serde_json::to_string(&resp).map_err(WireError::from)?;
254 self.transport.write_line(&line).await
255 }
256
257 async fn initialize(
258 &mut self,
259 params: InitializeParams,
260 ) -> Result<InitializeResult, WireError> {
261 let id = self.next_id();
262 let req = JsonRpcRequest {
263 jsonrpc: crate::protocol::JsonRpcVersion::V2,
264 method: "initialize".to_string(),
265 id: id.clone(),
266 params,
267 };
268 self.send_request(&req).await?;
269
270 let line = match self.transport.read_line().await? {
271 Some(line) => line,
272 None => return Err(WireError::StreamClosed),
273 };
274
275 if let Ok(error_resp) = serde_json::from_str::<JsonRpcErrorResponse>(&line) {
277 if error_resp.error.code == crate::protocol::METHOD_NOT_FOUND {
278 tracing::warn!(
279 code = error_resp.error.code,
280 "Server does not support initialize, falling back to legacy no-handshake mode"
281 );
282 self.handshake_done = true;
283 return Ok(InitializeResult {
284 protocol_version: crate::WIRE_PROTOCOL_LEGACY_VERSION.to_string(),
285 server: crate::protocol::ServerInfo {
286 name: "unknown".to_string(),
287 version: "unknown".to_string(),
288 },
289 slash_commands: vec![],
290 external_tools: None,
291 capabilities: None,
292 hooks: None,
293 });
294 }
295 return Err(WireError::RequestFailed {
296 code: error_resp.error.code,
297 message: error_resp.error.message,
298 });
299 }
300
301 let resp: JsonRpcSuccessResponse<InitializeResult> =
302 serde_json::from_str(&line).map_err(WireError::from)?;
303 self.handshake_done = true;
304 Ok(resp.result)
305 }
306
307 fn is_handshake_done(&self) -> bool {
308 self.handshake_done
309 }
310
311 async fn shutdown(self) -> Result<(), WireError> {
312 self.transport.shutdown().await
313 }
314}
315
316fn decode_raw_response<T: DeserializeOwned>(
317 msg: RawWireMessage,
318 _expected_id: &str,
319) -> Result<T, WireError> {
320 if let Some(error) = msg.error {
321 return Err(WireError::RequestFailed {
322 code: error.code,
323 message: error.message,
324 });
325 }
326 let result = msg
327 .result
328 .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
329 serde_json::from_value(result).map_err(WireError::from)
330}
331
/// [`Transport`] backed by a spawned child process.
///
/// Lines are written to the child's stdin and read from its stdout. The child's stderr is
/// forwarded to `tracing` (target `kimi.stderr`) by a background task created in `spawn`.
#[derive(Debug)]
pub struct ChildProcessTransport {
    /// The child process; taken (set to `None`) by `shutdown`.
    child: Option<Child>,
    /// The child's stdin; taken and dropped at the start of `shutdown` to close the pipe.
    stdin: Option<ChildStdin>,
    /// The child's stdout split into lines of at most `MAX_WIRE_LINE_LENGTH` bytes.
    stdout_reader: FramedRead<ChildStdout, LinesCodec>,
    /// Handle of the stderr-forwarding task, if the child's stderr was available.
    stderr_handle: Option<tokio::task::JoinHandle<()>>,
    /// Cancelling this token stops the stderr-forwarding task.
    cancel_token: CancellationToken,
}
345
346impl ChildProcessTransport {
347 pub async fn spawn(
353 kimi_binary: &str,
354 work_dir: Option<&Path>,
355 session: Option<&str>,
356 model: Option<&str>,
357 ) -> Result<Self, WireError> {
358 let mut child = None;
359 for attempt in 0..3 {
360 let mut cmd = tokio::process::Command::new(kimi_binary);
361 cmd.arg("--wire");
362 if let Some(dir) = work_dir {
363 cmd.arg("--work-dir").arg(dir);
364 }
365 if let Some(s) = session {
366 cmd.arg("--session").arg(s);
367 }
368 if let Some(m) = model {
369 cmd.arg("--model").arg(m);
370 }
371 cmd.stdin(std::process::Stdio::piped())
372 .stdout(std::process::Stdio::piped())
373 .stderr(std::process::Stdio::piped());
374
375 match cmd.kill_on_drop(true).spawn() {
376 Ok(spawned) => {
377 child = Some(spawned);
378 break;
379 }
380 Err(err) if err.raw_os_error() == Some(26) && attempt < 2 => {
383 tokio::time::sleep(Duration::from_millis(25)).await;
384 }
385 Err(err) => {
386 return Err(WireError::SpawnFailed(err.to_string()));
387 }
388 }
389 }
390
391 let mut child =
392 child.ok_or_else(|| WireError::SpawnFailed("all spawn attempts failed".to_string()))?;
393 let stdin = child
394 .stdin
395 .take()
396 .ok_or_else(|| WireError::SpawnFailed("no stdin".to_string()))?;
397 let stdout = child
398 .stdout
399 .take()
400 .ok_or_else(|| WireError::SpawnFailed("no stdout".to_string()))?;
401 let stdout_reader = FramedRead::new(
402 stdout,
403 LinesCodec::new_with_max_length(MAX_WIRE_LINE_LENGTH),
404 );
405
406 let cancel_token = CancellationToken::new();
407 let stderr_cancel = cancel_token.clone();
408 let stderr_handle = child.stderr.take().map(|stderr| {
409 tokio::spawn(async move {
410 let mut reader = BufReader::new(stderr).lines();
411 loop {
412 tokio::select! {
413 biased;
414 _ = stderr_cancel.cancelled() => break,
415 line = reader.next_line() => {
416 match line {
417 Ok(Some(line)) => {
418 #[cfg(feature = "redact")]
419 tracing::warn!(target: "kimi.stderr", "{}", crate::protocol::redact::scrub_secret_patterns(&line));
420 #[cfg(not(feature = "redact"))]
421 tracing::warn!(target: "kimi.stderr", "{line}");
422 }
423 _ => break,
424 }
425 }
426 }
427 }
428 })
429 });
430
431 tracing::info!(
432 kimi_binary,
433 ?work_dir,
434 ?session,
435 ?model,
436 "child process transport spawned"
437 );
438 Ok(Self {
439 child: Some(child),
440 stdin: Some(stdin),
441 stdout_reader,
442 stderr_handle,
443 cancel_token,
444 })
445 }
446}
447
448impl Transport for ChildProcessTransport {
449 async fn read_line(&mut self) -> Result<Option<String>, WireError> {
450 use tokio_stream::StreamExt;
451 match self.stdout_reader.next().await {
452 Some(Ok(line)) => {
453 tracing::trace!(len = line.len(), "read line from child process transport");
454 Ok(Some(line))
455 }
456 Some(Err(e)) => Err(WireError::Io(e.to_string())),
457 None => Ok(None),
458 }
459 }
460
461 async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
462 let stdin = self.stdin.as_mut().ok_or(WireError::StreamClosed)?;
463 stdin.write_all(line.as_bytes()).await?;
464 stdin.write_all(b"\n").await?;
465 stdin.flush().await?;
466 tracing::trace!(len = line.len(), "wrote line to child process transport");
467 Ok(())
468 }
469
470 async fn shutdown(mut self) -> Result<(), WireError> {
471 tracing::info!("shutting down child process transport");
472 drop(self.stdin.take());
474
475 let grace = Duration::from_secs(3);
477 if let Some(mut child) = self.child.take() {
478 match tokio::time::timeout(grace, child.wait()).await {
479 Ok(Ok(_)) => {}
480 Ok(Err(_)) => {}
481 Err(_) => {
482 #[allow(unused_must_use)]
485 let _ = child.kill().await;
486 }
487 }
488 }
489
490 self.cancel_token.cancel();
492 if let Some(handle) = self.stderr_handle.take() {
493 handle.abort();
494 }
495
496 Ok(())
497 }
498}
499
500impl Drop for ChildProcessTransport {
501 fn drop(&mut self) {
502 self.cancel_token.cancel();
503 if let Some(handle) = self.stderr_handle.take() {
504 handle.abort();
505 }
506 }
507}
508
/// In-memory [`Transport`] built on two unbounded tokio mpsc channels.
///
/// `ChannelTransport::pair` returns two connected endpoints: whatever one writes, the other
/// reads.
#[derive(Debug)]
pub struct ChannelTransport {
    /// Lines written by the peer endpoint.
    rx: tokio::sync::mpsc::UnboundedReceiver<String>,
    /// Lines delivered to the peer endpoint.
    tx: tokio::sync::mpsc::UnboundedSender<String>,
}
519
520impl ChannelTransport {
521 #[must_use]
523 pub fn pair() -> (Self, Self) {
524 let (tx1, rx1) = tokio::sync::mpsc::unbounded_channel();
525 let (tx2, rx2) = tokio::sync::mpsc::unbounded_channel();
526 (Self { rx: rx1, tx: tx2 }, Self { rx: rx2, tx: tx1 })
527 }
528}
529
530impl Transport for ChannelTransport {
531 async fn read_line(&mut self) -> Result<Option<String>, WireError> {
532 Ok(self.rx.recv().await)
533 }
534
535 async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
536 self.tx
537 .send(line.to_string())
538 .map_err(|_| WireError::StreamClosed)
539 }
540}