bindb/storage/
binary_tree.rs

1use std::{fmt::Debug, fs::File};
2use binbuf::{bytes_ptr, fixed::Readable, impls::{arb_num, ArbNum}, BytesPtr, Entry, Fixed as _};
3use super::OpenMode;
4
5mod search;
6
7pub trait NodeId: binbuf::fixed::Decode {
8    fn to_u64(self) -> u64;
9    fn from_u64(value: u64) -> Self;
10}
11
12impl<const LEN: usize> NodeId for ArbNum<LEN, u64> {
13    fn from_u64(value: u64) -> Self {
14        ArbNum::new(value)
15    }
16    fn to_u64(self) -> u64 {
17        self.unwrap()
18    }
19}
20
21impl NodeId for u64 {
22    fn from_u64(value: u64) -> Self {
23        value
24    }
25    fn to_u64(self) -> u64 {
26        self
27    }
28}
29
// Layout of a single tree node, declared through the `binbuf::fixed!` macro.
// Each `#[lens(name)]` attribute names the accessor that the rest of this file
// uses to reach that field inside a node buffer (e.g. `Node::buf_key(node)`).
// NOTE(review): the exact expansion of `fixed!` / `buf!` lives in the `binbuf`
// crate and is not visible here — the notes below only describe usage in this file.
30binbuf::fixed! {
31    pub struct Node<I: NodeId, K, V> {
        // Key the tree is ordered by; `search` compares against it with `buf_cmp`.
32        #[lens(buf_key)]
33        key: K,
        // Payload stored next to the key.
34        #[lens(buf_value)]
35        value: V,
        // Child links are stored as `child id + 1`; the value 0 means "no child"
        // (see `search`, which subtracts 1 before following a link).
36        #[lens(buf_left_id)]
37        left_id: I,
38        #[lens(buf_right_id)]
39        right_id: I,
40    }
    // Buffer view type over a `Node`; `P` is the pointer kind
    // (this file uses it with `bytes_ptr::Mut` in `remove_searched`).
41    buf! { pub struct NodeBuf<P, I: NodeId, K: binbuf::Fixed, V: binbuf::Fixed>(Node<I, K, V>, P); }
42
    // Associates `Node` with its buffer view type.
43    impl<I: NodeId, K: binbuf::Fixed, V: binbuf::Fixed> I for Node<I, K, V> {
44        type Buf<P> = NodeBuf<P, I, K, V>;
45    }
46
    // Presumably request the macro-generated encode/decode implementations — TODO confirm against binbuf.
47    impl<I: NodeId, K: binbuf::Fixed, V: binbuf::Fixed> Encode for Node<I, K, V> {}
48    impl<I: NodeId, K: binbuf::fixed::Decode, V: binbuf::fixed::Decode> Decode for Node<I, K, V> {}
49}
50
// Layout of the tree header, declared through the `binbuf::fixed!` macro.
// It holds the only piece of tree-wide state that this file persists.
51binbuf::fixed! {
52    pub struct Header {
        // Id of the root node, or `None` for an empty tree. Unlike the child links
        // in `Node`, this id is stored as-is (no `+ 1` offset): `search` uses it
        // directly as a node id.
53        #[lens(buf_root_id)]
54        root_id: Option<u64>,
55    }
    // Buffer view type over a `Header`.
56    buf! { pub struct HeaderBuf<P>(Header, P); }
57
    // Associates `Header` with its buffer view type.
58    impl I for Header {
59        type Buf<P> = HeaderBuf<P>;
60    }
    // Presumably requests both encode and decode implementations — TODO confirm against binbuf.
61    impl Code for Header {}
62}
63
/// Which child link of a node is being referred to.
#[derive(Debug, Clone, Copy)]
pub enum NodeBranch {
    /// The link followed for keys that compare less than the node's key.
    Left,
    /// The link followed for keys that compare greater than the node's key.
    Right,
}
69
70#[derive(Debug)]
71pub enum AddError {
72    AddNode(super::fixed::AddError),
73    RemoveLastFreeId(super::fixed::RemoveLastError)
74}
75
76#[derive(Debug)]
77pub enum RemoveError {
78    RemoveNode(RemoveNodeError),
79}
80
81#[derive(Debug)]
82pub enum RemoveNodeError {
83    RemoveIfLast(super::fixed::RemoveLastError),
84    AddFreeId(super::fixed::AddError),
85}
86
87#[derive(Debug)]
88pub enum CreateError {
89    CreateHeader(super::single::CreateError),
90    NodesNotEmpty,
91    FreeIdsNotEmpty,
92}
93
94#[derive(Debug)]
95pub enum OpenError {
96    FixedOpen(super::fixed::OpenError),
97    SingleOpen(super::single::OpenError),
98}
99
100pub struct Searched {
101    parent: Option<NodeParent>,
102    id: Option<u64>
103}
104
105impl Searched {
106    pub fn is_found(&self) -> bool {
107        self.id.is_some()
108    }
109
110    pub fn find(self) -> Result<SearchedFound, SearchedNotFound> {
111        match self.id {
112            Some(id) => Ok(SearchedFound { parent: self.parent, id }),
113            None => Err(SearchedNotFound { parent: self.parent })
114        }
115    }
116}
117
118pub struct SearchedFound {
119    parent: Option<NodeParent>,
120    id: u64,
121}
122
123pub struct SearchedNotFound {
124    parent: Option<NodeParent>,
125}
126
127#[derive(Clone, Copy)]
128pub struct NodeParent {
129    id: u64,
130    branch: NodeBranch,
131}
132
/// The three files that back one tree.
pub struct OpenFiles {
    /// File holding the node records.
    pub nodes: File,
    /// File holding the list of reusable node ids.
    pub free_ids: File,
    /// File holding the tree header (root id).
    pub header: File,
}
138
/// Per-store "max margin" values handed to `super::Fixed::open`.
///
/// NOTE(review): the meaning of a margin is defined by `super::Fixed` and is
/// not visible in this file.
pub struct OpenMaxMargins {
    /// Margin for the node store.
    pub nodes: u64,
    /// Margin for the free-id store.
    pub free_ids: u64,
}
143
144pub struct OpenConfig {
145    pub mode: OpenMode,
146    pub files: OpenFiles,
147    pub max_margins: OpenMaxMargins
148}
149
150pub struct Value<I: NodeId, K, V> {
151    nodes: super::Fixed<Node<I, K, V>>,
152    free_ids: super::Fixed<u64>,
153    header: super::Single<Header>,
154    root_id: Option<u64>,
155}
156
157impl<I: NodeId, K: binbuf::fixed::Decode + Debug, V: binbuf::Fixed> Value<I, K, V> {
158    pub unsafe fn open(OpenConfig { mode, files, max_margins }: OpenConfig) -> Result<Self, OpenError> {
159        let nodes = super::Fixed::open(mode, files.nodes, max_margins.nodes)
160            .map_err(OpenError::FixedOpen)?;
161
162        let header = super::Single::open(
163            match mode {
164                OpenMode::New => super::single::OpenMode::New(&Header { root_id: None }),
165                OpenMode::Existing => super::single::OpenMode::Existing,
166            },
167            files.header,
168        )
169            .map_err(OpenError::SingleOpen)?;
170
171        let root_id = header.get().root_id;
172
173        Ok(Self {
174            nodes,
175            free_ids: super::Fixed::open(mode, files.free_ids, max_margins.free_ids).map_err(OpenError::FixedOpen)?,
176            header,
177            root_id,
178        })
179    }
180
181    unsafe fn node_buf_by_id(&self, id: u64) -> binbuf::BufConst<Node<I, K, V>> {
182        self.nodes.buf_unchecked(id)
183    }
184
185    unsafe fn node_buf_mut_by_id(&mut self, id: u64) -> binbuf::BufMut<Node<I, K, V>> {
186        self.nodes.buf_mut_unchecked(id)
187    }
188
189    pub fn search(&self, key: impl binbuf::fixed::BufOrd<K> + Clone) -> Searched {
190        // let mut arr = [0u8; T::LEN];
191        // let key_buf = unsafe { T::buf(bytes_ptr::Mut::from_slice(&mut arr)) };
192        // key.write_to(key_buf);
193        let mut parent = None;
194        let Some(mut node_id) = self.root_id else {
195            return Searched { id: None, parent: None };
196        };
197        loop {
198            let node = unsafe { self.node_buf_by_id(node_id) };
199            match key.clone().buf_cmp(Node::buf_key(node)) {
200                std::cmp::Ordering::Less => {
201                    let left_id = binbuf::fixed::decode::<I, _>(Node::buf_left_id(node)).to_u64();
202                    parent = Some(NodeParent { id: node_id, branch: NodeBranch::Left });
203                    if left_id == 0 {
204                        return Searched { id: None, parent };
205                        // return Err(Some((node_id, NodeBranch::Left)));
206                    } else {
207                        // parent = Some((node_id, NodeBranch::Left));
208                        node_id = left_id - 1;
209                    }
210                },
211                std::cmp::Ordering::Equal => {
212                    return Searched { parent, id: Some(node_id) };
213                },
214                std::cmp::Ordering::Greater => {
215                    let right_id = binbuf::fixed::decode::<I, _>(Node::buf_right_id(node)).to_u64();
216                    parent = Some(NodeParent { id: node_id, branch: NodeBranch::Right });
217                    if right_id == 0 {
218                        return Searched { id: None, parent };
219                        // return Err(Some((node_id, NodeBranch::Right)));
220                    } else {
221                        node_id = right_id - 1;
222                    }
223                },
224            }
225        }
226    }
227
228    fn set_root_id(&mut self, id: Option<u64>) {
229        id.encode(Header::buf_root_id(self.header.buf_mut()));
230        self.root_id = id;
231    }
232
233    pub unsafe fn buf_searched(&self, searched: &SearchedFound) -> binbuf::BufConst<V> {
234        Node::buf_value(unsafe { self.node_buf_by_id(searched.id) })
235    }
236
237    pub unsafe fn buf_mut_searched(&mut self, searched: &SearchedFound) -> binbuf::BufMut<V> {
238        Node::buf_value(unsafe { self.node_buf_mut_by_id(searched.id) })
239    }
240
241    pub fn buf(&self, key: impl binbuf::fixed::BufOrd<K> + Clone) -> Option<binbuf::BufConst<V>> {
242        self.search(key).find().ok().map(|s| unsafe { self.buf_searched(&s) })
243    }
244
245    pub fn buf_mut(&mut self, key: impl binbuf::fixed::BufOrd<K> + Clone) -> Option<binbuf::BufMut<V>> {
246        self.search(key).find().ok().map(|s| unsafe { self.buf_mut_searched(&s) })
247    }
248
249    // Returns true if item already exists.
250    pub unsafe fn add_searched(
251        &mut self,
252        search: &SearchedNotFound,
253        key: impl binbuf::fixed::BufOrd<K> + Clone,
254        value: impl binbuf::fixed::Readable<V>
255    ) -> Result<(), AddError>
256    where [(); Node::<I, K, V>::LEN]: {
257        let mut node_arr = [0u8; Node::<I, K, V>::LEN];
258        let node_buf = unsafe { Node::buf(bytes_ptr::Mut::from_slice(&mut node_arr)) };
259        key.clone().write_to(Node::<I, K, V>::buf_key(node_buf));
260        value.write_to(Node::<I, K, V>::buf_value(node_buf));
261        I::from_u64(0u64).encode(Node::<I, K, V>::buf_left_id(node_buf));
262        I::from_u64(0u64).encode(Node::<I, K, V>::buf_right_id(node_buf));
263
264        match search.parent {
265            None => {
266                let id = self.nodes.add(node_buf).map_err(AddError::AddNode)?;
267                self.set_root_id(Some(id));
268            }
269            Some(parent) => {
270                let id = match self.free_ids.last_buf() {
271                    Some(id_buf) => {
272                        let id = binbuf::fixed::decode::<u64, _>(id_buf);
273                        node_buf.write_to(unsafe { self.node_buf_mut_by_id(id) });
274                        self.free_ids.remove_last().map_err(AddError::RemoveLastFreeId)?;
275                        id
276                    },
277                    None => {
278                        self.nodes.add(node_buf).map_err(AddError::AddNode)?
279                    }
280                };
281
282                let parent_buf = unsafe { self.node_buf_mut_by_id(parent.id) };
283                match parent.branch {
284                    NodeBranch::Left => {
285                        I::from_u64(id + 1).write_to(Node::<I, K, V>::buf_left_id(parent_buf));
286                    },
287                    NodeBranch::Right => {
288                        I::from_u64(id + 1).write_to(Node::<I, K, V>::buf_right_id(parent_buf));
289                    }
290                }
291            }
292        }
293        Ok(())
294    }
295
296    pub fn add(&mut self, key: impl binbuf::fixed::BufOrd<K> + Clone, value: impl binbuf::fixed::Readable<V>) -> Result<bool, AddError>
297    where [(); Node::<I, K, V>::LEN]: {
298        match self.search(key.clone()).find() {
299            Ok(_) => Ok(true),
300            Err(s) => {
301                unsafe { self.add_searched(&s, key, value) }?;
302                Ok(false)
303            }
304        }
305    }
306
307    fn remove_node(&mut self, id: u64) -> Result<(), RemoveNodeError> {
308        if self.nodes.remove_if_last(id).map_err(RemoveNodeError::RemoveIfLast)? {
309            self.free_ids.add(&id).map_err(RemoveNodeError::AddFreeId)?;
310        }
311        Ok(())
312    }
313
314    // Returns true if item doesn't exist.
315    pub unsafe fn remove_searched(&mut self, searched: &SearchedFound) -> Result<(), RemoveError> {
316        let node_buf: NodeBuf<bytes_ptr::Mut, I, K, V> = unsafe { self.node_buf_mut_by_id(searched.id) };
317        let left_id = binbuf::fixed::decode::<I, _>(Node::buf_left_id(node_buf)).to_u64();
318        let right_id = binbuf::fixed::decode::<I, _>(Node::buf_right_id(node_buf)).to_u64();
319        
320        match (left_id, right_id, searched.parent) {
321            (_, _, Some(parent)) if left_id == 0 || right_id == 0 => {
322                self.remove_node(searched.id).map_err(RemoveError::RemoveNode)?;
323
324                let connect_id = if left_id == 0 { right_id } else { left_id };
325                let mut parent_buf = unsafe { self.node_buf_mut_by_id(parent.id) };
326                match parent.branch {
327                    NodeBranch::Left => {
328                        I::from_u64(connect_id).encode(Node::<I, K, V>::buf_left_id(parent_buf));
329                    }
330                    NodeBranch::Right => {
331                        I::from_u64(connect_id).encode(Node::<I, K, V>::buf_right_id(parent_buf));
332                    }
333                }
334            },
335            (0, 0, None) => {
336                self.remove_node(searched.id).map_err(RemoveError::RemoveNode)?;
337                self.set_root_id(None);
338            },
339            (_, _, None) if left_id == 0 || right_id == 0 => {
340                self.remove_node(searched.id).map_err(RemoveError::RemoveNode)?;
341
342                let root_id = if left_id == 0 { right_id } else { left_id } - 1;
343                self.set_root_id(Some(root_id));
344            },
345
346            // The most complex case to handle: both left and right branches exist.
347            (_, _, parent) => {
348                debug_assert_ne!(left_id, 0);
349                debug_assert_ne!(right_id, 0);
350
351                self.remove_node(searched.id).map_err(RemoveError::RemoveNode)?;
352
353                let mut node_parent_id = searched.id;
354                let mut node_id = right_id - 1;
355                let mut idx = 0;
356                loop {
357                    idx += 1;
358                    if idx > 100 {
359                        panic!("Loop stuck!");
360                    }
361
362                    let node_buf = unsafe { self.node_buf_mut_by_id(node_id) };
363                    let node_left_id_buf = Node::<I, K, V>::buf_left_id(node_buf);
364                    let node_left_id = binbuf::fixed::decode::<I, _>(node_left_id_buf).to_u64();
365                    let node_right_id_buf = Node::<I, K, V>::buf_right_id(node_buf);
366                    let node_right_id = binbuf::fixed::decode::<I, _>(node_right_id_buf).to_u64();
367                    if node_left_id == 0 {
368                        I::from_u64(node_right_id)
369                            .write_to(Node::buf_left_id(unsafe { self.node_buf_mut_by_id(node_parent_id) }));
370
371                        I::from_u64(left_id).write_to(node_left_id_buf);
372                        if searched.id != node_parent_id {
373                            I::from_u64(right_id).write_to(node_right_id_buf);
374                        }
375
376                        match parent {
377                            Some(parent) => {
378                                let parent_buf = unsafe { self.node_buf_mut_by_id(parent.id) };
379                                match parent.branch {
380                                    NodeBranch::Left => {
381                                        I::from_u64(node_id + 1).encode(Node::<I, K, V>::buf_left_id(parent_buf));
382                                    }
383                                    NodeBranch::Right => {
384                                        I::from_u64(node_id + 1).encode(Node::<I, K, V>::buf_right_id(parent_buf));
385                                    }
386                                }
387                            },
388                            None => {
389                                self.set_root_id(Some(node_id));
390                            }
391                        }
392
393                        break;
394                    } else {
395                        node_parent_id = node_id;
396                        node_id = node_left_id - 1;
397                    }
398                }
399            }
400        }
401        Ok(())
402    }
403
404    pub fn remove(&mut self, key: impl binbuf::fixed::BufOrd<K> + Clone) -> Result<bool, RemoveError> {
405        match self.search(key).find() {
406            Ok(s) => {
407                unsafe { self.remove_searched(&s) }?;
408                Ok(false)
409            },
410            Err(_) => Ok(true)
411        }
412    }
413}
414
415impl<I: NodeId, K: binbuf::fixed::Decode + Debug, V: binbuf::fixed::Decode> Value<I, K, V> {
416    pub fn get(&self, key: impl binbuf::fixed::BufOrd<K> + Clone) -> Option<V> {
417        self.search(key).find().ok().map(|s| unsafe { self.get_searched(&s) })
418    }
419
420    pub unsafe fn get_searched(&self, searched: &SearchedFound) -> V {
421        V::decode(self.buf_searched(searched))
422    }
423}