Skip to main content

commy_sdk_rust/
connection.rs

1//! WebSocket connection management
2
3use crate::error::{CommyError, Result};
4use crate::message::{ClientMessage, ServerMessage};
5use futures::{SinkExt, StreamExt};
6use std::sync::Arc;
7use tokio::sync::{mpsc, RwLock};
8use tokio_tungstenite::{connect_async, tungstenite::Message};
9
/// Lifecycle state of a [`Connection`].
///
/// [`Connection::is_connected`] treats both `Connected` and `Authenticated` as "up".
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// No live socket.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The WebSocket handshake completed (the state a new `Connection` starts in).
    Connected,
    /// Connected and authenticated; presumably set by callers after a successful auth
    /// exchange — TODO confirm against the auth code.
    Authenticated,
    /// The connection is being shut down.
    Closing,
}
19
20/// Manages WebSocket connection
21pub struct Connection {
22    state: Arc<RwLock<ConnectionState>>,
23    tx: mpsc::UnboundedSender<ClientMessage>,
24    rx: Arc<RwLock<mpsc::UnboundedReceiver<ServerMessage>>>,
25}
26
27impl Connection {
28    /// Create a new connection
29    pub async fn new(url: &str) -> Result<Self> {
30        let (ws_stream, _) = connect_async(url)
31            .await
32            .map_err(|e| CommyError::WebSocketError(e.to_string()))?;
33
34        let (mut write, mut read) = ws_stream.split();
35        let (tx, mut rx) = mpsc::unbounded_channel::<ClientMessage>();
36        let (server_tx, server_rx) = mpsc::unbounded_channel::<ServerMessage>();
37
38        // Spawn tasks to handle message routing
39        tokio::spawn(async move {
40            while let Some(msg) = rx.recv().await {
41                if let Ok(serialized) = serde_json::to_string(&msg) {
42                    let _ = write.send(Message::Text(serialized)).await;
43                }
44            }
45        });
46
47        tokio::spawn(async move {
48            while let Some(Ok(msg)) = read.next().await {
49                if let Message::Text(text) = msg {
50                    match serde_json::from_str::<ServerMessage>(&text) {
51                        Ok(server_msg) => {
52                            let _ = server_tx.send(server_msg);
53                        }
54                        Err(e) => {
55                            eprintln!("[Client] Failed to deserialize ServerMessage: {}", e);
56                            eprintln!("[Client] Raw message: {}", text);
57                        }
58                    }
59                }
60            }
61        });
62
63        Ok(Self {
64            state: Arc::new(RwLock::new(ConnectionState::Connected)),
65            tx,
66            rx: Arc::new(RwLock::new(server_rx)),
67        })
68    }
69
70    /// Send a message to the server
71    pub async fn send(&self, message: ClientMessage) -> Result<()> {
72        self.tx
73            .send(message)
74            .map_err(|e| CommyError::ChannelError(format!("Failed to send message: {}", e)))?;
75        Ok(())
76    }
77
78    /// Receive a message from the server
79    pub async fn recv(&self) -> Result<Option<ServerMessage>> {
80        let mut rx = self.rx.write().await;
81        Ok(rx.recv().await)
82    }
83
84    /// Get current connection state
85    pub async fn state(&self) -> ConnectionState {
86        *self.state.read().await
87    }
88
89    /// Set connection state
90    pub async fn set_state(&self, state: ConnectionState) {
91        *self.state.write().await = state;
92    }
93
94    /// Check if connected
95    pub async fn is_connected(&self) -> bool {
96        matches!(
97            *self.state.read().await,
98            ConnectionState::Connected | ConnectionState::Authenticated
99        )
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// A state value matches the variant it was constructed from.
    #[test]
    fn test_connection_state() {
        let state = ConnectionState::Connected;
        assert!(matches!(state, ConnectionState::Connected));
    }

    /// `ConnectionState` is `Copy` and compares by variant.
    #[test]
    fn test_connection_state_copy_and_eq() {
        let original = ConnectionState::Authenticated;
        let copied = original;
        assert_eq!(original, copied);
        assert_ne!(original, ConnectionState::Connected);
    }

    /// Every variant is distinct from every other variant.
    #[test]
    fn test_connection_state_variants_distinct() {
        let all = [
            ConnectionState::Disconnected,
            ConnectionState::Connecting,
            ConnectionState::Connected,
            ConnectionState::Authenticated,
            ConnectionState::Closing,
        ];
        for (i, a) in all.iter().enumerate() {
            for (j, b) in all.iter().enumerate() {
                assert_eq!(i == j, a == b, "{:?} vs {:?}", a, b);
            }
        }
    }

    /// `Debug` renders the variant name.
    #[test]
    fn test_connection_state_debug() {
        assert_eq!(format!("{:?}", ConnectionState::Disconnected), "Disconnected");
        assert_eq!(format!("{:?}", ConnectionState::Closing), "Closing");
    }
}