armature_websocket/
server.rs1use crate::connection::{Connection, ConnectionWriter};
4use crate::error::{WebSocketError, WebSocketResult};
5use crate::handler::WebSocketHandler;
6use crate::message::Message;
7use crate::room::RoomManager;
8use futures_util::StreamExt;
9use std::net::SocketAddr;
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::net::{TcpListener, TcpStream};
13use tokio::sync::mpsc;
14use tokio_tungstenite::accept_async;
15
/// Configuration for a [`WebSocketServer`].
///
/// Use [`Default`] for the stock settings, or [`WebSocketServerBuilder`] to override
/// individual fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WebSocketServerConfig {
    /// Address the TCP listener binds to (default `0.0.0.0:9001`).
    pub bind_addr: SocketAddr,
    /// Maximum message size (default `64 * 1024`, presumably bytes).
    ///
    /// NOTE(review): not read by the server code in this file — confirm it is enforced
    /// elsewhere.
    pub max_message_size: usize,
    /// Heartbeat interval (default 30 s).
    ///
    /// NOTE(review): not read by the server code in this file — confirm it is enforced
    /// elsewhere.
    pub heartbeat_interval: Duration,
    /// Connection timeout (default 60 s).
    pub connection_timeout: Duration,
}

impl Default for WebSocketServerConfig {
    fn default() -> Self {
        Self {
            // Built directly rather than parsed from a string: infallible, no panic path.
            bind_addr: SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 9001)),
            max_message_size: 64 * 1024,
            heartbeat_interval: Duration::from_secs(30),
            connection_timeout: Duration::from_secs(60),
        }
    }
}
39
40#[derive(Debug, Default)]
42pub struct WebSocketServerBuilder {
43 config: WebSocketServerConfig,
44}
45
46impl WebSocketServerBuilder {
47 pub fn new() -> Self {
49 Self::default()
50 }
51
52 pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
54 self.config.bind_addr = addr;
55 self
56 }
57
58 pub fn bind(mut self, addr: &str) -> WebSocketResult<Self> {
60 self.config.bind_addr = addr
61 .parse()
62 .map_err(|e| WebSocketError::Server(format!("Invalid address: {}", e)))?;
63 Ok(self)
64 }
65
66 pub fn max_message_size(mut self, size: usize) -> Self {
68 self.config.max_message_size = size;
69 self
70 }
71
72 pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
74 self.config.heartbeat_interval = interval;
75 self
76 }
77
78 pub fn connection_timeout(mut self, timeout: Duration) -> Self {
80 self.config.connection_timeout = timeout;
81 self
82 }
83
84 pub fn build<H: WebSocketHandler>(self, handler: H) -> WebSocketServer<H> {
86 WebSocketServer::new(self.config, handler)
87 }
88}
89
90pub struct WebSocketServer<H: WebSocketHandler> {
92 config: WebSocketServerConfig,
93 handler: Arc<H>,
94 room_manager: Arc<RoomManager>,
95}
96
97impl<H: WebSocketHandler> WebSocketServer<H> {
98 pub fn new(config: WebSocketServerConfig, handler: H) -> Self {
100 Self {
101 config,
102 handler: Arc::new(handler),
103 room_manager: Arc::new(RoomManager::new()),
104 }
105 }
106
107 pub fn builder() -> WebSocketServerBuilder {
109 WebSocketServerBuilder::new()
110 }
111
112 pub fn room_manager(&self) -> &Arc<RoomManager> {
114 &self.room_manager
115 }
116
117 pub async fn run(&self) -> WebSocketResult<()> {
119 let listener = TcpListener::bind(self.config.bind_addr).await?;
120 tracing::info!(addr = %self.config.bind_addr, "WebSocket server listening");
121
122 loop {
123 match listener.accept().await {
124 Ok((stream, addr)) => {
125 let handler = Arc::clone(&self.handler);
126 let room_manager = Arc::clone(&self.room_manager);
127 let config = self.config.clone();
128
129 tokio::spawn(async move {
130 if let Err(e) =
131 Self::handle_connection(stream, addr, handler, room_manager, config)
132 .await
133 {
134 tracing::error!(addr = %addr, error = %e, "Connection error");
135 }
136 });
137 }
138 Err(e) => {
139 tracing::error!(error = %e, "Failed to accept connection");
140 }
141 }
142 }
143 }
144
145 async fn handle_connection(
147 stream: TcpStream,
148 addr: SocketAddr,
149 handler: Arc<H>,
150 room_manager: Arc<RoomManager>,
151 _config: WebSocketServerConfig,
152 ) -> WebSocketResult<()> {
153 let ws_stream = accept_async(stream).await?;
154 let connection_id = uuid::Uuid::new_v4().to_string();
155
156 tracing::debug!(connection_id = %connection_id, addr = %addr, "WebSocket connection established");
157
158 let (write, mut read) = ws_stream.split();
160
161 let (tx, rx) = mpsc::unbounded_channel();
163
164 let connection = Connection::new(connection_id.clone(), Some(addr), tx);
166
167 room_manager.register_connection(connection.clone());
169
170 handler.on_connect(&connection_id).await;
172
173 let writer = ConnectionWriter::new(write, rx);
175 let writer_handle = tokio::spawn(async move { writer.run().await });
176
177 while let Some(result) = read.next().await {
179 match result {
180 Ok(msg) => {
181 if msg.is_close() {
182 break;
183 }
184
185 let message: Message = msg.into();
186
187 if message.is_ping() {
189 let pong_payload = handler.on_ping(&connection_id, message.as_bytes()).await;
190 let _ = connection.send(Message::pong(pong_payload));
191 continue;
192 }
193
194 if message.is_pong() {
195 handler.on_pong(&connection_id, message.as_bytes()).await;
196 continue;
197 }
198
199 handler.on_message(&connection_id, message).await;
201 }
202 Err(e) => {
203 let ws_error = WebSocketError::Protocol(e);
204 handler.on_error(&connection_id, &ws_error).await;
205 break;
206 }
207 }
208 }
209
210 connection.close();
212
213 let _ = writer_handle.await;
215
216 handler.on_disconnect(&connection_id).await;
218
219 room_manager.unregister_connection(&connection_id);
221
222 tracing::debug!(connection_id = %connection_id, "WebSocket connection closed");
223
224 Ok(())
225 }
226}
227