1#![warn(missing_docs)]
9
10use async_trait::async_trait;
11use futures_util::{SinkExt, StreamExt};
12use serde::{Serialize, de::DeserializeOwned};
13use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
14use tokio::sync::{RwLock, broadcast, mpsc};
15use tokio_tungstenite::tungstenite::protocol::Message as WsMessage;
16use wae_types::{WaeError, WaeErrorKind, WaeResult};
17
/// Unique identifier assigned to each accepted connection.
pub type ConnectionId = String;

/// Name of a room that connections can join.
pub type RoomId = String;
23
/// A WebSocket message as seen by this crate's API.
///
/// Control frames (`Ping`, `Pong`, `Close`) carry no payload in this representation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Message {
    /// A text frame.
    Text(String),
    /// A binary frame.
    Binary(Vec<u8>),
    /// A ping control frame.
    Ping,
    /// A pong control frame.
    Pong,
    /// A close frame.
    Close,
}

impl Message {
    /// Builds a [`Message::Text`] from anything convertible into a `String`.
    pub fn text(content: impl Into<String>) -> Self {
        let owned: String = content.into();
        Self::Text(owned)
    }

    /// Builds a [`Message::Binary`] from anything convertible into a `Vec<u8>`.
    pub fn binary(data: impl Into<Vec<u8>>) -> Self {
        let bytes: Vec<u8> = data.into();
        Self::Binary(bytes)
    }

    /// Returns `true` for [`Message::Text`].
    pub fn is_text(&self) -> bool {
        self.as_text().is_some()
    }

    /// Returns `true` for [`Message::Binary`].
    pub fn is_binary(&self) -> bool {
        self.as_binary().is_some()
    }

    /// Borrows the text payload, or `None` for any other variant.
    pub fn as_text(&self) -> Option<&str> {
        if let Self::Text(content) = self { Some(content.as_str()) } else { None }
    }

    /// Borrows the binary payload, or `None` for any other variant.
    pub fn as_binary(&self) -> Option<&[u8]> {
        if let Self::Binary(bytes) = self { Some(bytes.as_slice()) } else { None }
    }
}
76
77impl From<WsMessage> for Message {
78 fn from(msg: WsMessage) -> Self {
79 match msg {
80 WsMessage::Text(s) => Message::Text(s.to_string()),
81 WsMessage::Binary(data) => Message::Binary(data.to_vec()),
82 WsMessage::Ping(_) => Message::Ping,
83 WsMessage::Pong(_) => Message::Pong,
84 WsMessage::Close(_) => Message::Close,
85 _ => Message::Close,
86 }
87 }
88}
89
90impl From<Message> for WsMessage {
91 fn from(msg: Message) -> Self {
92 match msg {
93 Message::Text(s) => WsMessage::Text(s.into()),
94 Message::Binary(data) => WsMessage::Binary(data.into()),
95 Message::Ping => WsMessage::Ping(Vec::new().into()),
96 Message::Pong => WsMessage::Pong(Vec::new().into()),
97 Message::Close => WsMessage::Close(None),
98 }
99 }
100}
101
102#[derive(Debug, Clone)]
104pub struct Connection {
105 pub id: ConnectionId,
107 pub addr: SocketAddr,
109 pub connected_at: std::time::Instant,
111 pub metadata: HashMap<String, String>,
113 pub rooms: Vec<RoomId>,
115}
116
117impl Connection {
118 pub fn new(id: ConnectionId, addr: SocketAddr) -> Self {
120 Self { id, addr, connected_at: std::time::Instant::now(), metadata: HashMap::new(), rooms: Vec::new() }
121 }
122
123 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
125 self.metadata.insert(key.into(), value.into());
126 self
127 }
128
129 pub fn duration(&self) -> Duration {
131 self.connected_at.elapsed()
132 }
133}
134
135pub struct ConnectionManager {
137 connections: Arc<RwLock<HashMap<ConnectionId, Connection>>>,
138 max_connections: u32,
139}
140
141impl ConnectionManager {
142 pub fn new(max_connections: u32) -> Self {
144 Self { connections: Arc::new(RwLock::new(HashMap::new())), max_connections }
145 }
146
147 pub async fn add(&self, connection: Connection) -> WaeResult<()> {
149 let mut connections = self.connections.write().await;
150 if connections.len() >= self.max_connections as usize {
151 return Err(WaeError::new(WaeErrorKind::ResourceConflict {
152 resource: "Connection".to_string(),
153 reason: format!("Maximum connections ({}) exceeded", self.max_connections),
154 }));
155 }
156 connections.insert(connection.id.clone(), connection);
157 Ok(())
158 }
159
160 pub async fn remove(&self, id: &str) -> Option<Connection> {
162 let mut connections = self.connections.write().await;
163 connections.remove(id)
164 }
165
166 pub async fn get(&self, id: &str) -> Option<Connection> {
168 let connections = self.connections.read().await;
169 connections.get(id).cloned()
170 }
171
172 pub async fn exists(&self, id: &str) -> bool {
174 let connections = self.connections.read().await;
175 connections.contains_key(id)
176 }
177
178 pub async fn count(&self) -> usize {
180 let connections = self.connections.read().await;
181 connections.len()
182 }
183
184 pub async fn all_ids(&self) -> Vec<ConnectionId> {
186 let connections = self.connections.read().await;
187 connections.keys().cloned().collect()
188 }
189
190 pub async fn join_room(&self, id: &str, room: &str) -> WaeResult<()> {
192 let mut connections = self.connections.write().await;
193 if let Some(conn) = connections.get_mut(id) {
194 if !conn.rooms.contains(&room.to_string()) {
195 conn.rooms.push(room.to_string());
196 }
197 return Ok(());
198 }
199 Err(WaeError::not_found("Connection", id))
200 }
201
202 pub async fn leave_room(&self, id: &str, room: &str) -> WaeResult<()> {
204 let mut connections = self.connections.write().await;
205 if let Some(conn) = connections.get_mut(id) {
206 conn.rooms.retain(|r| r != room);
207 return Ok(());
208 }
209 Err(WaeError::not_found("Connection", id))
210 }
211}
212
213pub struct RoomManager {
215 rooms: Arc<RwLock<HashMap<RoomId, Vec<ConnectionId>>>>,
216}
217
218impl RoomManager {
219 pub fn new() -> Self {
221 Self { rooms: Arc::new(RwLock::new(HashMap::new())) }
222 }
223
224 pub async fn create_room(&self, room_id: &str) {
226 let mut rooms = self.rooms.write().await;
227 rooms.entry(room_id.to_string()).or_insert_with(Vec::new);
228 }
229
230 pub async fn delete_room(&self, room_id: &str) -> Option<Vec<ConnectionId>> {
232 let mut rooms = self.rooms.write().await;
233 rooms.remove(room_id)
234 }
235
236 pub async fn join(&self, room_id: &str, connection_id: &str) {
238 let mut rooms = self.rooms.write().await;
239 let room = rooms.entry(room_id.to_string()).or_insert_with(Vec::new);
240 if !room.contains(&connection_id.to_string()) {
241 room.push(connection_id.to_string());
242 }
243 }
244
245 pub async fn leave(&self, room_id: &str, connection_id: &str) {
247 let mut rooms = self.rooms.write().await;
248 if let Some(room) = rooms.get_mut(room_id) {
249 room.retain(|id| id != connection_id);
250 if room.is_empty() {
251 rooms.remove(room_id);
252 }
253 }
254 }
255
256 pub async fn get_members(&self, room_id: &str) -> Vec<ConnectionId> {
258 let rooms = self.rooms.read().await;
259 rooms.get(room_id).cloned().unwrap_or_default()
260 }
261
262 pub async fn room_exists(&self, room_id: &str) -> bool {
264 let rooms = self.rooms.read().await;
265 rooms.contains_key(room_id)
266 }
267
268 pub async fn room_count(&self) -> usize {
270 let rooms = self.rooms.read().await;
271 rooms.len()
272 }
273
274 pub async fn member_count(&self, room_id: &str) -> usize {
276 let rooms = self.rooms.read().await;
277 rooms.get(room_id).map(|r| r.len()).unwrap_or(0)
278 }
279
280 pub async fn broadcast(&self, room_id: &str, sender: &Sender, message: &Message) -> WaeResult<Vec<ConnectionId>> {
282 let members = self.get_members(room_id).await;
283 let mut sent_to = Vec::new();
284 for conn_id in &members {
285 if sender.send_to(conn_id, message.clone()).await.is_ok() {
286 sent_to.push(conn_id.clone());
287 }
288 }
289 Ok(sent_to)
290 }
291}
292
293impl Default for RoomManager {
294 fn default() -> Self {
295 Self::new()
296 }
297}
298
299#[derive(Clone)]
301pub struct Sender {
302 senders: Arc<RwLock<HashMap<ConnectionId, mpsc::UnboundedSender<Message>>>>,
303}
304
305impl Sender {
306 pub fn new() -> Self {
308 Self { senders: Arc::new(RwLock::new(HashMap::new())) }
309 }
310
311 pub async fn register(&self, connection_id: ConnectionId, sender: mpsc::UnboundedSender<Message>) {
313 let mut senders = self.senders.write().await;
314 senders.insert(connection_id, sender);
315 }
316
317 pub async fn unregister(&self, connection_id: &str) {
319 let mut senders = self.senders.write().await;
320 senders.remove(connection_id);
321 }
322
323 pub async fn send_to(&self, connection_id: &str, message: Message) -> WaeResult<()> {
325 let senders = self.senders.read().await;
326 if let Some(sender) = senders.get(connection_id) {
327 sender
328 .send(message)
329 .map_err(|e| WaeError::new(WaeErrorKind::InternalError { reason: format!("Send failed: {}", e) }))?;
330 return Ok(());
331 }
332 Err(WaeError::not_found("Connection", connection_id))
333 }
334
335 pub async fn broadcast(&self, message: Message) -> WaeResult<usize> {
337 let senders = self.senders.read().await;
338 let mut count = 0;
339 for sender in senders.values() {
340 if sender.send(message.clone()).is_ok() {
341 count += 1;
342 }
343 }
344 Ok(count)
345 }
346
347 pub async fn count(&self) -> usize {
349 let senders = self.senders.read().await;
350 senders.len()
351 }
352}
353
354impl Default for Sender {
355 fn default() -> Self {
356 Self::new()
357 }
358}
359
/// Configuration for a WebSocket server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Interface to bind (default `"0.0.0.0"`).
    pub host: String,
    /// TCP port to bind (default `8080`).
    pub port: u16,
    /// Maximum number of simultaneously tracked connections (default `1000`).
    pub max_connections: u32,
    /// Heartbeat period (default 30 s).
    ///
    /// NOTE(review): the server code in this module does not currently read this value.
    pub heartbeat_interval: Duration,
    /// Connection timeout (default 60 s).
    ///
    /// NOTE(review): the server code in this module does not currently read this value.
    pub connection_timeout: Duration,
}

impl Default for ServerConfig {
    fn default() -> Self {
        Self {
            host: "0.0.0.0".to_string(),
            port: 8080,
            max_connections: 1000,
            heartbeat_interval: Duration::from_secs(30),
            connection_timeout: Duration::from_secs(60),
        }
    }
}

impl ServerConfig {
    /// Creates a configuration with default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the bind host.
    #[must_use]
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.host = host.into();
        self
    }

    /// Sets the bind port.
    #[must_use]
    pub fn port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Sets the maximum number of connections.
    #[must_use]
    pub fn max_connections(mut self, max: u32) -> Self {
        self.max_connections = max;
        self
    }

    /// Sets the heartbeat interval.
    #[must_use]
    pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
        self.heartbeat_interval = interval;
        self
    }

    /// Sets the connection timeout. Provided so every field has a builder setter.
    #[must_use]
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }
}
417
418#[async_trait]
420pub trait ClientHandler: Send + Sync {
421 async fn on_connect(&self, connection: &Connection) -> WaeResult<()>;
423
424 async fn on_message(&self, connection: &Connection, message: Message) -> WaeResult<()>;
426
427 async fn on_disconnect(&self, connection: &Connection);
429}
430
431pub struct DefaultClientHandler;
433
434#[async_trait]
435impl ClientHandler for DefaultClientHandler {
436 async fn on_connect(&self, _connection: &Connection) -> WaeResult<()> {
437 Ok(())
438 }
439
440 async fn on_message(&self, _connection: &Connection, _message: Message) -> WaeResult<()> {
441 Ok(())
442 }
443
444 async fn on_disconnect(&self, _connection: &Connection) {}
445}
446
447pub struct WebSocketServer {
449 config: ServerConfig,
450 connection_manager: Arc<ConnectionManager>,
451 room_manager: Arc<RoomManager>,
452 sender: Sender,
453 shutdown_tx: broadcast::Sender<()>,
454}
455
456impl WebSocketServer {
457 pub fn new(config: ServerConfig) -> Self {
459 let (shutdown_tx, _) = broadcast::channel(1);
460 Self {
461 config,
462 connection_manager: Arc::new(ConnectionManager::new(1000)),
463 room_manager: Arc::new(RoomManager::new()),
464 sender: Sender::new(),
465 shutdown_tx,
466 }
467 }
468
469 pub fn connection_manager(&self) -> &Arc<ConnectionManager> {
471 &self.connection_manager
472 }
473
474 pub fn room_manager(&self) -> &Arc<RoomManager> {
476 &self.room_manager
477 }
478
479 pub fn sender(&self) -> &Sender {
481 &self.sender
482 }
483
484 pub fn config(&self) -> &ServerConfig {
486 &self.config
487 }
488
489 pub async fn start<H: ClientHandler + 'static>(&self, handler: H) -> WaeResult<()> {
491 let addr = format!("{}:{}", self.config.host, self.config.port);
492 let listener = tokio::net::TcpListener::bind(&addr)
493 .await
494 .map_err(|_e| WaeError::new(WaeErrorKind::ConnectionFailed { target: addr.clone() }))?;
495
496 tracing::info!("WebSocket server listening on {}", addr);
497
498 let mut shutdown_rx = self.shutdown_tx.subscribe();
499 let handler = Arc::new(handler);
500
501 loop {
502 tokio::select! {
503 accept_result = listener.accept() => {
504 match accept_result {
505 Ok((stream, addr)) => {
506 let connection_manager = self.connection_manager.clone();
507 let room_manager = self.room_manager.clone();
508 let sender = self.sender.clone();
509 let handler = handler.clone();
510 let config = self.config.clone();
511
512 tokio::spawn(async move {
513 if let Err(e) = Self::handle_connection(
514 stream,
515 addr,
516 connection_manager,
517 room_manager,
518 sender,
519 handler,
520 config,
521 ).await {
522 tracing::error!("Connection error: {}", e);
523 }
524 });
525 }
526 Err(e) => {
527 tracing::error!("Accept error: {}", e);
528 }
529 }
530 }
531 _ = shutdown_rx.recv() => {
532 tracing::info!("WebSocket server shutting down");
533 break;
534 }
535 }
536 }
537
538 Ok(())
539 }
540
541 async fn handle_connection<H: ClientHandler>(
542 stream: tokio::net::TcpStream,
543 addr: SocketAddr,
544 connection_manager: Arc<ConnectionManager>,
545 room_manager: Arc<RoomManager>,
546 sender: Sender,
547 handler: Arc<H>,
548 config: ServerConfig,
549 ) -> WaeResult<()> {
550 let ws_stream = tokio_tungstenite::accept_async(stream)
551 .await
552 .map_err(|_e| WaeError::new(WaeErrorKind::ConnectionFailed { target: addr.to_string() }))?;
553
554 let connection_id = uuid::Uuid::new_v4().to_string();
555 let connection = Connection::new(connection_id.clone(), addr);
556
557 if connection_manager.add(connection.clone()).await.is_err() {
558 return Err(WaeError::new(WaeErrorKind::ResourceConflict {
559 resource: "Connection".to_string(),
560 reason: format!("Maximum connections ({}) exceeded", config.max_connections),
561 }));
562 }
563
564 handler.on_connect(&connection).await?;
565 tracing::info!("Client connected: {} from {}", connection_id, addr);
566
567 let (ws_sender, mut ws_receiver) = ws_stream.split();
568 let (tx, mut rx) = mpsc::unbounded_channel::<Message>();
569
570 sender.register(connection_id.clone(), tx).await;
571
572 let send_task = async move {
573 let mut ws_sender = ws_sender;
574 while let Some(msg) = rx.recv().await {
575 if ws_sender.send(msg.into()).await.is_err() {
576 break;
577 }
578 }
579 let _ = ws_sender.close().await;
580 };
581
582 let connection_manager_clone = connection_manager.clone();
583 let room_manager_clone = room_manager.clone();
584 let sender_clone = sender.clone();
585 let connection_id_clone = connection_id.clone();
586 let connection_clone = connection.clone();
587 let handler_clone = handler.clone();
588 let recv_task = async move {
589 while let Some(msg_result) = ws_receiver.next().await {
590 match msg_result {
591 Ok(ws_msg) => {
592 let msg: Message = ws_msg.into();
593 if matches!(msg, Message::Close) {
594 break;
595 }
596 if handler_clone.on_message(&connection_clone, msg).await.is_err() {
597 break;
598 }
599 }
600 Err(_) => break,
601 }
602 }
603 };
604
605 tokio::select! {
606 _ = send_task => {},
607 _ = recv_task => {},
608 }
609
610 for room_id in &connection.rooms {
611 room_manager_clone.leave(room_id, &connection_id_clone).await;
612 }
613
614 connection_manager_clone.remove(&connection_id_clone).await;
615 sender_clone.unregister(&connection_id_clone).await;
616 handler.on_disconnect(&connection).await;
617
618 tracing::info!("Client disconnected: {}", connection_id);
619
620 Ok(())
621 }
622
623 pub fn shutdown(&self) {
625 let _ = self.shutdown_tx.send(());
626 }
627
628 pub async fn broadcast(&self, message: Message) -> WaeResult<usize> {
630 self.sender.broadcast(message).await
631 }
632
633 pub async fn broadcast_to_room(&self, room_id: &str, message: Message) -> WaeResult<Vec<ConnectionId>> {
635 self.room_manager.broadcast(room_id, &self.sender, &message).await
636 }
637}
638
/// Configuration for a [`WebSocketClient`].
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Server URL (default `"ws://127.0.0.1:8080"`).
    pub url: String,
    /// Delay between reconnection attempts (default 5 s).
    pub reconnect_interval: Duration,
    /// Heartbeat period (default 30 s).
    ///
    /// NOTE(review): the client code in this module does not currently read this value.
    pub heartbeat_interval: Duration,
    /// Timeout for establishing a connection (default 10 s).
    pub connection_timeout: Duration,
    /// Limit on connection attempts before giving up; `0` (the default) retries forever.
    pub max_reconnect_attempts: u32,
}

impl Default for ClientConfig {
    fn default() -> Self {
        Self {
            url: "ws://127.0.0.1:8080".to_string(),
            reconnect_interval: Duration::from_secs(5),
            heartbeat_interval: Duration::from_secs(30),
            connection_timeout: Duration::from_secs(10),
            max_reconnect_attempts: 0,
        }
    }
}

impl ClientConfig {
    /// Creates a configuration for `url` with all other fields at their defaults.
    pub fn new(url: impl Into<String>) -> Self {
        Self { url: url.into(), ..Self::default() }
    }

    /// Sets the delay between reconnection attempts.
    #[must_use]
    pub fn reconnect_interval(mut self, interval: Duration) -> Self {
        self.reconnect_interval = interval;
        self
    }

    /// Sets the heartbeat interval.
    #[must_use]
    pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
        self.heartbeat_interval = interval;
        self
    }

    /// Sets the connection timeout. Provided so every timing field has a setter.
    #[must_use]
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }

    /// Sets the attempt limit; `0` means retry forever.
    #[must_use]
    pub fn max_reconnect_attempts(mut self, attempts: u32) -> Self {
        self.max_reconnect_attempts = attempts;
        self
    }
}
690
691pub struct WebSocketClient {
693 config: ClientConfig,
694 sender: mpsc::UnboundedSender<Message>,
695 receiver: mpsc::UnboundedReceiver<Message>,
696}
697
698impl WebSocketClient {
699 pub fn new(config: ClientConfig) -> Self {
701 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel::<Message>();
702 let (incoming_tx, incoming_rx) = mpsc::unbounded_channel::<Message>();
703
704 let config_clone = config.clone();
705
706 tokio::spawn(async move {
707 let mut attempt = 0u32;
708 loop {
709 match tokio_tungstenite::connect_async(&config_clone.url).await {
710 Ok((ws_stream, _)) => {
711 tracing::info!("WebSocket client connected to {}", config_clone.url);
712 attempt = 0;
713
714 let (mut ws_sender, mut ws_receiver) = ws_stream.split();
715
716 let send_task = async {
717 while let Some(msg) = outgoing_rx.recv().await {
718 if ws_sender.send(msg.into()).await.is_err() {
719 break;
720 }
721 }
722 };
723
724 let recv_task = async {
725 while let Some(msg_result) = ws_receiver.next().await {
726 match msg_result {
727 Ok(ws_msg) => {
728 let msg: Message = ws_msg.into();
729 if matches!(msg, Message::Close) {
730 break;
731 }
732 if incoming_tx.send(msg).is_err() {
733 break;
734 }
735 }
736 Err(_) => break,
737 }
738 }
739 };
740
741 tokio::select! {
742 _ = send_task => {},
743 _ = recv_task => {},
744 }
745
746 tracing::warn!("WebSocket client disconnected, attempting to reconnect...");
747 }
748 Err(e) => {
749 tracing::error!("WebSocket connection failed: {}", e);
750 }
751 }
752
753 attempt += 1;
754 if config_clone.max_reconnect_attempts > 0 && attempt >= config_clone.max_reconnect_attempts {
755 tracing::error!("Max reconnect attempts reached, giving up");
756 break;
757 }
758
759 tokio::time::sleep(config_clone.reconnect_interval).await;
760 }
761 });
762
763 Self { config, sender: outgoing_tx, receiver: incoming_rx }
764 }
765
766 pub async fn send(&self, message: Message) -> WaeResult<()> {
768 self.sender
769 .send(message)
770 .map_err(|e| WaeError::new(WaeErrorKind::InternalError { reason: format!("Send failed: {}", e) }))
771 }
772
773 pub async fn send_text(&self, text: impl Into<String>) -> WaeResult<()> {
775 self.send(Message::text(text)).await
776 }
777
778 pub async fn send_binary(&self, data: impl Into<Vec<u8>>) -> WaeResult<()> {
780 self.send(Message::binary(data)).await
781 }
782
783 pub async fn send_json<T: Serialize + ?Sized>(&self, value: &T) -> WaeResult<()> {
785 let json = serde_json::to_string(value).map_err(|_e| WaeError::serialization_failed("JSON"))?;
786 self.send_text(json).await
787 }
788
789 pub async fn receive(&mut self) -> Option<Message> {
791 self.receiver.recv().await
792 }
793
794 pub async fn receive_json<T: DeserializeOwned>(&mut self) -> WaeResult<Option<T>> {
796 match self.receive().await {
797 Some(msg) => {
798 let text = msg.as_text().ok_or_else(|| WaeError::deserialization_failed("Expected text message"))?;
799 let value: T = serde_json::from_str(text).map_err(|_e| WaeError::deserialization_failed("JSON"))?;
800 Ok(Some(value))
801 }
802 None => Ok(None),
803 }
804 }
805
806 pub fn config(&self) -> &ClientConfig {
808 &self.config
809 }
810
811 pub async fn close(&self) -> WaeResult<()> {
813 self.send(Message::Close).await
814 }
815}
816
817pub fn websocket_server(config: ServerConfig) -> WebSocketServer {
819 WebSocketServer::new(config)
820}
821
822pub fn websocket_client(config: ClientConfig) -> WebSocketClient {
824 WebSocketClient::new(config)
825}