1use std::collections::VecDeque;
8use std::path::Path;
9use std::time::Duration;
10
11use serde::de::DeserializeOwned;
12use serde::Serialize;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStdin, ChildStdout};
15use tokio_util::codec::{FramedRead, LinesCodec};
16use tokio_util::sync::CancellationToken;
17
18use crate::client::WireClient;
19use crate::error::WireError;
20use crate::protocol::{
21 InitializeParams, InitializeResult, JsonRpcErrorResponse, JsonRpcRequest,
22 JsonRpcSuccessResponse, RawWireMessage,
23};
24
/// Upper bound (16 MiB) on a single wire line; passed to the stdout `LinesCodec`
/// in `ChildProcessTransport::spawn`, which rejects longer lines with an error.
pub const MAX_WIRE_LINE_LENGTH: usize = 16 << 20;
31
32pub trait Transport: Send {
34 fn read_line(
36 &mut self,
37 ) -> impl std::future::Future<Output = Result<Option<String>, WireError>> + Send;
38
39 fn write_line(
41 &mut self,
42 line: &str,
43 ) -> impl std::future::Future<Output = Result<(), WireError>> + Send;
44}
45
46pub struct TransportWireClient<T: Transport> {
52 transport: T,
53 request_id_counter: u64,
54 handshake_done: bool,
55 pending_messages: VecDeque<RawWireMessage>,
56}
57
58impl<T: Transport> TransportWireClient<T> {
59 pub fn new(transport: T) -> Self {
61 Self {
62 transport,
63 request_id_counter: 0,
64 handshake_done: false,
65 pending_messages: VecDeque::new(),
66 }
67 }
68
69 pub fn into_transport(self) -> T {
71 self.transport
72 }
73}
74
75impl<T: Transport> WireClient for TransportWireClient<T> {
76 fn next_id(&mut self) -> String {
77 self.request_id_counter += 1;
78 format!("req-{}", self.request_id_counter)
79 }
80
81 async fn send_request<Params: Serialize + Sync>(
82 &mut self,
83 req: &JsonRpcRequest<Params>,
84 ) -> Result<(), WireError> {
85 let line = serde_json::to_string(req).map_err(WireError::from)?;
86 self.transport.write_line(&line).await
87 }
88
89 async fn read_raw_message(&mut self) -> Result<RawWireMessage, WireError> {
90 if let Some(msg) = self.pending_messages.pop_front() {
91 return Ok(msg);
92 }
93 let line = match self.transport.read_line().await? {
94 Some(line) => line,
95 None => return Err(WireError::StreamClosed),
96 };
97 serde_json::from_str(&line).map_err(WireError::from)
98 }
99
100 async fn read_raw_message_timeout(
101 &mut self,
102 timeout: Duration,
103 ) -> Result<RawWireMessage, WireError> {
104 match tokio::time::timeout(timeout, self.read_raw_message()).await {
105 Ok(msg) => msg,
106 Err(_) => Err(WireError::Timeout(timeout)),
107 }
108 }
109
110 async fn read_response<Res: DeserializeOwned + Send>(
111 &mut self,
112 expected_id: &str,
113 ) -> Result<Res, WireError> {
114 loop {
115 if let Some(idx) = self
116 .pending_messages
117 .iter()
118 .position(|msg| msg.id.as_deref() == Some(expected_id))
119 {
120 let msg = self
121 .pending_messages
122 .remove(idx)
123 .ok_or_else(|| WireError::Internal("pending index invalid".to_string()))?;
124 return decode_raw_response(msg, expected_id);
125 }
126
127 match self.read_raw_message().await? {
128 msg if msg.id.as_deref() == Some(expected_id) => {
129 return decode_raw_response(msg, expected_id);
130 }
131 other => {
132 self.pending_messages.push_back(other);
133 }
134 }
135 }
136 }
137
138 async fn send_response<Res: Serialize + Send>(
139 &mut self,
140 id: &str,
141 result: Res,
142 ) -> Result<(), WireError> {
143 let resp = JsonRpcSuccessResponse {
144 jsonrpc: crate::protocol::JsonRpcVersion::default(),
145 id: id.to_string(),
146 result,
147 };
148 let line = serde_json::to_string(&resp).map_err(WireError::from)?;
149 self.transport.write_line(&line).await
150 }
151
152 async fn send_error(
153 &mut self,
154 id: &str,
155 code: i32,
156 message: &str,
157 ) -> Result<(), WireError> {
158 let resp = JsonRpcErrorResponse {
159 jsonrpc: crate::protocol::JsonRpcVersion::default(),
160 id: id.to_string(),
161 error: crate::protocol::JsonRpcError {
162 code,
163 message: message.to_string(),
164 data: None,
165 },
166 };
167 let line = serde_json::to_string(&resp).map_err(WireError::from)?;
168 self.transport.write_line(&line).await
169 }
170
171 async fn initialize(
172 &mut self,
173 params: InitializeParams,
174 ) -> Result<InitializeResult, WireError> {
175 let id = self.next_id();
176 let req = JsonRpcRequest {
177 jsonrpc: crate::protocol::JsonRpcVersion::default(),
178 method: "initialize".to_string(),
179 id: id.clone(),
180 params,
181 };
182 self.send_request(&req).await?;
183
184 let line = match self.transport.read_line().await? {
185 Some(line) => line,
186 None => return Err(WireError::StreamClosed),
187 };
188
189 if let Ok(error_resp) = serde_json::from_str::<JsonRpcErrorResponse>(&line) {
191 if error_resp.error.code == crate::protocol::METHOD_NOT_FOUND {
192 tracing::warn!(
193 code = error_resp.error.code,
194 "Server does not support initialize, falling back to legacy no-handshake mode"
195 );
196 self.handshake_done = true;
197 return Ok(InitializeResult {
198 protocol_version: crate::WIRE_PROTOCOL_LEGACY_VERSION.to_string(),
199 server: crate::protocol::ServerInfo {
200 name: "unknown".to_string(),
201 version: "unknown".to_string(),
202 },
203 slash_commands: vec![],
204 external_tools: None,
205 capabilities: None,
206 hooks: None,
207 });
208 }
209 return Err(WireError::RequestFailed {
210 code: error_resp.error.code,
211 message: error_resp.error.message,
212 });
213 }
214
215 let resp: JsonRpcSuccessResponse<InitializeResult> =
216 serde_json::from_str(&line).map_err(WireError::from)?;
217 self.handshake_done = true;
218 Ok(resp.result)
219 }
220
221 fn is_handshake_done(&self) -> bool {
222 self.handshake_done
223 }
224
225 async fn shutdown(self) -> Result<(), WireError> {
226 Ok(())
227 }
228}
229
230fn decode_raw_response<T: DeserializeOwned>(
231 msg: RawWireMessage,
232 _expected_id: &str,
233) -> Result<T, WireError> {
234 if let Some(error) = msg.error {
235 return Err(WireError::RequestFailed {
236 code: error.code,
237 message: error.message,
238 });
239 }
240 let result = msg
241 .result
242 .ok_or_else(|| WireError::Internal("response missing result".to_string()))?;
243 serde_json::from_value(result).map_err(WireError::from)
244}
245
/// [`Transport`] backed by a spawned child process: outgoing lines go to the
/// child's stdin, incoming lines are read from its stdout.
///
/// NOTE(review): Rust drops fields in declaration order (after `Drop::drop`),
/// so `child` is dropped, and killed via `kill_on_drop`, before `stdin` is
/// closed. Keep that in mind before reordering fields.
pub struct ChildProcessTransport {
    // Not read after construction (hence the allow), but it must be owned here:
    // `spawn` sets `kill_on_drop(true)`, so dropping this handle kills the process.
    #[allow(dead_code)]
    child: Child,
    /// Write half: each outgoing line is written here followed by `\n`.
    stdin: ChildStdin,
    /// Read half: stdout split into lines of at most `MAX_WIRE_LINE_LENGTH`.
    stdout_reader: FramedRead<ChildStdout, LinesCodec>,
    /// Background task forwarding stderr to `tracing`; `None` if no stderr pipe was available.
    stderr_handle: Option<tokio::task::JoinHandle<()>>,
    /// Cancelled on drop to stop the stderr task.
    cancel_token: CancellationToken,
}
259
260impl ChildProcessTransport {
261 pub async fn spawn(
267 kimi_binary: &str,
268 work_dir: Option<&Path>,
269 session: Option<&str>,
270 model: Option<&str>,
271 ) -> Result<Self, WireError> {
272 let mut child = None;
273 for attempt in 0..3 {
274 let mut cmd = tokio::process::Command::new(kimi_binary);
275 cmd.arg("--wire");
276 if let Some(dir) = work_dir {
277 cmd.arg("--work-dir").arg(dir);
278 }
279 if let Some(s) = session {
280 cmd.arg("--session").arg(s);
281 }
282 if let Some(m) = model {
283 cmd.arg("--model").arg(m);
284 }
285 cmd.stdin(std::process::Stdio::piped())
286 .stdout(std::process::Stdio::piped())
287 .stderr(std::process::Stdio::piped());
288
289 match cmd.kill_on_drop(true).spawn() {
290 Ok(spawned) => {
291 child = Some(spawned);
292 break;
293 }
294 Err(err) if err.raw_os_error() == Some(26) && attempt < 2 => {
295 tokio::time::sleep(Duration::from_millis(25)).await;
296 }
297 Err(err) => {
298 return Err(WireError::SpawnFailed(err.to_string()));
299 }
300 }
301 }
302
303 let mut child = child
304 .ok_or_else(|| WireError::SpawnFailed("all spawn attempts failed".to_string()))?;
305 let stdin = child
306 .stdin
307 .take()
308 .ok_or_else(|| WireError::SpawnFailed("no stdin".to_string()))?;
309 let stdout = child
310 .stdout
311 .take()
312 .ok_or_else(|| WireError::SpawnFailed("no stdout".to_string()))?;
313 let stdout_reader = FramedRead::new(
314 stdout,
315 LinesCodec::new_with_max_length(MAX_WIRE_LINE_LENGTH),
316 );
317
318 let cancel_token = CancellationToken::new();
319 let stderr_cancel = cancel_token.clone();
320 let stderr_handle = child.stderr.take().map(|stderr| {
321 tokio::spawn(async move {
322 let mut reader = BufReader::new(stderr).lines();
323 loop {
324 tokio::select! {
325 biased;
326 _ = stderr_cancel.cancelled() => break,
327 line = reader.next_line() => {
328 match line {
329 Ok(Some(line)) => {
330 #[cfg(feature = "redact")]
331 tracing::warn!(target: "kimi.stderr", "{}", crate::protocol::redact::scrub_secret_patterns(&line));
332 #[cfg(not(feature = "redact"))]
333 tracing::warn!(target: "kimi.stderr", "{line}");
334 }
335 _ => break,
336 }
337 }
338 }
339 }
340 })
341 });
342
343 Ok(Self {
344 child,
345 stdin,
346 stdout_reader,
347 stderr_handle,
348 cancel_token,
349 })
350 }
351}
352
353impl Transport for ChildProcessTransport {
354 async fn read_line(&mut self) -> Result<Option<String>, WireError> {
355 use tokio_stream::StreamExt;
356 match self.stdout_reader.next().await {
357 Some(Ok(line)) => Ok(Some(line)),
358 Some(Err(e)) => Err(WireError::Io(e.to_string())),
359 None => Ok(None),
360 }
361 }
362
363 async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
364 self.stdin.write_all(line.as_bytes()).await?;
365 self.stdin.write_all(b"\n").await?;
366 self.stdin.flush().await?;
367 Ok(())
368 }
369}
370
371impl Drop for ChildProcessTransport {
372 fn drop(&mut self) {
373 self.cancel_token.cancel();
374 if let Some(handle) = self.stderr_handle.take() {
375 handle.abort();
376 }
377 }
378}
379
/// In-memory [`Transport`] built on unbounded tokio mpsc channels;
/// `ChannelTransport::pair` returns two cross-connected endpoints.
#[derive(Debug)]
pub struct ChannelTransport {
    /// Inbox: lines written by the peer endpoint.
    rx: tokio::sync::mpsc::UnboundedReceiver<String>,
    /// Outbox: lines delivered to the peer endpoint's inbox.
    tx: tokio::sync::mpsc::UnboundedSender<String>,
}
390
391impl ChannelTransport {
392 pub fn pair() -> (Self, Self) {
394 let (tx1, rx1) = tokio::sync::mpsc::unbounded_channel();
395 let (tx2, rx2) = tokio::sync::mpsc::unbounded_channel();
396 (
397 Self { rx: rx1, tx: tx2 },
398 Self { rx: rx2, tx: tx1 },
399 )
400 }
401}
402
403impl Transport for ChannelTransport {
404 async fn read_line(&mut self) -> Result<Option<String>, WireError> {
405 Ok(self.rx.recv().await)
406 }
407
408 async fn write_line(&mut self, line: &str) -> Result<(), WireError> {
409 self.tx
410 .send(line.to_string())
411 .map_err(|_| WireError::StreamClosed)
412 }
413}