runagent 0.1.49

RunAgent SDK for Rust - Client SDK for interacting with deployed AI agents
Documentation
//! WebSocket client for streaming agent interactions

use crate::types::{MessageType, RunAgentError, RunAgentResult, SafeMessage};
use crate::utils::config::Config;
use crate::utils::serializer::CoreSerializer;
use futures::{SinkExt, Stream, StreamExt};
use serde_json::Value;
use std::collections::HashMap;
use std::pin::Pin;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use url::Url;

/// WebSocket client for agent streaming
pub struct SocketClient {
    base_socket_url: String,
    api_key: Option<String>,
    api_prefix: String,
    serializer: CoreSerializer,
}

impl SocketClient {
    /// Create a new WebSocket client
    pub fn new(
        base_socket_url: &str,
        api_key: Option<String>,
        api_prefix: Option<&str>,
    ) -> RunAgentResult<Self> {
        let serializer = CoreSerializer::new(10.0)?;

        Ok(Self {
            base_socket_url: base_socket_url.trim_end_matches('/').to_string(),
            api_key,
            api_prefix: api_prefix.unwrap_or("/api/v1").to_string(),
            serializer,
        })
    }

    /// Create a default WebSocket client using configuration
    #[allow(clippy::should_implement_trait)]
    pub fn default() -> RunAgentResult<Self> {
        let config = Config::load()?;
        let base_url = config.base_url();

        // Convert HTTP URL to WebSocket URL
        let ws_url = if base_url.starts_with("https://") {
            base_url.replace("https://", "wss://")
        } else if base_url.starts_with("http://") {
            base_url.replace("http://", "ws://")
        } else {
            format!("ws://{}", base_url)
        };

        Self::new(&ws_url, config.api_key(), Some("/api/v1"))
    }

    fn get_websocket_url(&self, agent_id: &str, _entrypoint_tag: &str) -> RunAgentResult<Url> {
        let path = format!("agents/{}/run-stream", agent_id);
        let mut full_url = format!("{}{}/{}", self.base_socket_url, self.api_prefix, path);

        // Add API key as token parameter if available
        if let Some(ref api_key) = self.api_key {
            full_url = format!("{}?token={}", full_url, api_key);
        }

        Url::parse(&full_url)
            .map_err(|e| RunAgentError::validation(format!("Invalid WebSocket URL: {}", e)))
    }

    /// Run agent with streaming response
    pub async fn run_stream(
        &self,
        agent_id: &str,
        entrypoint_tag: &str,
        input_args: &[Value],
        input_kwargs: &HashMap<String, Value>,
        user_id: Option<&str>,
        persistent_memory: bool,
    ) -> RunAgentResult<Pin<Box<dyn Stream<Item = RunAgentResult<Value>> + Send>>> {
        let url = self.get_websocket_url(agent_id, entrypoint_tag)?;

        tracing::debug!("Connecting to WebSocket: {}", url);

        // Connect to WebSocket
        let (ws_stream, _) = connect_async(url).await.map_err(|e| {
            RunAgentError::connection(format!("WebSocket connection failed: {}", e))
        })?;

        let (mut write, mut read) = ws_stream.split();

        // Prepare start stream request with id field (as middleware expects)
        let mut request_data = serde_json::json!({
            "id": "stream_start",
            "entrypoint_tag": entrypoint_tag,
            "input_args": input_args,
            "input_kwargs": input_kwargs,
            "timeout_seconds": 600,
            "async_execution": false
        });

        // Add persistent storage parameters if provided (matches Python SDK)
        if let Some(uid) = user_id {
            if let Some(obj) = request_data.as_object_mut() {
                obj.insert("user_id".to_string(), serde_json::json!(uid));
            }
        }
        if persistent_memory {
            if let Some(obj) = request_data.as_object_mut() {
                obj.insert(
                    "persistent_memory".to_string(),
                    serde_json::json!(persistent_memory),
                );
            }
        }

        // Send the request data directly (matching Python SDK format)
        let serialized_msg = serde_json::to_string(&request_data)?;
        write
            .send(Message::Text(serialized_msg))
            .await
            .map_err(|e| {
                RunAgentError::connection(format!("Failed to send start message: {}", e))
            })?;

        // Clone serializer for use in async stream
        let serializer = self.serializer.clone();

        // Create stream that processes incoming messages (matching Python SDK behavior)
        let stream = async_stream::stream! {
            while let Some(message) = read.next().await {
                match message {
                    Ok(Message::Text(text)) => {
                        // Parse as plain JSON (matching Python SDK)
                        match serde_json::from_str::<serde_json::Value>(&text) {
                            Ok(msg) => {
                                let message_type = msg.get("type").and_then(|v| v.as_str());

                                match message_type {
                                    Some("status") => {
                                        if let Some(status) = msg.get("status").and_then(|v| v.as_str()) {
                                            if status == "stream_completed" {
                                                break;
                                            } else if status == "stream_started" {
                                                continue; // Skip status messages
                                            }
                                        }
                                    }
                                    Some("error") => {
                                        let error_msg = msg.get("error")
                                            .or_else(|| msg.get("detail"))
                                            .and_then(|v| v.as_str())
                                            .unwrap_or("Unknown error");
                                        yield Err(RunAgentError::server(format!("Stream error: {}", error_msg)));
                                        break;
                                    }
                                    Some("data") => {
                                        // Extract content and deserialize it using the common deserializer
                                        if let Some(content) = msg.get("content") {
                                            // Use common deserializer preparation logic (handles JSON strings)
                                            let prepared = serializer.prepare_for_deserialization(content.clone());

                                            // Deserialize using the common serializer (handles {type, payload} structure)
                                            match serializer.deserialize_object(prepared) {
                                                Ok(deserialized) => yield Ok(deserialized),
                                                Err(e) => {
                                                    yield Err(RunAgentError::server(format!("Deserialization error: {}", e)));
                                                    break;
                                                }
                                            }
                                        } else {
                                            // If no content, yield the whole message
                                            yield Ok(msg);
                                        }
                                    }
                                    _ => {
                                        // For other message types, yield the whole message
                                        yield Ok(msg);
                                    }
                                }
                            }
                            Err(e) => {
                                yield Err(RunAgentError::server(format!("Stream error: JSON error: {}", e)));
                                break;
                            }
                        }
                    }
                    Ok(Message::Close(_)) => {
                        break;
                    }
                    Ok(_) => {
                        // Ignore binary and other message types
                        continue;
                    }
                    Err(e) => {
                        yield Err(RunAgentError::connection(format!("WebSocket error: {}", e)));
                        break;
                    }
                }
            }
        };

        Ok(Box::pin(stream))
    }

    /// Send a ping message to test connection
    pub async fn ping(&self, agent_id: &str, entrypoint_tag: &str) -> RunAgentResult<bool> {
        let url = self.get_websocket_url(agent_id, entrypoint_tag)?;

        let (ws_stream, _) = connect_async(url).await.map_err(|e| {
            RunAgentError::connection(format!("WebSocket connection failed: {}", e))
        })?;

        let (mut write, mut read) = ws_stream.split();

        // Send ping
        let ping_msg = SafeMessage::new(
            "ping".to_string(),
            MessageType::Status,
            serde_json::json!({"ping": true}),
        );

        let serialized_msg = self.serializer.serialize_message(&ping_msg)?;
        write
            .send(Message::Text(serialized_msg))
            .await
            .map_err(|e| RunAgentError::connection(format!("Failed to send ping: {}", e)))?;

        // Wait for pong
        tokio::time::timeout(std::time::Duration::from_secs(5), read.next())
            .await
            .map_err(|_| RunAgentError::connection("Ping timeout".to_string()))?
            .ok_or_else(|| RunAgentError::connection("Connection closed during ping".to_string()))?
            .map_err(|e| RunAgentError::connection(format!("Ping failed: {}", e)))?;

        Ok(true)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_websocket_url_construction() {
        let client = SocketClient::new("ws://localhost:8000", None, Some("/api/v1")).unwrap();
        let url = client.get_websocket_url("test-agent", "generic").unwrap();
        // Updated expected URL to match the actual implementation
        assert_eq!(
            url.as_str(),
            "ws://localhost:8000/api/v1/agents/test-agent/run-stream"
        );
    }

    #[test]
    fn test_client_creation() {
        let client = SocketClient::new("ws://localhost:8000", None, None);
        assert!(client.is_ok());
    }

    #[test]
    fn test_url_conversion() {
        // Test HTTP to WebSocket URL conversion
        let client = SocketClient::default();
        // This would test the URL conversion logic in a real implementation
        assert!(client.is_ok());
    }
}