rsclaw 0.0.1-alpha.1

rsclaw: a high-performance AI agent (beta), optimized for the M4 Max and for 2 GB VPS instances. 100% compatible with openclaw.
Documentation
use super::router::MessageRouter;
use super::types::{ChannelType, InboundMessage, OutboundMessage};
use crate::agent::AgentManager;
use axum::{
    extract::ws::{Message, WebSocket, WebSocketUpgrade},
    extract::State,
    response::IntoResponse,
};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;

/// JSON payload a client sends in each WebSocket text frame.
///
/// Only `text` is required; `session_id` and `metadata` may be omitted or set to `null`,
/// in which case they deserialize to `None`.
#[derive(Serialize, Deserialize, Debug)]
pub struct WsMessage {
    /// The message text to hand to the agent router.
    pub text: String,
    /// Session to route this message to; when absent, the connection's own session is used.
    pub session_id: Option<String>,
    /// Arbitrary JSON carried through unchanged into the routed inbound message.
    pub metadata: Option<serde_json::Value>,
}

/// WebSocket channel handler.
pub struct WsChannel {
    router: Arc<MessageRouter>,
}

impl WsChannel {
    /// Create a new WebSocket channel.
    pub fn new(agent_manager: Arc<RwLock<AgentManager>>) -> Self {
        Self {
            router: Arc::new(MessageRouter::new(agent_manager)),
        }
    }

    /// Handle WebSocket upgrade.
    pub async fn handler(
        ws: WebSocketUpgrade,
        State(channel): State<Arc<Self>>,
    ) -> impl IntoResponse {
        ws.on_upgrade(move |socket| Self::handle_socket(socket, channel))
    }

    async fn handle_socket(socket: WebSocket, channel: Arc<Self>) {
        let (mut sender, mut receiver) = socket.split();
        let session_id: Arc<str> = Arc::from(uuid::Uuid::new_v4().to_string());

        tracing::info!("WebSocket connected: {}", session_id);

        while let Some(msg) = receiver.next().await {
            match msg {
                Ok(Message::Text(text)) => {
                    let ws_msg: WsMessage = match serde_json::from_str(&text) {
                        Ok(m) => m,
                        Err(e) => {
                            let error = OutboundMessage::error(format!("Invalid message: {}", e));
                            let _ = sender.send(Message::Text(
                                serde_json::to_string(&error).unwrap_or_default()
                            )).await;
                            continue;
                        }
                    };

                    let sid = ws_msg.session_id
                        .map(|s| Arc::from(s.as_str()))
                        .unwrap_or_else(|| session_id.clone());

                    let inbound = InboundMessage {
                        text: Arc::from(ws_msg.text.as_str()),
                        session_id: sid,
                        channel: ChannelType::WebSocket,
                        user_id: None,
                        metadata: ws_msg.metadata,
                    };

                    match channel.router.route(inbound).await {
                        Ok(response) => {
                            let _ = sender.send(Message::Text(
                                serde_json::to_string(&response).unwrap_or_default()
                            )).await;
                        }
                        Err(e) => {
                            let error = OutboundMessage::error(e.to_string());
                            let _ = sender.send(Message::Text(
                                serde_json::to_string(&error).unwrap_or_default()
                            )).await;
                        }
                    }
                }
                Ok(Message::Close(_)) => {
                    tracing::info!("WebSocket disconnected: {}", session_id);
                    break;
                }
                Err(e) => {
                    tracing::error!("WebSocket error: {}", e);
                    break;
                }
                _ => {}
            }
        }

        tracing::info!("WebSocket session ended: {}", session_id);
    }
}