Skip to main content

pure_stage/
resources.rs

1// Copyright 2024 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#[expect(clippy::disallowed_types)]
16use std::{
17    any::{Any, TypeId, type_name},
18    collections::HashMap,
19    sync::Arc,
20};
21
22use parking_lot::{MappedRwLockReadGuard, MappedRwLockWriteGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
23
24/// A collection of resources that can be accessed by external effects.
25///
26/// This is used to pass resources to external effects while properly scoping the resource to the running stage graph.
27/// If you want to share a resource across multiple stage graphs, you can use `Arc<Mutex<T>>` or similar.
28///
29/// ## API Design Choices
30///
31/// StageGraph supports single-threaded simulation as well as multi-threaded production code.
32/// Since effect implementations must cover both cases with the same code, the resulting API must
33/// be constrained by both environments. The simulation can easily provided a `&mut Resources`,
34/// but if we require that, then the production code will have to hold a lock on the `Resources`
35/// for the whole duration of each effect, serializing all resources.
36///
37/// Therefore, even though not needed in simulation, we design the API so that the effect can
38/// use its resources for shorter durations.
39///
40/// ## `Sync` Bound
41///
42/// In order to allow resources to be used without blocking the whole resource collection, shared
43/// references can be obtained with read locking. Since this fundamentally allows shared access
44/// from multiple threads, the resources must be `Sync`. If your resource is not `Sync`, you can
45/// use [`SyncWrapper`](https://docs.rs/sync_wrapper/latest/sync_wrapper/struct.SyncWrapper.html)
46/// or a mutex.
47#[derive(Default, Clone)]
48#[expect(clippy::disallowed_types)]
49pub struct Resources(Arc<RwLock<HashMap<TypeId, Box<dyn Any + Send + Sync>>>>);
50
51impl Resources {
52    /// Put a resource into the resources collection.
53    ///
54    /// This variant uses locking to ensure that the resource is not accessed concurrently.
55    pub fn put<T: Any + Send + Sync>(&self, resource: T) {
56        self.0.write().insert(TypeId::of::<T>(), Box::new(resource));
57    }
58
59    /// Get a resource from the resources collection.
60    ///
61    /// This variant only takes a read lock on the resource collection, allowing other `get`
62    /// operations to proceed concurrently. [`get_mut`](Self::get_mut) will be blocked while
63    /// the returned guard is held, so [`drop`](std::mem::drop) it as soon as you don't need it
64    /// any more.
65    pub fn get<T: Any + Send + Sync>(&self) -> anyhow::Result<MappedRwLockReadGuard<'_, T>> {
66        RwLockReadGuard::try_map(self.0.read(), |res| res.get(&TypeId::of::<T>())?.downcast_ref::<T>())
67            .map_err(|_| anyhow::anyhow!("Resource of type `{}` not found", type_name::<T>()))
68    }
69
70    /// Get a mutable reference to a resource from the resources collection.
71    ///
72    /// This variant takes a write lock on the resource collection, blocking all other operations.
73    /// See [`get`](Self::get) for a variant that uses read locking. Concurrent operations will
74    /// be blocked while the returned guard is held, so [`drop`](std::mem::drop) it as soon as you
75    /// don't need it any more.
76    ///
77    /// If you need exclusive access to a single resource without blocking the rest of the
78    /// resource collection, consider putting an `Arc<Mutex<T>>` in the resources collection.
79    pub fn get_mut<T: Any + Send + Sync>(&self) -> anyhow::Result<MappedRwLockWriteGuard<'_, T>> {
80        RwLockWriteGuard::try_map(self.0.write(), |res| res.get_mut(&TypeId::of::<T>())?.downcast_mut::<T>())
81            .map_err(|_| anyhow::anyhow!("Resource of type `{}` not found", type_name::<T>()))
82    }
83
84    /// Take a resource from the resources collection.
85    ///
86    /// This variant uses locking to ensure that the resource is not accessed concurrently.
87    pub fn take<T: Any + Send + Sync>(&self) -> anyhow::Result<T> {
88        self.0
89            .write()
90            .remove(&TypeId::of::<T>())
91            .ok_or_else(|| anyhow::anyhow!("Resource of type `{}` not found", type_name::<T>()))?
92            .downcast::<T>()
93            .map(|x| *x)
94            .map_err(|_| anyhow::anyhow!("Resource of type `{}` not found", type_name::<T>()))
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resources() {
        let resources = Resources::default();

        // Missing resources are reported with the type name, for both access modes.
        assert_eq!(resources.get::<u32>().unwrap_err().to_string(), "Resource of type `u32` not found");
        assert_eq!(resources.get_mut::<u32>().unwrap_err().to_string(), "Resource of type `u32` not found");

        resources.put(42u32);
        assert_eq!(*resources.get::<u32>().unwrap(), 42);

        // `put` replaces any existing resource of the same type.
        resources.put(43u32);
        assert_eq!(*resources.get_mut::<u32>().unwrap(), 43);

        // Mutations made through `get_mut` are visible to later readers.
        *resources.get_mut::<u32>().unwrap() += 1;
        assert_eq!(*resources.get::<u32>().unwrap(), 44);

        // Resources are keyed by type, so different types do not collide.
        resources.put(String::from("hello"));
        assert_eq!(*resources.get::<String>().unwrap(), "hello");
        assert_eq!(*resources.get::<u32>().unwrap(), 44);

        // Clones share the same underlying collection.
        let shared = resources.clone();
        shared.put(7u8);
        assert_eq!(*resources.get::<u8>().unwrap(), 7);

        assert_eq!(resources.take::<u32>().unwrap(), 44);
        assert_eq!(resources.take::<u32>().unwrap_err().to_string(), "Resource of type `u32` not found");
        assert_eq!(resources.take::<String>().unwrap(), "hello");
    }
}