use axum::extract::ws::Message;
use std::fmt;
use std::hash::{Hash, Hasher};
use tokio::sync::mpsc;
use uuid::Uuid;
/// Opaque, unique identifier for a single WebSocket connection.
///
/// Wraps a [`Uuid`]. Equality and hashing are based solely on that UUID, so two
/// `ConnectionId`s built from the same UUID compare equal and hash identically,
/// which makes the type usable as a `HashMap`/`HashSet` key.
#[derive(Clone, Copy, Eq)]
pub struct ConnectionId(Uuid);

impl ConnectionId {
    /// Generates a new identifier backed by a random (version 4) UUID.
    #[must_use]
    pub fn new() -> Self {
        Self(Uuid::new_v4())
    }

    /// Returns the underlying UUID by value (`Uuid` is `Copy`, so this is cheap).
    #[must_use]
    pub fn as_uuid(&self) -> Uuid {
        self.0
    }
}

impl Default for ConnectionId {
    /// Equivalent to [`ConnectionId::new`]: every default value is a fresh random
    /// id, not a fixed sentinel.
    fn default() -> Self {
        Self::new()
    }
}

impl fmt::Debug for ConnectionId {
    /// Renders as `ConnectionId(<uuid>)`, using the UUID's `Display` form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "ConnectionId({})", self.0)
    }
}

impl fmt::Display for ConnectionId {
    /// Formats as the bare UUID, exactly as the wrapped `Uuid` displays itself.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Forward the caller's formatter to Uuid's own impl instead of
        // re-formatting through `write!`, which would discard any formatter flags.
        fmt::Display::fmt(&self.0, f)
    }
}

impl PartialEq for ConnectionId {
    /// Two ids are equal exactly when their wrapped UUIDs are equal.
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl Hash for ConnectionId {
    /// Hashes only the wrapped UUID, consistent with [`PartialEq`] above.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}

impl From<Uuid> for ConnectionId {
    /// Wraps an existing UUID, e.g. one parsed from a client request or storage.
    fn from(uuid: Uuid) -> Self {
        Self(uuid)
    }
}

impl From<ConnectionId> for Uuid {
    /// Unwraps the id back into its UUID, the inverse of `ConnectionId::from(uuid)`.
    fn from(id: ConnectionId) -> Self {
        id.0
    }
}
/// Server-side state for a single WebSocket client.
///
/// Outgoing frames are not written to the socket here. They are queued on
/// [`sender`](Self::sender), a bounded `tokio` mpsc channel. Presumably a separate
/// task drains the receiving half into the socket (TODO: confirm against the caller).
#[derive(Debug)]
pub struct WebSocketConnection {
    /// Unique identifier, generated when the connection value is constructed.
    pub id: ConnectionId,
    /// Sending half of the channel used to queue outgoing messages.
    pub sender: mpsc::Sender<Message>,
    /// Authenticated user's id, or `None` for an anonymous connection.
    pub user_id: Option<String>,
    /// Names of the rooms this connection belongs to; empty on construction.
    pub rooms: Vec<String>,
    /// Client IP address as supplied via [`with_client_ip`](Self::with_client_ip), if any.
    pub client_ip: Option<String>,
}

impl WebSocketConnection {
    /// Creates an anonymous (unauthenticated) connection.
    ///
    /// The connection gets a fresh [`ConnectionId`], no rooms and no client IP.
    #[must_use]
    pub fn new(sender: mpsc::Sender<Message>) -> Self {
        Self {
            id: ConnectionId::new(),
            sender,
            user_id: None,
            rooms: Vec::new(),
            client_ip: None,
        }
    }

    /// Creates a connection that is already authenticated as `user_id`.
    ///
    /// Apart from `user_id`, all fields start out exactly as in [`new`](Self::new).
    /// Accepts anything convertible into a `String` (e.g. `&str` or `String`).
    #[must_use]
    pub fn authenticated(sender: mpsc::Sender<Message>, user_id: impl Into<String>) -> Self {
        // Reuse `new` so the two constructors cannot drift apart in their defaults.
        Self {
            user_id: Some(user_id.into()),
            ..Self::new(sender)
        }
    }

    /// Builder-style setter that records the client's IP address.
    ///
    /// Overwrites any previously recorded value.
    #[must_use]
    pub fn with_client_ip(mut self, ip: impl Into<String>) -> Self {
        self.client_ip = Some(ip.into());
        self
    }

    /// Returns `true` if a user id is attached to this connection.
    #[must_use]
    pub fn is_authenticated(&self) -> bool {
        self.user_id.is_some()
    }

    /// Queues `message` on the outgoing channel.
    ///
    /// Waits for capacity if the bounded channel is currently full.
    ///
    /// # Errors
    ///
    /// Returns [`mpsc::error::SendError`] if the receiving half of the channel
    /// has been closed or dropped. The error carries the undelivered message.
    pub async fn send(&self, message: Message) -> Result<(), mpsc::error::SendError<Message>> {
        self.sender.send(message).await
    }

    /// Queues a text frame built from `text`.
    ///
    /// # Errors
    ///
    /// Same as [`send`](Self::send).
    pub async fn send_text(
        &self,
        text: impl Into<String>,
    ) -> Result<(), mpsc::error::SendError<Message>> {
        // The second `.into()` adapts the `String` to whatever payload type
        // `Message::Text` carries.
        self.send(Message::Text(text.into().into())).await
    }

    /// Queues a binary frame containing `data`.
    ///
    /// # Errors
    ///
    /// Same as [`send`](Self::send).
    pub async fn send_binary(&self, data: Vec<u8>) -> Result<(), mpsc::error::SendError<Message>> {
        self.send(Message::Binary(data.into())).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_connection_id_uniqueness() {
        let id1 = ConnectionId::new();
        let id2 = ConnectionId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_connection_id_display() {
        let id = ConnectionId::new();
        let display = id.to_string();
        assert!(!display.is_empty());
        // Display must be exactly the wrapped UUID's textual form.
        assert_eq!(display, id.as_uuid().to_string());
    }

    #[test]
    fn test_connection_id_from_uuid_round_trip() {
        let uuid = Uuid::new_v4();
        let id = ConnectionId::from(uuid);
        assert_eq!(id.as_uuid(), uuid);
        assert_eq!(id, ConnectionId::from(uuid));
    }

    #[test]
    fn test_connection_id_hash_agrees_with_eq() {
        let uuid = Uuid::new_v4();
        let mut ids = HashSet::new();
        assert!(ids.insert(ConnectionId::from(uuid)));
        // An equal id built from the same UUID must be treated as a duplicate.
        assert!(!ids.insert(ConnectionId::from(uuid)));
        assert!(ids.insert(ConnectionId::new()));
        assert_eq!(ids.len(), 2);
    }

    // Channel construction needs no runtime, so these stay plain synchronous tests.
    #[test]
    fn test_websocket_connection_creation() {
        let (tx, _rx) = mpsc::channel(32);
        let conn = WebSocketConnection::new(tx);
        assert!(!conn.is_authenticated());
        assert!(conn.rooms.is_empty());
        assert!(conn.client_ip.is_none());
    }

    #[test]
    fn test_authenticated_connection() {
        let (tx, _rx) = mpsc::channel(32);
        let conn = WebSocketConnection::authenticated(tx, "user123".to_string());
        assert!(conn.is_authenticated());
        assert_eq!(conn.user_id.as_deref(), Some("user123"));
        assert!(conn.rooms.is_empty());
        assert!(conn.client_ip.is_none());
    }

    #[test]
    fn test_with_client_ip() {
        let (tx, _rx) = mpsc::channel(32);
        let conn = WebSocketConnection::new(tx).with_client_ip("127.0.0.1".to_string());
        assert_eq!(conn.client_ip.as_deref(), Some("127.0.0.1"));
    }

    #[tokio::test]
    async fn test_send_text_and_binary_are_delivered_in_order() {
        let (tx, mut rx) = mpsc::channel(32);
        let conn = WebSocketConnection::new(tx);
        conn.send_text("hello").await.expect("receiver is alive");
        conn.send_binary(vec![1, 2, 3])
            .await
            .expect("receiver is alive");

        match rx.recv().await {
            Some(Message::Text(text)) => assert_eq!(text.as_str(), "hello"),
            other => panic!("expected a text message, got {other:?}"),
        }
        match rx.recv().await {
            Some(Message::Binary(data)) => assert_eq!(data.to_vec(), vec![1u8, 2, 3]),
            other => panic!("expected a binary message, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_send_fails_when_receiver_dropped() {
        let (tx, rx) = mpsc::channel(1);
        let conn = WebSocketConnection::new(tx);
        drop(rx);
        assert!(conn.send_text("lost").await.is_err());
    }
}