Skip to main content

orpc_procedure/
state.rs

1use std::any::{Any, TypeId};
2use std::collections::HashMap;
3use std::hash::{BuildHasherDefault, Hasher};
4
/// A pass-through hasher specialised for `TypeId` keys.
///
/// `TypeId` values are already well-distributed hashes, so re-hashing them is
/// wasted work. When a key is hashed with a single `write_u64` call (the path
/// the current `TypeId` implementation takes), `finish` returns exactly that
/// value.
///
/// Every write is folded into the running state instead of replacing it. If a
/// key's `Hash` implementation ever emits several writes (for example a
/// payload followed by a constant terminator byte), the result still depends
/// on all of them rather than only on the last one.
#[derive(Default)]
struct IdHasher(u64);

impl Hasher for IdHasher {
    /// Fallback for input that does not arrive through `write_u64`.
    ///
    /// The bytes are XOR-folded into one `u64` (byte `i` lands in byte lane
    /// `i % 8`) and the folded word is then mixed in like a `write_u64` call.
    fn write(&mut self, bytes: &[u8]) {
        let folded = bytes
            .iter()
            .enumerate()
            .fold(0u64, |acc, (i, &b)| acc ^ (u64::from(b) << ((i % 8) * 8)));
        self.write_u64(folded);
    }

    /// Mixes `i` into the state; a pure pass-through on a fresh hasher.
    fn write_u64(&mut self, i: u64) {
        // A fresh hasher holds 0, so the first write stores `i` unchanged.
        // Rotating the previous state keeps earlier writes from being
        // discarded, and keeps two identical writes from cancelling out.
        self.0 = self.0.rotate_left(5) ^ i;
    }

    /// Returns the accumulated state as the hash value.
    fn finish(&self) -> u64 {
        self.0
    }
}
33
34type TypeMap = HashMap<TypeId, Box<dyn Any + Send + Sync>, BuildHasherDefault<IdHasher>>;
35
36/// Type-safe heterogeneous state container.
37///
38/// Stores values keyed by their `TypeId`, allowing type-safe insertion and retrieval.
39/// Used for dependency injection and cross-procedure shared state.
40///
41/// Follows rspc's `State` pattern with `NoOpHasher` optimization.
42pub struct State(TypeMap);
43
44impl State {
45    pub fn new() -> Self {
46        State(HashMap::default())
47    }
48
49    /// Insert a value. Replaces any existing value of the same type.
50    pub fn insert<T: Send + Sync + 'static>(&mut self, value: T) {
51        self.0.insert(TypeId::of::<T>(), Box::new(value));
52    }
53
54    /// Get a reference to a stored value by type.
55    pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
56        self.0
57            .get(&TypeId::of::<T>())
58            .and_then(|v| v.downcast_ref())
59    }
60
61    /// Get a mutable reference to a stored value by type.
62    pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
63        self.0
64            .get_mut(&TypeId::of::<T>())
65            .and_then(|v| v.downcast_mut())
66    }
67
68    /// Check if a value of the given type exists.
69    pub fn contains<T: Send + Sync + 'static>(&self) -> bool {
70        self.0.contains_key(&TypeId::of::<T>())
71    }
72
73    /// Remove and return a stored value by type.
74    pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> {
75        self.0
76            .remove(&TypeId::of::<T>())
77            .and_then(|v| v.downcast().ok())
78            .map(|v| *v)
79    }
80}
81
82impl Default for State {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
88impl std::fmt::Debug for State {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.debug_tuple("State")
91            .field(&format!("{} entries", self.0.len()))
92            .finish()
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    // A stored value can be read back through its type.
    #[test]
    fn insert_and_get() {
        let mut container = State::new();
        container.insert(42u32);
        let found = container.get::<u32>();
        assert_eq!(found, Some(&42));
    }

    // A second insert of the same type overwrites the first.
    #[test]
    fn insert_replaces_existing() {
        let mut container = State::new();
        container.insert(1u32);
        container.insert(2u32);
        assert_eq!(container.get::<u32>().copied(), Some(2));
    }

    // Looking up a type that was never inserted yields `None`.
    #[test]
    fn get_missing_returns_none() {
        let container = State::new();
        assert!(container.get::<u32>().is_none());
    }

    // Changes made through `get_mut` are visible to later reads.
    #[test]
    fn get_mut_allows_modification() {
        let mut container = State::new();
        container.insert(String::from("hello"));
        if let Some(text) = container.get_mut::<String>() {
            *text += " world";
        }
        let current = container.get::<String>().unwrap();
        assert_eq!(current, "hello world");
    }

    // `contains` flips from false to true once a value is stored.
    #[test]
    fn contains() {
        let mut container = State::new();
        let before = container.contains::<u32>();
        assert!(!before);
        container.insert(42u32);
        let after = container.contains::<u32>();
        assert!(after);
    }

    // `remove` hands back the owned value and leaves nothing behind.
    #[test]
    fn remove_returns_owned_value() {
        let mut container = State::new();
        container.insert(String::from("removed"));
        let taken: Option<String> = container.remove();
        assert_eq!(taken, Some(String::from("removed")));
        assert!(!container.contains::<String>());
    }

    // Values of different types live side by side without interfering.
    #[test]
    fn multiple_types() {
        let mut container = State::new();
        container.insert(42u32);
        container.insert("hello");
        container.insert(vec![1i32, 2, 3]);

        assert_eq!(container.get::<u32>(), Some(&42));
        assert_eq!(container.get::<&str>(), Some(&"hello"));
        assert_eq!(container.get::<Vec<i32>>(), Some(&vec![1, 2, 3]));
    }

    // The Debug representation mentions the number of stored entries.
    #[test]
    fn debug_output() {
        let mut container = State::new();
        container.insert(1u32);
        container.insert("hello");
        let rendered = format!("{:?}", container);
        assert!(rendered.contains("2 entries"));
    }

    // Compile-time check that `State` can be shared across threads.
    #[test]
    fn state_is_send_sync() {
        fn require_send_sync<T>()
        where
            T: Send + Sync,
        {
        }
        require_send_sync::<State>();
    }
}