avl_console/
state.rs

1//! Application state management
2
3use crate::{
4    config::ConsoleConfig,
5    error::{ConsoleError, Result},
6};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10
11/// Shared application state
12pub struct AppState {
13    /// Configuration
14    pub config: ConsoleConfig,
15
16    /// Active WebSocket connections (user_id -> connection count)
17    pub ws_connections: Arc<RwLock<HashMap<String, usize>>>,
18
19    /// Rate limiter state (user_id -> request count)
20    pub rate_limiter: Arc<RwLock<HashMap<String, (u32, std::time::Instant)>>>,
21
22    /// Session store (session_id -> user_id)
23    pub sessions: Arc<RwLock<HashMap<String, String>>>,
24
25    /// Cached metrics for dashboard
26    pub metrics_cache: Arc<RwLock<Option<DashboardMetrics>>>,
27}
28
/// Snapshot of dashboard metrics, cached in `AppState::metrics_cache`.
///
/// Stored with `AppState::update_metrics` and read back (as a clone) with
/// `AppState::get_metrics`. Every field is comparable, so snapshots can be
/// checked for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DashboardMetrics {
    pub database_count: usize,
    pub storage_buckets: usize,
    pub storage_size_bytes: u64,
    pub active_connections: usize,
    pub requests_per_minute: u32,
    /// Monotonic timestamp for this snapshot (an `Instant`, not wall-clock time);
    /// presumably set by the producer when the metrics are collected.
    pub last_updated: std::time::Instant,
}
39
40// Type alias for convenience
41pub type ConsoleState = AppState;
42
43impl AppState {
44    /// Create new application state
45    pub async fn new(config: &ConsoleConfig) -> Result<Self> {
46        config.validate()?;
47
48        Ok(Self {
49            config: config.clone(),
50            ws_connections: Arc::new(RwLock::new(HashMap::new())),
51            rate_limiter: Arc::new(RwLock::new(HashMap::new())),
52            sessions: Arc::new(RwLock::new(HashMap::new())),
53            metrics_cache: Arc::new(RwLock::new(None)),
54        })
55    }
56
57    /// Check if user can create a new WebSocket connection
58    pub async fn can_create_ws_connection(&self, user_id: &str) -> bool {
59        let connections = self.ws_connections.read().await;
60        let count = connections.get(user_id).copied().unwrap_or(0);
61        count < self.config.max_ws_connections
62    }
63
64    /// Increment WebSocket connection count for user
65    pub async fn increment_ws_connection(&self, user_id: String) {
66        let mut connections = self.ws_connections.write().await;
67        *connections.entry(user_id).or_insert(0) += 1;
68    }
69
70    /// Decrement WebSocket connection count for user
71    pub async fn decrement_ws_connection(&self, user_id: &str) {
72        let mut connections = self.ws_connections.write().await;
73        if let Some(count) = connections.get_mut(user_id) {
74            *count = count.saturating_sub(1);
75            if *count == 0 {
76                connections.remove(user_id);
77            }
78        }
79    }
80
81    /// Check rate limit for user
82    pub async fn check_rate_limit(&self, user_id: &str) -> Result<()> {
83        let mut limiter = self.rate_limiter.write().await;
84        let now = std::time::Instant::now();
85
86        if let Some((count, last_check)) = limiter.get_mut(user_id) {
87            if now.duration_since(*last_check).as_secs() >= 60 {
88                // Reset after 1 minute
89                *count = 1;
90                *last_check = now;
91            } else if *count >= self.config.rate_limit {
92                return Err(ConsoleError::RateLimitExceeded);
93            } else {
94                *count += 1;
95            }
96        } else {
97            limiter.insert(user_id.to_string(), (1, now));
98        }
99
100        Ok(())
101    }
102
103    /// Store session
104    pub async fn store_session(&self, session_id: String, user_id: String) {
105        let mut sessions = self.sessions.write().await;
106        sessions.insert(session_id, user_id);
107    }
108
109    /// Get user ID from session
110    pub async fn get_session(&self, session_id: &str) -> Option<String> {
111        let sessions = self.sessions.read().await;
112        sessions.get(session_id).cloned()
113    }
114
115    /// Remove session
116    pub async fn remove_session(&self, session_id: &str) {
117        let mut sessions = self.sessions.write().await;
118        sessions.remove(session_id);
119    }
120
121    /// Update metrics cache
122    pub async fn update_metrics(&self, metrics: DashboardMetrics) {
123        let mut cache = self.metrics_cache.write().await;
124        *cache = Some(metrics);
125    }
126
127    /// Get cached metrics
128    pub async fn get_metrics(&self) -> Option<DashboardMetrics> {
129        let cache = self.metrics_cache.read().await;
130        cache.clone()
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_ws_connection_limit() {
        let config = ConsoleConfig::default();
        let state = AppState::new(&config).await.unwrap();

        assert!(state.can_create_ws_connection("user1").await);

        for _ in 0..config.max_ws_connections {
            state.increment_ws_connection("user1".to_string()).await;
        }

        assert!(!state.can_create_ws_connection("user1").await);
        // Limits are tracked per user.
        assert!(state.can_create_ws_connection("user2").await);
    }

    #[tokio::test]
    async fn test_ws_connection_decrement_frees_slot() {
        let config = ConsoleConfig::default();
        let state = AppState::new(&config).await.unwrap();

        for _ in 0..config.max_ws_connections {
            state.increment_ws_connection("user1".to_string()).await;
        }
        assert!(!state.can_create_ws_connection("user1").await);

        state.decrement_ws_connection("user1").await;
        assert!(state.can_create_ws_connection("user1").await);

        // Draining the remaining connections removes the user's entry entirely.
        for _ in 1..config.max_ws_connections {
            state.decrement_ws_connection("user1").await;
        }
        assert!(!state.ws_connections.read().await.contains_key("user1"));

        // Decrementing a user with no entry is a no-op.
        state.decrement_ws_connection("user1").await;
        assert!(state.ws_connections.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_rate_limiting() {
        let mut config = ConsoleConfig::default();
        config.rate_limit = 5;
        let state = AppState::new(&config).await.unwrap();

        for _ in 0..5 {
            assert!(state.check_rate_limit("user1").await.is_ok());
        }

        assert!(matches!(
            state.check_rate_limit("user1").await,
            Err(ConsoleError::RateLimitExceeded)
        ));
    }

    #[tokio::test]
    async fn test_rate_limit_is_per_user() {
        let mut config = ConsoleConfig::default();
        config.rate_limit = 5;
        let state = AppState::new(&config).await.unwrap();

        for _ in 0..5 {
            assert!(state.check_rate_limit("user1").await.is_ok());
        }
        assert!(state.check_rate_limit("user1").await.is_err());

        // Exhausting one user's budget does not affect another user.
        assert!(state.check_rate_limit("user2").await.is_ok());
    }

    #[tokio::test]
    async fn test_session_round_trip() {
        let state = AppState::new(&ConsoleConfig::default()).await.unwrap();

        assert_eq!(state.get_session("s1").await, None);

        state
            .store_session("s1".to_string(), "user1".to_string())
            .await;
        assert_eq!(state.get_session("s1").await.as_deref(), Some("user1"));

        state.remove_session("s1").await;
        assert_eq!(state.get_session("s1").await, None);
    }

    #[tokio::test]
    async fn test_metrics_cache() {
        let state = AppState::new(&ConsoleConfig::default()).await.unwrap();
        assert!(state.get_metrics().await.is_none());

        let metrics = DashboardMetrics {
            database_count: 3,
            storage_buckets: 2,
            storage_size_bytes: 1024,
            active_connections: 1,
            requests_per_minute: 10,
            last_updated: std::time::Instant::now(),
        };
        state.update_metrics(metrics.clone()).await;

        let cached = state.get_metrics().await.expect("metrics were just stored");
        assert_eq!(cached.database_count, 3);
        assert_eq!(cached.storage_buckets, 2);
        assert_eq!(cached.storage_size_bytes, 1024);
        assert_eq!(cached.active_connections, 1);
        assert_eq!(cached.requests_per_minute, 10);
        assert_eq!(cached.last_updated, metrics.last_updated);
    }
}
164}