syncable-ag-ui-client 0.2.0

Client-side AG-UI event consumer for Rust applications - Syncable SDK
Documentation
//! WebSocket Client
//!
//! This module provides a client for consuming AG-UI events via WebSocket.
//!
//! # Example
//!
//! ```rust,ignore
//! use syncable_ag_ui_client::WsClient;
//! use futures::StreamExt;
//!
//! let client = WsClient::connect("ws://localhost:3000/ws").await?;
//! let mut stream = client.into_stream();
//!
//! while let Some(event) = stream.next().await {
//!     println!("Event: {:?}", event?.event_type());
//! }
//! ```

use std::pin::Pin;
use std::task::{Context, Poll};

use syncable_ag_ui_core::{Event, JsonValue};
use futures::{SinkExt, Stream};
use tokio_tungstenite::{
    connect_async,
    tungstenite::{self, Message},
    MaybeTlsStream, WebSocketStream,
};

use crate::error::{ClientError, Result};

/// Configuration for WebSocket client connections.
///
/// Build it with the consuming builder methods, e.g.
/// `WsConfig::new().bearer_token("token")`. Then pass it to
/// `WsClient::connect_with_config`.
#[derive(Debug, Clone)]
pub struct WsConfig {
    /// Custom headers to include in the upgrade request, kept in insertion order.
    pub headers: Vec<(String, String)>,
    /// Whether the event stream answers incoming ping messages with a pong.
    ///
    /// Defaults to `true`.
    pub auto_pong: bool,
}

impl Default for WsConfig {
    /// Returns a configuration with no extra headers and automatic pongs enabled.
    fn default() -> Self {
        Self {
            headers: Vec::new(),
            auto_pong: true,
        }
    }
}

impl WsConfig {
    /// Creates a new configuration with default values.
    ///
    /// This is the same as `WsConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a custom header to the upgrade request.
    ///
    /// Headers are kept in the order they are added. Duplicate names are not
    /// merged or deduplicated here.
    #[must_use = "builder methods return the updated config; dropping it discards the header"]
    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.push((name.into(), value.into()));
        self
    }

    /// Appends an `Authorization: Bearer <token>` header.
    #[must_use = "builder methods return the updated config; dropping it discards the token"]
    pub fn bearer_token(self, token: impl Into<String>) -> Self {
        self.header("Authorization", format!("Bearer {}", token.into()))
    }

    /// Disables automatic pong responses to incoming pings.
    #[must_use = "builder methods return the updated config; dropping it discards the change"]
    pub fn disable_auto_pong(mut self) -> Self {
        self.auto_pong = false;
        self
    }
}

/// WebSocket client for consuming AG-UI event streams.
///
/// Wraps an established WebSocket connection (plain or TLS, over TCP) together
/// with the ping-handling preference taken from `WsConfig`. Call
/// `WsClient::into_stream` to start receiving parsed events.
pub struct WsClient {
    /// The established WebSocket connection.
    socket: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
    /// Copied from `WsConfig::auto_pong` and handed on to the event stream.
    auto_pong: bool,
}

impl WsClient {
    /// Opens a WebSocket connection using `WsConfig::default()`.
    ///
    /// # Arguments
    ///
    /// * `url` - The WebSocket endpoint URL (ws:// or wss://)
    ///
    /// # Errors
    ///
    /// Same as `WsClient::connect_with_config`.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let client = WsClient::connect("ws://localhost:3000/ws").await?;
    /// ```
    pub async fn connect(url: &str) -> Result<Self> {
        let defaults = WsConfig::default();
        Self::connect_with_config(url, defaults).await
    }

    /// Opens a WebSocket connection, applying the headers and options in `config`.
    ///
    /// The upgrade request is assembled by hand in this order: `Host` (derived
    /// from `url`), the `Connection` / `Upgrade` / `Sec-WebSocket-Version` /
    /// `Sec-WebSocket-Key` handshake headers, and then every entry of
    /// `config.headers`.
    ///
    /// # Arguments
    ///
    /// * `url` - The WebSocket endpoint URL (ws:// or wss://)
    /// * `config` - Connection configuration
    ///
    /// # Errors
    ///
    /// * `ClientError::InvalidUrl` if `url` cannot be parsed or has no host.
    /// * A connection error if the request cannot be built (for example, an
    ///   invalid header name or value) or the connection/handshake fails.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let config = WsConfig::new()
    ///     .bearer_token("my-token");
    /// let client = WsClient::connect_with_config("ws://localhost:3000/ws", config).await?;
    /// ```
    pub async fn connect_with_config(url: &str, config: WsConfig) -> Result<Self> {
        let host = extract_host(url)?;

        // Mandatory handshake headers first; user-supplied headers are appended after.
        let handshake = tungstenite::http::Request::builder()
            .uri(url)
            .header("Host", host)
            .header("Connection", "Upgrade")
            .header("Upgrade", "websocket")
            .header("Sec-WebSocket-Version", "13")
            .header(
                "Sec-WebSocket-Key",
                tungstenite::handshake::client::generate_key(),
            );

        let request = config
            .headers
            .into_iter()
            .fold(handshake, |builder, (name, value)| builder.header(name, value))
            .body(())
            .map_err(|err| ClientError::connection(err.to_string()))?;

        match connect_async(request).await {
            Ok((socket, _response)) => Ok(Self {
                socket,
                auto_pong: config.auto_pong,
            }),
            Err(err) => Err(ClientError::connection(err.to_string())),
        }
    }

    /// Consumes the client and returns a `WsEventStream` over its connection.
    ///
    /// The stream yields parsed AG-UI events as they arrive and inherits the
    /// client's `auto_pong` setting.
    pub fn into_stream(self) -> WsEventStream {
        let Self { socket, auto_pong } = self;
        WsEventStream { socket, auto_pong }
    }

    /// Closes the WebSocket connection without a close code or reason.
    ///
    /// # Errors
    ///
    /// Returns a connection error if closing the socket fails.
    pub async fn close(mut self) -> Result<()> {
        let outcome = self.socket.close(None).await;
        outcome.map_err(|err| ClientError::connection(err.to_string()))
    }
}

/// Stream of AG-UI events read from a WebSocket connection.
///
/// Obtained from `WsClient::into_stream`. Implements `Stream` with
/// `Result<Event<JsonValue>>` items; each text frame from the server becomes
/// one item.
pub struct WsEventStream {
    /// The connection that events are read from and pong replies are written to.
    socket: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
    /// When `true`, incoming pings are answered with a pong.
    auto_pong: bool,
}

impl Stream for WsEventStream {
    type Item = Result<Event<JsonValue>>;

    /// Polls the socket until a text frame arrives, then yields it as a parsed
    /// AG-UI event.
    ///
    /// * Text frames yield `Ok(event)`, or a parse error if the payload is not a
    ///   valid event. A bad frame does not end the stream.
    /// * Ping frames are answered with a best-effort pong when `auto_pong` is
    ///   set. Pong, binary and raw frames are skipped.
    /// * A close frame, or the end of the underlying socket, ends the stream.
    /// * Transport errors are yielded as `ClientError::WebSocket`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let msg = match Pin::new(&mut self.socket).poll_next(cx) {
                Poll::Ready(Some(Ok(msg))) => msg,
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Some(Err(ClientError::WebSocket(e))))
                }
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => return Poll::Pending,
            };

            match msg {
                Message::Text(text) => {
                    let parsed = serde_json::from_str::<Event<JsonValue>>(&text).map_err(|e| {
                        ClientError::parse(format!("failed to parse event: {}", e))
                    });
                    return Poll::Ready(Some(parsed));
                }
                Message::Ping(data) if self.auto_pong => {
                    let socket = &mut self.socket;
                    // The `Sink` contract requires `poll_ready` to report `Ready(Ok)`
                    // before `start_send`. If the sink is not ready, skip this pong
                    // rather than stall event delivery; the reply is best-effort.
                    if let Poll::Ready(Ok(())) = socket.poll_ready_unpin(cx) {
                        if socket.start_send_unpin(Message::Pong(data)).is_ok() {
                            // Push the pong out now instead of leaving it buffered
                            // until the next write. `Pending` just means it is sent later.
                            let _ = socket.poll_flush_unpin(cx);
                        }
                    }
                    // NOTE(review): tungstenite may also queue its own pong reply for
                    // every ping; confirm whether this manual reply is redundant.
                }
                Message::Ping(_) | Message::Pong(_) | Message::Binary(_) | Message::Frame(_) => {
                    // Control frames and non-text payloads carry no AG-UI events.
                }
                Message::Close(_) => return Poll::Ready(None),
            }
        }
    }
}

/// Builds the `Host` header value (`host` or `host:port`) for a WebSocket URL.
///
/// The port is appended only when `url::Url::port` returns one. That method
/// omits the scheme's default port, so `wss://example.com/events` yields
/// `example.com`.
///
/// # Errors
///
/// Returns `ClientError::InvalidUrl` if `url` does not parse or has no host.
fn extract_host(url: &str) -> Result<String> {
    let parsed = match url::Url::parse(url) {
        Ok(parsed) => parsed,
        Err(err) => return Err(ClientError::InvalidUrl(err.to_string())),
    };

    let host = parsed
        .host_str()
        .ok_or_else(|| ClientError::InvalidUrl("missing host".to_string()))?;

    Ok(parsed
        .port()
        .map_or_else(|| host.to_string(), |port| format!("{}:{}", host, port)))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ws_config_default() {
        let config = WsConfig::default();
        assert!(config.headers.is_empty());
        assert!(config.auto_pong);
    }

    #[test]
    fn test_ws_config_builder() {
        let config = WsConfig::new()
            .header("X-Custom", "value")
            .bearer_token("token123")
            .disable_auto_pong();

        assert_eq!(config.headers.len(), 2);
        assert_eq!(config.headers[0], ("X-Custom".to_string(), "value".to_string()));
        assert_eq!(
            config.headers[1],
            ("Authorization".to_string(), "Bearer token123".to_string())
        );
        assert!(!config.auto_pong);
    }

    #[test]
    fn test_extract_host() {
        assert_eq!(extract_host("ws://localhost:3000/ws").unwrap(), "localhost:3000");
        assert_eq!(extract_host("wss://example.com/events").unwrap(), "example.com");
        assert_eq!(
            extract_host("ws://api.example.com:8080/stream").unwrap(),
            "api.example.com:8080"
        );
    }

    #[test]
    fn test_extract_host_omits_default_port() {
        // `Url::port` drops the scheme's default port, so it never reaches the header.
        assert_eq!(extract_host("ws://example.com:80/ws").unwrap(), "example.com");
        assert_eq!(extract_host("wss://example.com:443/ws").unwrap(), "example.com");
    }

    #[test]
    fn test_extract_host_ipv6() {
        // IPv6 literals must keep their brackets in a Host header.
        assert_eq!(extract_host("ws://[::1]:9000/ws").unwrap(), "[::1]:9000");
    }

    #[test]
    fn test_extract_host_invalid() {
        assert!(matches!(extract_host("not a url"), Err(ClientError::InvalidUrl(_))));
    }

    #[test]
    fn test_extract_host_missing_host() {
        // Parses as a URL but has no host component.
        assert!(matches!(
            extract_host("mailto:someone@example.com"),
            Err(ClientError::InvalidUrl(_))
        ));
    }
}