1use core::{cell::RefMut, fmt::Debug, hash::BuildHasherDefault, ops::BitXor};
4use std::collections::HashMap;
5
6use std::rc::Rc;
7
8use crate::{
9 flag::AllocFlag, shape::Shape, Alloc, Buffer, CacheAble, Device, GlobalCount, GraphReturn,
10 Ident, PtrConv, PtrType,
11};
12
13pub trait CacheReturn: GraphReturn<GlobalCount> {
15 fn cache(&self) -> core::cell::Ref<Cache<Self>>
17 where
18 Self: PtrConv;
19
20 fn cache_mut(&self) -> RefMut<Cache<Self>>
22 where
23 Self: PtrConv;
24}
25
/// Multiplier folded into the hasher state on every [`IdentHasher::write_usize`] call.
///
/// The constant is 64 bits wide. It is written as a `u64` literal and cast to `usize` so that
/// it is truncated on targets with narrower pointers instead of being rejected as an
/// out-of-range `usize` literal. On 64-bit targets the value is unchanged.
const K: usize = 0x517cc1b727220a95_u64 as usize;

/// A small hasher whose whole state is a single `usize`.
///
/// Only [`write_usize`](std::hash::Hasher::write_usize) is supported; feeding raw bytes
/// through [`write`](std::hash::Hasher::write) panics.
#[derive(Default)]
pub struct IdentHasher {
    /// Running hash state; starts at `0` via `Default`.
    hash: usize,
}

impl std::hash::Hasher for IdentHasher {
    /// Returns the current state widened to `u64`.
    #[inline]
    fn finish(&self) -> u64 {
        self.hash as u64
    }

    /// Unsupported: this hasher only accepts `usize` input.
    ///
    /// # Panics
    /// Always panics.
    #[inline]
    fn write(&mut self, _bytes: &[u8]) {
        unimplemented!("IdentHasher only hashes usize.")
    }

    /// Mixes `i` into the state: rotate left by 5 bits, xor with `i`, then multiply by `K`
    /// with wrapping arithmetic.
    #[inline]
    fn write_usize(&mut self, i: usize) {
        self.hash = self.hash.rotate_left(5).bitxor(i).wrapping_mul(K);
    }
}
50
51impl<D> CacheAble<D> for Cache<D>
52where
53 D: PtrConv + CacheReturn,
54{
55 #[cfg(not(feature = "realloc"))]
56 #[inline]
57 fn retrieve<T, S: Shape>(
58 device: &D,
59 len: usize,
60 add_node: impl crate::AddGraph,
61 ) -> Buffer<T, D, S>
62 where
63 for<'b> D: Alloc<'b, T, S>,
64 {
65 device
66 .cache_mut()
67 .get(device, Ident::new(len), add_node, crate::bump_count)
68 }
69
70 #[cfg(feature = "realloc")]
71 #[inline]
72 fn retrieve<T, S: Shape>(
73 device: &D,
74 len: usize,
75 _add_node: impl crate::AddGraph,
76 ) -> Buffer<T, D, S>
77 where
78 for<'b> D: Alloc<'b, T, S>,
79 {
80 Buffer::new(device, len)
81 }
82
83 #[inline]
84 unsafe fn get_existing_buf<T, S: Shape>(device: &D, ident: Ident) -> Option<Buffer<T, D, S>> {
85 let ptr = D::convert(device.cache().nodes.get(&ident)?, AllocFlag::Wrapper);
86
87 Some(Buffer {
88 ptr,
89 device: Some(device),
90 ident: Some(ident),
91 })
92 }
93
94 #[inline]
95 fn remove(device: &D, ident: Ident) {
96 device.cache_mut().nodes.remove(&ident);
97 }
98
99 fn add_to_cache<T, S: Shape>(device: &D, ptr: &<D as Device>::Ptr<T, S>) -> Option<Ident> {
100 device.graph_mut().add_leaf(ptr.size());
101 let ident = Ident::new_bumped(ptr.size());
102 let raw_ptr = unsafe { std::rc::Rc::new(D::convert(ptr, AllocFlag::Wrapper)) };
103 device.cache_mut().nodes.insert(ident, raw_ptr);
104 Some(ident)
105 }
106}
107
108pub struct Cache<D: Device> {
110 pub nodes: HashMap<Ident, Rc<D::Ptr<u8, ()>>, BuildHasherDefault<IdentHasher>>,
112}
113
114impl<D: Device> Debug for Cache<D>
115where
116 D::Ptr<u8, ()>: Debug,
117{
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.debug_struct("Cache2")
120 .field("cache", &self.nodes)
121 .finish()
122 }
123}
124
125impl<D: Device> Default for Cache<D>
126where
127 D::Ptr<u8, ()>: Default,
128{
129 #[inline]
130 fn default() -> Self {
131 Self {
132 nodes: Default::default(),
133 }
134 }
135}
136
137impl<D: PtrConv + GraphReturn> Cache<D> {
138 #[cfg_attr(feature = "cpu", doc = "```")]
142 #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
143 pub fn add_node<'a, T, S: Shape>(
161 &mut self,
162 device: &'a D,
163 ident: Ident,
164 _add_node: impl crate::AddGraph,
165 callback: fn(),
166 ) -> Buffer<'a, T, D, S>
167 where
168 D: Alloc<'a, T, S>,
169 {
170 let ptr = device.alloc(ident.len, AllocFlag::Wrapper);
171
172 #[cfg(feature = "opt-cache")]
173 let graph_node = device.graph_mut().add(ident.len, _add_node);
174
175 #[cfg(not(feature = "opt-cache"))]
176 let graph_node = crate::Node {
177 idx: ident.idx,
178 deps: [0; 2],
179 len: ident.len,
180 };
181
182 let untyped_ptr = unsafe { D::convert(&ptr, AllocFlag::None) };
183 self.nodes.insert(ident, Rc::new(untyped_ptr));
184
185 callback();
186
187 Buffer {
188 ptr,
189 device: Some(device),
190 ident: Some(Ident {
191 idx: graph_node.idx,
192 len: ident.len,
193 }),
194 }
195 }
196
197 #[cfg_attr(feature = "cpu", doc = "```")]
202 #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
203 pub fn get<'a, T, S: Shape>(
219 &mut self,
220 device: &'a D,
221 ident: Ident,
222 add_node: impl crate::AddGraph,
223 callback: fn(),
224 ) -> Buffer<'a, T, D, S>
225 where
226 D: Alloc<'a, T, S>,
227 {
228 let may_allocated = self.nodes.get(&ident);
229
230 match may_allocated {
231 Some(ptr) => {
232 callback();
233 let typed_ptr = unsafe { D::convert(ptr, AllocFlag::Wrapper) };
234
235 Buffer {
236 ptr: typed_ptr,
237 device: Some(device),
238 ident: Some(ident),
239 }
240 }
241 None => self.add_node(device, ident, add_node, callback),
242 }
243 }
244}
245
#[cfg(test)]
mod tests {
    use core::hash::Hasher;
    use std::collections::HashSet;

    /// Feeds a long sequence of values through one `IdentHasher` and checks that no
    /// intermediate hash value repeats.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_ident_hasher() {
        use crate::IdentHasher;

        let mut hasher = IdentHasher::default();
        let mut seen = HashSet::new();

        for item in 0..2500000 {
            hasher.write_usize(item);
            hasher.write_usize(100000);

            // `insert` returns `false` if the value was already present, i.e. on a collision.
            assert!(seen.insert(hasher.finish()));
        }
    }

    /// The buffer returned by `add_node` must point at the same memory as the cached entry.
    #[cfg(feature = "cpu")]
    #[test]
    fn test_add_node() {
        use crate::{bump_count, Buffer, CacheReturn, Ident};

        let device = crate::CPU::new();
        let ident = Ident { idx: 0, len: 7 };

        let buf: Buffer = device
            .cache_mut()
            .add_node(&device, ident, (), bump_count);

        let cached = device.cache().nodes.get(&ident).unwrap().clone();

        assert_eq!(buf.host_ptr(), cached.ptr as *mut f32);
    }

    /// Two consecutive `get` calls yield distinct entries; after resetting the count, the
    /// first entry is returned again.
    #[cfg(feature = "cpu")]
    #[cfg(not(feature = "realloc"))]
    #[test]
    fn test_get() {
        use crate::{bump_count, set_count, Buffer, CacheReturn, Ident};

        // NOTE(review): `set_count` is unsafe; presumably the count is not shared with other
        // concurrently running tests. Confirm.
        unsafe { set_count(0) };
        let device = crate::CPU::new();

        let first: Buffer = device
            .cache_mut()
            .get(&device, Ident::new(10), (), bump_count);
        let second: Buffer = device
            .cache_mut()
            .get(&device, Ident::new(10), (), bump_count);

        assert_ne!(first.ptrs(), second.ptrs());

        unsafe { set_count(0) };

        let after_reset: Buffer = device
            .cache_mut()
            .get(&device, Ident::new(10), (), bump_count);
        assert_eq!(first.ptrs(), after_reset.ptrs());
    }
}