iii-sdk 0.4.1

SDK for III Engine - a platform for building distributed applications
Documentation
use std::sync::Arc;

use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::Mutex;
use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};

use crate::error::IIIError;

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ChannelDirection {
    #[default]
    Read,
    Write,
}

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct StreamChannelRef {
    pub channel_id: String,
    pub access_key: String,
    pub direction: ChannelDirection,
}

#[derive(Debug, Clone)]
pub enum ChannelItem {
    Text(String),
    Binary(Vec<u8>),
}

type WsWriter = futures_util::stream::SplitSink<
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
    WsMessage,
>;

type WsReader = futures_util::stream::SplitStream<
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
>;

fn build_channel_url(
    engine_ws_base: &str,
    channel_id: &str,
    access_key: &str,
    direction: &str,
) -> String {
    let base = engine_ws_base.trim_end_matches('/');
    let encoded_key = urlencoded(access_key);
    format!("{base}/ws/channels/{channel_id}?key={encoded_key}&dir={direction}")
}

fn urlencoded(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for b in s.bytes() {
        match b {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                result.push(b as char);
            }
            _ => {
                result.push('%');
                result.push(char::from(b"0123456789ABCDEF"[(b >> 4) as usize]));
                result.push(char::from(b"0123456789ABCDEF"[(b & 0x0F) as usize]));
            }
        }
    }
    result
}

/// WebSocket-backed writer for streaming binary data and text messages.
pub struct ChannelWriter {
    url: String,
    ws: Arc<Mutex<Option<WsWriter>>>,
}

impl ChannelWriter {
    pub fn new(engine_ws_base: &str, channel_ref: &StreamChannelRef) -> Self {
        Self {
            url: build_channel_url(
                engine_ws_base,
                &channel_ref.channel_id,
                &channel_ref.access_key,
                "write",
            ),
            ws: Arc::new(Mutex::new(None)),
        }
    }

    async fn ensure_connected(&self) -> Result<(), IIIError> {
        let mut guard = self.ws.lock().await;
        if guard.is_some() {
            return Ok(());
        }
        let (stream, _) = connect_async(&self.url).await?;
        let (writer, _reader) = stream.split();
        *guard = Some(writer);
        Ok(())
    }

    const MAX_FRAME_SIZE: usize = 64 * 1024;

    pub async fn write(&self, data: &[u8]) -> Result<(), IIIError> {
        self.ensure_connected().await?;
        let mut guard = self.ws.lock().await;
        let ws = guard.as_mut().ok_or(IIIError::NotConnected)?;
        for chunk in data.chunks(Self::MAX_FRAME_SIZE) {
            ws.send(WsMessage::Binary(chunk.to_vec().into())).await?;
        }
        Ok(())
    }

    pub async fn send_message(&self, msg: &str) -> Result<(), IIIError> {
        self.ensure_connected().await?;
        let mut guard = self.ws.lock().await;
        let ws = guard.as_mut().ok_or(IIIError::NotConnected)?;
        ws.send(WsMessage::Text(msg.to_string().into())).await?;
        Ok(())
    }

    pub async fn close(&self) -> Result<(), IIIError> {
        let mut guard = self.ws.lock().await;
        if let Some(ws) = guard.as_mut() {
            ws.send(WsMessage::Close(None)).await?;
        }
        *guard = None;
        Ok(())
    }
}

type MessageCallback = Box<dyn Fn(String) + Send + Sync>;
type MessageCallbackList = Arc<Mutex<Vec<MessageCallback>>>;

/// WebSocket-backed reader for streaming binary data and text messages.
pub struct ChannelReader {
    url: String,
    ws: Arc<Mutex<Option<WsReader>>>,
    message_callbacks: MessageCallbackList,
}

impl ChannelReader {
    pub fn new(engine_ws_base: &str, channel_ref: &StreamChannelRef) -> Self {
        Self {
            url: build_channel_url(
                engine_ws_base,
                &channel_ref.channel_id,
                &channel_ref.access_key,
                "read",
            ),
            ws: Arc::new(Mutex::new(None)),
            message_callbacks: Arc::new(Mutex::new(Vec::new())),
        }
    }

    async fn ensure_connected(&self) -> Result<(), IIIError> {
        let mut guard = self.ws.lock().await;
        if guard.is_some() {
            return Ok(());
        }
        let (stream, _) = connect_async(&self.url).await?;
        let (_writer, reader) = stream.split();
        *guard = Some(reader);
        Ok(())
    }

    /// Register a callback for text messages received on this channel.
    pub async fn on_message<F>(&self, callback: F)
    where
        F: Fn(String) + Send + Sync + 'static,
    {
        self.message_callbacks.lock().await.push(Box::new(callback));
    }

    /// Read the next binary chunk from the channel.
    /// Text messages are dispatched to registered callbacks.
    /// Returns `None` when the stream is closed.
    pub async fn next_binary(&self) -> Result<Option<Vec<u8>>, IIIError> {
        self.ensure_connected().await?;

        loop {
            let mut guard = self.ws.lock().await;
            let mut reader = guard.take().ok_or(IIIError::NotConnected)?;
            drop(guard);

            let msg = reader.next().await;

            let mut guard = self.ws.lock().await;
            *guard = Some(reader);
            drop(guard);

            match msg {
                Some(Ok(WsMessage::Binary(data))) => return Ok(Some(data.to_vec())),
                Some(Ok(WsMessage::Text(text))) => {
                    let callbacks = self.message_callbacks.lock().await;
                    for cb in callbacks.iter() {
                        cb(text.to_string());
                    }
                }
                Some(Ok(WsMessage::Close(_))) | None => return Ok(None),
                Some(Ok(_)) => continue,
                Some(Err(e)) => return Err(IIIError::WebSocket(e.to_string())),
            }
        }
    }

    /// Read the entire stream into a single Vec<u8>.
    pub async fn read_all(&self) -> Result<Vec<u8>, IIIError> {
        let mut buffer = Vec::new();
        while let Some(chunk) = self.next_binary().await? {
            buffer.extend_from_slice(&chunk);
        }
        Ok(buffer)
    }

    pub async fn close(&self) -> Result<(), IIIError> {
        let mut guard = self.ws.lock().await;
        *guard = None;
        Ok(())
    }
}

/// Check if a JSON value looks like a StreamChannelRef.
pub fn is_channel_ref(value: &Value) -> bool {
    value.is_object()
        && value.get("channel_id").is_some_and(|v| v.is_string())
        && value.get("access_key").is_some_and(|v| v.is_string())
        && value.get("direction").is_some_and(|v| v.is_string())
}

/// Extract all channel references from a JSON value's top-level fields,
/// returning the field path and the deserialized ref.
pub fn extract_channel_refs(data: &Value) -> Vec<(String, StreamChannelRef)> {
    let mut refs = Vec::new();
    extract_refs_recursive(data, String::new(), &mut refs);
    refs
}

fn extract_refs_recursive(
    data: &Value,
    prefix: String,
    refs: &mut Vec<(String, StreamChannelRef)>,
) {
    if let Some(obj) = data.as_object() {
        for (key, value) in obj {
            let path = if prefix.is_empty() {
                key.clone()
            } else {
                format!("{prefix}.{key}")
            };

            if is_channel_ref(value) {
                if let Ok(channel_ref) = serde_json::from_value::<StreamChannelRef>(value.clone()) {
                    refs.push((path, channel_ref));
                }
            } else if value.is_object() {
                extract_refs_recursive(value, path.clone(), refs);
            } else if let Some(arr) = value.as_array() {
                for (idx, item) in arr.iter().enumerate() {
                    extract_refs_recursive(item, format!("{path}[{idx}]"), refs);
                }
            }
        }
    } else if let Some(arr) = data.as_array() {
        for (idx, item) in arr.iter().enumerate() {
            let path = if prefix.is_empty() {
                format!("[{idx}]")
            } else {
                format!("{prefix}[{idx}]")
            };
            extract_refs_recursive(item, path, refs);
        }
    }
}