pure_stage/resources.rs
1// Copyright 2024 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#[expect(clippy::disallowed_types)]
16use std::{
17 any::{Any, TypeId, type_name},
18 collections::HashMap,
19 sync::Arc,
20};
21
22use parking_lot::{MappedRwLockReadGuard, MappedRwLockWriteGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
23
24/// A collection of resources that can be accessed by external effects.
25///
26/// This is used to pass resources to external effects while properly scoping the resource to the running stage graph.
27/// If you want to share a resource across multiple stage graphs, you can use `Arc<Mutex<T>>` or similar.
28///
29/// ## API Design Choices
30///
31/// StageGraph supports single-threaded simulation as well as multi-threaded production code.
32/// Since effect implementations must cover both cases with the same code, the resulting API must
33/// be constrained by both environments. The simulation can easily provided a `&mut Resources`,
34/// but if we require that, then the production code will have to hold a lock on the `Resources`
35/// for the whole duration of each effect, serializing all resources.
36///
37/// Therefore, even though not needed in simulation, we design the API so that the effect can
38/// use its resources for shorter durations.
39///
40/// ## `Sync` Bound
41///
42/// In order to allow resources to be used without blocking the whole resource collection, shared
43/// references can be obtained with read locking. Since this fundamentally allows shared access
44/// from multiple threads, the resources must be `Sync`. If your resource is not `Sync`, you can
45/// use [`SyncWrapper`](https://docs.rs/sync_wrapper/latest/sync_wrapper/struct.SyncWrapper.html)
46/// or a mutex.
47#[derive(Default, Clone)]
48#[expect(clippy::disallowed_types)]
49pub struct Resources(Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>);
50
51impl Resources {
52 /// Put a resource into the resources collection.
53 ///
54 /// This variant uses locking to ensure that the resource is not accessed concurrently.
55 pub fn put<T: Any + Send + Sync>(&self, resource: T) {
56 self.0.write().insert(TypeId::of::<T>(), Box::new(resource));
57 }
58
59 /// Get a resource from the resources collection.
60 ///
61 /// This variant only takes a read lock on the resource collection, allowing other `get`
62 /// operations to proceed concurrently. [`get_mut`](Self::get_mut) will be blocked while
63 /// the returned guard is held, so [`drop`](std::mem::drop) it as soon as you don't need it
64 /// any more.
65 pub fn get<T: Any + Send + Sync>(&self) -> anyhow::Result<MappedRwLockReadGuard<'_, T>> {
66 RwLockReadGuard::try_map(self.0.read(), |res| res.get(&TypeId::of::<T>())?.downcast_ref::<T>())
67 .map_err(|_| anyhow::anyhow!("Resource of type `{}` not found", type_name::<T>()))
68 }
69
70 /// Get a mutable reference to a resource from the resources collection.
71 ///
72 /// This variant takes a write lock on the resource collection, blocking all other operations.
73 /// See [`get`](Self::get) for a variant that uses read locking. Concurrent operations will
74 /// be blocked while the returned guard is held, so [`drop`](std::mem::drop) it as soon as you
75 /// don't need it any more.
76 ///
77 /// If you need exclusive access to a single resource without blocking the rest of the
78 /// resource collection, consider putting an `Arc<Mutex<T>>` in the resources collection.
79 pub fn get_mut<T: Any + Send + Sync>(&self) -> anyhow::Result<MappedRwLockWriteGuard<'_, T>> {
80 RwLockWriteGuard::try_map(self.0.write(), |res| res.get_mut(&TypeId::of::<T>())?.downcast_mut::<T>())
81 .map_err(|_| anyhow::anyhow!("Resource of type `{}` not found", type_name::<T>()))
82 }
83
84 /// Take a resource from the resources collection.
85 ///
86 /// This variant uses locking to ensure that the resource is not accessed concurrently.
87 pub fn take<T: Any + Send + Sync>(&self) -> anyhow::Result<T> {
88 self.0
89 .write()
90 .remove(&TypeId::of::<T>())
91 .ok_or_else(|| anyhow::anyhow!("Resource of type `{}` not found", type_name::<T>()))?
92 .downcast::<T>()
93 .map(|x| *x)
94 .map_err(|_| anyhow::anyhow!("Resource of type `{}` not found", type_name::<T>()))
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resources() {
        let resources = Resources::default();

        // Missing resources are reported by both the shared and the exclusive accessor.
        assert_eq!(resources.get::<u32>().unwrap_err().to_string(), "Resource of type `u32` not found");
        assert_eq!(resources.get_mut::<u32>().unwrap_err().to_string(), "Resource of type `u32` not found");

        resources.put(42u32);
        assert_eq!(*resources.get::<u32>().unwrap(), 42);

        // `put` replaces an existing resource of the same type.
        resources.put(43u32);
        assert_eq!(*resources.get_mut::<u32>().unwrap(), 43);

        // Mutations made through `get_mut` are visible to later readers.
        *resources.get_mut::<u32>().unwrap() += 1;
        assert_eq!(*resources.get::<u32>().unwrap(), 44);

        // Resources are keyed by type, so a value of another type does not disturb the first.
        resources.put(String::from("hello"));
        assert_eq!(resources.get::<String>().unwrap().as_str(), "hello");
        assert_eq!(*resources.get::<u32>().unwrap(), 44);

        // Clones are handles to the same collection.
        let shared = resources.clone();
        assert_eq!(shared.take::<u32>().unwrap(), 44);
        assert_eq!(resources.take::<u32>().unwrap_err().to_string(), "Resource of type `u32` not found");
    }
}