// nebari/tree/interior.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4};
5
6use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
7
8use super::{
9    btree_entry::{BTreeEntry, Reducer},
10    read_chunk, BinarySerialization, PagedWriter,
11};
12use crate::{
13    chunk_cache::CacheEntry,
14    error::Error,
15    io::File,
16    tree::{btree_entry::NodeInclusion, key_entry::ValueIndex},
17    vault::AnyVault,
18    AbortError, ArcBytes, ChunkCache, ErrorKind,
19};
20
/// An interior B-Tree node. Does not contain values directly, and instead
/// points to a node located on-disk elsewhere.
#[derive(Clone, Debug)]
pub struct Interior<Index, ReducedIndex> {
    /// The key with the highest sort value within the pointed-at node
    /// (taken from `BTreeEntry::max_key` when constructed via `From`).
    pub key: ArcBytes<'static>,
    /// The location of the node: either a position on-disk or an
    /// already-loaded in-memory entry.
    pub position: Pointer<Index, ReducedIndex>,
    /// The reduced statistics for the pointed-at node.
    pub stats: ReducedIndex,
}
32
33impl<Index, ReducedIndex> From<BTreeEntry<Index, ReducedIndex>> for Interior<Index, ReducedIndex>
34where
35    Index: Clone + Debug + ValueIndex + BinarySerialization + 'static,
36    ReducedIndex: Reducer<Index> + Clone + Debug + BinarySerialization + 'static,
37{
38    fn from(entry: BTreeEntry<Index, ReducedIndex>) -> Self {
39        let key = entry.max_key().clone();
40        let stats = entry.stats();
41
42        Self {
43            key,
44            stats,
45            position: Pointer::Loaded {
46                previous_location: None,
47                entry: Box::new(entry),
48            },
49        }
50    }
51}
52
/// A pointer to a location on-disk. May also contain the node already loaded.
#[derive(Clone, Debug)]
pub enum Pointer<Index, ReducedIndex> {
    /// The position on-disk of the node.
    OnDisk(u64),
    /// An in-memory node that may have previously been saved on-disk.
    Loaded {
        /// The position on-disk of the node, if it was previously saved.
        /// `None` for a node that has never been written.
        previous_location: Option<u64>,
        /// The loaded B-Tree entry.
        entry: Box<BTreeEntry<Index, ReducedIndex>>,
    },
}
66
67impl<
68        Index: BinarySerialization + Debug + Clone + 'static,
69        ReducedIndex: Reducer<Index> + BinarySerialization + Debug + Clone + 'static,
70    > Pointer<Index, ReducedIndex>
71{
72    /// Attempts to load the node from disk. If the node is already loaded, this
73    /// function does nothing.
74    #[allow(clippy::missing_panics_doc)] // Currently the only panic is if the types don't match, which shouldn't happen due to these nodes always being accessed through a root.
75    pub fn load(
76        &mut self,
77        file: &mut dyn File,
78        validate_crc: bool,
79        vault: Option<&dyn AnyVault>,
80        cache: Option<&ChunkCache>,
81        current_order: Option<usize>,
82    ) -> Result<(), Error> {
83        match self {
84            Pointer::OnDisk(position) => {
85                let entry = match read_chunk(*position, validate_crc, file, vault, cache)? {
86                    CacheEntry::ArcBytes(mut buffer) => {
87                        // It's worthless to store this node in the cache
88                        // because if we mutate, we'll be rewritten.
89                        Box::new(BTreeEntry::deserialize_from(&mut buffer, current_order)?)
90                    }
91                    CacheEntry::Decoded(node) => node
92                        .as_ref()
93                        .as_any()
94                        .downcast_ref::<Box<BTreeEntry<Index, ReducedIndex>>>()
95                        .unwrap()
96                        .clone(),
97                };
98                *self = Self::Loaded {
99                    entry,
100                    previous_location: Some(*position),
101                };
102            }
103            Pointer::Loaded { .. } => {}
104        }
105        Ok(())
106    }
107
108    /// Returns the previously-[`load()`ed](Self::load) entry.
109    pub fn get(&mut self) -> Option<&BTreeEntry<Index, ReducedIndex>> {
110        match self {
111            Pointer::OnDisk(_) => None,
112            Pointer::Loaded { entry, .. } => Some(entry),
113        }
114    }
115
116    /// Returns the previously-[`load()`ed](Self::load) entry as a mutable reference.
117    pub fn get_mut(&mut self) -> Option<&mut BTreeEntry<Index, ReducedIndex>> {
118        match self {
119            Pointer::OnDisk(_) => None,
120            Pointer::Loaded { entry, .. } => Some(entry.as_mut()),
121        }
122    }
123
124    /// Returns the position on-disk of the node being pointed at, if the node
125    /// has been saved before.
126    #[must_use]
127    pub fn position(&self) -> Option<u64> {
128        match self {
129            Pointer::OnDisk(location) => Some(*location),
130            Pointer::Loaded {
131                previous_location, ..
132            } => *previous_location,
133        }
134    }
135
136    /// Loads the pointed at node, if necessary, and invokes `callback` with the
137    /// loaded node. This is useful in situations where the node isn't needed to
138    /// be accessed mutably.
139    #[allow(clippy::missing_panics_doc)]
140    pub fn map_loaded_entry<
141        Output,
142        CallerError: Display + Debug,
143        Cb: FnOnce(
144            &BTreeEntry<Index, ReducedIndex>,
145            &mut dyn File,
146        ) -> Result<Output, AbortError<CallerError>>,
147    >(
148        &self,
149        file: &mut dyn File,
150        vault: Option<&dyn AnyVault>,
151        cache: Option<&ChunkCache>,
152        current_order: Option<usize>,
153        callback: Cb,
154    ) -> Result<Output, AbortError<CallerError>> {
155        match self {
156            Pointer::OnDisk(position) => match read_chunk(*position, false, file, vault, cache)? {
157                CacheEntry::ArcBytes(mut buffer) => {
158                    let decoded = BTreeEntry::deserialize_from(&mut buffer, current_order)?;
159
160                    let result = callback(&decoded, file);
161                    if let (Some(cache), Some(file_id)) = (cache, file.id()) {
162                        cache.replace_with_decoded(file_id, *position, Box::new(decoded));
163                    }
164                    result
165                }
166                CacheEntry::Decoded(value) => {
167                    let entry = value
168                        .as_ref()
169                        .as_any()
170                        .downcast_ref::<Box<BTreeEntry<Index, ReducedIndex>>>()
171                        .unwrap();
172                    callback(entry, file)
173                }
174            },
175            Pointer::Loaded { entry, .. } => callback(entry, file),
176        }
177    }
178}
179
impl<
        Index: Clone + ValueIndex + BinarySerialization + Debug + 'static,
        ReducedIndex: Reducer<Index> + Clone + BinarySerialization + Debug + 'static,
    > Interior<Index, ReducedIndex>
{
    /// Copies this node's data (and recursively its children's data, via
    /// `BTreeEntry::copy_data_to`) into `writer` during compaction.
    ///
    /// Returns `Ok(true)` if any data was copied. `index_callback` is invoked
    /// for each value index encountered so callers can rewrite value chunks.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn copy_data_to<Callback>(
        &mut self,
        include_nodes: NodeInclusion,
        file: &mut dyn File,
        copied_chunks: &mut HashMap<u64, u64>,
        writer: &mut PagedWriter<'_>,
        vault: Option<&dyn AnyVault>,
        scratch: &mut Vec<u8>,
        index_callback: &mut Callback,
    ) -> Result<bool, Error>
    where
        Callback: FnMut(
            &ArcBytes<'static>,
            &mut Index,
            &mut dyn File,
            &mut HashMap<u64, u64>,
            &mut PagedWriter<'_>,
            Option<&dyn AnyVault>,
        ) -> Result<bool, Error>,
    {
        // Load with CRC validation enabled and no cache/order hints: during
        // compaction we want to verify the source data we're copying.
        self.position.load(file, true, vault, None, None)?;
        let node = self.position.get_mut().unwrap();
        let mut any_data_copied = node.copy_data_to(
            include_nodes,
            file,
            copied_chunks,
            writer,
            vault,
            scratch,
            index_callback,
        )?;

        // Serialize if we are supposed to
        let position = if include_nodes.should_include() {
            any_data_copied = true;
            scratch.clear();
            node.serialize_to(scratch, writer)?;
            Some(writer.write_chunk(scratch)?)
        } else {
            // Not rewriting this node: keep its previously saved location, if any.
            self.position.position()
        };

        // Remove the node from memory to save RAM during the compaction process.
        if let Some(position) = position {
            self.position = Pointer::OnDisk(position);
        }

        Ok(any_data_copied)
    }
}
236
impl<
        Index: Clone + BinarySerialization + Debug + 'static,
        ReducedIndex: Reducer<Index> + Clone + BinarySerialization + Debug + 'static,
    > BinarySerialization for Interior<Index, ReducedIndex>
{
    /// Serializes this interior node: ensures the pointed-at child is on disk
    /// (writing it if dirty or never saved), then writes the key (u16 length +
    /// bytes), the child's on-disk position (u64), and the reduced stats.
    /// Returns the number of bytes appended to `writer`.
    fn serialize_to(
        &mut self,
        writer: &mut Vec<u8>,
        paged_writer: &mut PagedWriter<'_>,
    ) -> Result<usize, Error> {
        // Temporarily take ownership of the pointer so a `Loaded` entry can be
        // consumed (moved into the cache) without cloning.
        // NOTE(review): if serialization below returns early with an error,
        // `self.position` is left as the placeholder `OnDisk(0)` — presumably
        // callers discard the tree on error; confirm.
        let mut pointer = Pointer::OnDisk(0);
        std::mem::swap(&mut pointer, &mut self.position);
        let location_on_disk = match pointer {
            Pointer::OnDisk(position) => position,
            Pointer::Loaded {
                mut entry,
                previous_location,
            } => match (entry.dirty, previous_location) {
                // Serialize if dirty, or if this node hasn't been on-disk before.
                (true, _) | (_, None) => {
                    entry.dirty = false;
                    // Serialize into the shared scratch buffer, write that
                    // region as a chunk, then roll the buffer back.
                    let old_writer_length = writer.len();
                    entry.serialize_to(writer, paged_writer)?;
                    let position =
                        paged_writer.write_chunk(&writer[old_writer_length..writer.len()])?;
                    writer.truncate(old_writer_length);
                    // Hand the decoded entry to the cache so a later load can
                    // skip deserialization.
                    if let (Some(cache), Some(file_id)) = (paged_writer.cache, paged_writer.id()) {
                        cache.replace_with_decoded(file_id, position, entry);
                    }
                    position
                }
                // Clean and previously saved: reuse the existing location.
                (false, Some(position)) => position,
            },
        };
        self.position = Pointer::OnDisk(location_on_disk);
        let mut bytes_written = 0;
        // Write the key
        let key_len = u16::try_from(self.key.len()).map_err(|_| ErrorKind::KeyTooLarge)?;
        writer.write_u16::<BigEndian>(key_len)?;
        writer.extend_from_slice(&self.key);
        bytes_written += 2 + key_len as usize;

        writer.write_u64::<BigEndian>(location_on_disk)?;
        bytes_written += 8;

        bytes_written += self.stats.serialize_to(writer, paged_writer)?;

        Ok(bytes_written)
    }

    /// Deserializes an interior node written by [`serialize_to`](Self::serialize_to):
    /// key length + key bytes, child position, then the reduced stats.
    /// The child is left as `Pointer::OnDisk` (not loaded).
    fn deserialize_from(
        reader: &mut ArcBytes<'_>,
        current_order: Option<usize>,
    ) -> Result<Self, Error> {
        let key_len = reader.read_u16::<BigEndian>()? as usize;
        // Guard against a corrupt length prefix before slicing the key out.
        if key_len > reader.len() {
            return Err(Error::data_integrity(format!(
                "key length {} found but only {} bytes remaining",
                key_len,
                reader.len()
            )));
        }
        let key = reader.read_bytes(key_len)?.into_owned();

        let position = reader.read_u64::<BigEndian>()?;
        let stats = ReducedIndex::deserialize_from(reader, current_order)?;

        Ok(Self {
            key,
            position: Pointer::OnDisk(position),
            stats,
        })
    }
}