// qwencode_rs/transport/communication.rs

// CLI process communication layer.
// Spawns the QwenCode CLI process and manages bidirectional, newline-delimited JSON communication with it.
3
4use anyhow::{Context, Result};
5use serde::{Deserialize, Serialize};
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::process::{Child, ChildStdin, ChildStdout};
8use tokio::sync::mpsc;
9use tokio_util::sync::CancellationToken;
10use tracing::{debug, error, info, warn};
11
12use crate::transport::protocol::{create_notification, create_request, ProtocolMessage};
13use crate::types::config::QueryOptions;
14use crate::types::message::SDKMessage;
15
16/// Request to send to the CLI process
17#[derive(Debug, Clone, Serialize)]
18pub struct CLIRequest {
19    /// Request type
20    #[serde(rename = "type")]
21    pub request_type: String,
22    /// Prompt text
23    pub prompt: String,
24    /// Session ID
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub session_id: Option<String>,
27    /// Query options
28    #[serde(flatten)]
29    pub options: QueryOptions,
30}
31
32/// Initialize request
33#[derive(Debug, Clone, Serialize)]
34pub struct InitializeRequest {
35    pub protocol_version: String,
36    pub client: String,
37    pub client_version: String,
38}
39
40/// Initialize response from CLI
41#[derive(Debug, Clone, Deserialize)]
42pub struct InitializeResponse {
43    pub protocol_version: String,
44    pub capabilities: CLICapabilities,
45}
46
47/// CLI capabilities
48#[derive(Debug, Clone, Deserialize)]
49pub struct CLICapabilities {
50    #[serde(default)]
51    pub streaming: bool,
52    #[serde(default)]
53    pub tool_use: bool,
54    #[serde(default)]
55    pub multi_turn: bool,
56}
57
58/// Spawn QwenCode CLI process and return stdin/stdout handles
59pub async fn spawn_cli_process(executable_path: Option<&str>) -> Result<CLIProcess> {
60    let executable = executable_path.unwrap_or("qwen");
61
62    info!("Spawning QwenCode CLI process: {}", executable);
63
64    let mut child = tokio::process::Command::new(executable)
65        .kill_on_drop(true)
66        .stdin(std::process::Stdio::piped())
67        .stdout(std::process::Stdio::piped())
68        .stderr(std::process::Stdio::piped())
69        .spawn()
70        .context("Failed to spawn QwenCode CLI process")?;
71
72    let stdin = child.stdin.take().context("Failed to get stdin handle")?;
73
74    let stdout = child.stdout.take().context("Failed to get stdout handle")?;
75
76    let stderr = child.stderr.take().context("Failed to get stderr handle")?;
77
78    // Spawn stderr reader task
79    let (stderr_tx, stderr_rx) = mpsc::unbounded_channel::<String>();
80    tokio::spawn(read_stderr(stderr, stderr_tx));
81
82    debug!(
83        "QwenCode CLI process spawned successfully (PID: {:?})",
84        child.id()
85    );
86
87    Ok(CLIProcess {
88        child,
89        stdin,
90        stdout,
91        stderr_rx,
92        message_counter: 0,
93    })
94}
95
96/// Handle to a spawned CLI process
97pub struct CLIProcess {
98    child: Child,
99    stdin: ChildStdin,
100    stdout: ChildStdout,
101    stderr_rx: mpsc::UnboundedReceiver<String>,
102    message_counter: u64,
103}
104
105impl CLIProcess {
106    /// Send initialize request and wait for response
107    pub async fn initialize(
108        &mut self,
109        cancel_token: &CancellationToken,
110    ) -> Result<InitializeResponse> {
111        info!("Initializing CLI connection");
112
113        let init_request = InitializeRequest {
114            protocol_version: "1.0".to_string(),
115            client: "qwencode-rs".to_string(),
116            client_version: env!("CARGO_PKG_VERSION").to_string(),
117        };
118
119        let json = serde_json::to_string(&init_request)?;
120        let message = format!("{}\n", json);
121
122        self.stdin
123            .write_all(message.as_bytes())
124            .await
125            .context("Failed to send initialize request")?;
126        self.stdin.flush().await.context("Failed to flush stdin")?;
127
128        debug!("Initialize request sent");
129
130        // Read response
131        let mut reader = BufReader::new(&mut self.stdout);
132        let mut line = String::new();
133
134        tokio::select! {
135            result = reader.read_line(&mut line) => {
136                let bytes_read = result.context("Failed to read initialize response")?;
137                if bytes_read == 0 {
138                    return Err(anyhow::anyhow!("CLI process exited before responding"));
139                }
140
141                debug!("Initialize response: {}", line.trim());
142                let response: InitializeResponse = serde_json::from_str(&line)
143                    .context("Failed to parse initialize response")?;
144
145                info!("CLI initialized with protocol version: {}", response.protocol_version);
146                Ok(response)
147            }
148            _ = cancel_token.cancelled() => {
149                Err(anyhow::anyhow!("Initialize cancelled"))
150            }
151        }
152    }
153
154    /// Send a query request to the CLI
155    pub async fn send_query(&mut self, request: &CLIRequest) -> Result<()> {
156        self.message_counter += 1;
157        let id = self.message_counter;
158
159        let params = serde_json::to_value(request)?;
160        let message = create_request(id, "query", Some(params));
161
162        self.send_message(&message).await
163    }
164
165    /// Send a generic protocol message
166    async fn send_message(&mut self, message: &ProtocolMessage) -> Result<()> {
167        let json = serde_json::to_string(message)?;
168        let line = format!("{}\n", json);
169
170        debug!("Sending to CLI: {}", json);
171
172        self.stdin
173            .write_all(line.as_bytes())
174            .await
175            .context("Failed to write to stdin")?;
176        self.stdin.flush().await.context("Failed to flush stdin")?;
177
178        Ok(())
179    }
180
181    /// Read next message from stdout
182    pub async fn read_message(&mut self) -> Result<Option<ProtocolMessage>> {
183        let mut reader = BufReader::new(&mut self.stdout);
184        let mut line = String::new();
185
186        let bytes_read = reader
187            .read_line(&mut line)
188            .await
189            .context("Failed to read from stdout")?;
190
191        if bytes_read == 0 {
192            debug!("stdout closed (EOF)");
193            return Ok(None);
194        }
195
196        let line = line.trim().to_string();
197        if line.is_empty() {
198            return Ok(None);
199        }
200
201        debug!("Received from CLI: {}", line);
202
203        let message: ProtocolMessage = serde_json::from_str(&line)
204            .with_context(|| format!("Failed to parse message: {}", line))?;
205
206        Ok(Some(message))
207    }
208
209    /// Check if process is still running
210    pub fn is_running(&mut self) -> bool {
211        self.child
212            .try_wait()
213            .map(|opt| opt.is_none())
214            .unwrap_or(false)
215    }
216
217    /// Get process ID
218    pub fn pid(&self) -> Option<u32> {
219        self.child.id()
220    }
221
222    /// Gracefully shutdown the process
223    pub async fn shutdown(&mut self) -> Result<()> {
224        info!("Shutting down CLI process (PID: {:?})", self.pid());
225
226        // Send close notification
227        let close_msg = create_notification("close", None);
228        let _ = self.send_message(&close_msg).await;
229
230        // Wait a bit for graceful shutdown
231        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
232
233        // Try graceful exit
234        if let Ok(Some(status)) = self.child.try_wait() {
235            debug!("Process exited with status: {:?}", status);
236            return Ok(());
237        }
238
239        // Force kill if still running
240        if let Err(e) = self.child.kill().await {
241            warn!("Failed to kill process: {}", e);
242        }
243
244        match self.child.wait().await {
245            Ok(status) => {
246                info!("Process terminated with status: {:?}", status);
247                Ok(())
248            }
249            Err(e) => Err(anyhow::anyhow!("Failed to wait for process: {}", e)),
250        }
251    }
252
253    /// Poll stderr for any messages
254    pub fn try_receive_stderr(&mut self) -> Option<String> {
255        self.stderr_rx.try_recv().ok()
256    }
257}
258
259/// Read from stderr and send to channel
260async fn read_stderr(stderr: tokio::process::ChildStderr, sender: mpsc::UnboundedSender<String>) {
261    let mut reader = BufReader::new(stderr);
262    let mut line = String::new();
263
264    loop {
265        match reader.read_line(&mut line).await {
266            Ok(0) => {
267                debug!("stderr closed");
268                break;
269            }
270            Ok(_) => {
271                let trimmed = line.trim().to_string();
272                if !trimmed.is_empty() {
273                    debug!("stderr: {}", trimmed);
274                    let _ = sender.send(trimmed);
275                }
276                line.clear();
277            }
278            Err(e) => {
279                error!("Error reading stderr: {}", e);
280                break;
281            }
282        }
283    }
284}
285
286/// Convert a ProtocolMessage to SDKMessage
287pub fn protocol_to_sdk_message(message: &ProtocolMessage) -> Result<Option<SDKMessage>> {
288    // Check if it's a method call (incoming message from CLI)
289    if let Some(method) = &message.method {
290        match method.as_str() {
291            "assistant_message" => {
292                if let Some(params) = &message.params {
293                    let content = params
294                        .get("content")
295                        .and_then(|v| v.as_str())
296                        .unwrap_or("")
297                        .to_string();
298
299                    return Ok(Some(SDKMessage::from_assistant_text(&content)));
300                }
301            }
302            "result" => {
303                if let Some(params) = &message.params {
304                    return Ok(Some(SDKMessage::from_result_value(params.clone())));
305                }
306            }
307            "error" => {
308                if let Some(error) = &message.error {
309                    return Err(anyhow::anyhow!("CLI error: {}", error.message));
310                }
311            }
312            _ => {
313                debug!("Unknown method: {}", method);
314            }
315        }
316    }
317
318    // Check if it's a response
319    if let Some(result) = &message.result {
320        return Ok(Some(SDKMessage::from_result_value(result.clone())));
321    }
322
323    Ok(None)
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cli_request_serialization() {
        let request = CLIRequest {
            request_type: "query".to_string(),
            prompt: "Hello".to_string(),
            session_id: Some("test-session".to_string()),
            options: QueryOptions::default(),
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"type\":\"query\""));
        assert!(json.contains("\"prompt\":\"Hello\""));
        assert!(json.contains("\"session_id\":\"test-session\""));
    }

    /// Checks the JSON the CLI actually receives, not just the struct fields.
    #[test]
    fn test_initialize_request_serialization() {
        let request = InitializeRequest {
            protocol_version: "1.0".to_string(),
            client: "qwencode-rs".to_string(),
            client_version: "0.1.0".to_string(),
        };

        let value = serde_json::to_value(&request).unwrap();
        assert_eq!(value["protocol_version"], "1.0");
        assert_eq!(value["client"], "qwencode-rs");
        assert_eq!(value["client_version"], "0.1.0");
    }

    /// Capability flags missing from the JSON must fall back to `false`.
    #[test]
    fn test_initialize_response_missing_capabilities_default_to_false() {
        let json = r#"{"protocol_version":"1.0","capabilities":{"streaming":true}}"#;

        let response: InitializeResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.protocol_version, "1.0");
        assert!(response.capabilities.streaming);
        assert!(!response.capabilities.tool_use);
        assert!(!response.capabilities.multi_turn);
    }

    #[test]
    fn test_protocol_to_sdk_message_assistant() {
        let protocol_msg = ProtocolMessage {
            id: Some(1),
            jsonrpc: "2.0".to_string(),
            method: Some("assistant_message".to_string()),
            params: Some(serde_json::json!({
                "content": "Hello from assistant"
            })),
            result: None,
            error: None,
        };

        let sdk_msg = protocol_to_sdk_message(&protocol_msg).unwrap().unwrap();
        assert!(sdk_msg.is_assistant_message());
    }

    #[test]
    fn test_protocol_to_sdk_message_result() {
        let protocol_msg = ProtocolMessage {
            id: Some(2),
            jsonrpc: "2.0".to_string(),
            method: None,
            params: None,
            result: Some(serde_json::json!({
                "status": "success",
                "data": "test data"
            })),
            error: None,
        };

        let sdk_msg = protocol_to_sdk_message(&protocol_msg).unwrap().unwrap();
        assert!(sdk_msg.is_result_message());
    }

    #[test]
    fn test_protocol_to_sdk_message_error() {
        let protocol_msg = ProtocolMessage {
            id: Some(3),
            jsonrpc: "2.0".to_string(),
            method: Some("error".to_string()),
            params: None,
            result: None,
            error: Some(crate::transport::protocol::ProtocolError {
                code: -1,
                message: "Something went wrong".to_string(),
                data: None,
            }),
        };

        let result = protocol_to_sdk_message(&protocol_msg);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("CLI error"));
    }

    #[test]
    fn test_protocol_to_sdk_message_unknown() {
        let protocol_msg = ProtocolMessage {
            id: Some(4),
            jsonrpc: "2.0".to_string(),
            method: Some("unknown_method".to_string()),
            params: None,
            result: None,
            error: None,
        };

        let result = protocol_to_sdk_message(&protocol_msg).unwrap();
        assert!(result.is_none());
    }

    // Pure serialization: no async runtime is needed.
    #[test]
    fn test_cli_request_with_options() {
        let options = QueryOptions {
            model: Some("qwen-max".to_string()),
            debug: true,
            ..Default::default()
        };

        let request = CLIRequest {
            request_type: "query".to_string(),
            prompt: "Test prompt".to_string(),
            session_id: None,
            options,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"qwen-max\""));
        assert!(json.contains("\"debug\":true"));
    }
}