// sokoban/node_allocator.rs

1use bytemuck::{Pod, Zeroable};
2use num_derive::FromPrimitive;
3use std::mem::{align_of, size_of};
4
5/// Enum representing the fields of a tree node:
6/// 0 - left pointer
7/// 1 - right pointer
8/// 2 - parent pointer
9/// 3 - value pointer (index of leaf)
10#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive)]
11pub enum TreeField {
12    Left = 0,
13    Right = 1,
14    Parent = 2,
15    Value = 3,
16}
17
18/// Enum representing the fields of a simple node (Linked List / Binary Tree):
19/// 0 - left pointer
20/// 1 - right pointer
21#[derive(Debug, Copy, Clone, PartialEq, Eq, FromPrimitive)]
22pub enum NodeField {
23    Left = 0,
24    Right = 1,
25}
26
/// Convenience interface for obtaining a `&mut Self` view over an arbitrary byte buffer.
///
/// How much of `data` is used, and how it is interpreted, is up to each implementor.
pub trait FromSlice {
    /// Returns a mutable view of `data` as `Self`; the returned reference borrows `data`.
    fn new_from_slice<'a>(data: &'a mut [u8]) -> &'a mut Self;
}
31
/// This trait provides an API for map-like data structures that use the NodeAllocator
/// struct as the underlying container.
///
/// The exact meaning of return values such as the `u32` from `insert` is defined by each
/// implementor.
pub trait NodeAllocatorMap<K, V> {
    /// Inserts `value` under `key`.
    ///
    /// NOTE(review): the returned `u32` is presumably the allocator index of the entry, and
    /// `None` a failed insert; confirm against the implementors.
    fn insert(&mut self, key: K, value: V) -> Option<u32>;
    /// Removes `key`, returning its value if it was present.
    fn remove(&mut self, key: &K) -> Option<V>;
    /// Returns `true` if `key` is present.
    fn contains(&self, key: &K) -> bool;
    /// Returns a shared reference to the value stored under `key`, if any.
    fn get(&self, key: &K) -> Option<&V>;
    /// Returns a mutable reference to the value stored under `key`, if any.
    fn get_mut(&mut self, key: &K) -> Option<&mut V>;
    /// Deprecated alias of [`NodeAllocatorMap::len`].
    ///
    /// The default implementation simply forwards to `len`, so implementors no longer need to
    /// provide it. Existing overrides keep working.
    #[deprecated(note = "use `len` instead")]
    fn size(&self) -> usize {
        self.len()
    }
    /// Number of entries currently stored.
    fn len(&self) -> usize;
    /// Returns `true` when `len()` is 0.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Maximum number of entries the structure can hold.
    fn capacity(&self) -> usize;
    /// Iterates over `(key, value)` pairs; iteration order is defined by the implementor.
    fn iter(&self) -> Box<dyn DoubleEndedIterator<Item = (&K, &V)> + '_>;
    /// Iterates over `(key, mutable value)` pairs; iteration order is defined by the implementor.
    fn iter_mut(&mut self) -> Box<dyn DoubleEndedIterator<Item = (&K, &mut V)> + '_>;
}
50
51/// This trait adds additional functions for sorted map data structures that use the NodeAllocator
52pub trait OrderedNodeAllocatorMap<K, V>: NodeAllocatorMap<K, V> {
53    fn get_min_index(&mut self) -> u32;
54    fn get_max_index(&mut self) -> u32;
55    fn get_min(&mut self) -> Option<(K, V)>;
56    fn get_max(&mut self) -> Option<(K, V)>;
57}
58
59pub trait ZeroCopy: Pod {
60    fn load_mut_bytes(data: &'_ mut [u8]) -> Option<&'_ mut Self> {
61        let size = std::mem::size_of::<Self>();
62        bytemuck::try_from_bytes_mut(&mut data[..size]).ok()
63    }
64
65    fn load_bytes(data: &'_ [u8]) -> Option<&'_ Self> {
66        let size = std::mem::size_of::<Self>();
67        bytemuck::try_from_bytes(&data[..size]).ok()
68    }
69}
70
/// The "no node" index.
///
/// `NodeAllocator` indices are 1-based (`get(i)` reads `nodes[i - 1]`), so 0 never names a
/// stored node. The allocator's register helpers treat it as a no-op target.
pub const SENTINEL: u32 = 0;
72
73#[repr(C)]
74#[derive(Copy, Clone)]
75pub struct Node<T: Copy + Clone + Pod + Zeroable + Default, const NUM_REGISTERS: usize> {
76    /// Arbitrary registers (generally used for pointers)
77    /// Note: Register 0 is ALWAYS used for the free list
78    registers: [u32; NUM_REGISTERS],
79    value: T,
80}
81
82impl<T: Copy + Clone + Pod + Zeroable + Default, const NUM_REGISTERS: usize> Default
83    for Node<T, NUM_REGISTERS>
84{
85    fn default() -> Self {
86        assert!(NUM_REGISTERS >= 1);
87        Self {
88            registers: [SENTINEL; NUM_REGISTERS],
89            value: T::default(),
90        }
91    }
92}
93
94impl<T: Copy + Clone + Pod + Zeroable + Default, const NUM_REGISTERS: usize>
95    Node<T, NUM_REGISTERS>
96{
97    #[inline(always)]
98    pub(crate) fn get_free_list_register(&self) -> u32 {
99        self.registers[0]
100    }
101
102    #[inline(always)]
103    pub fn get_register(&self, r: usize) -> u32 {
104        self.registers[r]
105    }
106
107    #[inline(always)]
108    pub(crate) fn set_free_list_register(&mut self, v: u32) {
109        self.registers[0] = v;
110    }
111
112    #[inline(always)]
113    pub fn set_register(&mut self, r: usize, v: u32) {
114        self.registers[r] = v;
115    }
116
117    #[inline(always)]
118    pub fn set_value(&mut self, v: T) {
119        self.value = v;
120    }
121
122    #[inline(always)]
123    pub fn get_value_mut(&mut self) -> &mut T {
124        &mut self.value
125    }
126
127    #[inline(always)]
128    pub fn get_value(&self) -> &T {
129        &self.value
130    }
131}
132
133#[repr(C)]
134#[derive(Copy, Clone)]
135pub struct NodeAllocator<
136    T: Default + Copy + Clone + Pod + Zeroable,
137    const MAX_SIZE: usize,
138    const NUM_REGISTERS: usize,
139> {
140    /// Size of the allocator. The max value this can take is `MAX_SIZE`
141    pub size: u64,
142    /// Index that represents the "boundary" of the allocator. When this value reaches `MAX_SIZE`
143    /// this indicates that all of the nodes has been used at least once and all new allocated
144    /// indicies must be pulled from the free list.
145    bump_index: u32,
146    /// Buffer index of the first element in the free list. The free list is a singly-linked list
147    /// of unallocated nodes. The free list operates like a stack. When a node is removed from the
148    /// allocator, the removed node becomes the new free list head. When new nodes are added,
149    /// the new index to allocated is pulled from the `free_list_head`
150    free_list_head: u32,
151    /// Nodes containing data, with `NUM_REGISTERS` registers that store arbitrary data  
152    pub nodes: [Node<T, NUM_REGISTERS>; MAX_SIZE],
153}
154
155unsafe impl<
156        T: Default + Copy + Clone + Pod + Zeroable,
157        const MAX_SIZE: usize,
158        const NUM_REGISTERS: usize,
159    > Zeroable for NodeAllocator<T, MAX_SIZE, NUM_REGISTERS>
160{
161}
162unsafe impl<
163        T: Default + Copy + Clone + Pod + Zeroable,
164        const MAX_SIZE: usize,
165        const NUM_REGISTERS: usize,
166    > Pod for NodeAllocator<T, MAX_SIZE, NUM_REGISTERS>
167{
168}
169
170impl<
171        T: Default + Copy + Clone + Pod + Zeroable,
172        const MAX_SIZE: usize,
173        const NUM_REGISTERS: usize,
174    > ZeroCopy for NodeAllocator<T, MAX_SIZE, NUM_REGISTERS>
175{
176}
177
178impl<
179        T: Default + Copy + Clone + Pod + Zeroable,
180        const MAX_SIZE: usize,
181        const NUM_REGISTERS: usize,
182    > Default for NodeAllocator<T, MAX_SIZE, NUM_REGISTERS>
183{
184    fn default() -> Self {
185        assert!(NUM_REGISTERS >= 1);
186        let na = NodeAllocator {
187            size: 0,
188            bump_index: 1,
189            free_list_head: 1,
190            nodes: [Node::<T, NUM_REGISTERS>::default(); MAX_SIZE],
191        };
192        na.assert_proper_alignemnt();
193        na
194    }
195}
196
197impl<
198        T: Default + Copy + Clone + Pod + Zeroable,
199        const MAX_SIZE: usize,
200        const NUM_REGISTERS: usize,
201    > NodeAllocator<T, MAX_SIZE, NUM_REGISTERS>
202{
203    pub fn new() -> Self {
204        Self::default()
205    }
206
207    #[inline(always)]
208    fn assert_proper_alignemnt(&self) {
209        let reg_size = size_of::<u32>() * NUM_REGISTERS;
210        let self_ptr = std::slice::from_ref(self).as_ptr() as usize;
211        let node_ptr = std::slice::from_ref(&self.nodes).as_ptr() as usize;
212        let self_align = align_of::<Self>();
213        let t_index = node_ptr + reg_size;
214        let t_align = align_of::<T>();
215        let t_size = size_of::<T>();
216        assert!(
217            self_ptr % self_align as usize == 0,
218            "NodeAllocator alignment mismatch, address is {} which is not a multiple of the struct alignment ({})",
219            self_ptr,
220            self_align,
221        );
222        assert!(
223            t_size % t_align == 0,
224            "Size of T ({}) is not a multiple of the alignment of T ({})",
225            t_size,
226            t_align,
227        );
228        assert!(
229            t_size == 0 || t_size >= self_align,
230            "Size of T ({}) must be >= than the alignment of NodeAllocator ({})",
231            t_size,
232            self_align,
233        );
234        assert!(node_ptr == self_ptr + 16, "Nodes are misaligned");
235        assert!(t_index % t_align == 0, "First index of T is misaligned");
236        assert!(
237            (t_index + t_size + reg_size) % t_align == 0,
238            "Subsequent indices of T are misaligned"
239        );
240    }
241
242    pub fn initialize(&mut self) {
243        assert!(NUM_REGISTERS >= 1);
244        self.assert_proper_alignemnt();
245        if self.size == 0 && self.bump_index == 0 && self.free_list_head == 0 {
246            self.bump_index = 1;
247            self.free_list_head = 1;
248        } else {
249            panic!("Cannot reinitialize NodeAllocator");
250        }
251    }
252
253    #[inline(always)]
254    pub fn get(&self, i: u32) -> &Node<T, NUM_REGISTERS> {
255        &self.nodes[(i - 1) as usize]
256    }
257
258    #[inline(always)]
259    pub fn get_mut(&mut self, i: u32) -> &mut Node<T, NUM_REGISTERS> {
260        &mut self.nodes[(i - 1) as usize]
261    }
262
263    /// Adds a new node to the allocator. The function returns the current pointer
264    /// to the free list, where the new node is inserted
265    pub fn add_node(&mut self, node: T) -> u32 {
266        let i = self.free_list_head;
267        if self.free_list_head == self.bump_index {
268            if self.bump_index == (MAX_SIZE + 1) as u32 {
269                panic!("Buffer is full, size {}", self.size);
270            }
271            self.bump_index += 1;
272            self.free_list_head = self.bump_index;
273        } else {
274            self.free_list_head = self.get(i).get_free_list_register();
275            self.get_mut(i).set_free_list_register(SENTINEL);
276        }
277        self.get_mut(i).set_value(node);
278        self.size += 1;
279        i
280    }
281
282    /// Removes the node at index `i` from the allocator and adds the index to the free list
283    /// When deleting nodes, you MUST clear all registers prior to calling `remove_node`
284    pub fn remove_node(&mut self, i: u32) -> Option<&T> {
285        if i == SENTINEL {
286            return None;
287        }
288        let free_list_head = self.free_list_head;
289        self.get_mut(i).set_free_list_register(free_list_head);
290        self.free_list_head = i;
291        self.size -= 1;
292        Some(self.get(i).get_value())
293    }
294
295    #[inline(always)]
296    pub fn disconnect(&mut self, i: u32, j: u32, r_i: u32, r_j: u32) {
297        if i != SENTINEL {
298            // assert!(j == self.get_register(i, r_i), "Nodes are not connected");
299            self.clear_register(i, r_i);
300        }
301        if j != SENTINEL {
302            // assert!(i == self.get_register(j, r_j), "Nodes are not connected");
303            self.clear_register(j, r_j);
304        }
305    }
306
307    #[inline(always)]
308    pub fn clear_register(&mut self, i: u32, r_i: u32) {
309        if i != SENTINEL {
310            self.get_mut(i).set_register(r_i as usize, SENTINEL);
311        }
312    }
313
314    #[inline(always)]
315    pub fn connect(&mut self, i: u32, j: u32, r_i: u32, r_j: u32) {
316        if i != SENTINEL {
317            self.get_mut(i).set_register(r_i as usize, j);
318        }
319        if j != SENTINEL {
320            self.get_mut(j).set_register(r_j as usize, i);
321        }
322    }
323
324    #[inline(always)]
325    pub fn set_register(&mut self, i: u32, value: u32, r_i: u32) {
326        if i != SENTINEL {
327            self.get_mut(i).set_register(r_i as usize, value);
328        }
329    }
330
331    #[inline(always)]
332    pub fn get_register(&self, i: u32, r_i: u32) -> u32 {
333        if i != SENTINEL {
334            self.get(i).get_register(r_i as usize)
335        } else {
336            SENTINEL
337        }
338    }
339}