use async_trait::async_trait;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::Mutex;
use crate::context::ConnectionContext;
use crate::error::ExtensionError;
use crate::extension::Extension;
/// Connection state as observed by the status-viewer extensions.
///
/// Viewers start in [`ConnectionStatus::Disconnected`] and switch to
/// [`ConnectionStatus::Connected`] from their `on_connect` hook.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// Set by the `on_connect` lifecycle hook.
    Connected,
    /// Initial state; also set by the `on_disconnect` lifecycle hook.
    Disconnected,
}
/// Extension that tracks only the latest [`ConnectionStatus`].
///
/// The status lives behind an `Arc<Mutex<_>>` so it can be shared with other
/// components via `status_handle`.
#[derive(Debug)]
pub struct StatusViewer {
    /// Latest status written by the lifecycle hooks.
    current: Arc<Mutex<ConnectionStatus>>,
}
impl StatusViewer {
    /// Creates a viewer whose initial status is [`ConnectionStatus::Disconnected`].
    #[must_use]
    pub fn new() -> Self {
        let initial = Mutex::new(ConnectionStatus::Disconnected);
        Self {
            current: Arc::new(initial),
        }
    }

    /// Returns a snapshot (clone) of the tracked status.
    pub async fn current_status(&self) -> ConnectionStatus {
        self.current.lock().await.clone()
    }

    /// Returns a shared handle to the underlying status cell.
    ///
    /// The handle points at the same `Mutex` this viewer updates, so writes
    /// through either side are visible to the other.
    #[must_use]
    pub fn status_handle(&self) -> Arc<Mutex<ConnectionStatus>> {
        Arc::clone(&self.current)
    }

    /// Returns `true` when the tracked status is [`ConnectionStatus::Connected`].
    pub async fn is_connected(&self) -> bool {
        self.current_status().await == ConnectionStatus::Connected
    }
}
impl Default for StatusViewer {
fn default() -> Self {
Self::new()
}
}
/// Lifecycle-only extension: records connect/disconnect into the shared status.
#[async_trait]
impl Extension for StatusViewer {
    /// Returns the extension's name.
    fn name(&self) -> &'static str {
        "status_viewer"
    }

    /// Returns the extension's version string.
    fn version(&self) -> &'static str {
        "2.0.0"
    }

    /// Returns a human-readable description of the extension.
    fn description(&self) -> &'static str {
        "Tracks current WebSocket connection status"
    }

    /// Returns `false`: this extension does not process messages.
    fn handles_messages(&self) -> bool {
        false
    }

    /// Returns `true`: this extension implements `on_connect` / `on_disconnect`.
    fn handles_lifecycle(&self) -> bool {
        true
    }

    /// Logs the connection at debug level and marks the status `Connected`.
    /// Never fails.
    async fn on_connect(&self, ctx: &ConnectionContext) -> Result<(), ExtensionError> {
        tracing::debug!(
            connection_id = ctx.connection_id,
            reconnect_count = ctx.reconnect_count,
            "StatusViewer: connected"
        );
        let mut slot = self.current.lock().await;
        *slot = ConnectionStatus::Connected;
        Ok(())
    }

    /// Logs the disconnection at debug level and marks the status `Disconnected`.
    /// Never fails.
    async fn on_disconnect(&self, ctx: &ConnectionContext) -> Result<(), ExtensionError> {
        tracing::debug!(
            connection_id = ctx.connection_id,
            "StatusViewer: disconnected"
        );
        let mut slot = self.current.lock().await;
        *slot = ConnectionStatus::Disconnected;
        Ok(())
    }
}
/// Extension that tracks the latest [`ConnectionStatus`] plus a bounded,
/// timestamped history of status transitions.
#[derive(Debug)]
pub struct AdvancedStatusViewer {
    /// Latest status written by the lifecycle hooks.
    current: Arc<Mutex<ConnectionStatus>>,
    /// `(timestamp, status)` transitions, oldest first, capped at `max_history`.
    history: Arc<Mutex<Vec<(Instant, ConnectionStatus)>>>,
    /// Maximum number of entries retained in `history`; older ones are evicted.
    max_history: usize,
}
impl AdvancedStatusViewer {
    /// Creates a viewer that retains at most 100 history entries.
    #[must_use]
    pub fn new() -> Self {
        Self::with_history_limit(100)
    }

    /// Creates a viewer that retains at most `max_history` status transitions.
    ///
    /// With `max_history == 0` no history is kept at all. In that case
    /// [`Self::get_uptime`] always returns `None` and
    /// [`Self::get_connection_count`] always returns `0`.
    #[must_use]
    pub fn with_history_limit(max_history: usize) -> Self {
        Self {
            current: Arc::new(Mutex::new(ConnectionStatus::Disconnected)),
            history: Arc::new(Mutex::new(Vec::new())),
            max_history,
        }
    }

    /// Returns a snapshot (clone) of the current status.
    pub async fn current_status(&self) -> ConnectionStatus {
        self.current.lock().await.clone()
    }

    /// Returns a copy of the retained transition history, oldest first.
    pub async fn get_history(&self) -> Vec<(Instant, ConnectionStatus)> {
        self.history.lock().await.clone()
    }

    /// Returns how long the current connection has been up.
    ///
    /// The result is measured from the most recent recorded transition. It is
    /// `None` if that transition is `Disconnected`, or if no history is retained.
    pub async fn get_uptime(&self) -> Option<std::time::Duration> {
        let history = self.history.lock().await;
        // Only the latest transition matters: an older `Connected` entry
        // followed by a disconnect must not be reported as uptime.
        match history.last() {
            Some((since, ConnectionStatus::Connected)) => Some(since.elapsed()),
            Some((_, ConnectionStatus::Disconnected)) | None => None,
        }
    }

    /// Counts `Connected` transitions among the retained history.
    ///
    /// Entries evicted by the `max_history` cap are not counted.
    pub async fn get_connection_count(&self) -> usize {
        self.history
            .lock()
            .await
            .iter()
            .filter(|(_, status)| matches!(status, ConnectionStatus::Connected))
            .count()
    }

    /// Returns `true` when the current status is [`ConnectionStatus::Connected`].
    pub async fn is_connected(&self) -> bool {
        matches!(self.current_status().await, ConnectionStatus::Connected)
    }

    /// Appends a timestamped transition and evicts the oldest entries beyond
    /// `max_history`.
    async fn add_to_history(&self, status: ConnectionStatus) {
        let mut history = self.history.lock().await;
        history.push((Instant::now(), status));
        let excess = history.len().saturating_sub(self.max_history);
        history.drain(..excess);
    }

    /// Updates the current status and records the transition in history.
    async fn set_status(&self, status: ConnectionStatus) {
        // Keep `current` locked until the history entry is written. Concurrent
        // transitions are then recorded in the same order they are applied.
        // Lock order is always `current` then `history`; no other method holds
        // both, so this cannot deadlock.
        let mut current = self.current.lock().await;
        *current = status.clone();
        self.add_to_history(status).await;
    }
}
impl Default for AdvancedStatusViewer {
fn default() -> Self {
Self::new()
}
}
/// Lifecycle-only extension: records each connect/disconnect as a status
/// change plus a timestamped history entry.
#[async_trait]
impl Extension for AdvancedStatusViewer {
    /// Returns the extension's name.
    fn name(&self) -> &'static str {
        "advanced_status_viewer"
    }

    /// Returns the extension's version string.
    fn version(&self) -> &'static str {
        "2.0.0"
    }

    /// Returns a human-readable description of the extension.
    fn description(&self) -> &'static str {
        "Advanced status viewer with connection history and metrics"
    }

    /// Returns `false`: this extension does not process messages.
    fn handles_messages(&self) -> bool {
        false
    }

    /// Returns `true`: this extension implements `on_connect` / `on_disconnect`.
    fn handles_lifecycle(&self) -> bool {
        true
    }

    /// Logs the connection at info level and records a `Connected` transition.
    /// Never fails.
    async fn on_connect(&self, ctx: &ConnectionContext) -> Result<(), ExtensionError> {
        tracing::info!(
            connection_id = ctx.connection_id,
            reconnect_count = ctx.reconnect_count,
            "AdvancedStatusViewer: connected"
        );
        let next = ConnectionStatus::Connected;
        self.set_status(next).await;
        Ok(())
    }

    /// Logs the disconnection at info level and records a `Disconnected` transition.
    /// Never fails.
    async fn on_disconnect(&self, ctx: &ConnectionContext) -> Result<(), ExtensionError> {
        tracing::info!(
            connection_id = ctx.connection_id,
            "AdvancedStatusViewer: disconnected"
        );
        let next = ConnectionStatus::Disconnected;
        self.set_status(next).await;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `StatusViewer` starts disconnected and follows the lifecycle hooks.
    #[tokio::test]
    async fn test_status_viewer() {
        let viewer = StatusViewer::new();
        let initial = viewer.current_status().await;
        assert_eq!(initial, ConnectionStatus::Disconnected);

        let ctx = ConnectionContext::new(1);

        viewer.on_connect(&ctx).await.unwrap();
        let after_connect = viewer.current_status().await;
        assert_eq!(after_connect, ConnectionStatus::Connected);

        viewer.on_disconnect(&ctx).await.unwrap();
        let after_disconnect = viewer.current_status().await;
        assert_eq!(after_disconnect, ConnectionStatus::Disconnected);
    }

    /// `AdvancedStatusViewer` tracks the current state, counts connects, and
    /// records every transition in history.
    #[tokio::test]
    async fn test_advanced_status_viewer() {
        let viewer = AdvancedStatusViewer::new();
        let ctx = ConnectionContext::new(1);

        // First connection: one Connected entry.
        viewer.on_connect(&ctx).await.unwrap();
        let connected = viewer.is_connected().await;
        assert!(connected);
        let connects = viewer.get_connection_count().await;
        assert_eq!(connects, 1);

        // Disconnect flips the current state.
        viewer.on_disconnect(&ctx).await.unwrap();
        let connected = viewer.is_connected().await;
        assert!(!connected);

        // Reconnect: two Connected entries, three transitions in total.
        viewer.on_connect(&ctx).await.unwrap();
        let connects = viewer.get_connection_count().await;
        assert_eq!(connects, 2);
        let recorded = viewer.get_history().await;
        assert_eq!(recorded.len(), 3);
    }
}