litellm_rs/utils/sys/
state.rs1#![allow(dead_code)] use parking_lot::RwLock;
9use std::sync::Arc;
10use std::sync::OnceLock;
11
/// Marker trait for values that can be handed freely across threads.
///
/// It adds no methods; it simply names the `Send + Sync + 'static` bound so it
/// can be written once. Every qualifying type implements it via the blanket
/// impl that follows.
pub trait SharedResource: Send + Sync + 'static {}

// Blanket impl: any type satisfying the supertrait bounds is a `SharedResource`.
impl<T: Send + Sync + 'static> SharedResource for T {}
17
/// Cheaply clonable, read-only handle to a value stored behind an [`Arc`].
///
/// Cloning a `Shared<T>` only bumps the reference count. `T` itself is never
/// cloned, so no `T: Clone` bound is required.
#[derive(Debug)]
pub struct Shared<T> {
    inner: Arc<T>,
}

impl<T> Shared<T> {
    /// Moves `value` into a new reference-counted allocation.
    pub fn new(value: T) -> Self {
        let inner = Arc::new(value);
        Shared { inner }
    }

    /// Borrows the shared value.
    pub fn get(&self) -> &T {
        &*self.inner
    }

    /// Returns another strong `Arc` pointing at the same allocation.
    pub fn arc(&self) -> Arc<T> {
        Arc::clone(&self.inner)
    }

    /// Number of strong references to the allocation, this handle included.
    pub fn strong_count(&self) -> usize {
        Arc::strong_count(&self.inner)
    }
}

// Written by hand rather than derived: `#[derive(Clone)]` would add a
// `T: Clone` bound even though only the `Arc` is cloned.
impl<T> Clone for Shared<T> {
    fn clone(&self) -> Self {
        Shared {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<T> std::ops::Deref for Shared<T> {
    type Target = T;

    fn deref(&self) -> &T {
        self.get()
    }
}
63
64#[derive(Debug)]
66pub struct SharedMut<T> {
67 inner: Arc<RwLock<T>>,
68}
69
70impl<T> SharedMut<T> {
71 pub fn new(value: T) -> Self {
73 Self {
74 inner: Arc::new(RwLock::new(value)),
75 }
76 }
77
78 pub fn read(&self) -> parking_lot::RwLockReadGuard<'_, T> {
80 self.inner.read()
81 }
82
83 pub fn write(&self) -> parking_lot::RwLockWriteGuard<'_, T> {
85 self.inner.write()
86 }
87
88 pub fn try_read(&self) -> Option<parking_lot::RwLockReadGuard<'_, T>> {
90 self.inner.try_read()
91 }
92
93 pub fn try_write(&self) -> Option<parking_lot::RwLockWriteGuard<'_, T>> {
95 self.inner.try_write()
96 }
97
98 pub fn arc(&self) -> Arc<RwLock<T>> {
100 Arc::clone(&self.inner)
101 }
102}
103
104impl<T> Clone for SharedMut<T> {
105 fn clone(&self) -> Self {
106 Self {
107 inner: Arc::clone(&self.inner),
108 }
109 }
110}
111
112pub struct GlobalShared<T> {
114 cell: OnceLock<Shared<T>>,
115}
116
117impl<T> GlobalShared<T> {
118 pub const fn new() -> Self {
120 Self {
121 cell: OnceLock::new(),
122 }
123 }
124
125 pub fn init(&self, value: T) -> Result<(), T> {
127 self.cell.set(Shared::new(value)).map_err(|shared| {
128 Arc::try_unwrap(shared.inner)
131 .unwrap_or_else(|_| unreachable!("freshly created Arc should have refcount 1"))
132 })
133 }
134
135 pub fn get(&self) -> &Shared<T> {
137 self.cell.get().expect("Global resource not initialized")
138 }
139
140 pub fn try_get(&self) -> Option<&Shared<T>> {
142 self.cell.get()
143 }
144}
145
146impl<T> Default for GlobalShared<T> {
147 fn default() -> Self {
148 Self::new()
149 }
150}
151
/// Declares a `static` `GlobalShared<$type>` slot named `$name`.
///
/// Outer attributes (doc comments included) and a visibility may optionally
/// precede the name. The plain `NAME: Type` form still expands to a private
/// static, exactly as before.
///
/// ```ignore
/// global_shared!(CONFIG: Config);
/// global_shared!(
///     /// Process-wide registry.
///     pub(crate) REGISTRY: Registry
/// );
/// ```
// NOTE(review): the expansion names `$crate::utils::shared_state::GlobalShared`,
// while this file appears to live at `utils/sys/state.rs`. Confirm that path is
// re-exported; otherwise every invocation fails to resolve.
#[macro_export]
macro_rules! global_shared {
    ($(#[$meta:meta])* $vis:vis $name:ident: $type:ty) => {
        $(#[$meta])*
        $vis static $name: $crate::utils::shared_state::GlobalShared<$type> =
            $crate::utils::shared_state::GlobalShared::new();
    };
}
160
161pub struct SharedBuilder<T> {
163 value: Option<T>,
164}
165
166impl<T> SharedBuilder<T> {
167 pub fn new() -> Self {
169 Self { value: None }
170 }
171
172 pub fn with_value(mut self, value: T) -> Self {
174 self.value = Some(value);
175 self
176 }
177
178 pub fn build(self) -> Option<Shared<T>> {
180 self.value.map(Shared::new)
181 }
182
183 pub fn build_or_panic(self, msg: &str) -> Shared<T> {
185 self.build().expect(msg)
186 }
187}
188
189impl<T> Default for SharedBuilder<T> {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
195pub mod utils {
197 use super::*;
198
199 pub fn share<T>(value: T) -> Shared<T> {
201 Shared::new(value)
202 }
203
204 pub fn share_mut<T>(value: T) -> SharedMut<T> {
206 SharedMut::new(value)
207 }
208
209 pub fn from_arc<T>(arc: Arc<T>) -> Shared<T> {
211 Shared { inner: arc }
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shared_resource() {
        let shared = Shared::new(42);
        assert_eq!(*shared.get(), 42);
        assert_eq!(*shared, 42);

        let cloned = shared.clone();
        assert_eq!(*cloned.get(), 42);
        assert_eq!(shared.strong_count(), 2);

        // `arc()` hands out another strong reference to the same allocation.
        let raw = shared.arc();
        assert!(Arc::ptr_eq(&raw, &cloned.arc()));
        assert_eq!(shared.strong_count(), 3);

        drop(cloned);
        drop(raw);
        assert_eq!(shared.strong_count(), 1);
    }

    #[test]
    fn test_shared_mut_resource() {
        let shared = SharedMut::new(42);

        {
            let read_guard = shared.read();
            assert_eq!(*read_guard, 42);
            // Readers may coexist, but a writer cannot get in.
            assert!(shared.try_read().is_some());
            assert!(shared.try_write().is_none());
        }

        {
            let mut write_guard = shared.write();
            *write_guard = 100;
            // While the write guard is alive, no reader can get in.
            assert!(shared.try_read().is_none());
        }

        {
            let read_guard = shared.read();
            assert_eq!(*read_guard, 100);
        }

        // Clones share the same lock, so they observe the write.
        let cloned = shared.clone();
        assert_eq!(*cloned.read(), 100);
        assert!(Arc::ptr_eq(&shared.arc(), &cloned.arc()));
    }

    #[test]
    fn test_global_shared() {
        let global: GlobalShared<i32> = GlobalShared::new();

        assert!(global.try_get().is_none());

        global.init(42).unwrap();
        assert_eq!(**global.get(), 42);

        // A second init is rejected, hands the caller's value back, and leaves
        // the stored value untouched.
        assert_eq!(global.init(100), Err(100));
        assert_eq!(**global.get(), 42);
        assert_eq!(global.try_get().map(|s| *s.get()), Some(42));
    }

    #[test]
    #[should_panic(expected = "Global resource not initialized")]
    fn test_global_shared_get_uninitialized_panics() {
        let global: GlobalShared<i32> = GlobalShared::default();
        let _ = global.get();
    }

    #[test]
    fn test_shared_builder() {
        let shared = SharedBuilder::new().with_value(42).build().unwrap();
        assert_eq!(*shared.get(), 42);

        // With no value supplied there is nothing to build.
        assert!(SharedBuilder::<i32>::default().build().is_none());

        // The last `with_value` wins.
        let shared = SharedBuilder::new()
            .with_value(1)
            .with_value(2)
            .build_or_panic("value was set");
        assert_eq!(*shared, 2);
    }

    #[test]
    #[should_panic(expected = "no value")]
    fn test_shared_builder_build_or_panic_without_value() {
        let _ = SharedBuilder::<i32>::new().build_or_panic("no value");
    }

    #[test]
    fn test_utils_helpers() {
        assert_eq!(*utils::share("x"), "x");

        // `from_arc` adopts the Arc without copying the value.
        let arc = Arc::new(5);
        let shared = utils::from_arc(Arc::clone(&arc));
        assert!(Arc::ptr_eq(&arc, &shared.arc()));
        assert_eq!(shared.strong_count(), 2);

        let shared_mut = utils::share_mut(vec![1]);
        shared_mut.write().push(2);
        assert_eq!(*shared_mut.read(), vec![1, 2]);
    }
}