1use std::collections::VecDeque;
8use std::path::Path;
9use std::time::Duration;
10
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStdin, ChildStdout};
15use tokio_util::codec::{FramedRead, LinesCodec};
16use tokio_util::sync::CancellationToken;
17
18use crate::client::WireClient;
19use crate::error::WireError;
20use crate::protocol::{
21 InitializeParams, InitializeResult, JsonRpcErrorResponse, JsonRpcRequest,
22 JsonRpcSuccessResponse, RawWireMessage,
23};
24
/// Maximum length, in bytes, of a single line read from a child process's stdout
/// (16 MiB). Used as the `LinesCodec` limit by `ChildProcessTransport`.
pub const MAX_WIRE_LINE_LENGTH: usize = 16 << 20;

/// Maximum number of unrelated messages `TransportWireClient` buffers while it waits
/// for a response with a specific id. When the buffer is full and another unrelated
/// message arrives, the wait fails instead of growing without bound.
pub const MAX_PENDING_MESSAGES: usize = 1 << 10;
38
39pub trait Transport: Send {
41 fn read_line(
43 &mut self,
44 ) -> impl std::future::Future<Output = Result<Option<String>, WireError>> + Send;
45
46 fn write_line(
48 &mut self,
49 line: &str,
50 ) -> impl std::future::Future<Output = Result<(), WireError>> + Send;
51
52 fn shutdown(self) -> impl std::future::Future<Output = Result<(), WireError>> + Send
59 where
60 Self: Sized,
61 {
62 async { Ok(()) }
63 }
64}
65
66pub struct TransportWireClient<T: Transport> {
72 transport: T,
73 request_id_counter: u64,
74 handshake_done: bool,
75 pending_messages: VecDeque<RawWireMessage>,
76 default_timeout: Option<Duration>,
77}
78
79impl<T: Transport> TransportWireClient<T> {
80 pub fn new(transport: T) -> Self {
82 Self {
83 transport,
84 request_id_counter: 0,
85 handshake_done: false,
86 pending_messages: VecDeque::new(),
87 default_timeout: None,
88 }
89 }
90
91 pub fn into_transport(self) -> T {
93 self.transport
94 }
95
96 pub fn with_default_timeout(mut self, timeout: Duration) -> Self {
99 self.default_timeout = Some(timeout);
100 self
101 }
102}
103
104impl<T: Transport> WireClient for TransportWireClient<T> {
105 fn next_id(&mut self) -> String {
106 self.request_id_counter += 1;
107 format!("req-{}", self.request_id_counter)
108 }
109
110 async fn send_request<Params: Serialize + Sync>(
111 &mut self,
112 req: &JsonRpcRequest<Params>,
113 ) -> Result<(), WireError> {
114 let line = serde_json::to_string(req).map_err(WireError::from)?;
115 self.transport.write_line(&line).await
116 }
117
118 async fn read_raw_message(&mut self) -> Result<RawWireMessage, WireError> {
119 if let Some(msg) = self.pending_messages.pop_front() {
120 return Ok(msg);
121 }
122 let line = match self.transport.read_line().await? {
123 Some(line) => line,
124 None => return Err(WireError::StreamClosed),
125 };
126 serde_json::from_str(&line).map_err(WireError::from)
127 }
128
129 async fn read_raw_message_timeout(
130 &mut self,
131 timeout: Duration,
132 ) -> Result<RawWireMessage, WireError> {
133 match tokio::time::timeout(timeout, self.read_raw_message()).await {
134 Ok(msg) => msg,
135 Err(_) => Err(WireError::Timeout(timeout)),
136 }
137 }
138
139 async fn read_response<Res: DeserializeOwned + Send>(
140 &mut self,
141 expected_id: &str,
142 ) -> Result<Res, WireError> {
143 let fut = async {
144 loop {
145 if let Some(idx) = self
146 .pending_messages
147 .iter()
148 .position(|msg| msg.id.as_deref() == Some(expected_id))
149 {
150 let msg = self
151 .pending_messages
152 .remove(idx)
153 .ok_or_else(|| WireError::Internal("pending index invalid".to_string()))?;
154 return decode_raw_response(msg, expected_id);
155 }
156
157 let line = match self.transport.read_line().await? {
158 Some(line) => line,
159 None => return Err(WireError::StreamClosed),
160 };
161 let msg: RawWireMessage = serde_json::from_str(&line).map_err(WireError::from)?;
162 if msg.id.as_deref() == Some(expected_id) {
163 return decode_raw_response(msg, expected_id);
164 }
165 if self.pending_messages.len() >= MAX_PENDING_MESSAGES {
166 return Err(WireError::Internal(format!(
167 "pending message buffer overflow ({} entries) waiting for id {:?}",
168 MAX_PENDING_MESSAGES, expected_id
169 )));
170 }
171 self.pending_messages.push_back(msg);
172 }
173 };
174
175 match self.default_timeout {
176 Some(d) => tokio::time::timeout(d, fut)
177 .await
178 .map_err(|_| WireError::Timeout(d))?,
179 None => fut.await,
180 }
181 }
182
183 async fn send_response<Res: Serialize + Send>(
184 &mut self,
185 id: &str,
186 result: Res,
187 ) -> Result<(), WireError> {
188 let resp = JsonRpcSuccessResponse {
189 jsonrpc: crate::protocol::JsonRpcVersion::V2,
190 id: id.to_string(),
191 result,
192 };
193 let line = serde_json::to_string(&resp).map_err(WireError::from)?;
194 self.transport.write_line(&line).await
195 }
196
197 async fn send_error(
198 &mut self,
199 id: &str,
200 code: i32,
201 message: &str,
202 ) -> Result<(), WireError> {
203 let resp = JsonRpcErrorResponse {
204 jsonrpc: crate::protocol::JsonRpcVersion::V2,
205 id: id.to_string(),
206 error: crate::protocol::JsonRpcError {
207 code,
208 message: message.to_string(),
209 data: None,
210 },
211 };
212 let line = serde_json::to_string(&resp).map_err(WireError::from)?;
213 self.transport.write_line(&line).await
214 }
215
216 async fn initialize(
217 &mut self,
218 params: InitializeParams,
219 ) -> Result<InitializeResult, WireError> {
220 let id = self.next_id();
221 let req = JsonRpcRequest {
222 jsonrpc: crate::protocol::JsonRpcVersion::V2,
223 method: "initialize".to_string(),
224 id: id.clone(),
225 params,
226 };
227 self.send_request(&req).await?;
228
229 let line = match self.transport.read_line().await? {
230 Some(line) => line,
231 None => return Err(WireError::StreamClosed),
232 };
233
234 if let Ok(error_resp) = serde_json::from_str::<JsonRpcErrorResponse>(&line) {
236 if error_resp.error.code == crate::protocol::METHOD_NOT_FOUND {
237 tracing::warn!(
238 code = error_resp.error.code,
239 "Server does not support initialize, falling back to legacy no-handshake mode"
240 );
241 self.handshake_done = true;
242 return Ok(InitializeResult {
243 protocol_version: crate::WIRE_PROTOCOL_LEGACY_VERSION.to_string(),
244 server: crate::protocol::ServerInfo {
245 name: "unknown".to_string(),
246 version: "unknown".to_string(),
247 },
248 slash_commands: vec![],
249 external_tools: None,
250 capabilities: None,
251 hooks: None,
252 });
253 }
254 return Err(WireError::RequestFailed {
255 code: error_resp.error.code,
256 message: error_resp.error.message,
257 });
258 }
259
260 let resp: JsonRpcSuccessResponse<InitializeResult> =
261 serde_json::from_str(&line).map_err(WireError::from)?;
262 self.handshake_done = true;
263 Ok(resp.result)
264 }
265
266 fn is_handshake_done(&self) -> bool {
267 self.handshake_done
268 }
269
270 async fn shutdown(self) -> Result<(), WireError> {
271 self.transport.shutdown().await
272 }
273}
274
275fn decode_raw_response<T: DeserializeOwned>(
276 msg: RawWireMessage,
277 _expected_id: &str,
278) -> Result<T, WireError> {
279 if let Some(error) = msg.error {
280 return Err(WireError::RequestFailed {
281 code: error.code,
282 message: error.message,
283 });
284 }
285 let result = msg
286 .result
287 .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
288 serde_json::from_value(result).map_err(WireError::from)
289}
290
291#[derive(Debug)]
297pub struct ChildProcessTransport {
298 child: Option<Child>,
299 stdin: Option<ChildStdin>,
300 stdout_reader: FramedRead<ChildStdout, LinesCodec>,
301 stderr_handle: Option<tokio::task::JoinHandle<()>>,
302 cancel_token: CancellationToken,
303}
304
305impl ChildProcessTransport {
306 pub async fn spawn(
312 kimi_binary: &str,
313 work_dir: Option<&Path>,
314 session: Option<&str>,
315 model: Option<&str>,
316 ) -> Result<Self, WireError> {
317 let mut child = None;
318 for attempt in 0..3 {
319 let mut cmd = tokio::process::Command::new(kimi_binary);
320 cmd.arg("--wire");
321 if let Some(dir) = work_dir {
322 cmd.arg("--work-dir").arg(dir);
323 }
324 if let Some(s) = session {
325 cmd.arg("--session").arg(s);
326 }
327 if let Some(m) = model {
328 cmd.arg("--model").arg(m);
329 }
330 cmd.stdin(std::process::Stdio::piped())
331 .stdout(std::process::Stdio::piped())
332 .stderr(std::process::Stdio::piped());
333
334 match cmd.kill_on_drop(true).spawn() {
335 Ok(spawned) => {
336 child = Some(spawned);
337 break;
338 }
339 Err(err) if err.raw_os_error() == Some(26) && attempt < 2 => {
340 tokio::time::sleep(Duration::from_millis(25)).await;
341 }
342 Err(err) => {
343 return Err(WireError::SpawnFailed(err.to_string()));
344 }
345 }
346 }
347
348 let mut child = child
349 .ok_or_else(|| WireError::SpawnFailed("all spawn attempts failed".to_string()))?;
350 let stdin = child
351 .stdin
352 .take()
353 .ok_or_else(|| WireError::SpawnFailed("no stdin".to_string()))?;
354 let stdout = child
355 .stdout
356 .take()
357 .ok_or_else(|| WireError::SpawnFailed("no stdout".to_string()))?;
358 let stdout_reader = FramedRead::new(
359 stdout,
360 LinesCodec::new_with_max_length(MAX_WIRE_LINE_LENGTH),
361 );
362
363 let cancel_token = CancellationToken::new();
364 let stderr_cancel = cancel_token.clone();
365 let stderr_handle = child.stderr.take().map(|stderr| {
366 tokio::spawn(async move {
367 let mut reader = BufReader::new(stderr).lines();
368 loop {
369 tokio::select! {
370 biased;
371 _ = stderr_cancel.cancelled() => break,
372 line = reader.next_line() => {
373 match line {
374 Ok(Some(line)) => {
375 #[cfg(feature = "redact")]
376 tracing::warn!(target: "kimi.stderr", "{}", crate::protocol::redact::scrub_secret_patterns(&line));
377 #[cfg(not(feature = "redact"))]
378 tracing::warn!(target: "kimi.stderr", "{line}");
379 }
380 _ => break,
381 }
382 }
383 }
384 }
385 })
386 });
387
388 Ok(Self {
389 child: Some(child),
390 stdin: Some(stdin),
391 stdout_reader,
392 stderr_handle,
393 cancel_token,
394 })
395 }
396}
397
398impl Transport for ChildProcessTransport {
399 async fn read_line(&mut self) -> Result<Option<String>, WireError> {
400 use tokio_stream::StreamExt;
401 match self.stdout_reader.next().await {
402 Some(Ok(line)) => Ok(Some(line)),
403 Some(Err(e)) => Err(WireError::Io(e.to_string())),
404 None => Ok(None),
405 }
406 }
407
408 async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
409 let stdin = self
410 .stdin
411 .as_mut()
412 .ok_or(WireError::StreamClosed)?;
413 stdin.write_all(line.as_bytes()).await?;
414 stdin.write_all(b"\n").await?;
415 stdin.flush().await?;
416 Ok(())
417 }
418
419 async fn shutdown(mut self) -> Result<(), WireError> {
420 drop(self.stdin.take());
422
423 let grace = Duration::from_secs(3);
425 if let Some(mut child) = self.child.take() {
426 match tokio::time::timeout(grace, child.wait()).await {
427 Ok(Ok(_)) => {}
428 Ok(Err(_)) => {}
429 Err(_) => {
430 let _ = child.kill().await;
431 }
432 }
433 }
434
435 self.cancel_token.cancel();
437 if let Some(handle) = self.stderr_handle.take() {
438 handle.abort();
439 }
440
441 Ok(())
442 }
443}
444
445impl Drop for ChildProcessTransport {
446 fn drop(&mut self) {
447 self.cancel_token.cancel();
448 if let Some(handle) = self.stderr_handle.take() {
449 handle.abort();
450 }
451 }
452}
453
454#[derive(Debug)]
460pub struct ChannelTransport {
461 rx: tokio::sync::mpsc::UnboundedReceiver<String>,
462 tx: tokio::sync::mpsc::UnboundedSender<String>,
463}
464
465impl ChannelTransport {
466 pub fn pair() -> (Self, Self) {
468 let (tx1, rx1) = tokio::sync::mpsc::unbounded_channel();
469 let (tx2, rx2) = tokio::sync::mpsc::unbounded_channel();
470 (
471 Self { rx: rx1, tx: tx2 },
472 Self { rx: rx2, tx: tx1 },
473 )
474 }
475}
476
477impl Transport for ChannelTransport {
478 async fn read_line(&mut self) -> Result<Option<String>, WireError> {
479 Ok(self.rx.recv().await)
480 }
481
482 async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
483 self.tx
484 .send(line.to_string())
485 .map_err(|_| WireError::StreamClosed)
486 }
487}