1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
//! WebSocket handler for real-time agent chat and event streaming.
//!
//! Protocol:
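//! - Client sends: `{"type":"subscribe","agent_id":"..."}` to watch one agent
//!   (events sent by or addressed to it); omitting `agent_id` clears the filter
//! - Client sends: `{"type":"message","agent_id":"...","content":"..."}` to send
//!   `content` to that agent as a direct message
//! - Server sends: `{"type":"agent_event","agent_id":"...","event":{...},"content":"..."}`
//!   (`event` is included only when the event payload parses as JSON)
//! - Server sends: `{"type":"agent_message","agent_id":"...","content":"..."}`
//!   (NOTE(review): `handle_ws` in this file never emits this frame; confirm it is
//!   still part of the protocol)
//! - Heartbeat: the server sends a WebSocket ping every 30s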
use std::sync::Arc;
use std::time::Duration;
use axum::extract::ws::{Message, WebSocket};
use tokio::sync::broadcast;
use crate::sse::{EventBus, IpcEvent};
/// A JSON command received from the client over the WebSocket.
#[derive(serde::Deserialize)]
struct WsClientMsg {
    /// Command discriminator, carried in the JSON `type` key
    /// (the handler acts on `"subscribe"` and `"message"`).
    #[serde(rename = "type")]
    msg_type: String,

    /// Agent the command refers to, if the client supplied one.
    agent_id: Option<String>,

    /// Message body, if the client supplied one.
    content: Option<String>,
}
/// A JSON frame sent from the server to the client.
///
/// `event` and `content` are left out of the serialized object when they are `None`.
#[derive(serde::Serialize)]
struct WsServerMsg {
    /// Frame discriminator, emitted under the JSON `type` key.
    #[serde(rename = "type")]
    msg_type: String,

    /// Identifier of the agent this frame is about.
    agent_id: String,

    /// Structured payload, when one is available.
    #[serde(skip_serializing_if = "Option::is_none")]
    event: Option<serde_json::Value>,

    /// Raw text payload, when one is available.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
}
/// Upper bound, in bytes, on an inbound WebSocket message (64 KiB).
const MAX_WS_MSG_BYTES: usize = 64 * 1024;
/// Handle an upgraded WebSocket connection.
pub async fn handle_ws(mut socket: WebSocket, bus: Arc<EventBus>) {
let mut rx = bus.subscribe();
let mut agent_filter: Option<String> = None;
let mut heartbeat = tokio::time::interval(Duration::from_secs(30));
loop {
tokio::select! {
// Incoming message from client
maybe_msg = socket.recv() => {
match maybe_msg {
Some(Ok(Message::Text(text))) => {
if text.len() > MAX_WS_MSG_BYTES {
continue;
}
if let Ok(msg) = serde_json::from_str::<WsClientMsg>(&text) {
match msg.msg_type.as_str() {
"subscribe" => {
agent_filter = msg.agent_id;
}
"message" => {
if let (Some(aid), Some(content)) = (msg.agent_id, msg.content) {
bus.publish(IpcEvent {
from: "ws-client".into(),
to: Some(aid),
content,
event_type: "direct".into(),
ts: chrono::Utc::now().to_rfc3339(),
});
}
}
_ => {}
}
}
}
Some(Ok(Message::Close(_))) | None => break,
_ => {}
}
}
// Outgoing events from EventBus
result = rx.recv() => {
match result {
Ok(event) => {
if let Some(ref filter) = agent_filter {
let matches = event.from == *filter
|| event.to.as_deref() == Some(filter);
if !matches { continue; }
}
let server_msg = WsServerMsg {
msg_type: "agent_event".into(),
agent_id: event.from.clone(),
event: serde_json::from_str(&event.content).ok(),
content: Some(event.content),
};
let json = serde_json::to_string(&server_msg).unwrap_or_default();
if socket.send(Message::Text(json)).await.is_err() {
break;
}
}
Err(broadcast::error::RecvError::Lagged(_)) => continue,
Err(broadcast::error::RecvError::Closed) => break,
}
}
// Heartbeat
_ = heartbeat.tick() => {
if socket.send(Message::Ping(vec![])).await.is_err() {
break;
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `subscribe` command maps the JSON `type` key onto `msg_type`, and an
    /// absent `content` becomes `None`.
    #[test]
    fn ws_client_msg_deserialize() {
        let json = r#"{"type":"subscribe","agent_id":"copilot-1"}"#;
        let msg: WsClientMsg = serde_json::from_str(json).unwrap();
        assert_eq!(msg.msg_type, "subscribe");
        assert_eq!(msg.agent_id.as_deref(), Some("copilot-1"));
        assert!(msg.content.is_none());
    }

    /// A `message` command carries both the target agent and the body.
    #[test]
    fn ws_client_msg_deserialize_message() {
        let json = r#"{"type":"message","agent_id":"copilot-1","content":"hi"}"#;
        let msg: WsClientMsg = serde_json::from_str(json).unwrap();
        assert_eq!(msg.msg_type, "message");
        assert_eq!(msg.agent_id.as_deref(), Some("copilot-1"));
        assert_eq!(msg.content.as_deref(), Some("hi"));
    }

    /// `type` is the only mandatory key; a frame without it is rejected.
    #[test]
    fn ws_client_msg_requires_type() {
        let json = r#"{"agent_id":"copilot-1"}"#;
        assert!(serde_json::from_str::<WsClientMsg>(json).is_err());
    }

    /// Compared structurally rather than by substring so that the `type` rename
    /// and the omission of a `None` `event` are both actually checked.
    #[test]
    fn ws_server_msg_serialize() {
        let msg = WsServerMsg {
            msg_type: "agent_event".into(),
            agent_id: "copilot-1".into(),
            event: None,
            content: Some("hello".into()),
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "agent_event",
                "agent_id": "copilot-1",
                "content": "hello",
            })
        );
    }

    /// A present `event` is emitted as nested JSON, and a `None` `content` is omitted.
    #[test]
    fn ws_server_msg_serialize_with_event() {
        let msg = WsServerMsg {
            msg_type: "agent_event".into(),
            agent_id: "copilot-1".into(),
            event: Some(serde_json::json!({ "k": 1 })),
            content: None,
        };
        let value = serde_json::to_value(&msg).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "agent_event",
                "agent_id": "copilot-1",
                "event": { "k": 1 },
            })
        );
    }
}