1use std::any::{Any, TypeId, type_name};
18use std::collections::HashSet;
19use std::sync::Arc;
20
21use ahash::AHashMap;
22use uuid::Uuid;
23
/// Runtime identity of a resource type that can be stored on a `Bus`.
///
/// Pairs the [`TypeId`] used for lookups with the human-readable type name
/// that access-policy errors report.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BusTypeRef {
    /// Identifier compared against the keys of stored resources.
    pub type_id: TypeId,
    /// Name produced by [`std::any::type_name`]; used for diagnostics.
    pub type_name: &'static str,
}

impl BusTypeRef {
    /// Builds the reference describing the concrete type `T`.
    pub fn of<T: Any + Send + Sync + 'static>() -> Self {
        let id = TypeId::of::<T>();
        let name = type_name::<T>();
        Self {
            type_id: id,
            type_name: name,
        }
    }
}
39
40#[derive(Debug, Clone, Default)]
45pub struct BusAccessPolicy {
46 pub allow: Option<Vec<BusTypeRef>>,
47 pub deny: Vec<BusTypeRef>,
48}
49
50impl BusAccessPolicy {
51 pub fn allow_only(types: Vec<BusTypeRef>) -> Self {
52 Self {
53 allow: Some(types),
54 deny: Vec::new(),
55 }
56 }
57
58 pub fn deny_only(types: Vec<BusTypeRef>) -> Self {
59 Self {
60 allow: None,
61 deny: types,
62 }
63 }
64}
65
66#[derive(Debug, Clone)]
67struct BusAccessGuard {
68 transition_label: Arc<str>,
69 allow: Option<HashSet<TypeId>>,
70 allow_names: Arc<[&'static str]>,
71 deny: HashSet<TypeId>,
72 deny_names: Arc<[&'static str]>,
73}
74
75impl BusAccessGuard {
76 fn from_policy(transition_label: String, policy: BusAccessPolicy) -> Self {
77 let allow_names: Arc<[&'static str]> = policy
78 .allow
79 .as_ref()
80 .map(|types| types.iter().map(|t| t.type_name).collect())
81 .unwrap_or_default();
82 let allow = policy
83 .allow
84 .map(|types| types.into_iter().map(|t| t.type_id).collect::<HashSet<_>>());
85 let deny_names: Arc<[&'static str]> = policy.deny.iter().map(|t| t.type_name).collect();
86 let deny = policy
87 .deny
88 .into_iter()
89 .map(|type_ref| type_ref.type_id)
90 .collect::<HashSet<_>>();
91 Self {
92 transition_label: transition_label.into(),
93 allow,
94 allow_names,
95 deny,
96 deny_names,
97 }
98 }
99}
100
/// Failure returned by the fallible `Bus` accessors (`get` / `get_mut`).
#[derive(Debug, Clone)]
pub enum BusAccessError {
    /// The active access policy rejected the requested type.
    Unauthorized {
        /// Label of the transition whose policy rejected the access.
        transition: String,
        /// Type name of the requested resource.
        resource: &'static str,
        /// Allow-listed type names; `None` when no allow-list is reported.
        allow: Option<Vec<&'static str>>,
        /// Denied type names.
        deny: Vec<&'static str>,
    },
    /// No resource of the requested type is stored.
    NotFound {
        /// Type name of the requested resource.
        resource: &'static str,
    },
}

impl std::fmt::Display for BusAccessError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (transition, resource, allow, deny) = match self {
            Self::NotFound { resource } => {
                return write!(f, "Bus resource not found: `{resource}`");
            }
            Self::Unauthorized {
                transition,
                resource,
                allow,
                deny,
            } => (transition, resource, allow, deny),
        };

        write!(
            f,
            "Bus access denied in transition `{transition}` for resource `{resource}`"
        )?;
        // Each list is appended only when there is something to report.
        if let Some(allow_list) = allow {
            write!(f, " (allow={allow_list:?})")?;
        }
        if deny.is_empty() {
            return Ok(());
        }
        write!(f, " (deny={deny:?})")
    }
}

impl std::error::Error for BusAccessError {}
144
145pub struct Bus {
150 resources: AHashMap<std::any::TypeId, Box<dyn Any + Send + Sync>>,
152 pub id: Uuid,
154 access_guard: Option<BusAccessGuard>,
156}
157
158impl Bus {
159 #[inline]
161 pub fn new() -> Self {
162 Self {
163 resources: AHashMap::new(),
164 id: Uuid::new_v4(),
165 access_guard: None,
166 }
167 }
168
169 #[inline]
182 pub fn insert<T: Any + Send + Sync + 'static>(&mut self, resource: T) {
183 let type_id = std::any::TypeId::of::<T>();
184 self.resources.insert(type_id, Box::new(resource));
185 }
186
187 #[inline]
201 pub fn read<T: Any + Send + Sync + 'static>(&self) -> Option<&T> {
202 match self.get::<T>() {
203 Ok(value) => Some(value),
204 Err(BusAccessError::NotFound { .. }) => None,
205 Err(err) => panic!("{err}"),
206 }
207 }
208
209 #[inline]
213 pub fn read_mut<T: Any + Send + Sync + 'static>(&mut self) -> Option<&mut T> {
214 match self.get_mut::<T>() {
215 Ok(value) => Some(value),
216 Err(BusAccessError::NotFound { .. }) => None,
217 Err(err) => panic!("{err}"),
218 }
219 }
220
221 #[inline]
223 pub fn get<T: Any + Send + Sync + 'static>(&self) -> Result<&T, BusAccessError> {
224 self.ensure_access::<T>()?;
225 let type_id = TypeId::of::<T>();
226 self.resources
227 .get(&type_id)
228 .and_then(|r| r.downcast_ref::<T>())
229 .ok_or(BusAccessError::NotFound {
230 resource: type_name::<T>(),
231 })
232 }
233
234 #[inline]
236 pub fn get_mut<T: Any + Send + Sync + 'static>(&mut self) -> Result<&mut T, BusAccessError> {
237 self.ensure_access::<T>()?;
238 let type_id = TypeId::of::<T>();
239 self.resources
240 .get_mut(&type_id)
241 .and_then(|r| r.downcast_mut::<T>())
242 .ok_or(BusAccessError::NotFound {
243 resource: type_name::<T>(),
244 })
245 }
246
247 #[inline]
249 pub fn has<T: Any + Send + Sync + 'static>(&self) -> bool {
250 if let Err(err) = self.ensure_access::<T>() {
251 panic!("{err}");
252 }
253 let type_id = std::any::TypeId::of::<T>();
254 self.resources.contains_key(&type_id)
255 }
256
257 pub fn remove<T: Any + Send + Sync + 'static>(&mut self) -> Option<T> {
261 if let Err(err) = self.ensure_access::<T>() {
262 panic!("{err}");
263 }
264 let type_id = std::any::TypeId::of::<T>();
265 self.resources
266 .remove(&type_id)
267 .and_then(|r| r.downcast::<T>().ok().map(|b| *b))
268 }
269
270 pub fn len(&self) -> usize {
272 self.resources.len()
273 }
274
275 pub fn is_empty(&self) -> bool {
277 self.resources.is_empty()
278 }
279
280 #[inline]
294 pub fn provide<T: Any + Send + Sync + 'static>(&mut self, resource: T) {
295 self.insert(resource);
296 }
297
298 #[inline]
317 pub fn require<T: Any + Send + Sync + 'static>(&self) -> &T {
318 self.read::<T>().unwrap_or_else(|| {
319 panic!(
320 "Bus: required resource `{}` not found. Did you forget to call bus.provide()?",
321 std::any::type_name::<T>()
322 )
323 })
324 }
325
326 #[inline]
339 pub fn try_require<T: Any + Send + Sync + 'static>(&self) -> Option<&T> {
340 self.read::<T>()
341 }
342
343 pub fn set_access_policy(
345 &mut self,
346 transition_label: impl Into<String>,
347 policy: Option<BusAccessPolicy>,
348 ) {
349 self.access_guard =
350 policy.map(|policy| BusAccessGuard::from_policy(transition_label.into(), policy));
351 }
352
353 pub fn clear_access_policy(&mut self) {
355 self.access_guard = None;
356 }
357
358 #[inline]
359 fn ensure_access<T: Any + Send + Sync + 'static>(&self) -> Result<(), BusAccessError> {
360 let Some(guard) = &self.access_guard else {
361 return Ok(());
362 };
363
364 let requested = TypeId::of::<T>();
365 if guard.deny.contains(&requested) {
366 return Err(BusAccessError::Unauthorized {
367 transition: guard.transition_label.to_string(),
368 resource: type_name::<T>(),
369 allow: if guard.allow_names.is_empty() {
370 None
371 } else {
372 Some(guard.allow_names.to_vec())
373 },
374 deny: guard.deny_names.to_vec(),
375 });
376 }
377
378 if let Some(allow) = &guard.allow
379 && !allow.contains(&requested)
380 {
381 return Err(BusAccessError::Unauthorized {
382 transition: guard.transition_label.to_string(),
383 resource: type_name::<T>(),
384 allow: Some(guard.allow_names.to_vec()),
385 deny: guard.deny_names.to_vec(),
386 });
387 }
388
389 Ok(())
390 }
391}
392
393impl Default for Bus {
394 fn default() -> Self {
395 Self::new()
396 }
397}
398
399#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
401pub struct ConnectionId(pub Uuid);
402
403impl Default for ConnectionId {
404 fn default() -> Self {
405 Self(Uuid::new_v4())
406 }
407}
408
409impl ConnectionId {
410 pub fn new() -> Self {
411 Self::default()
412 }
413}
414
415pub struct ConnectionBus {
420 pub bus: Bus,
422 pub id: ConnectionId,
424}
425
426impl ConnectionBus {
427 pub fn new(id: ConnectionId) -> Self {
429 Self {
430 bus: Bus::new(),
431 id,
432 }
433 }
434
435 pub fn from_bus(id: ConnectionId, bus: Bus) -> Self {
437 Self { bus, id }
438 }
439
440 pub fn connection_id(&self) -> ConnectionId {
442 self.id
443 }
444}
445
446impl std::ops::Deref for ConnectionBus {
447 type Target = Bus;
448
449 fn deref(&self) -> &Self::Target {
450 &self.bus
451 }
452}
453
454impl std::ops::DerefMut for ConnectionBus {
455 fn deref_mut(&mut self) -> &mut Self::Target {
456 &mut self.bus
457 }
458}
459
460#[cfg(test)]
461mod tests {
462 use super::*;
463
464 #[test]
465 fn test_insert_and_read() {
466 let mut bus = Bus::new();
467 bus.insert(42i32);
468
469 assert!(bus.has::<i32>());
470 assert_eq!(*bus.read::<i32>().unwrap(), 42);
471 }
472
473 #[test]
474 fn test_read_none() {
475 let bus = Bus::new();
476 assert!(bus.read::<i32>().is_none());
477 }
478
479 #[test]
480 fn test_remove() {
481 let mut bus = Bus::new();
482 bus.insert(42i32);
483
484 let value = bus.remove::<i32>();
485 assert_eq!(value, Some(42));
486 assert!(!bus.has::<i32>());
487 }
488
489 #[test]
490 fn test_multiple_types() {
491 let mut bus = Bus::new();
492 bus.insert(42i32);
493 bus.insert("hello".to_string());
494
495 assert_eq!(*bus.read::<i32>().unwrap(), 42);
496 assert_eq!(bus.read::<String>().unwrap(), "hello");
497 }
498
499 #[test]
500 fn bus_policy_allow_only_blocks_unauthorized_get() {
501 let mut bus = Bus::new();
502 bus.insert(42i32);
503 bus.insert("hello".to_string());
504 bus.set_access_policy(
505 "OnlyInt",
506 Some(BusAccessPolicy::allow_only(vec![BusTypeRef::of::<i32>()])),
507 );
508
509 let err = bus.get::<String>().expect_err("String should be denied");
510 assert!(err.to_string().contains("OnlyInt"));
511 assert!(err.to_string().contains("alloc::string::String"));
512 }
513
514 #[test]
515 fn bus_policy_deny_only_blocks_explicit_type() {
516 let mut bus = Bus::new();
517 bus.insert(42i32);
518 bus.insert("hello".to_string());
519 bus.set_access_policy(
520 "DenyString",
521 Some(BusAccessPolicy::deny_only(vec![BusTypeRef::of::<String>()])),
522 );
523
524 let err = bus.get::<String>().expect_err("String should be denied");
525 assert!(err.to_string().contains("DenyString"));
526 }
527
528 #[test]
529 fn test_connection_bus() {
530 let id = ConnectionId::new();
531 let conn = ConnectionBus::new(id);
532
533 assert_eq!(conn.connection_id(), id);
534 }
535
536 #[test]
537 fn provide_and_require_round_trip() {
538 let mut bus = Bus::new();
539 bus.provide(42i32);
540 assert_eq!(*bus.require::<i32>(), 42);
541 }
542
543 #[test]
544 #[should_panic(expected = "required resource")]
545 fn require_panics_with_helpful_message_when_missing() {
546 let bus = Bus::new();
547 let _ = bus.require::<String>();
548 }
549
550 #[test]
551 fn try_require_returns_none_when_missing() {
552 let bus = Bus::new();
553 assert!(bus.try_require::<i32>().is_none());
554 }
555
556 #[test]
557 fn try_require_returns_some_when_present() {
558 let mut bus = Bus::new();
559 bus.provide("hello".to_string());
560 assert_eq!(bus.try_require::<String>().unwrap(), "hello");
561 }
562
563 #[test]
564 fn test_reinsertion_overwrites_previous_value() {
565 let mut bus = Bus::new();
566 bus.insert(42i32);
567 assert_eq!(*bus.read::<i32>().unwrap(), 42);
568
569 bus.insert(100i32);
570 assert_eq!(*bus.read::<i32>().unwrap(), 100);
571 }
572
573 #[test]
574 fn test_remove_then_read_returns_none() {
575 let mut bus = Bus::new();
576 bus.insert(42i32);
577 assert!(bus.read::<i32>().is_some());
578
579 let removed = bus.remove::<i32>();
580 assert_eq!(removed, Some(42));
581 assert!(bus.read::<i32>().is_none());
582 }
583
584 #[test]
585 fn test_is_empty_after_insertions_and_removals() {
586 let mut bus = Bus::new();
587 assert!(bus.is_empty());
588 assert_eq!(bus.len(), 0);
589
590 bus.insert(42i32);
591 assert!(!bus.is_empty());
592 assert_eq!(bus.len(), 1);
593
594 bus.insert("hello".to_string());
595 assert!(!bus.is_empty());
596 assert_eq!(bus.len(), 2);
597
598 bus.remove::<i32>();
599 assert!(!bus.is_empty());
600 assert_eq!(bus.len(), 1);
601
602 bus.remove::<String>();
603 assert!(bus.is_empty());
604 assert_eq!(bus.len(), 0);
605 }
606
607 #[test]
608 fn test_read_mut_modifies_value_in_place() {
609 let mut bus = Bus::new();
610 bus.insert(42i32);
611
612 if let Some(value) = bus.read_mut::<i32>() {
613 *value = 100;
614 }
615
616 assert_eq!(*bus.read::<i32>().unwrap(), 100);
617 }
618
619 #[test]
620 fn test_multiple_types_coexist() {
621 let mut bus = Bus::new();
622 bus.insert(42i32);
623 bus.insert(3.14f64);
624 bus.insert("hello".to_string());
625 bus.insert(true);
626
627 assert!(bus.has::<i32>());
628 assert!(bus.has::<f64>());
629 assert!(bus.has::<String>());
630 assert!(bus.has::<bool>());
631
632 assert_eq!(*bus.read::<i32>().unwrap(), 42);
633 assert_eq!(*bus.read::<f64>().unwrap(), 3.14);
634 assert_eq!(bus.read::<String>().unwrap(), "hello");
635 assert_eq!(*bus.read::<bool>().unwrap(), true);
636 }
637}