use crate::error::{Error, Result};
use crate::protocol::ProtocolCodec;
use crate::protocol::message::Message;
use futures::{SinkExt, StreamExt};
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::SystemTime;
use tokio::net::TcpStream;
use tokio::sync::{Mutex, mpsc};
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::Message as WsMessage;
use uuid::Uuid;
/// Unique identifier of a [`Connection`]; `Connection::new` assigns a random v4 UUID.
pub type ConnectionId = Uuid;
/// Lifecycle phase of a [`Connection`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Handshake still in progress. Nothing in this file sets this state;
    /// it is available for callers that track the pre-handshake phase.
    Connecting,
    /// The connection is established; this is the state a new `Connection` starts in.
    Connected,
    /// A close has been requested locally or a Close frame was received from the peer.
    Disconnecting,
    /// The stream has ended or the local close has completed.
    Disconnected,
}
/// A single WebSocket client connection together with its bookkeeping.
///
/// Every mutable part sits behind an `Arc` plus a lock or atomic, so all
/// methods take `&self`.
pub struct Connection {
    /// Random v4 UUID generated when the connection is created.
    id: ConnectionId,
    /// Socket address of the remote peer, as supplied to `Connection::new`.
    remote_addr: SocketAddr,
    /// Current lifecycle phase.
    state: Arc<Mutex<ConnectionState>>,
    /// The underlying WebSocket stream; one lock guards both reading and writing.
    ws: Arc<Mutex<WebSocketStream<TcpStream>>>,
    /// Codec used to encode outgoing and decode incoming protocol messages.
    codec: Arc<ProtocolCodec>,
    /// Sending half of the outgoing queue; the receiving half is returned by `Connection::new`.
    tx: mpsc::UnboundedSender<Message>,
    /// Time of the most recent activity, in whole seconds since the Unix epoch.
    last_activity: Arc<AtomicU64>,
    /// Caller-managed metadata (user id, tags, subscriptions, rooms).
    metadata: Arc<Mutex<ConnectionMetadata>>,
    /// Live traffic and error counters.
    stats: Arc<ConnectionStatistics>,
}
/// Caller-defined information attached to a connection.
///
/// Nothing in this file interprets these fields; they are read through
/// `Connection::metadata` and changed through `Connection::update_metadata`.
#[derive(Clone, Debug, Default)]
pub struct ConnectionMetadata {
    /// Optional user identifier associated with the connection.
    pub user_id: Option<String>,
    /// Free-form string key/value pairs.
    pub tags: std::collections::HashMap<String, String>,
    /// Set of subscription names (presumably topics or channels; meaning is defined by callers).
    pub subscriptions: std::collections::HashSet<String>,
    /// Set of room names (meaning is defined by callers).
    pub rooms: std::collections::HashSet<String>,
}
/// Live, lock-free counters for a connection.
///
/// All counters start at zero and are updated through atomic operations, so
/// they can be shared behind an `Arc` without further locking.
#[derive(Default, Debug)]
pub struct ConnectionStatistics {
    /// Number of outbound messages counted by the connection.
    pub messages_sent: AtomicU64,
    /// Number of inbound messages counted by the connection.
    pub messages_received: AtomicU64,
    /// Total size in bytes of encoded outbound payloads.
    pub bytes_sent: AtomicU64,
    /// Total size in bytes of inbound binary/text payloads.
    pub bytes_received: AtomicU64,
    /// Number of errors recorded while sending or receiving.
    pub errors: AtomicU64,
}
impl Connection {
pub fn new(
ws: WebSocketStream<TcpStream>,
remote_addr: SocketAddr,
codec: ProtocolCodec,
) -> (Self, mpsc::UnboundedReceiver<Message>) {
let (tx, rx) = mpsc::unbounded_channel();
let connection = Self {
id: Uuid::new_v4(),
remote_addr,
state: Arc::new(Mutex::new(ConnectionState::Connected)),
ws: Arc::new(Mutex::new(ws)),
codec: Arc::new(codec),
tx,
last_activity: Arc::new(AtomicU64::new(Self::current_timestamp())),
metadata: Arc::new(Mutex::new(ConnectionMetadata::default())),
stats: Arc::new(ConnectionStatistics::default()),
};
(connection, rx)
}
pub fn id(&self) -> ConnectionId {
self.id
}
pub fn remote_addr(&self) -> SocketAddr {
self.remote_addr
}
pub async fn state(&self) -> ConnectionState {
*self.state.lock().await
}
pub async fn set_state(&self, new_state: ConnectionState) {
let mut state = self.state.lock().await;
*state = new_state;
}
pub async fn send(&self, message: Message) -> Result<()> {
self.tx
.send(message)
.map_err(|e| Error::Connection(format!("Failed to send message: {}", e)))?;
Ok(())
}
pub async fn receive(&self) -> Result<Option<Message>> {
let mut ws = self.ws.lock().await;
match ws.next().await {
Some(Ok(ws_msg)) => {
self.update_activity();
self.stats.messages_received.fetch_add(1, Ordering::Relaxed);
match ws_msg {
WsMessage::Binary(data) => {
let bytes: &[u8] = &data;
self.stats
.bytes_received
.fetch_add(bytes.len() as u64, Ordering::Relaxed);
let message = self.codec.decode(bytes)?;
Ok(Some(message))
}
WsMessage::Text(text) => {
let bytes = text.as_bytes();
self.stats
.bytes_received
.fetch_add(bytes.len() as u64, Ordering::Relaxed);
let message = self.codec.decode(bytes)?;
Ok(Some(message))
}
WsMessage::Ping(data) => {
ws.send(WsMessage::Pong(data)).await?;
Ok(None)
}
WsMessage::Pong(_) => {
Ok(None)
}
WsMessage::Close(_) => {
self.set_state(ConnectionState::Disconnecting).await;
Ok(None)
}
_ => Ok(None),
}
}
Some(Err(e)) => {
self.stats.errors.fetch_add(1, Ordering::Relaxed);
Err(Error::WebSocket(e.to_string()))
}
None => {
self.set_state(ConnectionState::Disconnected).await;
Ok(None)
}
}
}
pub async fn process_outgoing(&self, mut rx: mpsc::UnboundedReceiver<Message>) -> Result<()> {
while let Some(message) = rx.recv().await {
if let Err(e) = self.send_message(message).await {
tracing::error!("Failed to send message: {}", e);
self.stats.errors.fetch_add(1, Ordering::Relaxed);
}
}
Ok(())
}
async fn send_message(&self, message: Message) -> Result<()> {
let encoded = self.codec.encode(&message)?;
self.stats
.bytes_sent
.fetch_add(encoded.len() as u64, Ordering::Relaxed);
self.stats.messages_sent.fetch_add(1, Ordering::Relaxed);
let mut ws = self.ws.lock().await;
ws.send(WsMessage::Binary(encoded.to_vec().into())).await?;
self.update_activity();
Ok(())
}
pub async fn ping(&self) -> Result<()> {
let mut ws = self.ws.lock().await;
ws.send(WsMessage::Ping(Vec::new().into())).await?;
self.update_activity();
Ok(())
}
pub async fn close(&self) -> Result<()> {
self.set_state(ConnectionState::Disconnecting).await;
let mut ws = self.ws.lock().await;
ws.close(None).await?;
self.set_state(ConnectionState::Disconnected).await;
Ok(())
}
pub async fn metadata(&self) -> ConnectionMetadata {
self.metadata.lock().await.clone()
}
pub async fn update_metadata<F>(&self, f: F)
where
F: FnOnce(&mut ConnectionMetadata),
{
let mut metadata = self.metadata.lock().await;
f(&mut metadata);
}
pub fn last_activity(&self) -> u64 {
self.last_activity.load(Ordering::Relaxed)
}
fn update_activity(&self) {
self.last_activity
.store(Self::current_timestamp(), Ordering::Relaxed);
}
fn current_timestamp() -> u64 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0)
}
pub fn is_idle(&self, timeout_secs: u64) -> bool {
let now = Self::current_timestamp();
let last = self.last_activity();
now.saturating_sub(last) > timeout_secs
}
pub fn stats(&self) -> ConnectionStats {
ConnectionStats {
messages_sent: self.stats.messages_sent.load(Ordering::Relaxed),
messages_received: self.stats.messages_received.load(Ordering::Relaxed),
bytes_sent: self.stats.bytes_sent.load(Ordering::Relaxed),
bytes_received: self.stats.bytes_received.load(Ordering::Relaxed),
errors: self.stats.errors.load(Ordering::Relaxed),
}
}
}
/// Plain-value snapshot of a connection's counters, as returned by `Connection::stats`.
///
/// The type holds only `u64` fields, so it is `Copy` and comparable. That lets
/// two snapshots be diffed or asserted on directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ConnectionStats {
    /// Outbound messages counted at snapshot time.
    pub messages_sent: u64,
    /// Inbound messages counted at snapshot time.
    pub messages_received: u64,
    /// Bytes of encoded outbound payloads counted at snapshot time.
    pub bytes_sent: u64,
    /// Bytes of inbound payloads counted at snapshot time.
    pub bytes_received: u64,
    /// Errors counted at snapshot time.
    pub errors: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two freshly generated v4 UUIDs must not collide.
    #[test]
    fn test_connection_id() {
        let first = Uuid::new_v4();
        let second = Uuid::new_v4();
        assert_ne!(first, second);
    }

    /// States compare equal to themselves and unequal to other variants.
    #[test]
    fn test_connection_state() {
        let connected = ConnectionState::Connected;
        assert_eq!(connected, ConnectionState::Connected);
        assert_ne!(connected, ConnectionState::Disconnected);
    }

    /// Metadata keeps the user id and tags it was given.
    #[test]
    fn test_connection_metadata() {
        let tags =
            std::collections::HashMap::from([(String::from("role"), String::from("admin"))]);
        let metadata = ConnectionMetadata {
            user_id: Some(String::from("user123")),
            tags,
            ..Default::default()
        };
        assert_eq!(metadata.user_id.as_deref(), Some("user123"));
        assert_eq!(metadata.tags.get("role").map(String::as_str), Some("admin"));
    }

    /// Counters start at zero and accumulate what is added to them.
    #[test]
    fn test_connection_stats() {
        let counters = ConnectionStatistics::default();
        counters.messages_sent.fetch_add(5, Ordering::Relaxed);
        counters.bytes_sent.fetch_add(1024, Ordering::Relaxed);

        let sent = counters.messages_sent.load(Ordering::Relaxed);
        let bytes = counters.bytes_sent.load(Ordering::Relaxed);
        assert_eq!(sent, 5);
        assert_eq!(bytes, 1024);
    }
}