1use std::any::{Any, TypeId, type_name};
18use std::collections::{HashMap, HashSet};
19use uuid::Uuid;
20
/// Identifies a resource type by its [`TypeId`] paired with a printable name.
///
/// The id is what gets compared and hashed by the access guard; the name is
/// carried along so error messages can show something human-readable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BusTypeRef {
    /// Runtime identifier of the referenced type.
    pub type_id: TypeId,
    /// Name of the referenced type as reported by `std::any::type_name`.
    pub type_name: &'static str,
}

impl BusTypeRef {
    /// Builds the reference describing the type `T`.
    pub fn of<T>() -> Self
    where
        T: Any + Send + Sync + 'static,
    {
        let id = TypeId::of::<T>();
        let name = type_name::<T>();
        Self {
            type_id: id,
            type_name: name,
        }
    }
}
36
/// Declarative description of which resource types may be accessed on a [`Bus`].
///
/// The policy is installed with `Bus::set_access_policy`. When it is enforced,
/// the deny list is consulted first; after that, if an allow list is present,
/// only types contained in it are permitted.
#[derive(Debug, Clone, Default)]
pub struct BusAccessPolicy {
    /// Optional allow list. `None` means no allow-list restriction is applied;
    /// `Some(list)` restricts access to the listed types.
    pub allow: Option<Vec<BusTypeRef>>,
    /// Types that are always rejected, even when they also appear in `allow`.
    pub deny: Vec<BusTypeRef>,
}
46
47impl BusAccessPolicy {
48 pub fn allow_only(types: Vec<BusTypeRef>) -> Self {
49 Self {
50 allow: Some(types),
51 deny: Vec::new(),
52 }
53 }
54
55 pub fn deny_only(types: Vec<BusTypeRef>) -> Self {
56 Self {
57 allow: None,
58 deny: types,
59 }
60 }
61}
62
/// Lookup-friendly form of a [`BusAccessPolicy`] bound to a transition label.
///
/// Type ids are kept in hash sets for membership checks, while the matching
/// type names are kept separately so they can be copied into error values.
#[derive(Debug, Clone)]
struct BusAccessGuard {
    /// Label of the transition the policy was installed for; echoed in errors.
    transition_label: String,
    /// Ids of the allowed types, or `None` when the policy has no allow list.
    allow: Option<HashSet<TypeId>>,
    /// Names of the allowed types (empty when there is no allow list).
    allow_names: Vec<&'static str>,
    /// Ids of the explicitly denied types.
    deny: HashSet<TypeId>,
    /// Names of the explicitly denied types.
    deny_names: Vec<&'static str>,
}
71
72impl BusAccessGuard {
73 fn from_policy(transition_label: String, policy: BusAccessPolicy) -> Self {
74 let allow_names = policy
75 .allow
76 .as_ref()
77 .map(|types| types.iter().map(|t| t.type_name).collect::<Vec<_>>())
78 .unwrap_or_default();
79 let allow = policy
80 .allow
81 .map(|types| types.into_iter().map(|t| t.type_id).collect::<HashSet<_>>());
82 let deny_names = policy.deny.iter().map(|t| t.type_name).collect::<Vec<_>>();
83 let deny = policy
84 .deny
85 .into_iter()
86 .map(|type_ref| type_ref.type_id)
87 .collect::<HashSet<_>>();
88 Self {
89 transition_label,
90 allow,
91 allow_names,
92 deny,
93 deny_names,
94 }
95 }
96}
97
/// Errors produced when accessing a resource on a `Bus`.
///
/// The type implements `PartialEq`/`Eq` so callers and tests can compare
/// errors directly (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BusAccessError {
    /// The active access policy rejected the requested resource type.
    Unauthorized {
        /// Label of the transition whose policy rejected the access.
        transition: String,
        /// Type name of the requested resource.
        resource: &'static str,
        /// Names on the policy's allow list, or `None` when it has none.
        allow: Option<Vec<&'static str>>,
        /// Names on the policy's deny list.
        deny: Vec<&'static str>,
    },
    /// No resource of the requested type is stored.
    NotFound {
        /// Type name of the requested resource.
        resource: &'static str,
    },
}
111
112impl std::fmt::Display for BusAccessError {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 match self {
115 BusAccessError::Unauthorized {
116 transition,
117 resource,
118 allow,
119 deny,
120 } => {
121 write!(
122 f,
123 "Bus access denied in transition `{transition}` for resource `{resource}`"
124 )?;
125 if let Some(allow_list) = allow {
126 write!(f, " (allow={allow_list:?})")?;
127 }
128 if !deny.is_empty() {
129 write!(f, " (deny={deny:?})")?;
130 }
131 Ok(())
132 }
133 BusAccessError::NotFound { resource } => {
134 write!(f, "Bus resource not found: `{resource}`")
135 }
136 }
137 }
138}
139
// Uses the trait's default methods only; the message comes from the `Display`
// impl and no underlying `source()` is reported.
impl std::error::Error for BusAccessError {}
141
/// Type-indexed container holding at most one value per concrete type.
///
/// An optional access guard can restrict which types may be read, checked for
/// or removed; see `Bus::set_access_policy`.
pub struct Bus {
    /// Stored resources, keyed by the `TypeId` of their concrete type.
    resources: HashMap<std::any::TypeId, Box<dyn Any + Send + Sync>>,
    /// Identifier assigned when the bus is created (`Uuid::new_v4` in `Bus::new`).
    pub id: Uuid,
    /// Active access guard, or `None` when access is unrestricted.
    access_guard: Option<BusAccessGuard>,
}
154
155impl Bus {
156 pub fn new() -> Self {
158 Self {
159 resources: HashMap::new(),
160 id: Uuid::new_v4(),
161 access_guard: None,
162 }
163 }
164
165 pub fn insert<T: Any + Send + Sync + 'static>(&mut self, resource: T) {
178 let type_id = std::any::TypeId::of::<T>();
179 self.resources.insert(type_id, Box::new(resource));
180 }
181
182 pub fn read<T: Any + Send + Sync + 'static>(&self) -> Option<&T> {
196 match self.get::<T>() {
197 Ok(value) => Some(value),
198 Err(BusAccessError::NotFound { .. }) => None,
199 Err(err) => panic!("{err}"),
200 }
201 }
202
203 pub fn read_mut<T: Any + Send + Sync + 'static>(&mut self) -> Option<&mut T> {
207 match self.get_mut::<T>() {
208 Ok(value) => Some(value),
209 Err(BusAccessError::NotFound { .. }) => None,
210 Err(err) => panic!("{err}"),
211 }
212 }
213
214 pub fn get<T: Any + Send + Sync + 'static>(&self) -> Result<&T, BusAccessError> {
216 self.ensure_access::<T>()?;
217 let type_id = TypeId::of::<T>();
218 self.resources
219 .get(&type_id)
220 .and_then(|r| r.downcast_ref::<T>())
221 .ok_or(BusAccessError::NotFound {
222 resource: type_name::<T>(),
223 })
224 }
225
226 pub fn get_mut<T: Any + Send + Sync + 'static>(&mut self) -> Result<&mut T, BusAccessError> {
228 self.ensure_access::<T>()?;
229 let type_id = TypeId::of::<T>();
230 self.resources
231 .get_mut(&type_id)
232 .and_then(|r| r.downcast_mut::<T>())
233 .ok_or(BusAccessError::NotFound {
234 resource: type_name::<T>(),
235 })
236 }
237
238 pub fn has<T: Any + Send + Sync + 'static>(&self) -> bool {
240 if let Err(err) = self.ensure_access::<T>() {
241 panic!("{err}");
242 }
243 let type_id = std::any::TypeId::of::<T>();
244 self.resources.contains_key(&type_id)
245 }
246
247 pub fn remove<T: Any + Send + Sync + 'static>(&mut self) -> Option<T> {
251 if let Err(err) = self.ensure_access::<T>() {
252 panic!("{err}");
253 }
254 let type_id = std::any::TypeId::of::<T>();
255 self.resources
256 .remove(&type_id)
257 .and_then(|r| r.downcast::<T>().ok().map(|b| *b))
258 }
259
260 pub fn len(&self) -> usize {
262 self.resources.len()
263 }
264
265 pub fn is_empty(&self) -> bool {
267 self.resources.is_empty()
268 }
269
270 pub fn set_access_policy(
272 &mut self,
273 transition_label: impl Into<String>,
274 policy: Option<BusAccessPolicy>,
275 ) {
276 self.access_guard =
277 policy.map(|policy| BusAccessGuard::from_policy(transition_label.into(), policy));
278 }
279
280 pub fn clear_access_policy(&mut self) {
282 self.access_guard = None;
283 }
284
285 fn ensure_access<T: Any + Send + Sync + 'static>(&self) -> Result<(), BusAccessError> {
286 let Some(guard) = &self.access_guard else {
287 return Ok(());
288 };
289
290 let requested = TypeId::of::<T>();
291 if guard.deny.contains(&requested) {
292 return Err(BusAccessError::Unauthorized {
293 transition: guard.transition_label.clone(),
294 resource: type_name::<T>(),
295 allow: if guard.allow_names.is_empty() {
296 None
297 } else {
298 Some(guard.allow_names.clone())
299 },
300 deny: guard.deny_names.clone(),
301 });
302 }
303
304 if let Some(allow) = &guard.allow {
305 if !allow.contains(&requested) {
306 return Err(BusAccessError::Unauthorized {
307 transition: guard.transition_label.clone(),
308 resource: type_name::<T>(),
309 allow: Some(guard.allow_names.clone()),
310 deny: guard.deny_names.clone(),
311 });
312 }
313 }
314
315 Ok(())
316 }
317}
318
319impl Default for Bus {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
325#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
327pub struct ConnectionId(pub Uuid);
328
329impl ConnectionId {
330 pub fn new() -> Self {
331 Self(Uuid::new_v4())
332 }
333}
334
/// A [`Bus`] paired with the [`ConnectionId`] of the connection it belongs to.
///
/// The wrapper dereferences to the inner bus, so `Bus` methods can be called
/// on it directly.
pub struct ConnectionBus {
    /// The wrapped resource bus.
    pub bus: Bus,
    /// Identifier of the owning connection.
    pub id: ConnectionId,
}
345
346impl ConnectionBus {
347 pub fn new(id: ConnectionId) -> Self {
349 Self {
350 bus: Bus::new(),
351 id,
352 }
353 }
354
355 pub fn from_bus(id: ConnectionId, bus: Bus) -> Self {
357 Self { bus, id }
358 }
359
360 pub fn connection_id(&self) -> ConnectionId {
362 self.id
363 }
364}
365
366impl std::ops::Deref for ConnectionBus {
367 type Target = Bus;
368
369 fn deref(&self) -> &Self::Target {
370 &self.bus
371 }
372}
373
374impl std::ops::DerefMut for ConnectionBus {
375 fn deref_mut(&mut self) -> &mut Self::Target {
376 &mut self.bus
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    // A stored value is visible through both `has` and `read`.
    #[test]
    fn test_insert_and_read() {
        let mut bus = Bus::new();
        bus.insert(42i32);

        assert!(bus.has::<i32>());
        assert_eq!(bus.read::<i32>().copied(), Some(42));
    }

    // Reading a type that was never inserted yields `None`.
    #[test]
    fn test_read_none() {
        let bus = Bus::new();
        assert_eq!(bus.read::<i32>(), None);
    }

    // Removing hands back the value and leaves the slot empty.
    #[test]
    fn test_remove() {
        let mut bus = Bus::new();
        bus.insert(42i32);

        assert_eq!(bus.remove::<i32>(), Some(42));
        assert!(!bus.has::<i32>());
    }

    // Values of different types are stored independently.
    #[test]
    fn test_multiple_types() {
        let mut bus = Bus::new();
        bus.insert(42i32);
        bus.insert("hello".to_string());

        assert_eq!(bus.read::<i32>(), Some(&42));
        assert_eq!(bus.read::<String>().map(String::as_str), Some("hello"));
    }

    // With an allow list, a type outside the list is rejected by `get`.
    #[test]
    fn bus_policy_allow_only_blocks_unauthorized_get() {
        let mut bus = Bus::new();
        bus.insert(42i32);
        bus.insert("hello".to_string());

        let policy = BusAccessPolicy::allow_only(vec![BusTypeRef::of::<i32>()]);
        bus.set_access_policy("OnlyInt", Some(policy));

        let message = bus
            .get::<String>()
            .expect_err("String should be denied")
            .to_string();
        assert!(message.contains("OnlyInt"));
        assert!(message.contains("alloc::string::String"));
    }

    // With a deny list, an explicitly listed type is rejected by `get`.
    #[test]
    fn bus_policy_deny_only_blocks_explicit_type() {
        let mut bus = Bus::new();
        bus.insert(42i32);
        bus.insert("hello".to_string());

        let policy = BusAccessPolicy::deny_only(vec![BusTypeRef::of::<String>()]);
        bus.set_access_policy("DenyString", Some(policy));

        let message = bus
            .get::<String>()
            .expect_err("String should be denied")
            .to_string();
        assert!(message.contains("DenyString"));
    }

    // A connection bus reports the id it was created with.
    #[test]
    fn test_connection_bus() {
        let id = ConnectionId::new();
        assert_eq!(ConnectionBus::new(id).connection_id(), id);
    }
}