qwencode_rs/transport/
stream.rs1use crate::types::message::SDKMessage;
2use async_channel::{Receiver, Sender};
3use tracing::debug;
4
5pub struct MessageStream {
7 receiver: Receiver<Result<SDKMessage, anyhow::Error>>,
8 closed: bool,
9}
10
11impl MessageStream {
12 pub fn new(receiver: Receiver<Result<SDKMessage, anyhow::Error>>) -> Self {
13 MessageStream {
14 receiver,
15 closed: false,
16 }
17 }
18
19 pub async fn next_message(&self) -> Option<Result<SDKMessage, anyhow::Error>> {
21 match self.receiver.recv().await {
22 Ok(msg) => {
23 debug!("Received message from stream");
24 Some(msg)
25 }
26 Err(_) => {
27 debug!("Message stream closed");
28 None
29 }
30 }
31 }
32
33 pub fn is_closed(&self) -> bool {
35 self.closed || self.receiver.is_closed()
36 }
37}
38
39pub struct MessageHandler {
41 sender: Sender<Result<SDKMessage, anyhow::Error>>,
42}
43
44impl Clone for MessageHandler {
45 fn clone(&self) -> Self {
46 MessageHandler {
47 sender: self.sender.clone(),
48 }
49 }
50}
51
52impl MessageHandler {
53 pub fn new(sender: Sender<Result<SDKMessage, anyhow::Error>>) -> Self {
54 MessageHandler { sender }
55 }
56
57 pub async fn send_message(&self, message: SDKMessage) -> Result<(), anyhow::Error> {
59 self.sender.send(Ok(message)).await?;
60 Ok(())
61 }
62
63 pub async fn send_error(&self, error: anyhow::Error) -> Result<(), anyhow::Error> {
65 self.sender.send(Err(error)).await?;
66 Ok(())
67 }
68
69 pub fn close(&self) {
71 self.sender.close();
72 }
73}
74
75pub fn create_message_stream() -> (MessageHandler, MessageStream) {
77 let (sender, receiver) = async_channel::unbounded();
78 let handler = MessageHandler::new(sender);
79 let stream = MessageStream::new(receiver);
80
81 (handler, stream)
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::message::{MessageContent, MessageRole, SDKUserMessage};

    /// Builds a user message fixture with the given session id and content.
    fn user_message(session_id: &str, content: &str) -> SDKMessage {
        SDKMessage::User(SDKUserMessage {
            session_id: session_id.to_string(),
            message: MessageContent {
                role: MessageRole::User,
                content: content.to_string(),
            },
            parent_tool_use_id: None,
        })
    }

    #[tokio::test]
    async fn test_message_stream_send_and_receive() {
        let (handler, stream) = create_message_stream();

        // The fixture is moved into the channel; no clone is needed.
        handler
            .send_message(user_message("test", "Hello"))
            .await
            .unwrap();

        let received = stream.next_message().await.unwrap().unwrap();

        assert_eq!(received.session_id(), "test");
        assert!(received.is_user_message());
    }

    #[tokio::test]
    async fn test_message_stream_send_error() {
        let (handler, stream) = create_message_stream();

        handler
            .send_error(anyhow::anyhow!("Test error"))
            .await
            .unwrap();

        // The error must arrive as an item, not as end-of-stream.
        match stream.next_message().await {
            Some(Err(err)) => assert_eq!(err.to_string(), "Test error"),
            Some(Ok(_)) => panic!("expected a forwarded error, got a message"),
            None => panic!("expected a forwarded error, got end of stream"),
        }
    }

    #[tokio::test]
    async fn test_message_stream_close() {
        let (handler, stream) = create_message_stream();

        assert!(!stream.is_closed());

        handler.close();
        assert!(stream.is_closed());

        // A closed channel with nothing buffered yields end-of-stream.
        assert!(stream.next_message().await.is_none());
    }

    #[tokio::test]
    async fn test_message_stream_multiple_messages() {
        let (handler, stream) = create_message_stream();

        for i in 0..3 {
            let message = user_message(&format!("session-{}", i), &format!("Message {}", i));
            handler.send_message(message).await.unwrap();
        }

        // Items must come back in the order they were sent.
        for i in 0..3 {
            let received = stream.next_message().await.unwrap().unwrap();
            assert_eq!(received.session_id(), format!("session-{}", i));
        }
    }

    #[test]
    fn test_message_stream_initial_state() {
        // Keep the handler bound for the whole test: async_channel closes the
        // channel once every sender has been dropped, so `_` would not do.
        let (_handler, stream) = create_message_stream();

        assert!(!stream.is_closed());
    }

    #[tokio::test]
    async fn test_message_handler_creation() {
        let (handler, stream) = create_message_stream();

        handler
            .send_message(user_message("test", "test"))
            .await
            .unwrap();

        let received = stream.next_message().await.unwrap().unwrap();
        assert_eq!(received.session_id(), "test");
    }
}