1use anyhow::{Context, Result};
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::process::{Child, ChildStdin, ChildStdout};
8use tokio::sync::mpsc;
9use tokio_util::sync::CancellationToken;
10use tracing::{debug, error, info, warn};
11
12use crate::transport::protocol::{create_notification, create_request, ProtocolMessage};
13use crate::types::config::QueryOptions;
14use crate::types::message::SDKMessage;
15
16#[derive(Debug, Clone, Serialize)]
18pub struct CLIRequest {
19 #[serde(rename = "type")]
21 pub request_type: String,
22 pub prompt: String,
24 #[serde(skip_serializing_if = "Option::is_none")]
26 pub session_id: Option<String>,
27 #[serde(flatten)]
29 pub options: QueryOptions,
30}
31
32#[derive(Debug, Clone, Serialize)]
34pub struct InitializeRequest {
35 pub protocol_version: String,
36 pub client: String,
37 pub client_version: String,
38}
39
40#[derive(Debug, Clone, Deserialize)]
42pub struct InitializeResponse {
43 pub protocol_version: String,
44 pub capabilities: CLICapabilities,
45}
46
47#[derive(Debug, Clone, Deserialize)]
49pub struct CLICapabilities {
50 #[serde(default)]
51 pub streaming: bool,
52 #[serde(default)]
53 pub tool_use: bool,
54 #[serde(default)]
55 pub multi_turn: bool,
56}
57
58pub async fn spawn_cli_process(executable_path: Option<&str>) -> Result<CLIProcess> {
60 let executable = executable_path.unwrap_or("qwen");
61
62 info!("Spawning QwenCode CLI process: {}", executable);
63
64 let mut child = tokio::process::Command::new(executable)
65 .kill_on_drop(true)
66 .stdin(std::process::Stdio::piped())
67 .stdout(std::process::Stdio::piped())
68 .stderr(std::process::Stdio::piped())
69 .spawn()
70 .context("Failed to spawn QwenCode CLI process")?;
71
72 let stdin = child.stdin.take().context("Failed to get stdin handle")?;
73
74 let stdout = child.stdout.take().context("Failed to get stdout handle")?;
75
76 let stderr = child.stderr.take().context("Failed to get stderr handle")?;
77
78 let (stderr_tx, stderr_rx) = mpsc::unbounded_channel::<String>();
80 tokio::spawn(read_stderr(stderr, stderr_tx));
81
82 debug!(
83 "QwenCode CLI process spawned successfully (PID: {:?})",
84 child.id()
85 );
86
87 Ok(CLIProcess {
88 child,
89 stdin,
90 stdout,
91 stderr_rx,
92 message_counter: 0,
93 })
94}
95
96pub struct CLIProcess {
98 child: Child,
99 stdin: ChildStdin,
100 stdout: ChildStdout,
101 stderr_rx: mpsc::UnboundedReceiver<String>,
102 message_counter: u64,
103}
104
105impl CLIProcess {
106 pub async fn initialize(
108 &mut self,
109 cancel_token: &CancellationToken,
110 ) -> Result<InitializeResponse> {
111 info!("Initializing CLI connection");
112
113 let init_request = InitializeRequest {
114 protocol_version: "1.0".to_string(),
115 client: "qwencode-rs".to_string(),
116 client_version: env!("CARGO_PKG_VERSION").to_string(),
117 };
118
119 let json = serde_json::to_string(&init_request)?;
120 let message = format!("{}\n", json);
121
122 self.stdin
123 .write_all(message.as_bytes())
124 .await
125 .context("Failed to send initialize request")?;
126 self.stdin.flush().await.context("Failed to flush stdin")?;
127
128 debug!("Initialize request sent");
129
130 let mut reader = BufReader::new(&mut self.stdout);
132 let mut line = String::new();
133
134 tokio::select! {
135 result = reader.read_line(&mut line) => {
136 let bytes_read = result.context("Failed to read initialize response")?;
137 if bytes_read == 0 {
138 return Err(anyhow::anyhow!("CLI process exited before responding"));
139 }
140
141 debug!("Initialize response: {}", line.trim());
142 let response: InitializeResponse = serde_json::from_str(&line)
143 .context("Failed to parse initialize response")?;
144
145 info!("CLI initialized with protocol version: {}", response.protocol_version);
146 Ok(response)
147 }
148 _ = cancel_token.cancelled() => {
149 Err(anyhow::anyhow!("Initialize cancelled"))
150 }
151 }
152 }
153
154 pub async fn send_query(&mut self, request: &CLIRequest) -> Result<()> {
156 self.message_counter += 1;
157 let id = self.message_counter;
158
159 let params = serde_json::to_value(request)?;
160 let message = create_request(id, "query", Some(params));
161
162 self.send_message(&message).await
163 }
164
165 async fn send_message(&mut self, message: &ProtocolMessage) -> Result<()> {
167 let json = serde_json::to_string(message)?;
168 let line = format!("{}\n", json);
169
170 debug!("Sending to CLI: {}", json);
171
172 self.stdin
173 .write_all(line.as_bytes())
174 .await
175 .context("Failed to write to stdin")?;
176 self.stdin.flush().await.context("Failed to flush stdin")?;
177
178 Ok(())
179 }
180
181 pub async fn read_message(&mut self) -> Result<Option<ProtocolMessage>> {
183 let mut reader = BufReader::new(&mut self.stdout);
184 let mut line = String::new();
185
186 let bytes_read = reader
187 .read_line(&mut line)
188 .await
189 .context("Failed to read from stdout")?;
190
191 if bytes_read == 0 {
192 debug!("stdout closed (EOF)");
193 return Ok(None);
194 }
195
196 let line = line.trim().to_string();
197 if line.is_empty() {
198 return Ok(None);
199 }
200
201 debug!("Received from CLI: {}", line);
202
203 let message: ProtocolMessage = serde_json::from_str(&line)
204 .with_context(|| format!("Failed to parse message: {}", line))?;
205
206 Ok(Some(message))
207 }
208
209 pub fn is_running(&mut self) -> bool {
211 self.child
212 .try_wait()
213 .map(|opt| opt.is_none())
214 .unwrap_or(false)
215 }
216
217 pub fn pid(&self) -> Option<u32> {
219 self.child.id()
220 }
221
222 pub async fn shutdown(&mut self) -> Result<()> {
224 info!("Shutting down CLI process (PID: {:?})", self.pid());
225
226 let close_msg = create_notification("close", None);
228 let _ = self.send_message(&close_msg).await;
229
230 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
232
233 if let Ok(Some(status)) = self.child.try_wait() {
235 debug!("Process exited with status: {:?}", status);
236 return Ok(());
237 }
238
239 if let Err(e) = self.child.kill().await {
241 warn!("Failed to kill process: {}", e);
242 }
243
244 match self.child.wait().await {
245 Ok(status) => {
246 info!("Process terminated with status: {:?}", status);
247 Ok(())
248 }
249 Err(e) => Err(anyhow::anyhow!("Failed to wait for process: {}", e)),
250 }
251 }
252
253 pub fn try_receive_stderr(&mut self) -> Option<String> {
255 self.stderr_rx.try_recv().ok()
256 }
257}
258
259async fn read_stderr(stderr: tokio::process::ChildStderr, sender: mpsc::UnboundedSender<String>) {
261 let mut reader = BufReader::new(stderr);
262 let mut line = String::new();
263
264 loop {
265 match reader.read_line(&mut line).await {
266 Ok(0) => {
267 debug!("stderr closed");
268 break;
269 }
270 Ok(_) => {
271 let trimmed = line.trim().to_string();
272 if !trimmed.is_empty() {
273 debug!("stderr: {}", trimmed);
274 let _ = sender.send(trimmed);
275 }
276 line.clear();
277 }
278 Err(e) => {
279 error!("Error reading stderr: {}", e);
280 break;
281 }
282 }
283 }
284}
285
286pub fn protocol_to_sdk_message(message: &ProtocolMessage) -> Result<Option<SDKMessage>> {
288 if let Some(method) = &message.method {
290 match method.as_str() {
291 "assistant_message" => {
292 if let Some(params) = &message.params {
293 let content = params
294 .get("content")
295 .and_then(|v| v.as_str())
296 .unwrap_or("")
297 .to_string();
298
299 return Ok(Some(SDKMessage::from_assistant_text(&content)));
300 }
301 }
302 "result" => {
303 if let Some(params) = &message.params {
304 return Ok(Some(SDKMessage::from_result_value(params.clone())));
305 }
306 }
307 "error" => {
308 if let Some(error) = &message.error {
309 return Err(anyhow::anyhow!("CLI error: {}", error.message));
310 }
311 }
312 _ => {
313 debug!("Unknown method: {}", method);
314 }
315 }
316 }
317
318 if let Some(result) = &message.result {
320 return Ok(Some(SDKMessage::from_result_value(result.clone())));
321 }
322
323 Ok(None)
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// A JSON-RPC 2.0 message with the given method and every other field empty;
    /// individual tests fill in what they need with struct-update syntax.
    fn bare_message(method: Option<&str>) -> ProtocolMessage {
        ProtocolMessage {
            id: None,
            jsonrpc: "2.0".to_string(),
            method: method.map(str::to_string),
            params: None,
            result: None,
            error: None,
        }
    }

    #[test]
    fn test_cli_request_serialization() {
        let request = CLIRequest {
            request_type: "query".to_string(),
            prompt: "Hello".to_string(),
            session_id: Some("test-session".to_string()),
            options: QueryOptions::default(),
        };

        let json = serde_json::to_string(&request).expect("CLIRequest should serialize");
        for needle in [
            "\"type\":\"query\"",
            "\"prompt\":\"Hello\"",
            "\"session_id\":\"test-session\"",
        ] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn test_initialize_request_structure() {
        let InitializeRequest {
            protocol_version,
            client,
            client_version,
        } = InitializeRequest {
            protocol_version: "1.0".to_string(),
            client: "qwencode-rs".to_string(),
            client_version: "0.1.0".to_string(),
        };

        assert_eq!(protocol_version, "1.0");
        assert_eq!(client, "qwencode-rs");
        assert_eq!(client_version, "0.1.0");
    }

    #[test]
    fn test_protocol_to_sdk_message_assistant() {
        let protocol_msg = ProtocolMessage {
            id: Some(1),
            params: Some(serde_json::json!({ "content": "Hello from assistant" })),
            ..bare_message(Some("assistant_message"))
        };

        let sdk_msg = protocol_to_sdk_message(&protocol_msg)
            .expect("conversion should succeed")
            .expect("an SDK message should be produced");
        assert!(sdk_msg.is_assistant_message());
    }

    #[test]
    fn test_protocol_to_sdk_message_result() {
        let protocol_msg = ProtocolMessage {
            id: Some(2),
            result: Some(serde_json::json!({
                "status": "success",
                "data": "test data"
            })),
            ..bare_message(None)
        };

        let sdk_msg = protocol_to_sdk_message(&protocol_msg)
            .expect("conversion should succeed")
            .expect("an SDK message should be produced");
        assert!(sdk_msg.is_result_message());
    }

    #[test]
    fn test_protocol_to_sdk_message_error() {
        let protocol_msg = ProtocolMessage {
            id: Some(3),
            error: Some(crate::transport::protocol::ProtocolError {
                code: -1,
                message: "Something went wrong".to_string(),
                data: None,
            }),
            ..bare_message(Some("error"))
        };

        let err = protocol_to_sdk_message(&protocol_msg).expect_err("error message must fail");
        assert!(err.to_string().contains("CLI error"));
    }

    #[test]
    fn test_protocol_to_sdk_message_unknown() {
        let protocol_msg = ProtocolMessage {
            id: Some(4),
            ..bare_message(Some("unknown_method"))
        };

        let converted = protocol_to_sdk_message(&protocol_msg).expect("conversion should succeed");
        assert!(converted.is_none());
    }

    #[tokio::test]
    async fn test_cli_request_with_options() {
        let request = CLIRequest {
            request_type: "query".to_string(),
            prompt: "Test prompt".to_string(),
            session_id: None,
            options: QueryOptions {
                model: Some("qwen-max".to_string()),
                debug: true,
                ..Default::default()
            },
        };

        let json = serde_json::to_string(&request).expect("CLIRequest should serialize");
        for needle in ["\"model\":\"qwen-max\"", "\"debug\":true"] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }
}