Skip to main content

delta/
splay.rs

1//! Tarjan-Sleator splay tree keyed on u64 fingerprints.
2//!
3//! A self-adjusting binary search tree: every access (find/insert)
4//! splays the accessed node to the root via zig/zig-zig/zig-zag
5//! rotations.  Amortized O(log n) per operation.
6//!
7//! Reference: Sleator & Tarjan, "Self-Adjusting Binary Search Trees",
8//! JACM 32(3), 1985.
9
10use std::ptr;
11
/// A node in the splay tree.
///
/// Nodes are allocated with `Box::into_raw` and are reachable through exactly
/// one link: either `SplayTree::root` or a parent's `left`/`right` pointer.
struct Node<V> {
    key: u64,
    value: V,
    left: *mut Node<V>,
    right: *mut Node<V>,
}

/// A splay tree mapping u64 keys to values of type V.
///
/// Every lookup or insertion restructures the tree, which is why even
/// [`find`](Self::find) takes `&mut self`.
pub struct SplayTree<V> {
    /// Root node, or null when the tree is empty.
    root: *mut Node<V>,
    /// Number of nodes currently in the tree.
    len: usize,
}

impl<V> SplayTree<V> {
    /// Creates an empty tree.
    pub fn new() -> Self {
        SplayTree {
            root: ptr::null_mut(),
            len: 0,
        }
    }

    /// Returns the number of keys stored in the tree.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the tree holds no keys.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Looks up `key`, returning a mutable reference to its value if present.
    ///
    /// The node holding `key` (or, on a miss, the last node visited) is
    /// splayed to the root.
    pub fn find(&mut self, key: u64) -> Option<&mut V> {
        if self.root.is_null() {
            return None;
        }
        self.splay(key);
        // SAFETY: the tree is non-empty, so after `splay` the root is a live
        // node owned by `self`; the returned borrow is bounded by `&mut self`.
        let root = unsafe { &mut *self.root };
        if root.key == key {
            Some(&mut root.value)
        } else {
            None
        }
    }

    /// Inserts `key` with `value` if absent and returns a mutable reference to
    /// the stored value.
    ///
    /// If `key` is already present, its existing value is kept and the
    /// `value` argument is dropped. Either way the entry ends up at the root.
    pub fn insert_or_get(&mut self, key: u64, value: V) -> &mut V {
        if !self.root.is_null() {
            self.splay(key);
            // SAFETY: non-empty tree, so the root is a live node after `splay`.
            if unsafe { (*self.root).key } == key {
                // SAFETY: same live root; the borrow is bounded by `&mut self`.
                return unsafe { &mut (*self.root).value };
            }
        }
        self.push_root(key, value);
        // SAFETY: `push_root` just installed a freshly allocated, live root.
        unsafe { &mut (*self.root).value }
    }

    /// Inserts `key` with `value`, overwriting (and dropping) any existing
    /// value for that key. The entry ends up at the root.
    pub fn insert(&mut self, key: u64, value: V) {
        if !self.root.is_null() {
            self.splay(key);
            // SAFETY: non-empty tree, so the root is a live node after `splay`.
            let root = unsafe { &mut *self.root };
            if root.key == key {
                root.value = value;
                return;
            }
        }
        self.push_root(key, value);
    }

    /// Allocates a node for `key`, which must be absent, and makes it the new
    /// root, incrementing `len`.
    ///
    /// Call this only on an empty tree or right after `self.splay(key)` has
    /// missed. In that state every key in the old root's left subtree is
    /// below `key` (when `key < root.key`) or every key in its right subtree
    /// is above `key` (otherwise). The old tree can therefore be split
    /// between the new node's two children without breaking BST order.
    fn push_root(&mut self, key: u64, value: V) {
        let old = self.root;
        // SAFETY: when non-null, `old` is a live node owned by the tree. One
        // of its subtrees is detached and both pieces are handed to the new
        // node, so every node still has exactly one owning link.
        let (left, right) = unsafe {
            if old.is_null() {
                (ptr::null_mut(), ptr::null_mut())
            } else if key < (*old).key {
                let below = (*old).left;
                (*old).left = ptr::null_mut();
                (below, old)
            } else {
                let above = (*old).right;
                (*old).right = ptr::null_mut();
                (old, above)
            }
        };
        self.root = Box::into_raw(Box::new(Node {
            key,
            value,
            left,
            right,
        }));
        self.len += 1;
    }

    /// Top-down splay (Sleator & Tarjan 1985).
    ///
    /// Moves the node with `key` to the root, or, if `key` is absent, the
    /// last node on its search path.
    ///
    /// Uses MaybeUninit for the header sentinel since we only access
    /// the left/right pointer fields, never the key or value.
    fn splay(&mut self, key: u64) {
        use std::mem::MaybeUninit;

        if self.root.is_null() {
            return;
        }

        // Sentinel header — only left/right are used. `header.right` collects
        // the left tree (keys < key) and `header.left` collects the right
        // tree (keys > key).
        let mut header = MaybeUninit::<Node<V>>::uninit();
        let header_ptr = header.as_mut_ptr();
        // SAFETY: assigning raw-pointer fields through a raw place neither
        // reads nor drops the uninitialised contents. The header's `key` and
        // `value` are never read, written, or dropped.
        unsafe {
            (*header_ptr).left = ptr::null_mut();
            (*header_ptr).right = ptr::null_mut();
        }
        // `l` is the rightmost node of the left tree; `r` is the leftmost
        // node of the right tree. Both start at the sentinel.
        let mut l: *mut Node<V> = header_ptr;
        let mut r: *mut Node<V> = header_ptr;
        let mut t = self.root;

        // SAFETY: `t` always points at a live tree node; `l` and `r` point at
        // either a live node or the sentinel, and only their `left`/`right`
        // fields are accessed.
        unsafe {
            loop {
                if key < (*t).key {
                    if (*t).left.is_null() {
                        break;
                    }
                    if key < (*(*t).left).key {
                        // Zig-zig: rotate right
                        let y = (*t).left;
                        (*t).left = (*y).right;
                        (*y).right = t;
                        t = y;
                        if (*t).left.is_null() {
                            break;
                        }
                    }
                    // Link right: `t` and its right subtree exceed `key`.
                    (*r).left = t;
                    r = t;
                    t = (*t).left;
                } else if key > (*t).key {
                    if (*t).right.is_null() {
                        break;
                    }
                    if key > (*(*t).right).key {
                        // Zig-zig: rotate left
                        let y = (*t).right;
                        (*t).right = (*y).left;
                        (*y).left = t;
                        t = y;
                        if (*t).right.is_null() {
                            break;
                        }
                    }
                    // Link left: `t` and its left subtree are below `key`.
                    (*l).right = t;
                    l = t;
                    t = (*t).right;
                } else {
                    break; // found
                }
            }

            // Assemble: hang t's subtrees off the side trees, then make the
            // side trees t's children.
            (*l).right = (*t).left;
            (*r).left = (*t).right;
            (*t).left = (*header_ptr).right;
            (*t).right = (*header_ptr).left;
            self.root = t;
        }
    }
}

impl<V> Default for SplayTree<V> {
    /// Equivalent to [`SplayTree::new`].
    fn default() -> Self {
        Self::new()
    }
}
219
220impl<V> Drop for SplayTree<V> {
221    fn drop(&mut self) {
222        // Iterative destruction using a stack to avoid deep recursion.
223        let mut stack = Vec::new();
224        if !self.root.is_null() {
225            stack.push(self.root);
226        }
227        while let Some(node) = stack.pop() {
228            unsafe {
229                if !(*node).left.is_null() {
230                    stack.push((*node).left);
231                }
232                if !(*node).right.is_null() {
233                    stack.push((*node).right);
234                }
235                drop(Box::from_raw(node));
236            }
237        }
238        self.root = ptr::null_mut();
239        self.len = 0;
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insert_and_find() {
        let mut tree = SplayTree::new();
        tree.insert_or_get(42, vec![0usize]);
        tree.find(42).unwrap().push(1);
        assert_eq!(tree.find(42).unwrap(), &vec![0, 1]);
        assert!(tree.find(99).is_none());
        assert_eq!(tree.len(), 1);
    }

    #[test]
    fn insert_or_get_retains_existing() {
        let mut tree = SplayTree::new();
        tree.insert_or_get(10, 100usize);
        tree.insert_or_get(10, 200);
        assert_eq!(*tree.find(10).unwrap(), 100);
    }

    #[test]
    fn insert_overwrites() {
        let mut tree = SplayTree::new();
        tree.insert(10, 100usize);
        tree.insert(10, 200);
        assert_eq!(*tree.find(10).unwrap(), 200);
    }

    #[test]
    fn many_keys() {
        let mut tree = SplayTree::new();
        for i in 0..1000u64 {
            tree.insert_or_get(i, i as usize);
        }
        assert_eq!(tree.len(), 1000);
        for i in 0..1000u64 {
            assert_eq!(*tree.find(i).unwrap(), i as usize);
        }
    }

    /// Descending insertion always hits the `key < root` split branch, which
    /// the ascending-order tests above never exercise.
    #[test]
    fn descending_inserts_take_left_split() {
        let mut tree = SplayTree::new();
        for i in (0..500u64).rev() {
            tree.insert_or_get(i, i * 3);
        }
        assert_eq!(tree.len(), 500);
        for i in 0..500u64 {
            assert_eq!(*tree.find(i).unwrap(), i * 3);
        }
    }

    #[test]
    fn scrambled_inserts_and_misses() {
        let mut tree = SplayTree::new();
        // 7919 is coprime with 1000, so `k` visits every value in 0..1000
        // exactly once, in a scrambled order.
        for i in 0..1000u64 {
            let k = (i * 7919) % 1000;
            tree.insert(k * 2, k); // even keys only
        }
        assert_eq!(tree.len(), 1000);
        for k in 0..1000u64 {
            assert_eq!(tree.find(k * 2).map(|v| *v), Some(k));
            assert!(tree.find(k * 2 + 1).is_none());
        }
    }

    #[test]
    fn duplicates_do_not_grow_len() {
        let mut tree = SplayTree::new();
        assert!(tree.is_empty());
        tree.insert(5, 'a');
        tree.insert(5, 'b');
        tree.insert_or_get(5, 'c');
        assert_eq!(tree.len(), 1);
        assert!(!tree.is_empty());
        assert_eq!(*tree.find(5).unwrap(), 'b');
    }

    #[test]
    fn extreme_keys() {
        let mut tree = SplayTree::new();
        tree.insert(u64::MAX, 1u8);
        tree.insert(0, 2);
        tree.insert(u64::MAX / 2, 3);
        assert_eq!(*tree.find(0).unwrap(), 2);
        assert_eq!(*tree.find(u64::MAX).unwrap(), 1);
        assert_eq!(*tree.find(u64::MAX / 2).unwrap(), 3);
        assert!(tree.find(1).is_none());
        assert_eq!(tree.len(), 3);
    }

    #[test]
    fn find_on_empty_tree() {
        let mut tree: SplayTree<String> = SplayTree::new();
        assert!(tree.find(0).is_none());
        assert_eq!(tree.len(), 0);
        assert!(tree.is_empty());
    }

    /// Values must be released on overwrite, on a rejected `insert_or_get`,
    /// and when the tree itself is dropped.
    #[test]
    fn values_are_released() {
        use std::rc::Rc;
        let marker = Rc::new(());
        {
            let mut tree = SplayTree::new();
            for i in 0..100u64 {
                tree.insert(i, Rc::clone(&marker));
            }
            assert_eq!(Rc::strong_count(&marker), 101);
            tree.insert(50, Rc::clone(&marker)); // old value dropped
            assert_eq!(Rc::strong_count(&marker), 101);
            tree.insert_or_get(10, Rc::clone(&marker)); // argument dropped
            assert_eq!(Rc::strong_count(&marker), 101);
        }
        assert_eq!(Rc::strong_count(&marker), 1);
    }
}