Skip to main content

iii_sdk/
channels.rs

1use std::sync::Arc;
2
3use futures_util::{SinkExt, StreamExt};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use tokio::sync::Mutex;
7use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};
8
9use crate::error::IIIError;
10
11#[derive(Debug, Clone, Serialize, Deserialize, Default)]
12#[serde(rename_all = "lowercase")]
13pub enum ChannelDirection {
14    #[default]
15    Read,
16    Write,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, Default)]
20pub struct StreamChannelRef {
21    pub channel_id: String,
22    pub access_key: String,
23    pub direction: ChannelDirection,
24}
25
/// A single piece of channel data: either a text message or a binary chunk.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChannelItem {
    /// A text message.
    Text(String),
    /// A chunk of raw bytes.
    Binary(Vec<u8>),
}
31
32type WsWriter = futures_util::stream::SplitSink<
33    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
34    WsMessage,
35>;
36
37type WsReader = futures_util::stream::SplitStream<
38    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
39>;
40
41fn build_channel_url(
42    engine_ws_base: &str,
43    channel_id: &str,
44    access_key: &str,
45    direction: &str,
46) -> String {
47    let base = engine_ws_base.trim_end_matches('/');
48    let encoded_key = urlencoded(access_key);
49    format!("{base}/ws/channels/{channel_id}?key={encoded_key}&dir={direction}")
50}
51
/// Percent-encode every byte of `s` outside the RFC 3986 unreserved set.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` are copied unchanged. Every other byte is
/// written as `%XX` with uppercase hex digits; a multi-byte UTF-8 character therefore
/// becomes one `%XX` triple per byte.
fn urlencoded(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let is_unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if is_unreserved {
            encoded.push(char::from(byte));
            continue;
        }
        let high = HEX_DIGITS[usize::from(byte >> 4)];
        let low = HEX_DIGITS[usize::from(byte & 0x0F)];
        encoded.push('%');
        encoded.push(char::from(high));
        encoded.push(char::from(low));
    }
    encoded
}
68
69/// WebSocket-backed writer for streaming binary data and text messages.
70pub struct ChannelWriter {
71    url: String,
72    ws: Arc<Mutex<Option<WsWriter>>>,
73}
74
75impl ChannelWriter {
76    pub fn new(engine_ws_base: &str, channel_ref: &StreamChannelRef) -> Self {
77        Self {
78            url: build_channel_url(
79                engine_ws_base,
80                &channel_ref.channel_id,
81                &channel_ref.access_key,
82                "write",
83            ),
84            ws: Arc::new(Mutex::new(None)),
85        }
86    }
87
88    async fn ensure_connected(&self) -> Result<(), IIIError> {
89        let mut guard = self.ws.lock().await;
90        if guard.is_some() {
91            return Ok(());
92        }
93        let (stream, _) = connect_async(&self.url).await?;
94        let (writer, _reader) = stream.split();
95        *guard = Some(writer);
96        Ok(())
97    }
98
99    const MAX_FRAME_SIZE: usize = 64 * 1024;
100
101    pub async fn write(&self, data: &[u8]) -> Result<(), IIIError> {
102        self.ensure_connected().await?;
103        let mut guard = self.ws.lock().await;
104        let ws = guard.as_mut().ok_or(IIIError::NotConnected)?;
105        for chunk in data.chunks(Self::MAX_FRAME_SIZE) {
106            ws.send(WsMessage::Binary(chunk.to_vec().into())).await?;
107        }
108        Ok(())
109    }
110
111    pub async fn send_message(&self, msg: &str) -> Result<(), IIIError> {
112        self.ensure_connected().await?;
113        let mut guard = self.ws.lock().await;
114        let ws = guard.as_mut().ok_or(IIIError::NotConnected)?;
115        ws.send(WsMessage::Text(msg.to_string().into())).await?;
116        Ok(())
117    }
118
119    pub async fn close(&self) -> Result<(), IIIError> {
120        // Delay the close frame slightly to allow the TCP stack to flush
121        // all buffered send() data. Without this, the close frame can arrive
122        // at the engine before all data frames, causing data truncation.
123        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
124        let mut guard = self.ws.lock().await;
125        if let Some(ws) = guard.as_mut() {
126            ws.send(WsMessage::Close(None)).await?;
127        }
128        *guard = None;
129        Ok(())
130    }
131}
132
133type MessageCallback = Box<dyn Fn(String) + Send + Sync>;
134type MessageCallbackList = Arc<Mutex<Vec<MessageCallback>>>;
135
136/// WebSocket-backed reader for streaming binary data and text messages.
137pub struct ChannelReader {
138    url: String,
139    ws: Arc<Mutex<Option<WsReader>>>,
140    message_callbacks: MessageCallbackList,
141}
142
143impl ChannelReader {
144    pub fn new(engine_ws_base: &str, channel_ref: &StreamChannelRef) -> Self {
145        Self {
146            url: build_channel_url(
147                engine_ws_base,
148                &channel_ref.channel_id,
149                &channel_ref.access_key,
150                "read",
151            ),
152            ws: Arc::new(Mutex::new(None)),
153            message_callbacks: Arc::new(Mutex::new(Vec::new())),
154        }
155    }
156
157    async fn ensure_connected(&self) -> Result<(), IIIError> {
158        let mut guard = self.ws.lock().await;
159        if guard.is_some() {
160            return Ok(());
161        }
162        let (stream, _) = connect_async(&self.url).await?;
163        let (_writer, reader) = stream.split();
164        *guard = Some(reader);
165        Ok(())
166    }
167
168    /// Register a callback for text messages received on this channel.
169    pub async fn on_message<F>(&self, callback: F)
170    where
171        F: Fn(String) + Send + Sync + 'static,
172    {
173        self.message_callbacks.lock().await.push(Box::new(callback));
174    }
175
176    /// Read the next binary chunk from the channel.
177    /// Text messages are dispatched to registered callbacks.
178    /// Returns `None` when the stream is closed.
179    pub async fn next_binary(&self) -> Result<Option<Vec<u8>>, IIIError> {
180        self.ensure_connected().await?;
181
182        loop {
183            let mut guard = self.ws.lock().await;
184            let mut reader = guard.take().ok_or(IIIError::NotConnected)?;
185            drop(guard);
186
187            let msg = reader.next().await;
188
189            let mut guard = self.ws.lock().await;
190            *guard = Some(reader);
191            drop(guard);
192
193            match msg {
194                Some(Ok(WsMessage::Binary(data))) => return Ok(Some(data.to_vec())),
195                Some(Ok(WsMessage::Text(text))) => {
196                    let callbacks = self.message_callbacks.lock().await;
197                    for cb in callbacks.iter() {
198                        cb(text.to_string());
199                    }
200                }
201                Some(Ok(WsMessage::Close(_))) | None => return Ok(None),
202                Some(Ok(_)) => continue,
203                Some(Err(e)) => return Err(IIIError::WebSocket(e.to_string())),
204            }
205        }
206    }
207
208    /// Read the entire stream into a single `Vec<u8>`.
209    pub async fn read_all(&self) -> Result<Vec<u8>, IIIError> {
210        let mut buffer = Vec::new();
211        while let Some(chunk) = self.next_binary().await? {
212            buffer.extend_from_slice(&chunk);
213        }
214        Ok(buffer)
215    }
216
217    pub async fn close(&self) -> Result<(), IIIError> {
218        let mut guard = self.ws.lock().await;
219        *guard = None;
220        Ok(())
221    }
222}
223
224/// Check if a JSON value looks like a StreamChannelRef.
225pub fn is_channel_ref(value: &Value) -> bool {
226    value.is_object()
227        && value.get("channel_id").is_some_and(|v| v.is_string())
228        && value.get("access_key").is_some_and(|v| v.is_string())
229        && value.get("direction").is_some_and(|v| v.is_string())
230}
231
232/// Extract all channel references from a JSON value's top-level fields,
233/// returning the field path and the deserialized ref.
234pub fn extract_channel_refs(data: &Value) -> Vec<(String, StreamChannelRef)> {
235    let mut refs = Vec::new();
236    extract_refs_recursive(data, String::new(), &mut refs);
237    refs
238}
239
240fn extract_refs_recursive(
241    data: &Value,
242    prefix: String,
243    refs: &mut Vec<(String, StreamChannelRef)>,
244) {
245    if let Some(obj) = data.as_object() {
246        for (key, value) in obj {
247            let path = if prefix.is_empty() {
248                key.clone()
249            } else {
250                format!("{prefix}.{key}")
251            };
252
253            if is_channel_ref(value) {
254                if let Ok(channel_ref) = serde_json::from_value::<StreamChannelRef>(value.clone()) {
255                    refs.push((path, channel_ref));
256                }
257            } else if value.is_object() {
258                extract_refs_recursive(value, path.clone(), refs);
259            } else if let Some(arr) = value.as_array() {
260                for (idx, item) in arr.iter().enumerate() {
261                    extract_refs_recursive(item, format!("{path}[{idx}]"), refs);
262                }
263            }
264        }
265    } else if let Some(arr) = data.as_array() {
266        for (idx, item) in arr.iter().enumerate() {
267            let path = if prefix.is_empty() {
268                format!("[{idx}]")
269            } else {
270                format!("{prefix}[{idx}]")
271            };
272            extract_refs_recursive(item, path, refs);
273        }
274    }
275}
276
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use tokio::sync::Mutex;

    use super::*;

    /// A writer that has never connected: its `ws` slot starts out as `None`.
    fn unconnected_writer() -> ChannelWriter {
        ChannelWriter {
            url: String::from("ws://test"),
            ws: Arc::new(Mutex::new(None)),
        }
    }

    // --- ChannelWriter::close(): timing and state --------------------------------

    /// Even without a live connection, `close()` must wait at least 10 ms before it
    /// clears the `ws` slot.
    #[tokio::test]
    async fn close_sleeps_before_clearing_ws() {
        tokio::time::pause();

        let writer = unconnected_writer();
        let min_delay = std::time::Duration::from_millis(10);

        let started = tokio::time::Instant::now();
        writer.close().await.expect("close() should not fail");
        let waited = started.elapsed();

        assert!(
            waited >= min_delay,
            "expected at least 10 ms elapsed, got {:?}",
            waited
        );
    }

    /// Closing a writer that never connected reports success.
    #[tokio::test]
    async fn close_when_not_connected_succeeds() {
        let writer = unconnected_writer();
        let outcome = writer.close().await;
        assert!(outcome.is_ok(), "expected Ok(()), got {:?}", outcome);
    }

    /// Once `close()` has returned, the `ws` slot is empty.
    #[tokio::test]
    async fn close_sets_ws_to_none() {
        let writer = unconnected_writer();
        writer.close().await.expect("close() should not fail");

        let slot_is_empty = writer.ws.lock().await.is_none();
        assert!(
            slot_is_empty,
            "expected ws to be None after close(), but it was Some"
        );
    }

    // --- build_channel_url ---------------------------------------------------------

    /// The URL carries the `/ws/channels/{id}` path plus `key` and `dir` query parameters.
    #[test]
    fn build_channel_url_formats_correctly() {
        assert_eq!(
            build_channel_url("http://engine", "chan-1", "mykey", "write"),
            "http://engine/ws/channels/chan-1?key=mykey&dir=write"
        );
    }

    /// A trailing slash on the base is dropped, so no `//` appears before `ws`.
    #[test]
    fn build_channel_url_strips_trailing_slash_from_base() {
        assert_eq!(
            build_channel_url("http://engine/", "chan-2", "k", "read"),
            "http://engine/ws/channels/chan-2?key=k&dir=read"
        );
    }

    // --- urlencoded ----------------------------------------------------------------

    /// Bytes outside the unreserved set (letters, digits, `-`, `_`, `.`, `~`) become `%XX`.
    #[test]
    fn urlencoded_encodes_special_chars() {
        assert_eq!(urlencoded("hello world+/="), "hello%20world%2B%2F%3D");
    }

    /// Every unreserved character (A-Z, a-z, 0-9, `-`, `_`, `.`, `~`) passes through as-is.
    #[test]
    fn urlencoded_leaves_unreserved_chars_unchanged() {
        let unreserved: String = ('A'..='Z')
            .chain('a'..='z')
            .chain('0'..='9')
            .chain("-_.~".chars())
            .collect();
        assert_eq!(urlencoded(&unreserved), unreserved);
    }

    /// Encoding the empty string yields the empty string.
    #[test]
    fn urlencoded_returns_empty_for_empty_input() {
        assert!(urlencoded("").is_empty());
    }
}
384}