Skip to main content

bob_chat/
state.rs

1//! # Type-Safe State Container
2//!
3//! A lightweight, type-safe dependency injection container for sharing state
4//! across chat event handlers without manual `Arc` cloning.
5//!
6//! Inspired by `actix-web::Extensions` and `http::Extensions`.
7//!
8//! ## Example
9//!
10//! ```rust,ignore
11//! use bob_chat::state::AppState;
12//!
13//! let mut state = AppState::new();
14//! state.insert(DatabasePool::new());
15//! state.insert(Config::load());
16//!
17//! // In a handler:
18//! let db = state.get::<DatabasePool>().unwrap();
19//! ```
20
21use std::{
22    any::{Any, TypeId},
23    collections::HashMap,
24};
25
/// Type-safe state container using `TypeId` as key.
///
/// Stores one value per type. Inserting a value of type `T` replaces
/// any previous value of the same type.
///
/// The container is `Send + Sync` because every stored value is required
/// to be `Send + Sync` at insertion time.
#[derive(Default)]
pub struct AppState {
    map: HashMap<TypeId, StateEntry>,
}

/// A stored value paired with the name of its concrete type.
///
/// The name is captured when the value is inserted because a `TypeId` alone
/// cannot be turned back into something human-readable.
struct StateEntry {
    /// Result of `std::any::type_name::<T>()`; used for diagnostics only.
    type_name: &'static str,
    /// The type-erased value itself.
    value: Box<dyn Any + Send + Sync>,
}

impl std::fmt::Debug for AppState {
    /// Formats the container as `AppState { types: [...] }`, listing the names
    /// of the stored types rather than the values (which need not be `Debug`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut type_names: Vec<&'static str> =
            self.map.values().map(|entry| entry.type_name).collect();
        // `HashMap` iteration order is unspecified; sort so the output is stable.
        type_names.sort_unstable();
        f.debug_struct("AppState").field("types", &type_names).finish()
    }
}

impl AppState {
    /// Create an empty state container.
    #[must_use]
    pub fn new() -> Self {
        Self { map: HashMap::new() }
    }

    /// Insert a value of type `T`. Replaces any previous value of the same type.
    ///
    /// The replaced value, if any, is dropped.
    pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) {
        let entry = StateEntry {
            type_name: std::any::type_name::<T>(),
            value: Box::new(val),
        };
        self.map.insert(TypeId::of::<T>(), entry);
    }

    /// Get an immutable reference to the value of type `T`.
    ///
    /// Returns `None` if no value of type `T` has been inserted.
    #[must_use]
    pub fn get<T: 'static>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|entry| entry.value.downcast_ref())
    }

    /// Get a mutable reference to the value of type `T`.
    ///
    /// Returns `None` if no value of type `T` has been inserted.
    pub fn get_mut<T: 'static>(&mut self) -> Option<&mut T> {
        self.map
            .get_mut(&TypeId::of::<T>())
            .and_then(|entry| entry.value.downcast_mut())
    }

    /// Remove and return the value of type `T`.
    ///
    /// Returns `None` if no value of type `T` was stored.
    pub fn remove<T: 'static>(&mut self) -> Option<T> {
        let entry = self.map.remove(&TypeId::of::<T>())?;
        entry.value.downcast::<T>().ok().map(|boxed| *boxed)
    }

    /// Check if a value of type `T` is present.
    #[must_use]
    pub fn contains<T: 'static>(&self) -> bool {
        self.map.contains_key(&TypeId::of::<T>())
    }

    /// Return the number of stored values.
    #[must_use]
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Return `true` if no values are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Clear all stored values.
    pub fn clear(&mut self) {
        self.map.clear();
    }
}
96
97// ── Tests ────────────────────────────────────────────────────────────
98
#[cfg(test)]
mod tests {
    use super::*;

    /// Stand-in for a connection pool; wraps a connection string.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct DbPool(String);

    /// Stand-in for application configuration.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct Config {
        max_connections: u32,
    }

    /// Build a `DbPool` fixture from a connection string.
    fn pool(url: &str) -> DbPool {
        DbPool(String::from(url))
    }

    /// Build a `Config` fixture with the given connection limit.
    fn config(max_connections: u32) -> Config {
        Config { max_connections }
    }

    #[test]
    fn insert_and_get() {
        let mut container = AppState::new();
        container.insert(pool("postgres://localhost"));

        let stored_url = container.get::<DbPool>().map(|p| p.0.as_str());
        assert_eq!(stored_url, Some("postgres://localhost"));
    }

    #[test]
    fn different_types_coexist() {
        let mut container = AppState::new();
        container.insert(pool("pg"));
        container.insert(config(10));

        assert_eq!(container.get::<DbPool>(), Some(&pool("pg")));
        assert_eq!(container.get::<Config>(), Some(&config(10)));
    }

    #[test]
    fn insert_replaces_previous() {
        let mut container = AppState::new();
        for url in ["old", "new"] {
            container.insert(pool(url));
        }

        assert_eq!(container.get::<DbPool>(), Some(&pool("new")));
    }

    #[test]
    fn get_missing_returns_none() {
        let container = AppState::new();
        assert_eq!(container.get::<DbPool>(), None);
    }

    #[test]
    fn remove_returns_value() {
        let mut container = AppState::new();
        container.insert(pool("pg"));

        assert_eq!(container.remove::<DbPool>(), Some(pool("pg")));
        assert_eq!(container.get::<DbPool>(), None);
    }

    #[test]
    fn contains_check() {
        let mut container = AppState::new();
        let before = container.contains::<DbPool>();
        container.insert(pool("pg"));
        let after = container.contains::<DbPool>();

        assert!(!before);
        assert!(after);
    }

    #[test]
    fn len_and_is_empty() {
        let mut container = AppState::new();
        assert_eq!((container.len(), container.is_empty()), (0, true));

        container.insert(pool("pg"));
        container.insert(config(5));
        assert_eq!((container.len(), container.is_empty()), (2, false));
    }

    #[test]
    fn clear_removes_all() {
        let mut container = AppState::new();
        container.insert(pool("pg"));
        container.insert(config(5));

        container.clear();
        assert!(container.is_empty());
    }

    #[test]
    fn get_mut_allows_mutation() {
        let mut container = AppState::new();
        container.insert(config(5));

        if let Some(cfg) = container.get_mut::<Config>() {
            cfg.max_connections = 20;
        }
        assert_eq!(container.get::<Config>(), Some(&config(20)));
    }

    // Compile-time guarantee: the container can be shared across async tasks
    // only if it is both `Send` and `Sync`.
    const _: () = {
        const fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<AppState>();
    };
}