Skip to main content

tirea_contract/runtime/
extensions.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3
/// Type-keyed extension map for [`StepContext`](super::phase::StepContext).
///
/// Each slot is keyed by `TypeId` and holds at most one value of that type.
/// Plugins insert domain-specific state (e.g. `InferenceContext`, `ToolGate`)
/// via the [`Action::apply`](super::Action::apply) method; the loop reads
/// them back after a phase completes.
///
/// Stored values must be `'static + Send`. They are type-erased as
/// `Box<dyn Any + Send>` and recovered by downcasting on access.
#[derive(Default)]
pub struct Extensions {
    // Invariant: the value stored under `TypeId::of::<T>()` is always a `T`,
    // because every write path (`insert`, `get_or_default`) keys the entry by
    // the inserted value's own type.
    map: HashMap<TypeId, Box<dyn Any + Send>>,
}

impl Extensions {
    /// Create an empty extension map.
    pub fn new() -> Self {
        Self::default()
    }

    /// Get a shared reference to the value of type `T`, or `None` if absent.
    pub fn get<T: 'static + Send>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }

    /// Get a mutable reference to the value of type `T`, or `None` if absent.
    pub fn get_mut<T: 'static + Send>(&mut self) -> Option<&mut T> {
        self.map
            .get_mut(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_mut::<T>())
    }

    /// Get a mutable reference to the value of type `T`, inserting
    /// `T::default()` first if absent. An existing value is left untouched.
    ///
    /// # Panics
    ///
    /// Never in practice: the entry is keyed by `TypeId::of::<T>()`, so the
    /// stored value is always a `T` and the downcast cannot fail.
    pub fn get_or_default<T: 'static + Send + Default>(&mut self) -> &mut T {
        self.map
            .entry(TypeId::of::<T>())
            .or_insert_with(|| Box::new(T::default()))
            .downcast_mut::<T>()
            .expect("type mismatch in Extensions (impossible)")
    }

    /// Insert a value, returning any previous value of the same type.
    pub fn insert<T: 'static + Send>(&mut self, val: T) -> Option<T> {
        self.map
            .insert(TypeId::of::<T>(), Box::new(val))
            .and_then(|prev| prev.downcast::<T>().ok())
            .map(|boxed| *boxed)
    }

    /// Remove and return the value of type `T`, or `None` if absent.
    ///
    /// Lets callers take ownership of a stored value without cloning it.
    pub fn remove<T: 'static + Send>(&mut self) -> Option<T> {
        self.map
            .remove(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast::<T>().ok())
            .map(|boxed| *boxed)
    }

    /// Number of distinct types currently stored.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` if no values are stored.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Remove all entries.
    pub fn clear(&mut self) {
        self.map.clear();
    }

    /// Check whether the map contains a value of type `T`.
    pub fn contains<T: 'static + Send>(&self) -> bool {
        self.map.contains_key(&TypeId::of::<T>())
    }
}

// Stored values are type-erased and not required to be `Debug`, so only the
// number of entries is reported.
impl std::fmt::Debug for Extensions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Extensions")
            .field("len", &self.map.len())
            .finish()
    }
}
76
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insert_and_get() {
        let mut ext = Extensions::new();
        ext.insert(42u32);
        assert_eq!(ext.get::<u32>(), Some(&42));
    }

    #[test]
    fn insert_into_empty_returns_none() {
        let mut ext = Extensions::new();
        assert_eq!(ext.insert(7u8), None);
    }

    #[test]
    fn get_mut() {
        let mut ext = Extensions::new();
        ext.insert(String::from("hello"));
        ext.get_mut::<String>().unwrap().push_str(" world");
        assert_eq!(ext.get::<String>().unwrap(), "hello world");
    }

    #[test]
    fn get_mut_missing_type_returns_none() {
        let mut ext = Extensions::new();
        ext.insert(1u32);
        assert!(ext.get_mut::<u64>().is_none());
    }

    #[test]
    fn get_or_default() {
        let mut ext = Extensions::new();
        let val: &mut Vec<i32> = ext.get_or_default();
        val.push(1);
        assert_eq!(ext.get::<Vec<i32>>().unwrap(), &vec![1]);
    }

    #[test]
    fn get_or_default_keeps_existing_value() {
        let mut ext = Extensions::new();
        ext.insert(5i64);
        *ext.get_or_default::<i64>() += 1;
        assert_eq!(ext.get::<i64>(), Some(&6));
    }

    #[test]
    fn insert_replaces() {
        let mut ext = Extensions::new();
        ext.insert(1u32);
        let prev = ext.insert(2u32);
        assert_eq!(prev, Some(1));
        assert_eq!(ext.get::<u32>(), Some(&2));
    }

    #[test]
    fn contains_reports_inserted_type_only() {
        let mut ext = Extensions::new();
        ext.insert(1u32);
        assert!(ext.contains::<u32>());
        assert!(!ext.contains::<u64>());
    }

    #[test]
    fn clear_removes_all() {
        let mut ext = Extensions::new();
        ext.insert(1u32);
        ext.insert(String::from("x"));
        ext.clear();
        assert!(!ext.contains::<u32>());
        assert!(!ext.contains::<String>());
    }

    #[test]
    fn different_types_coexist() {
        let mut ext = Extensions::new();
        ext.insert(42u32);
        ext.insert(String::from("hello"));
        ext.insert(true);
        assert_eq!(ext.get::<u32>(), Some(&42));
        assert_eq!(ext.get::<String>().unwrap(), "hello");
        assert_eq!(ext.get::<bool>(), Some(&true));
    }

    #[test]
    fn missing_type_returns_none() {
        let ext = Extensions::new();
        assert!(ext.get::<u32>().is_none());
    }

    #[test]
    fn debug_reports_len_only() {
        let mut ext = Extensions::new();
        ext.insert(1u32);
        ext.insert(String::from("x"));
        assert_eq!(format!("{:?}", ext), "Extensions { len: 2 }");
    }
}