1use std::ptr;
11
/// A single heap-allocated tree node.
///
/// Nodes are created with `Box::into_raw` and are owned exclusively by the
/// `SplayTree` that links them. A null child pointer means "no subtree".
struct Node<V> {
    key: u64,
    value: V,
    left: *mut Node<V>,
    right: *mut Node<V>,
}

impl<V> Node<V> {
    /// Allocates a childless node on the heap and hands ownership back as a
    /// raw pointer. The caller must eventually reclaim it with
    /// `Box::from_raw`.
    fn alloc(key: u64, value: V) -> *mut Node<V> {
        Box::into_raw(Box::new(Node {
            key,
            value,
            left: ptr::null_mut(),
            right: ptr::null_mut(),
        }))
    }
}

/// A map from `u64` keys to values, stored as a top-down splay tree.
///
/// Every lookup or insertion restructures the tree so that the accessed key
/// (or the last node on its search path) becomes the root, which is why even
/// [`SplayTree::find`] takes `&mut self`.
pub struct SplayTree<V> {
    /// Root of the tree, or null when the tree is empty.
    root: *mut Node<V>,
    /// Number of nodes currently linked into the tree.
    len: usize,
}

impl<V> Default for SplayTree<V> {
    /// Equivalent to [`SplayTree::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<V> SplayTree<V> {
    /// Creates an empty tree. Does not allocate.
    pub fn new() -> Self {
        SplayTree {
            root: ptr::null_mut(),
            len: 0,
        }
    }

    /// Returns the number of keys stored in the tree.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the tree holds no keys.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Looks up `key`, returning a mutable reference to its value if present.
    ///
    /// The tree is splayed around `key` first, so the root changes even when
    /// the key is absent.
    pub fn find(&mut self, key: u64) -> Option<&mut V> {
        self.splay(key);
        // SAFETY: `self.root` is either null or points to a live node that
        // this tree owns exclusively; `&mut self` guarantees nobody else is
        // reading or writing it for the lifetime of the returned reference.
        let root = unsafe { self.root.as_mut() }?;
        if root.key == key {
            Some(&mut root.value)
        } else {
            None
        }
    }

    /// Returns the value stored under `key`, inserting `value` first if the
    /// key is absent.
    ///
    /// When the key already exists the stored value is kept and the supplied
    /// `value` is dropped.
    pub fn insert_or_get(&mut self, key: u64, value: V) -> &mut V {
        self.splay(key);
        // SAFETY: after `splay`, `self.root` is null or a live node owned by
        // this tree.
        let present = unsafe { !self.root.is_null() && (*self.root).key == key };
        if !present {
            self.push_root(key, value);
        }
        // SAFETY: the root is non-null here: either it matched `key`, or
        // `push_root` just installed a freshly allocated node. `&mut self`
        // makes the returned borrow exclusive.
        unsafe { &mut (*self.root).value }
    }

    /// Stores `value` under `key`, replacing (and dropping) any previous
    /// value for that key.
    pub fn insert(&mut self, key: u64, value: V) {
        self.splay(key);
        // SAFETY: after `splay`, `self.root` is null or a live node owned by
        // this tree, and `&mut self` gives exclusive access to it.
        let existing = unsafe { self.root.as_mut() };
        match existing {
            Some(root) if root.key == key => root.value = value,
            _ => self.push_root(key, value),
        }
    }

    /// Allocates a node for `key` and installs it as the new root.
    ///
    /// Must only be called when the tree is empty, or right after
    /// `splay(key)` left a root whose key differs from `key`; otherwise the
    /// search-tree ordering is not maintained.
    fn push_root(&mut self, key: u64, value: V) {
        let node = Node::alloc(key, value);
        let old_root = self.root;
        if !old_root.is_null() {
            // SAFETY: `node` was just allocated and `old_root` is a live node
            // owned by this tree; they are distinct allocations.
            unsafe {
                if key < (*old_root).key {
                    // The old root and its right subtree are all larger than
                    // `key`; its left subtree is all smaller.
                    (*node).left = (*old_root).left;
                    (*node).right = old_root;
                    (*old_root).left = ptr::null_mut();
                } else {
                    (*node).right = (*old_root).right;
                    (*node).left = old_root;
                    (*old_root).right = ptr::null_mut();
                }
            }
        }
        self.root = node;
        self.len += 1;
    }

    /// Top-down splay: restructures the tree so that the node holding `key`
    /// becomes the root, or, if `key` is absent, the last node visited while
    /// searching for it. Does nothing on an empty tree.
    fn splay(&mut self, key: u64) {
        if self.root.is_null() {
            return;
        }

        // Nodes known to be smaller than `key` are collected in the "left"
        // tree and nodes known to be larger in the "right" tree. Each `*_tail`
        // is the link slot where the next collected node must be attached.
        let mut left_root: *mut Node<V> = ptr::null_mut();
        let mut right_root: *mut Node<V> = ptr::null_mut();
        let mut left_tail: *mut *mut Node<V> = ptr::addr_of_mut!(left_root);
        let mut right_tail: *mut *mut Node<V> = ptr::addr_of_mut!(right_root);
        let mut current = self.root;

        // SAFETY: `current` starts at the non-null root and is only ever
        // advanced to a child pointer that was checked to be non-null, so it
        // always refers to a live node owned by this tree. Each tail pointer
        // refers either to one of the two locals above or to a child field of
        // a live node, and `&mut self` rules out concurrent access.
        unsafe {
            loop {
                if key < (*current).key {
                    let mut child = (*current).left;
                    if child.is_null() {
                        break;
                    }
                    if key < (*child).key {
                        // Two lefts in a row: rotate right first.
                        (*current).left = (*child).right;
                        (*child).right = current;
                        current = child;
                        child = (*current).left;
                        if child.is_null() {
                            break;
                        }
                    }
                    // `current` is larger than `key`: hang it on the right tree.
                    *right_tail = current;
                    right_tail = ptr::addr_of_mut!((*current).left);
                    current = child;
                } else if key > (*current).key {
                    let mut child = (*current).right;
                    if child.is_null() {
                        break;
                    }
                    if key > (*child).key {
                        // Two rights in a row: rotate left first.
                        (*current).right = (*child).left;
                        (*child).left = current;
                        current = child;
                        child = (*current).right;
                        if child.is_null() {
                            break;
                        }
                    }
                    // `current` is smaller than `key`: hang it on the left tree.
                    *left_tail = current;
                    left_tail = ptr::addr_of_mut!((*current).right);
                    current = child;
                } else {
                    break;
                }
            }

            // Reassemble: the remaining subtrees of `current` close off the
            // two collected trees, which then become its children.
            *left_tail = (*current).left;
            *right_tail = (*current).right;
            (*current).left = left_root;
            (*current).right = right_root;
        }
        self.root = current;
    }
}

impl<V> Drop for SplayTree<V> {
    /// Frees every node without recursion, so a degenerate (list-shaped) tree
    /// cannot overflow the call stack.
    fn drop(&mut self) {
        let mut pending = Vec::new();
        if !self.root.is_null() {
            pending.push(self.root);
        }
        while let Some(raw) = pending.pop() {
            // SAFETY: only non-null pointers are pushed, each came from
            // `Box::into_raw`, and each node has exactly one parent link, so
            // it is reclaimed exactly once.
            let node = unsafe { Box::from_raw(raw) };
            if !node.left.is_null() {
                pending.push(node.left);
            }
            if !node.right.is_null() {
                pending.push(node.right);
            }
            // `node` (and the value it holds) is dropped here.
        }
        self.root = ptr::null_mut();
        self.len = 0;
    }
}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored value can be fetched, mutated in place, and fetched again;
    /// absent keys report `None` and do not affect the length.
    #[test]
    fn insert_and_find() {
        let mut splay = SplayTree::new();
        splay.insert_or_get(42, vec![0usize]);

        let stored = splay.find(42).unwrap();
        stored.push(1);

        let expected = vec![0usize, 1];
        assert_eq!(splay.find(42).unwrap(), &expected);
        assert!(splay.find(99).is_none());
        assert_eq!(splay.len(), 1);
    }

    /// `insert_or_get` keeps the first value when the key is already present.
    #[test]
    fn insert_or_get_retains_existing() {
        let mut splay = SplayTree::new();
        for candidate in [100usize, 200].iter() {
            splay.insert_or_get(10, *candidate);
        }
        let current = splay.find(10).unwrap();
        assert_eq!(*current, 100);
    }

    /// `insert` replaces the value when the key is already present.
    #[test]
    fn insert_overwrites() {
        let mut splay = SplayTree::new();
        for candidate in [100usize, 200].iter() {
            splay.insert(10, *candidate);
        }
        let current = splay.find(10).unwrap();
        assert_eq!(*current, 200);
    }

    /// A thousand ascending keys are all inserted and all retrievable.
    #[test]
    fn many_keys() {
        const COUNT: u64 = 1000;

        let mut splay = SplayTree::new();
        let mut key = 0u64;
        while key < COUNT {
            splay.insert_or_get(key, key as usize);
            key += 1;
        }
        assert_eq!(splay.len(), 1000);

        for probe in 0..COUNT {
            let found = splay.find(probe).unwrap();
            assert_eq!(*found, probe as usize);
        }
    }
}