oxigdal_websocket/server/
ws_server.rs1use crate::error::{Error, Result};
4use crate::protocol::{MessageFormat, ProtocolCodec, ProtocolConfig};
5use crate::server::connection::Connection;
6use crate::server::heartbeat::{HeartbeatConfig, HeartbeatMonitor};
7use crate::server::manager::ConnectionManager;
8use crate::server::pool::{ConnectionPool, PoolConfig};
9use crate::server::{DEFAULT_MAX_CONNECTIONS, DEFAULT_MAX_MESSAGE_SIZE};
10use std::net::SocketAddr;
11use std::sync::Arc;
12use tokio::net::{TcpListener, TcpStream};
13use tokio::sync::RwLock;
14use tokio_tungstenite::accept_async;
15
16#[derive(Debug, Clone)]
18pub struct ServerConfig {
19 pub bind_addr: SocketAddr,
21 pub max_connections: usize,
23 pub max_message_size: usize,
25 pub protocol: ProtocolConfig,
27 pub heartbeat: HeartbeatConfig,
29 pub pool: PoolConfig,
31}
32
33impl Default for ServerConfig {
34 fn default() -> Self {
35 Self {
36 bind_addr: SocketAddr::from(([0, 0, 0, 0], 9001)),
37 max_connections: DEFAULT_MAX_CONNECTIONS,
38 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
39 protocol: ProtocolConfig::default(),
40 heartbeat: HeartbeatConfig::default(),
41 pool: PoolConfig::default(),
42 }
43 }
44}
45
46pub struct Server {
48 config: ServerConfig,
49 manager: Arc<ConnectionManager>,
50 pool: Arc<ConnectionPool>,
51 heartbeat: Arc<HeartbeatMonitor>,
52 shutdown: Arc<RwLock<bool>>,
53}
54
55impl Server {
56 pub fn new(config: ServerConfig) -> Self {
58 let manager = Arc::new(ConnectionManager::new(config.max_connections));
59 let pool = Arc::new(ConnectionPool::new(config.pool.clone()));
60 let heartbeat = Arc::new(HeartbeatMonitor::new(config.heartbeat.clone()));
61
62 Self {
63 config,
64 manager,
65 pool,
66 heartbeat,
67 shutdown: Arc::new(RwLock::new(false)),
68 }
69 }
70
71 pub fn builder() -> ServerBuilder {
73 ServerBuilder::new()
74 }
75
76 pub async fn start(self: Arc<Self>) -> Result<()> {
78 tracing::info!("Starting WebSocket server on {}", self.config.bind_addr);
79
80 let heartbeat = self.heartbeat.clone();
82 tokio::spawn(async move {
83 if let Err(e) = heartbeat.start().await {
84 tracing::error!("Heartbeat monitor error: {}", e);
85 }
86 });
87
88 let pool = self.pool.clone();
90 let shutdown = self.shutdown.clone();
91 tokio::spawn(async move {
92 let mut interval = tokio::time::interval(std::time::Duration::from_secs(60));
93 loop {
94 if *shutdown.read().await {
95 break;
96 }
97 interval.tick().await;
98 if let Err(e) = pool.maintain().await {
99 tracing::error!("Pool maintenance error: {}", e);
100 }
101 }
102 });
103
104 let listener = TcpListener::bind(&self.config.bind_addr)
106 .await
107 .map_err(|e| Error::Connection(format!("Failed to bind: {}", e)))?;
108
109 tracing::info!("Server listening on {}", self.config.bind_addr);
110
111 loop {
113 if *self.shutdown.read().await {
115 break;
116 }
117
118 match listener.accept().await {
119 Ok((stream, addr)) => {
120 let server = self.clone();
121 tokio::spawn(async move {
122 if let Err(e) = server.handle_connection(stream, addr).await {
123 tracing::error!("Connection error from {}: {}", addr, e);
124 }
125 });
126 }
127 Err(e) => {
128 tracing::error!("Accept error: {}", e);
129 }
130 }
131 }
132
133 tracing::info!("Server stopped");
134 Ok(())
135 }
136
137 async fn handle_connection(&self, stream: TcpStream, addr: SocketAddr) -> Result<()> {
139 tracing::info!("New connection from {}", addr);
140
141 let ws_stream = accept_async(stream)
143 .await
144 .map_err(|e| Error::WebSocket(format!("Handshake failed: {}", e)))?;
145
146 let codec = ProtocolCodec::new(self.config.protocol.clone());
148
149 let (connection, rx) = Connection::new(ws_stream, addr, codec);
151 let connection = Arc::new(connection);
152
153 self.manager.add(connection.clone())?;
155
156 self.heartbeat.add_connection(connection.clone()).await;
158
159 let conn = connection.clone();
161 tokio::spawn(async move {
162 if let Err(e) = conn.process_outgoing(rx).await {
163 tracing::error!("Outgoing message handler error: {}", e);
164 }
165 });
166
167 loop {
169 match connection.receive().await {
170 Ok(Some(message)) => {
171 tracing::debug!(
172 "Received message from {}: {:?}",
173 addr,
174 message.message_type()
175 );
176 }
178 Ok(None) => {
179 }
181 Err(e) => {
182 tracing::error!("Receive error from {}: {}", addr, e);
183 break;
184 }
185 }
186
187 let state = connection.state().await;
189 if state != crate::server::connection::ConnectionState::Connected {
190 break;
191 }
192 }
193
194 self.manager.remove(&connection.id());
196 self.heartbeat.remove_connection(&connection.id()).await;
197
198 tracing::info!("Connection from {} closed", addr);
199 Ok(())
200 }
201
202 pub async fn shutdown(&self) -> Result<()> {
204 tracing::info!("Shutting down server");
205
206 let mut shutdown = self.shutdown.write().await;
208 *shutdown = true;
209 drop(shutdown);
210
211 self.heartbeat.shutdown().await;
213
214 self.manager.close_all().await?;
216
217 self.pool.shutdown().await?;
219
220 tracing::info!("Server shutdown complete");
221 Ok(())
222 }
223
224 pub fn manager(&self) -> &Arc<ConnectionManager> {
226 &self.manager
227 }
228
229 pub fn pool(&self) -> &Arc<ConnectionPool> {
231 &self.pool
232 }
233
234 pub fn heartbeat(&self) -> &Arc<HeartbeatMonitor> {
236 &self.heartbeat
237 }
238
239 pub async fn stats(&self) -> ServerStats {
241 let manager_stats = self.manager.stats();
242 let pool_stats = self.pool.stats();
243 let heartbeat_stats = self.heartbeat.stats().await;
244
245 ServerStats {
246 active_connections: manager_stats.active_connections,
247 total_connections: manager_stats.total_connections,
248 messages_sent: manager_stats.messages_sent,
249 messages_received: manager_stats.messages_received,
250 bytes_sent: manager_stats.bytes_sent,
251 bytes_received: manager_stats.bytes_received,
252 errors: manager_stats.errors,
253 pool_idle: pool_stats.idle_connections,
254 pool_active: pool_stats.active_connections,
255 heartbeat_monitored: heartbeat_stats.monitored_connections,
256 }
257 }
258}
259
/// Point-in-time snapshot of server metrics, as assembled by `Server::stats`.
///
/// `Default` yields an all-zero snapshot; `PartialEq`/`Eq` allow snapshots to be
/// compared directly (e.g. in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerStats {
    /// Active connections, as reported by the connection manager.
    pub active_connections: usize,
    /// Total connections, as reported by the connection manager.
    pub total_connections: u64,
    /// Messages sent, as reported by the connection manager.
    pub messages_sent: u64,
    /// Messages received, as reported by the connection manager.
    pub messages_received: u64,
    /// Bytes sent, as reported by the connection manager.
    pub bytes_sent: u64,
    /// Bytes received, as reported by the connection manager.
    pub bytes_received: u64,
    /// Error count, as reported by the connection manager.
    pub errors: u64,
    /// Idle connections, as reported by the connection pool.
    pub pool_idle: usize,
    /// Active connections, as reported by the connection pool.
    pub pool_active: usize,
    /// Connections monitored by the heartbeat monitor.
    pub heartbeat_monitored: usize,
}

285pub struct ServerBuilder {
287 config: ServerConfig,
288}
289
290impl ServerBuilder {
291 pub fn new() -> Self {
293 Self {
294 config: ServerConfig::default(),
295 }
296 }
297
298 pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
300 self.config.bind_addr = addr;
301 self
302 }
303
304 pub fn max_connections(mut self, max: usize) -> Self {
306 self.config.max_connections = max;
307 self
308 }
309
310 pub fn max_message_size(mut self, size: usize) -> Self {
312 self.config.max_message_size = size;
313 self
314 }
315
316 pub fn message_format(mut self, format: MessageFormat) -> Self {
318 self.config.protocol.format = format;
319 self
320 }
321
322 pub fn heartbeat_interval(mut self, secs: u64) -> Self {
324 self.config.heartbeat.interval_secs = secs;
325 self
326 }
327
328 pub fn enable_heartbeat(mut self, enabled: bool) -> Self {
330 self.config.heartbeat.enabled = enabled;
331 self
332 }
333
334 pub fn pool_config(mut self, config: PoolConfig) -> Self {
336 self.config.pool = config;
337 self
338 }
339
340 pub fn build(self) -> Arc<Server> {
342 Arc::new(Server::new(self.config))
343 }
344}
345
346impl Default for ServerBuilder {
347 fn default() -> Self {
348 Self::new()
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_server_config_default() {
        let config = ServerConfig::default();
        assert_eq!(config.bind_addr, SocketAddr::from(([0, 0, 0, 0], 9001)));
        assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS);
        assert_eq!(config.max_message_size, DEFAULT_MAX_MESSAGE_SIZE);
    }

    #[test]
    fn test_server_builder() {
        let addr = SocketAddr::from(([127, 0, 0, 1], 9100));
        let server = Server::builder()
            .bind_addr(addr)
            .max_connections(5000)
            .max_message_size(8 * 1024 * 1024)
            .heartbeat_interval(60)
            .enable_heartbeat(false)
            .build();

        // `config` is private to the parent module but visible here, so check that
        // every setter actually reached the final configuration.
        assert_eq!(server.config.bind_addr, addr);
        assert_eq!(server.config.max_connections, 5000);
        assert_eq!(server.config.max_message_size, 8 * 1024 * 1024);
        assert_eq!(server.config.heartbeat.interval_secs, 60);
        assert!(!server.config.heartbeat.enabled);
    }

    #[tokio::test]
    async fn test_server_stats() {
        let server = Server::builder().build();
        let stats = server.stats().await;

        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.total_connections, 0);
    }
}