1use std::sync::Arc;
2
3use futures_util::{SinkExt, StreamExt};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use tokio::sync::Mutex;
7use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};
8
9use crate::error::IIIError;
10
11#[derive(Debug, Clone, Serialize, Deserialize, Default)]
12#[serde(rename_all = "lowercase")]
13pub enum ChannelDirection {
14 #[default]
15 Read,
16 Write,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, Default)]
20pub struct StreamChannelRef {
21 pub channel_id: String,
22 pub access_key: String,
23 pub direction: ChannelDirection,
24}
25
/// A single item carried over a channel: either text or raw bytes.
///
/// Derives `PartialEq`/`Eq` so items can be compared by value (for example in
/// `assert_eq!`), in addition to `Debug` and `Clone`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChannelItem {
    /// A UTF-8 text payload.
    Text(String),
    /// An arbitrary binary payload.
    Binary(Vec<u8>),
}
31
32type WsWriter = futures_util::stream::SplitSink<
33 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
34 WsMessage,
35>;
36
37type WsReader = futures_util::stream::SplitStream<
38 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
39>;
40
41fn build_channel_url(
42 engine_ws_base: &str,
43 channel_id: &str,
44 access_key: &str,
45 direction: &str,
46) -> String {
47 let base = engine_ws_base.trim_end_matches('/');
48 let encoded_key = urlencoded(access_key);
49 format!("{base}/ws/channels/{channel_id}?key={encoded_key}&dir={direction}")
50}
51
/// Percent-encodes `s` byte by byte.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` are copied through unchanged;
/// every other byte (including each byte of a multi-byte UTF-8 sequence) is
/// written as `%XX` with uppercase hexadecimal digits.
fn urlencoded(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    s.bytes()
        .fold(String::with_capacity(s.len()), |mut encoded, byte| {
            if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
                encoded.push(char::from(byte));
            } else {
                // High nibble first, then low nibble.
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
            encoded
        })
}
68
69pub struct ChannelWriter {
71 url: String,
72 ws: Arc<Mutex<Option<WsWriter>>>,
73}
74
75impl ChannelWriter {
76 pub fn new(engine_ws_base: &str, channel_ref: &StreamChannelRef) -> Self {
77 Self {
78 url: build_channel_url(
79 engine_ws_base,
80 &channel_ref.channel_id,
81 &channel_ref.access_key,
82 "write",
83 ),
84 ws: Arc::new(Mutex::new(None)),
85 }
86 }
87
88 async fn ensure_connected(&self) -> Result<(), IIIError> {
89 let mut guard = self.ws.lock().await;
90 if guard.is_some() {
91 return Ok(());
92 }
93 let (stream, _) = connect_async(&self.url).await?;
94 let (writer, _reader) = stream.split();
95 *guard = Some(writer);
96 Ok(())
97 }
98
99 const MAX_FRAME_SIZE: usize = 64 * 1024;
100
101 pub async fn write(&self, data: &[u8]) -> Result<(), IIIError> {
102 self.ensure_connected().await?;
103 let mut guard = self.ws.lock().await;
104 let ws = guard.as_mut().ok_or(IIIError::NotConnected)?;
105 for chunk in data.chunks(Self::MAX_FRAME_SIZE) {
106 ws.send(WsMessage::Binary(chunk.to_vec().into())).await?;
107 }
108 Ok(())
109 }
110
111 pub async fn send_message(&self, msg: &str) -> Result<(), IIIError> {
112 self.ensure_connected().await?;
113 let mut guard = self.ws.lock().await;
114 let ws = guard.as_mut().ok_or(IIIError::NotConnected)?;
115 ws.send(WsMessage::Text(msg.to_string().into())).await?;
116 Ok(())
117 }
118
119 pub async fn close(&self) -> Result<(), IIIError> {
120 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
124 let mut guard = self.ws.lock().await;
125 if let Some(ws) = guard.as_mut() {
126 ws.send(WsMessage::Close(None)).await?;
127 }
128 *guard = None;
129 Ok(())
130 }
131}
132
133type MessageCallback = Box<dyn Fn(String) + Send + Sync>;
134type MessageCallbackList = Arc<Mutex<Vec<MessageCallback>>>;
135
136pub struct ChannelReader {
138 url: String,
139 ws: Arc<Mutex<Option<WsReader>>>,
140 message_callbacks: MessageCallbackList,
141}
142
143impl ChannelReader {
144 pub fn new(engine_ws_base: &str, channel_ref: &StreamChannelRef) -> Self {
145 Self {
146 url: build_channel_url(
147 engine_ws_base,
148 &channel_ref.channel_id,
149 &channel_ref.access_key,
150 "read",
151 ),
152 ws: Arc::new(Mutex::new(None)),
153 message_callbacks: Arc::new(Mutex::new(Vec::new())),
154 }
155 }
156
157 async fn ensure_connected(&self) -> Result<(), IIIError> {
158 let mut guard = self.ws.lock().await;
159 if guard.is_some() {
160 return Ok(());
161 }
162 let (stream, _) = connect_async(&self.url).await?;
163 let (_writer, reader) = stream.split();
164 *guard = Some(reader);
165 Ok(())
166 }
167
168 pub async fn on_message<F>(&self, callback: F)
170 where
171 F: Fn(String) + Send + Sync + 'static,
172 {
173 self.message_callbacks.lock().await.push(Box::new(callback));
174 }
175
176 pub async fn next_binary(&self) -> Result<Option<Vec<u8>>, IIIError> {
180 self.ensure_connected().await?;
181
182 loop {
183 let mut guard = self.ws.lock().await;
184 let mut reader = guard.take().ok_or(IIIError::NotConnected)?;
185 drop(guard);
186
187 let msg = reader.next().await;
188
189 let mut guard = self.ws.lock().await;
190 *guard = Some(reader);
191 drop(guard);
192
193 match msg {
194 Some(Ok(WsMessage::Binary(data))) => return Ok(Some(data.to_vec())),
195 Some(Ok(WsMessage::Text(text))) => {
196 let callbacks = self.message_callbacks.lock().await;
197 for cb in callbacks.iter() {
198 cb(text.to_string());
199 }
200 }
201 Some(Ok(WsMessage::Close(_))) | None => return Ok(None),
202 Some(Ok(_)) => continue,
203 Some(Err(e)) => return Err(IIIError::WebSocket(e.to_string())),
204 }
205 }
206 }
207
208 pub async fn read_all(&self) -> Result<Vec<u8>, IIIError> {
210 let mut buffer = Vec::new();
211 while let Some(chunk) = self.next_binary().await? {
212 buffer.extend_from_slice(&chunk);
213 }
214 Ok(buffer)
215 }
216
217 pub async fn close(&self) -> Result<(), IIIError> {
218 let mut guard = self.ws.lock().await;
219 *guard = None;
220 Ok(())
221 }
222}
223
224pub fn is_channel_ref(value: &Value) -> bool {
226 value.is_object()
227 && value.get("channel_id").is_some_and(|v| v.is_string())
228 && value.get("access_key").is_some_and(|v| v.is_string())
229 && value.get("direction").is_some_and(|v| v.is_string())
230}
231
232pub fn extract_channel_refs(data: &Value) -> Vec<(String, StreamChannelRef)> {
235 let mut refs = Vec::new();
236 extract_refs_recursive(data, String::new(), &mut refs);
237 refs
238}
239
240fn extract_refs_recursive(
241 data: &Value,
242 prefix: String,
243 refs: &mut Vec<(String, StreamChannelRef)>,
244) {
245 if let Some(obj) = data.as_object() {
246 for (key, value) in obj {
247 let path = if prefix.is_empty() {
248 key.clone()
249 } else {
250 format!("{prefix}.{key}")
251 };
252
253 if is_channel_ref(value) {
254 if let Ok(channel_ref) = serde_json::from_value::<StreamChannelRef>(value.clone()) {
255 refs.push((path, channel_ref));
256 }
257 } else if value.is_object() {
258 extract_refs_recursive(value, path.clone(), refs);
259 } else if let Some(arr) = value.as_array() {
260 for (idx, item) in arr.iter().enumerate() {
261 extract_refs_recursive(item, format!("{path}[{idx}]"), refs);
262 }
263 }
264 }
265 } else if let Some(arr) = data.as_array() {
266 for (idx, item) in arr.iter().enumerate() {
267 let path = if prefix.is_empty() {
268 format!("[{idx}]")
269 } else {
270 format!("{prefix}[{idx}]")
271 };
272 extract_refs_recursive(item, path, refs);
273 }
274 }
275}
276
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::Mutex;

    use super::*;

    /// Builds a writer whose socket slot is empty, so no network is touched.
    fn disconnected_writer() -> ChannelWriter {
        ChannelWriter {
            url: "ws://test".to_string(),
            ws: Arc::new(Mutex::new(None)),
        }
    }

    /// `close()` must wait at least 10 ms; the tokio clock is paused so the
    /// sleep is measured in virtual time.
    #[tokio::test]
    async fn close_sleeps_before_clearing_ws() {
        tokio::time::pause();

        let writer = disconnected_writer();

        let started_at = tokio::time::Instant::now();
        writer.close().await.expect("close() should not fail");
        let elapsed = started_at.elapsed();

        assert!(
            elapsed >= Duration::from_millis(10),
            "expected at least 10 ms elapsed, got {:?}",
            elapsed
        );
    }

    /// Closing a writer that never connected is not an error.
    #[tokio::test]
    async fn close_when_not_connected_succeeds() {
        let result = disconnected_writer().close().await;
        assert!(result.is_ok(), "expected Ok(()), got {:?}", result);
    }

    /// After `close()` the socket slot is empty.
    #[tokio::test]
    async fn close_sets_ws_to_none() {
        let writer = disconnected_writer();

        writer.close().await.expect("close() should not fail");

        assert!(
            writer.ws.lock().await.is_none(),
            "expected ws to be None after close(), but it was Some"
        );
    }

    /// Base, channel id, key and direction land in the expected URL slots.
    #[test]
    fn build_channel_url_formats_correctly() {
        assert_eq!(
            build_channel_url("http://engine", "chan-1", "mykey", "write"),
            "http://engine/ws/channels/chan-1?key=mykey&dir=write"
        );
    }

    /// A trailing slash on the base does not produce `//` in the path.
    #[test]
    fn build_channel_url_strips_trailing_slash_from_base() {
        assert_eq!(
            build_channel_url("http://engine/", "chan-2", "k", "read"),
            "http://engine/ws/channels/chan-2?key=k&dir=read"
        );
    }

    /// Space, `+`, `/` and `=` are percent-encoded with uppercase hex.
    #[test]
    fn urlencoded_encodes_special_chars() {
        assert_eq!(urlencoded("hello world+/="), "hello%20world%2B%2F%3D");
    }

    /// Letters, digits and `-_.~` pass through untouched.
    #[test]
    fn urlencoded_leaves_unreserved_chars_unchanged() {
        let unreserved = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~";
        assert_eq!(urlencoded(unreserved), unreserved);
    }

    /// Empty input yields empty output.
    #[test]
    fn urlencoded_returns_empty_for_empty_input() {
        assert_eq!(urlencoded(""), "");
    }
}