use super::router::MessageRouter;
use super::types::{ChannelType, InboundMessage, OutboundMessage};
use crate::agent::AgentManager;
use axum::{
extract::ws::{Message, WebSocket, WebSocketUpgrade},
extract::State,
response::IntoResponse,
};
use futures::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
/// JSON payload a client sends in a WebSocket text frame.
#[derive(Debug)]
#[derive(Serialize, Deserialize)]
pub struct WsMessage {
    /// Message text forwarded to the router.
    pub text: String,
    /// Session to route this message into; when absent, the connection's
    /// generated session id is used instead.
    pub session_id: Option<String>,
    /// Optional extra JSON copied into the inbound message as-is.
    pub metadata: Option<serde_json::Value>,
}
pub struct WsChannel {
router: Arc<MessageRouter>,
}
impl WsChannel {
pub fn new(agent_manager: Arc<RwLock<AgentManager>>) -> Self {
Self {
router: Arc::new(MessageRouter::new(agent_manager)),
}
}
pub async fn handler(
ws: WebSocketUpgrade,
State(channel): State<Arc<Self>>,
) -> impl IntoResponse {
ws.on_upgrade(move |socket| Self::handle_socket(socket, channel))
}
async fn handle_socket(socket: WebSocket, channel: Arc<Self>) {
let (mut sender, mut receiver) = socket.split();
let session_id: Arc<str> = Arc::from(uuid::Uuid::new_v4().to_string());
tracing::info!("WebSocket connected: {}", session_id);
while let Some(msg) = receiver.next().await {
match msg {
Ok(Message::Text(text)) => {
let ws_msg: WsMessage = match serde_json::from_str(&text) {
Ok(m) => m,
Err(e) => {
let error = OutboundMessage::error(format!("Invalid message: {}", e));
let _ = sender.send(Message::Text(
serde_json::to_string(&error).unwrap_or_default()
)).await;
continue;
}
};
let sid = ws_msg.session_id
.map(|s| Arc::from(s.as_str()))
.unwrap_or_else(|| session_id.clone());
let inbound = InboundMessage {
text: Arc::from(ws_msg.text.as_str()),
session_id: sid,
channel: ChannelType::WebSocket,
user_id: None,
metadata: ws_msg.metadata,
};
match channel.router.route(inbound).await {
Ok(response) => {
let _ = sender.send(Message::Text(
serde_json::to_string(&response).unwrap_or_default()
)).await;
}
Err(e) => {
let error = OutboundMessage::error(e.to_string());
let _ = sender.send(Message::Text(
serde_json::to_string(&error).unwrap_or_default()
)).await;
}
}
}
Ok(Message::Close(_)) => {
tracing::info!("WebSocket disconnected: {}", session_id);
break;
}
Err(e) => {
tracing::error!("WebSocket error: {}", e);
break;
}
_ => {}
}
}
tracing::info!("WebSocket session ended: {}", session_id);
}
}