Skip to main content

qwencode_rs/transport/
stream.rs

1use crate::types::message::SDKMessage;
2use async_channel::{Receiver, Sender};
3use tracing::debug;
4
5/// Stream of SDK messages from the CLI
6pub struct MessageStream {
7    receiver: Receiver<Result<SDKMessage, anyhow::Error>>,
8    closed: bool,
9}
10
11impl MessageStream {
12    pub fn new(receiver: Receiver<Result<SDKMessage, anyhow::Error>>) -> Self {
13        MessageStream {
14            receiver,
15            closed: false,
16        }
17    }
18
19    /// Receive the next message from the stream
20    pub async fn next_message(&self) -> Option<Result<SDKMessage, anyhow::Error>> {
21        match self.receiver.recv().await {
22            Ok(msg) => {
23                debug!("Received message from stream");
24                Some(msg)
25            }
26            Err(_) => {
27                debug!("Message stream closed");
28                None
29            }
30        }
31    }
32
33    /// Check if the stream is closed
34    pub fn is_closed(&self) -> bool {
35        self.closed || self.receiver.is_closed()
36    }
37}
38
39/// Message handler that processes incoming messages
40pub struct MessageHandler {
41    sender: Sender<Result<SDKMessage, anyhow::Error>>,
42}
43
44impl Clone for MessageHandler {
45    fn clone(&self) -> Self {
46        MessageHandler {
47            sender: self.sender.clone(),
48        }
49    }
50}
51
52impl MessageHandler {
53    pub fn new(sender: Sender<Result<SDKMessage, anyhow::Error>>) -> Self {
54        MessageHandler { sender }
55    }
56
57    /// Send a message to the stream
58    pub async fn send_message(&self, message: SDKMessage) -> Result<(), anyhow::Error> {
59        self.sender.send(Ok(message)).await?;
60        Ok(())
61    }
62
63    /// Send an error to the stream
64    pub async fn send_error(&self, error: anyhow::Error) -> Result<(), anyhow::Error> {
65        self.sender.send(Err(error)).await?;
66        Ok(())
67    }
68
69    /// Close the handler
70    pub fn close(&self) {
71        self.sender.close();
72    }
73}
74
75/// Create a message stream pair (sender and MessageStream)
76pub fn create_message_stream() -> (MessageHandler, MessageStream) {
77    let (sender, receiver) = async_channel::unbounded();
78    let handler = MessageHandler::new(sender);
79    let stream = MessageStream::new(receiver);
80
81    (handler, stream)
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::message::{MessageContent, MessageRole, SDKUserMessage};

    /// Build a user-role message with no parent tool use.
    fn user_message(session_id: impl Into<String>, content: impl Into<String>) -> SDKMessage {
        SDKMessage::User(SDKUserMessage {
            session_id: session_id.into(),
            message: MessageContent {
                role: MessageRole::User,
                content: content.into(),
            },
            parent_tool_use_id: None,
        })
    }

    #[tokio::test]
    async fn test_message_stream_send_and_receive() {
        let (handler, stream) = create_message_stream();
        handler
            .send_message(user_message("test", "Hello"))
            .await
            .expect("send should succeed");

        let received = stream
            .next_message()
            .await
            .expect("stream should yield an item")
            .expect("item should be a message");

        assert_eq!(received.session_id(), "test");
        assert!(received.is_user_message());
    }

    #[tokio::test]
    async fn test_message_stream_send_error() {
        let (handler, stream) = create_message_stream();
        handler
            .send_error(anyhow::anyhow!("Test error"))
            .await
            .expect("send should succeed");

        let item = stream.next_message().await.expect("stream should yield an item");
        assert!(item.is_err());
    }

    #[tokio::test]
    async fn test_message_stream_close() {
        let (handler, stream) = create_message_stream();
        assert!(!stream.is_closed());

        handler.close();

        assert!(stream.is_closed());
    }

    #[tokio::test]
    async fn test_message_stream_multiple_messages() {
        let (handler, stream) = create_message_stream();

        for idx in 0..3 {
            let msg = user_message(format!("session-{}", idx), format!("Message {}", idx));
            handler.send_message(msg).await.expect("send should succeed");
        }

        // Items must come back in the order they were sent.
        for idx in 0..3 {
            let got = stream
                .next_message()
                .await
                .expect("stream should yield an item")
                .expect("item should be a message");
            assert_eq!(got.session_id(), format!("session-{}", idx));
        }
    }

    #[test]
    fn test_message_stream_initial_state() {
        let (handler, stream) = create_message_stream();

        // The channel stays open while a handler exists.
        assert!(!stream.is_closed());

        // Explicit drop keeps `handler` from being flagged as unused.
        drop(handler);
    }

    #[tokio::test]
    async fn test_message_handler_creation() {
        let (handler, stream) = create_message_stream();
        handler
            .send_message(user_message("test", "test"))
            .await
            .expect("send should succeed");

        let got = stream
            .next_message()
            .await
            .expect("stream should yield an item")
            .expect("item should be a message");
        assert_eq!(got.session_id(), "test");
    }
}