custos/devices/
cache.rs

1//! Contains the [`Cache`]ing logic.
2
3use core::{cell::RefMut, fmt::Debug, hash::BuildHasherDefault, ops::BitXor};
4use std::collections::HashMap;
5
6use std::rc::Rc;
7
8use crate::{
9    flag::AllocFlag, shape::Shape, Alloc, Buffer, CacheAble, Device, GlobalCount, GraphReturn,
10    Ident, PtrConv, PtrType,
11};
12
13/// This trait makes a device's [`Cache`] accessible and is implemented for all compute devices.
14pub trait CacheReturn: GraphReturn<GlobalCount> {
15    /// Returns a reference to a device's [`Cache`].
16    fn cache(&self) -> core::cell::Ref<Cache<Self>>
17    where
18        Self: PtrConv;
19
20    /// Returns a mutable reference to a device's [`Cache`].
21    fn cache_mut(&self) -> RefMut<Cache<Self>>
22    where
23        Self: PtrConv;
24}
25
/// Multiplicative constant of the "FxHasher" round, sized for the target's `usize`.
///
/// The 64-bit constant does not fit into a 32-bit `usize` (it would be a compile error on
/// targets such as `wasm32`), so 32-bit and smaller targets use FxHasher's 32-bit constant.
#[cfg(target_pointer_width = "64")]
const K: usize = 0x517c_c1b7_2722_0a95;
#[cfg(not(target_pointer_width = "64"))]
const K: usize = 0x9e37_79b9;

/// A low-overhead [`Ident`] hasher using "FxHasher".
///
/// The fast path is [`write_usize`](std::hash::Hasher::write_usize), which every `usize`
/// field is hashed through. Any other input passed to `write` is folded in `usize`-sized words,
/// so hashing non-`usize` data is supported instead of panicking.
#[derive(Debug, Clone, Default)]
pub struct IdentHasher {
    hash: usize,
}

impl IdentHasher {
    /// Mixes one word into the running hash (a single FxHash round).
    #[inline]
    fn add_to_hash(&mut self, word: usize) {
        self.hash = self.hash.rotate_left(5).bitxor(word).wrapping_mul(K);
    }
}

impl std::hash::Hasher for IdentHasher {
    #[inline]
    fn finish(&self) -> u64 {
        self.hash as u64
    }

    /// Feeds arbitrary bytes into the hasher.
    ///
    /// Complete native-endian `usize` words are mixed in as words; trailing bytes that do not
    /// fill a whole word are mixed in one at a time. The default `write_u8`/`write_u32`/
    /// `write_u64`/... methods forward here, so they no longer panic.
    #[inline]
    fn write(&mut self, bytes: &[u8]) {
        const WORD: usize = std::mem::size_of::<usize>();

        let mut words = bytes.chunks_exact(WORD);
        for word in &mut words {
            let mut buf = [0u8; WORD];
            buf.copy_from_slice(word);
            self.add_to_hash(usize::from_ne_bytes(buf));
        }
        for &byte in words.remainder() {
            self.add_to_hash(usize::from(byte));
        }
    }

    #[inline]
    fn write_usize(&mut self, i: usize) {
        self.add_to_hash(i);
    }
}
50
51impl<D> CacheAble<D> for Cache<D>
52where
53    D: PtrConv + CacheReturn,
54{
55    #[cfg(not(feature = "realloc"))]
56    #[inline]
57    fn retrieve<T, S: Shape>(
58        device: &D,
59        len: usize,
60        add_node: impl crate::AddGraph,
61    ) -> Buffer<T, D, S>
62    where
63        for<'b> D: Alloc<'b, T, S>,
64    {
65        device
66            .cache_mut()
67            .get(device, Ident::new(len), add_node, crate::bump_count)
68    }
69
70    #[cfg(feature = "realloc")]
71    #[inline]
72    fn retrieve<T, S: Shape>(
73        device: &D,
74        len: usize,
75        _add_node: impl crate::AddGraph,
76    ) -> Buffer<T, D, S>
77    where
78        for<'b> D: Alloc<'b, T, S>,
79    {
80        Buffer::new(device, len)
81    }
82
83    #[inline]
84    unsafe fn get_existing_buf<T, S: Shape>(device: &D, ident: Ident) -> Option<Buffer<T, D, S>> {
85        let ptr = D::convert(device.cache().nodes.get(&ident)?, AllocFlag::Wrapper);
86
87        Some(Buffer {
88            ptr,
89            device: Some(device),
90            ident: Some(ident),
91        })
92    }
93
94    #[inline]
95    fn remove(device: &D, ident: Ident) {
96        device.cache_mut().nodes.remove(&ident);
97    }
98
99    fn add_to_cache<T, S: Shape>(device: &D, ptr: &<D as Device>::Ptr<T, S>) -> Option<Ident> {
100        device.graph_mut().add_leaf(ptr.size());
101        let ident = Ident::new_bumped(ptr.size());
102        let raw_ptr = unsafe { std::rc::Rc::new(D::convert(ptr, AllocFlag::Wrapper)) };
103        device.cache_mut().nodes.insert(ident, raw_ptr);
104        Some(ident)
105    }
106}
107
108/// A cache for 'no-generic' raw pointers.
109pub struct Cache<D: Device> {
110    /// A map of all cached buffers using a custom hash function.
111    pub nodes: HashMap<Ident, Rc<D::Ptr<u8, ()>>, BuildHasherDefault<IdentHasher>>,
112}
113
114impl<D: Device> Debug for Cache<D>
115where
116    D::Ptr<u8, ()>: Debug,
117{
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        f.debug_struct("Cache2")
120            .field("cache", &self.nodes)
121            .finish()
122    }
123}
124
125impl<D: Device> Default for Cache<D>
126where
127    D::Ptr<u8, ()>: Default,
128{
129    #[inline]
130    fn default() -> Self {
131        Self {
132            nodes: Default::default(),
133        }
134    }
135}
136
137impl<D: PtrConv + GraphReturn> Cache<D> {
138    /// Adds a new cache entry to the cache.
139    /// The next get call will return this entry if the [Ident] is correct.
140    /// # Example
141    #[cfg_attr(feature = "cpu", doc = "```")]
142    #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
143    /// use custos::prelude::*;
144    /// use custos::{Ident, bump_count};
145    ///
146    /// let device = CPU::new();
147    /// let cache: Buffer = device
148    ///     .cache_mut()
149    ///     .add_node(&device, Ident { idx: 0, len: 7 }, (), bump_count);
150    ///
151    /// let ptr = device
152    ///     .cache()
153    ///     .nodes
154    ///     .get(&Ident { idx: 0, len: 7 })
155    ///     .unwrap()
156    ///     .clone();
157    ///
158    /// assert_eq!(cache.host_ptr(), ptr.ptr as *mut f32);
159    /// ```
160    pub fn add_node<'a, T, S: Shape>(
161        &mut self,
162        device: &'a D,
163        ident: Ident,
164        _add_node: impl crate::AddGraph,
165        callback: fn(),
166    ) -> Buffer<'a, T, D, S>
167    where
168        D: Alloc<'a, T, S>,
169    {
170        let ptr = device.alloc(ident.len, AllocFlag::Wrapper);
171
172        #[cfg(feature = "opt-cache")]
173        let graph_node = device.graph_mut().add(ident.len, _add_node);
174
175        #[cfg(not(feature = "opt-cache"))]
176        let graph_node = crate::Node {
177            idx: ident.idx,
178            deps: [0; 2],
179            len: ident.len,
180        };
181
182        let untyped_ptr = unsafe { D::convert(&ptr, AllocFlag::None) };
183        self.nodes.insert(ident, Rc::new(untyped_ptr));
184
185        callback();
186
187        Buffer {
188            ptr,
189            device: Some(device),
190            ident: Some(Ident {
191                idx: graph_node.idx,
192                len: ident.len,
193            }),
194        }
195    }
196
197    /// Retrieves cached pointers and constructs a [`Buffer`] with the pointers and the given `len`gth.
198    /// If a cached pointer doesn't exist, a new `Buffer` will be added to the cache and returned.
199    ///
200    /// # Example
201    #[cfg_attr(feature = "cpu", doc = "```")]
202    #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
203    /// use custos::prelude::*;
204    /// use custos::bump_count;
205    ///
206    /// let device = CPU::new();
207    ///     
208    /// let cache_entry: Buffer = device.cache_mut().get(&device, Ident::new(10), (), bump_count);
209    /// let new_cache_entry: Buffer = device.cache_mut().get(&device, Ident::new(10), (), bump_count);
210    ///
211    /// assert_ne!(cache_entry.ptrs(), new_cache_entry.ptrs());
212    ///
213    /// unsafe { set_count(0) };
214    ///
215    /// let first_entry: Buffer = device.cache_mut().get(&device, Ident::new(10), (), bump_count);
216    /// assert_eq!(cache_entry.ptrs(), first_entry.ptrs());
217    /// ```
218    pub fn get<'a, T, S: Shape>(
219        &mut self,
220        device: &'a D,
221        ident: Ident,
222        add_node: impl crate::AddGraph,
223        callback: fn(),
224    ) -> Buffer<'a, T, D, S>
225    where
226        D: Alloc<'a, T, S>,
227    {
228        let may_allocated = self.nodes.get(&ident);
229
230        match may_allocated {
231            Some(ptr) => {
232                callback();
233                let typed_ptr = unsafe { D::convert(ptr, AllocFlag::Wrapper) };
234
235                Buffer {
236                    ptr: typed_ptr,
237                    device: Some(device),
238                    ident: Some(ident),
239                }
240            }
241            None => self.add_node(device, ident, add_node, callback),
242        }
243    }
244}
245
#[cfg(test)]
mod tests {
    use core::hash::Hasher;
    use std::collections::HashSet;

    /// Feeds a long, stateful sequence of `usize` pairs through one [`crate::IdentHasher`] and
    /// checks that no intermediate digest repeats.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_ident_hasher() {
        use crate::IdentHasher;

        const ITEMS: usize = 2_500_000;

        let mut hashed_items = HashSet::with_capacity(ITEMS);
        let mut hasher = IdentHasher::default();

        for item in 0..ITEMS {
            hasher.write_usize(item);
            hasher.write_usize(100_000);
            // `insert` returns `false` when the digest was already present (a collision).
            assert!(hashed_items.insert(hasher.finish()));
        }
    }

    #[cfg(feature = "cpu")]
    #[test]
    fn test_add_node() {
        use crate::{bump_count, Buffer, CacheReturn, Ident};

        let device = crate::CPU::new();
        let cache: Buffer =
            device
                .cache_mut()
                .add_node(&device, Ident { idx: 0, len: 7 }, (), bump_count);

        let ptr = device
            .cache()
            .nodes
            .get(&Ident { idx: 0, len: 7 })
            .unwrap()
            .clone();

        assert_eq!(cache.host_ptr(), ptr.ptr as *mut f32);
    }

    #[cfg(feature = "cpu")]
    #[cfg(not(feature = "realloc"))]
    #[test]
    fn test_get() {
        // for: cargo test -- --test-threads=1

        use crate::{bump_count, set_count, Buffer, CacheReturn, Ident};
        unsafe { set_count(0) };
        let device = crate::CPU::new();

        let cache_entry: Buffer = device
            .cache_mut()
            .get(&device, Ident::new(10), (), bump_count);
        let new_cache_entry: Buffer =
            device
                .cache_mut()
                .get(&device, Ident::new(10), (), bump_count);

        assert_ne!(cache_entry.ptrs(), new_cache_entry.ptrs());

        unsafe { set_count(0) };

        let first_entry: Buffer = device
            .cache_mut()
            .get(&device, Ident::new(10), (), bump_count);
        assert_eq!(cache_entry.ptrs(), first_entry.ptrs());
    }
}