litellm_rs/utils/sys/
state.rs1#![allow(dead_code)] use once_cell::sync::OnceCell;
9use parking_lot::RwLock;
10use std::sync::Arc;
11
/// Marker trait for values that can be shared across threads for the whole
/// lifetime of the program.
///
/// It has no methods; it only bundles the `Send + Sync + 'static` bounds under
/// one name.
pub trait SharedResource: Send + Sync + 'static {}

/// Blanket implementation: every type that already satisfies the bounds is a
/// `SharedResource`, so the trait never has to be implemented by hand.
impl<T: Send + Sync + 'static> SharedResource for T {}
17
/// Cheaply clonable, read-only handle to a value stored behind an [`Arc`].
///
/// All handles created from the same `Shared` point at the same allocation.
#[derive(Debug)]
pub struct Shared<T> {
    /// Reference-counted pointer to the wrapped value.
    inner: Arc<T>,
}

impl<T> Shared<T> {
    /// Moves `value` into a new reference-counted allocation.
    pub fn new(value: T) -> Self {
        let inner = Arc::new(value);
        Self { inner }
    }

    /// Borrows the wrapped value.
    pub fn get(&self) -> &T {
        self.inner.as_ref()
    }

    /// Returns another `Arc` pointing at the same allocation.
    ///
    /// This increments the strong reference count.
    pub fn arc(&self) -> Arc<T> {
        let handle = &self.inner;
        Arc::clone(handle)
    }

    /// Number of strong references currently pointing at the allocation.
    pub fn strong_count(&self) -> usize {
        let handle = &self.inner;
        Arc::strong_count(handle)
    }
}
47
48impl<T> Clone for Shared<T> {
49 fn clone(&self) -> Self {
50 Self {
51 inner: Arc::clone(&self.inner),
52 }
53 }
54}
55
56impl<T> std::ops::Deref for Shared<T> {
57 type Target = T;
58
59 fn deref(&self) -> &Self::Target {
60 &self.inner
61 }
62}
63
64#[derive(Debug)]
66pub struct SharedMut<T> {
67 inner: Arc<RwLock<T>>,
68}
69
70impl<T> SharedMut<T> {
71 pub fn new(value: T) -> Self {
73 Self {
74 inner: Arc::new(RwLock::new(value)),
75 }
76 }
77
78 pub fn read(&self) -> parking_lot::RwLockReadGuard<'_, T> {
80 self.inner.read()
81 }
82
83 pub fn write(&self) -> parking_lot::RwLockWriteGuard<'_, T> {
85 self.inner.write()
86 }
87
88 pub fn try_read(&self) -> Option<parking_lot::RwLockReadGuard<'_, T>> {
90 self.inner.try_read()
91 }
92
93 pub fn try_write(&self) -> Option<parking_lot::RwLockWriteGuard<'_, T>> {
95 self.inner.try_write()
96 }
97
98 pub fn arc(&self) -> Arc<RwLock<T>> {
100 Arc::clone(&self.inner)
101 }
102}
103
104impl<T> Clone for SharedMut<T> {
105 fn clone(&self) -> Self {
106 Self {
107 inner: Arc::clone(&self.inner),
108 }
109 }
110}
111
112pub struct GlobalShared<T> {
114 cell: OnceCell<Shared<T>>,
115}
116
117impl<T> GlobalShared<T> {
118 pub const fn new() -> Self {
120 Self {
121 cell: OnceCell::new(),
122 }
123 }
124
125 pub fn init(&self, value: T) -> Result<(), T> {
127 self.cell.set(Shared::new(value)).map_err(|shared| {
128 match Arc::try_unwrap(shared.inner) {
130 Ok(value) => value,
131 Err(_) => panic!("Failed to extract value from shared resource"),
132 }
133 })
134 }
135
136 pub fn get(&self) -> &Shared<T> {
138 self.cell.get().expect("Global resource not initialized")
139 }
140
141 pub fn try_get(&self) -> Option<&Shared<T>> {
143 self.cell.get()
144 }
145}
146
147impl<T> Default for GlobalShared<T> {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
/// Declares a module-private `static` named `$name` holding an empty
/// `GlobalShared<$ty>` (created with `GlobalShared::new()`).
///
/// Usage: `global_shared!(CONFIG: MyConfig);`
///
/// NOTE(review): the expansion refers to `$crate::utils::shared_state`;
/// confirm that this path resolves to the module defining `GlobalShared`
/// (for example through a re-export) from every crate that uses the macro.
#[macro_export]
macro_rules! global_shared {
    ($name:ident : $ty:ty) => {
        static $name: $crate::utils::shared_state::GlobalShared<$ty> =
            $crate::utils::shared_state::GlobalShared::new();
    };
}
161
162pub struct SharedBuilder<T> {
164 value: Option<T>,
165}
166
167impl<T> SharedBuilder<T> {
168 pub fn new() -> Self {
170 Self { value: None }
171 }
172
173 pub fn with_value(mut self, value: T) -> Self {
175 self.value = Some(value);
176 self
177 }
178
179 pub fn build(self) -> Option<Shared<T>> {
181 self.value.map(Shared::new)
182 }
183
184 pub fn build_or_panic(self, msg: &str) -> Shared<T> {
186 self.build().expect(msg)
187 }
188}
189
190impl<T> Default for SharedBuilder<T> {
191 fn default() -> Self {
192 Self::new()
193 }
194}
195
196pub mod utils {
198 use super::*;
199
200 pub fn share<T>(value: T) -> Shared<T> {
202 Shared::new(value)
203 }
204
205 pub fn share_mut<T>(value: T) -> SharedMut<T> {
207 SharedMut::new(value)
208 }
209
210 pub fn from_arc<T>(arc: Arc<T>) -> Shared<T> {
212 Shared { inner: arc }
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// `Shared` exposes its value through `get` and `Deref`, and a clone
    /// shares the same allocation (strong count goes to two).
    #[test]
    fn test_shared_resource() {
        let original = Shared::new(42);
        assert_eq!(*original.get(), 42);
        assert_eq!(*original, 42);

        let duplicate = original.clone();
        assert_eq!(*duplicate.get(), 42);
        assert_eq!(original.strong_count(), 2);
    }

    /// A write through `SharedMut` is visible to subsequent reads.
    #[test]
    fn test_shared_mut_resource() {
        let cell = SharedMut::new(42);

        // Each guard is a temporary that is released at the end of its statement.
        assert_eq!(*cell.read(), 42);
        *cell.write() = 100;
        assert_eq!(*cell.read(), 100);
    }

    /// `GlobalShared` starts empty, accepts one `init`, and rejects a second.
    #[test]
    fn test_global_shared() {
        let global = GlobalShared::<i32>::new();
        assert!(global.try_get().is_none());

        global.init(42).unwrap();
        assert_eq!(**global.get(), 42);

        assert!(global.init(100).is_err());
    }

    /// A builder with a value set produces a `Shared` holding that value.
    #[test]
    fn test_shared_builder() {
        let built = SharedBuilder::new().with_value(42).build().unwrap();
        assert_eq!(*built.get(), 42);
    }
}