1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
/*
 * Created on Thu Jun 29 2023
 *
 * Copyright (c) storycraft. Licensed under the MIT Licence.
 */

#![doc = include_str!("../README.md")]

use core::{marker::PhantomData, mem, ptr};
use std::{
    cell::RefCell,
    collections::{hash_map::Entry, HashMap},
};

use bumpalo::Bump;
use parking_lot::RwLock;
use type_key::TypeKey;

/// A raw persistent value store that uses a closure's type as the key and stores its return value.
///
/// Values are allocated in an internal bump arena and dropped in place when the store is reset
/// or dropped.
#[derive(Debug)]
pub struct RawFnStore<'a> {
    map: HashMap<TypeKey, ManuallyDealloc>,

    // Fields drop in declaration order: the allocator must drop after the values it holds,
    // otherwise their destructors would run on freed memory.
    bump: Bump,

    // Values stored in `map` may borrow data for `'a`, but `ManuallyDealloc` erases that
    // lifetime, so drop check cannot see it. Owning a (phantom) `'a`-bounded trait object makes
    // the compiler assume that dropping the store may use `'a` data, so such data must outlive
    // the store. Only the lifetime matters; `Send` is just a convenient object-safe trait.
    _phantom: PhantomData<Box<dyn Send + 'a>>,
}

impl<'a> RawFnStore<'a> {
    pub fn new() -> Self {
        Self {
            map: HashMap::new(),
            bump: Bump::new(),
            _phantom: PhantomData,
        }
    }

    pub fn get_ptr<T: 'a>(&self, key: &impl FnOnce() -> T) -> Option<*const T> {
        Some(self.map.get(&TypeKey::of_val(key))?.ptr().cast::<T>())
    }

    pub fn insert_ptr<F: FnOnce() -> T, T: 'a>(&mut self, value: T) -> *const T {
        // SAFETY: Exclusively borrowed reference, original value is forgotten by Bump allocator and does not outlive.
        let value = unsafe { ManuallyDealloc::new(self.bump.alloc(value)) };
        let ptr = value.ptr();

        self.map.insert(TypeKey::of::<F>(), value);

        ptr.cast::<T>()
    }

    pub fn reset(&mut self) {
        self.map.clear();
        self.bump.reset();
    }
}

impl Default for RawFnStore<'_> {
    fn default() -> Self {
        Self::new()
    }
}

/// Single-threaded `FnStore` implementation.
///
/// The inner [`RawFnStore`] is wrapped in a [`RefCell`] so values can be inserted through `&self`.
#[derive(Debug)]
pub struct LocalFnStore<'a>(RefCell<RawFnStore<'a>>);

impl<'a> LocalFnStore<'a> {
    /// Creates an empty store.
    pub fn new() -> Self {
        let raw = RawFnStore::new();
        Self(RefCell::new(raw))
    }

    /// Returns the value stored for the closure type `F`, computing and storing it on first use.
    ///
    /// `key` is only called when nothing is stored for its type yet. It runs while no borrow of
    /// the inner store is held, so it may itself call `get` on this store.
    pub fn get<T: 'a + Send, F: FnOnce() -> T>(&self, key: F) -> &T {
        // The `Ref` guard is a temporary of this statement and is released right here.
        let cached = self.0.borrow().get_ptr(&key);

        let ptr = match cached {
            Some(ptr) => ptr,
            None => {
                let computed = key();
                self.0.borrow_mut().insert_ptr::<F, T>(computed)
            }
        };

        // SAFETY: the pointer targets a value in the store's bump arena. The arena is only reset
        // through `&mut self` or dropped with the store, neither of which can happen while the
        // returned `&T` borrows `self`.
        unsafe { &*ptr }
    }

    /// Drops every stored value and resets the underlying arena.
    pub fn reset(&mut self) {
        let raw = self.0.get_mut();
        raw.reset();
    }
}

impl Default for LocalFnStore<'_> {
    fn default() -> Self {
        Self::new()
    }
}

// SAFETY: `LocalFnStore::get` only accepts `T: Send`, so every stored value may move to another
// thread together with the store. The auto impl is missing only because `ManuallyDealloc` holds a
// raw pointer; those pointers target the store's own arena, which is owned by (and moves with) it.
unsafe impl Send for LocalFnStore<'_> {}

/// Thread-safe `FnStore` implementation.
///
/// Uses parking_lot's [`RwLock`] to acquire mutable access to the inner store: lookups take the
/// read lock, insertions take the write lock.
#[derive(Debug)]
pub struct AtomicFnStore<'a>(RwLock<RawFnStore<'a>>);

impl<'a> AtomicFnStore<'a> {
    /// Creates an empty store.
    pub fn new() -> Self {
        let raw = RawFnStore::new();
        Self(RwLock::new(raw))
    }

    /// Returns the value stored for the closure type `F`, computing and storing it on first use.
    ///
    /// `key` runs with no lock held, so it may call `get` on this store again. If several threads
    /// miss at the same time, each of them runs its own `key`; which result is kept is decided by
    /// [`RawFnStore::insert_ptr`].
    pub fn get<T: 'a + Send + Sync, F: FnOnce() -> T>(&self, key: F) -> &T {
        // The read guard is a temporary of this statement and is released right here.
        let cached = self.0.read().get_ptr(&key);

        let ptr = match cached {
            Some(ptr) => ptr,
            None => {
                let computed = key();
                self.0.write().insert_ptr::<F, T>(computed)
            }
        };

        // SAFETY: the pointer targets a value in the store's bump arena. The arena is only reset
        // through `&mut self` or dropped with the store, neither of which can happen while the
        // returned `&T` borrows `self`.
        unsafe { &*ptr }
    }

    /// Drops every stored value and resets the underlying arena.
    pub fn reset(&mut self) {
        let raw = self.0.get_mut();
        raw.reset();
    }
}

impl Default for AtomicFnStore<'_> {
    fn default() -> Self {
        Self::new()
    }
}

// SAFETY: `AtomicFnStore::get` only accepts `T: Send + Sync`, so stored values may move between
// threads and be shared by reference (`get` hands out `&T` to any thread holding `&self`).
// Insertions happen only under the write lock; readers under the read lock only call `get_ptr`,
// which reads the map and never touches the bump arena.
unsafe impl Send for AtomicFnStore<'_> {}
unsafe impl Sync for AtomicFnStore<'_> {}

/// Object-safe marker trait implemented by every sized type.
///
/// Lets `ManuallyDealloc` hold a `*mut dyn Erased` to a value of any type; the trait-object vtable
/// still carries the value's destructor, which `drop_in_place` uses.
trait Erased {}
impl<T> Erased for T {}

/// Owning, type-erased pointer to a value whose memory is managed elsewhere.
///
/// Dropping it runs the pointee's destructor in place but never frees the memory; it is meant
/// for values allocated in a bump arena, which releases memory on its own.
///
/// # Safety
/// Dereferencing (or dropping) the pointer is only sound while the pointee is still alive and
/// its memory is still allocated.
#[repr(transparent)]
#[derive(Debug)]
struct ManuallyDealloc(*mut dyn Erased);

impl ManuallyDealloc {
    /// Erases the type and lifetime of `reference`, taking over responsibility for dropping it.
    ///
    /// # Safety
    /// Calling this function is only safe if the value referenced by `reference` is forgotten,
    /// i.e. nothing else will drop it, because the returned handle drops it in place when it is
    /// itself dropped. The referenced memory must also stay allocated for as long as the handle
    /// (or any pointer obtained from it) is used.
    pub unsafe fn new<T>(reference: &mut T) -> Self {
        // The transmute erases the trait object's lifetime so it can be stored in the
        // lifetime-free `*mut dyn Erased` field; callers are responsible for the real lifetime.
        Self(mem::transmute::<&mut dyn Erased, &mut dyn Erased>(reference) as *mut _)
    }

    /// Returns the stored pointer as a `*const` pointer (the value is not touched).
    pub const fn ptr(&self) -> *const dyn Erased {
        self.0.cast_const()
    }
}

impl Drop for ManuallyDealloc {
    fn drop(&mut self) {
        // SAFETY: Safe to drop since its original was forgotten and only the pointer is pointing the value. See [`ManuallyDealloc::new`]
        unsafe { ptr::drop_in_place(self.0) }
    }
}

#[cfg(test)]
mod tests {
    use crate::{AtomicFnStore, LocalFnStore};

    /// Key function shared by the tests; its fn-item type is the lookup key.
    fn one() -> i32 {
        1
    }

    /// Compile-time check of the auto traits each store provides.
    #[test]
    fn test_trait() {
        const fn assert_send<T: Send>() {}
        const fn assert_sync<T: Sync>() {}

        assert_send::<LocalFnStore>();
        assert_send::<AtomicFnStore>();
        assert_sync::<AtomicFnStore>();
    }

    /// Computing a value may re-enter the same `LocalFnStore`.
    #[test]
    fn test_local() {
        let store = LocalFnStore::new();

        let two = store.get(|| store.get(one) + 1);
        let first = store.get(one);

        assert_eq!((*two, *first), (2, 1));
    }

    /// Computing a value may re-enter the same `AtomicFnStore`.
    #[test]
    fn test_atomic() {
        let store = AtomicFnStore::new();

        let two = store.get(|| store.get(one) + 1);
        let first = store.get(one);

        assert_eq!((*two, *first), (2, 1));
    }
}