// cellular_raza_core/storage/sled_database.rs

use std::collections::{BTreeMap, HashMap};
use std::fmt;
use std::marker::PhantomData;

use serde::{Deserialize, Serialize};

use super::concepts::StorageError;
use super::concepts::{StorageInterfaceLoad, StorageInterfaceOpen, StorageInterfaceStore};

9/// Use the [sled] database to save results to an embedded database.
10// TODO use custom field for config [](https://docs.rs/sled/latest/sled/struct.Config.html) to let the user control these parameters
11#[derive(Clone, Debug)]
12pub struct SledStorageInterface<Id, Element, const TEMP: bool = false> {
13    db: sled::Db,
14    // TODO use this buffer
15    // buffer: StorageBuffer<Id, Element>,
16    id_phantom: PhantomData<Id>,
17    element_phantom: PhantomData<Element>,
18}
19
20impl<Id, Element, const TEMP: bool> SledStorageInterface<Id, Element, TEMP> {
21    /// Transform a u64 value to an iteration key which can be given to a sled tree.
22    fn iteration_to_key(iteration: u64) -> [u8; 8] {
23        iteration.to_le_bytes()
24    }
25
26    /// Transform the key given by the tree to the corresponding iteartion u64 value
27    fn key_to_iteration(key: &sled::IVec) -> Result<u64, StorageError> {
28        let iteration = bincode::deserialize::<u64>(key)?;
29        Ok(iteration)
30    }
31
32    /// Get the correct tree of the iteration or create if not currently present.
33    fn open_or_create_tree(&self, iteration: u64) -> Result<sled::Tree, StorageError> {
34        let tree_key = Self::iteration_to_key(iteration);
35        let tree = self.db.open_tree(&tree_key)?;
36        Ok(tree)
37    }
38
39    fn open_tree(&self, iteration: u64) -> Result<Option<sled::Tree>, StorageError> {
40        let tree_key = Self::iteration_to_key(iteration);
41        if !self.db.tree_names().contains(&sled::IVec::from(&tree_key)) {
42            Ok(None)
43        } else {
44            let tree = self.db.open_tree(tree_key)?;
45            Ok(Some(tree))
46        }
47    }
48}
49
50impl<Id, Element, const TEMP: bool> StorageInterfaceOpen
51    for SledStorageInterface<Id, Element, TEMP>
52{
53    fn open_or_create(
54        location: &std::path::Path,
55        _storage_instance: u64,
56    ) -> Result<Self, StorageError> {
57        let config = sled::Config::default()
58            .mode(sled::Mode::HighThroughput)
59            .cache_capacity(1024 * 1024 * 1024 * 5) // 5gb
60            .path(&location)
61            .temporary(TEMP)
62            .use_compression(false);
63
64        let db = config.open()?;
65
66        Ok(SledStorageInterface {
67            db,
68            id_phantom: PhantomData,
69            element_phantom: PhantomData,
70        })
71    }
72}
73
74impl<Id, Element, const TEMP: bool> StorageInterfaceStore<Id, Element>
75    for SledStorageInterface<Id, Element, TEMP>
76{
77    fn store_single_element(
78        &mut self,
79        iteration: u64,
80        identifier: &Id,
81        element: &Element,
82    ) -> Result<(), StorageError>
83    where
84        Id: Serialize,
85        Element: Serialize,
86    {
87        let tree = self.open_or_create_tree(iteration)?;
88
89        // Serialize the identifier and the element
90        let identifier_serialized = bincode::serialize(&identifier)?;
91        let element_serialized = bincode::serialize(&element)?;
92        match tree.insert(identifier_serialized, element_serialized)? {
93            None => Ok(()),
94            Some(_) => Err(StorageError::InitError(format!(
95                "Element already present at iteration {}",
96                iteration
97            ))),
98        }?;
99        Ok(())
100    }
101
102    fn store_batch_elements<'a, I>(
103        &'a mut self,
104        iteration: u64,
105        identifiers_elements: I,
106    ) -> Result<(), StorageError>
107    where
108        Id: 'a + Serialize,
109        Element: 'a + Serialize,
110        I: Clone + IntoIterator<Item = (&'a Id, &'a Element)>,
111    {
112        let tree = self.open_or_create_tree(iteration)?;
113        let mut batch = sled::Batch::default();
114        for (identifier, element) in identifiers_elements.into_iter() {
115            let identifier_serialized = bincode::serialize(&identifier)?;
116            let element_serialized = bincode::serialize(&element)?;
117            batch.insert(identifier_serialized, element_serialized)
118        }
119        tree.apply_batch(batch)?;
120        Ok(())
121    }
122}
123
124impl<Id, Element, const TEMP: bool> StorageInterfaceLoad<Id, Element>
125    for SledStorageInterface<Id, Element, TEMP>
126{
127    fn load_single_element(
128        &self,
129        iteration: u64,
130        identifier: &Id,
131    ) -> Result<Option<Element>, StorageError>
132    where
133        Id: Serialize + for<'a> Deserialize<'a>,
134        Element: for<'a> Deserialize<'a>,
135    {
136        let tree = match self.open_tree(iteration)? {
137            Some(tree) => tree,
138            None => return Ok(None),
139        };
140        let identifier_serialized = bincode::serialize(identifier)?;
141        match tree.get(&identifier_serialized)? {
142            Some(element_serialized) => {
143                let element: Element = bincode::deserialize(&element_serialized)?;
144                Ok(Some(element))
145            }
146            None => Ok(None),
147        }
148    }
149
150    fn load_element_history(&self, identifier: &Id) -> Result<HashMap<u64, Element>, StorageError>
151    where
152        Id: Serialize,
153        Element: for<'a> Deserialize<'a>,
154    {
155        // Keep track if the element has not been found in a database.
156        // If so we can either get the current minimal iteration or maximal depending on where it was found else.
157        let mut minimal_iteration = None;
158        let mut maximal_iteration = None;
159        let mut success_iteration = None;
160
161        // Save results in this hashmap
162        let mut accumulator = HashMap::new();
163        // Serialize the identifier
164        let identifier_serialized = bincode::serialize(identifier)?;
165        for iteration_serialized in self.db.tree_names() {
166            // If we are above the maximal or below the minimal iteration, we skip checking
167            let iteration: u64 = bincode::deserialize(&iteration_serialized)?;
168            match minimal_iteration {
169                None => (),
170                Some(min_iter) => {
171                    if iteration < min_iter {
172                        continue;
173                    }
174                }
175            }
176            match maximal_iteration {
177                None => (),
178                Some(max_iter) => {
179                    if max_iter < iteration {
180                        continue;
181                    }
182                }
183            }
184            // Get the tree for a random iteration
185            let tree = self.db.open_tree(iteration_serialized)?;
186            match tree.get(&identifier_serialized)? {
187                // We found and element insert it
188                Some(element_serialized) => {
189                    let element: Element = bincode::deserialize(&element_serialized)?;
190                    accumulator.insert(iteration, element);
191                    success_iteration = Some(iteration);
192                }
193                // We did not find an element. Thus update the helper variables atop.
194                None => match (minimal_iteration, maximal_iteration, success_iteration) {
195                    (None, None, Some(suc_iter)) => {
196                        if iteration > suc_iter {
197                            maximal_iteration = Some(iteration);
198                        }
199                        if iteration < suc_iter {
200                            minimal_iteration = Some(iteration);
201                        }
202                    }
203                    (Some(min_iter), None, Some(suc_iter)) => {
204                        if iteration > suc_iter {
205                            maximal_iteration = Some(iteration);
206                        }
207                        if iteration < suc_iter && iteration > min_iter {
208                            minimal_iteration = Some(iteration);
209                        }
210                    }
211                    (None, Some(max_iter), Some(suc_iter)) => {
212                        if iteration > suc_iter && iteration < max_iter {
213                            maximal_iteration = Some(iteration);
214                        }
215                        if iteration < suc_iter {
216                            minimal_iteration = Some(iteration);
217                        }
218                    }
219                    (Some(min_iter), Some(max_iter), Some(suc_iter)) => {
220                        if iteration > suc_iter && iteration < max_iter {
221                            maximal_iteration = Some(iteration);
222                        }
223                        if iteration < suc_iter && iteration > min_iter {
224                            minimal_iteration = Some(iteration);
225                        }
226                    }
227                    (_, _, None) => (),
228                },
229            };
230        }
231        Ok(accumulator)
232    }
233
234    fn load_all_elements_at_iteration(
235        &self,
236        iteration: u64,
237    ) -> Result<HashMap<Id, Element>, StorageError>
238    where
239        Id: std::hash::Hash + std::cmp::Eq + for<'a> Deserialize<'a>,
240        Element: for<'a> Deserialize<'a>,
241    {
242        let tree = match self.open_tree(iteration)? {
243            Some(tree) => tree,
244            None => return Ok(HashMap::new()),
245        };
246        tree.iter()
247            .map(|entry_result| {
248                let (identifier_serialized, element_serialized) = entry_result?;
249                let identifier: Id = bincode::deserialize(&identifier_serialized)?;
250                let element: Element = bincode::deserialize(&element_serialized)?;
251                Ok((identifier, element))
252            })
253            .collect::<Result<HashMap<Id, Element>, StorageError>>()
254    }
255
256    fn load_all_elements(&self) -> Result<BTreeMap<u64, HashMap<Id, Element>>, StorageError>
257    where
258        Id: std::hash::Hash + std::cmp::Eq + for<'a> Deserialize<'a>,
259        Element: for<'a> Deserialize<'a>,
260    {
261        self.db
262            .tree_names()
263            .iter()
264            .map(|tree_name_serialized| {
265                let tree = self.db.open_tree(tree_name_serialized)?;
266                let iteration = Self::key_to_iteration(tree_name_serialized)?;
267                let identifier_to_element = tree
268                    .iter()
269                    .map(|entry_result| {
270                        let (identifier_serialized, element_serialized) = entry_result?;
271                        let identifier: Id = bincode::deserialize(&identifier_serialized)?;
272                        let element: Element = bincode::deserialize(&element_serialized)?;
273                        Ok((identifier, element))
274                    })
275                    .collect::<Result<HashMap<Id, Element>, StorageError>>()?;
276                Ok((iteration, identifier_to_element))
277            })
278            .collect::<Result<BTreeMap<u64, HashMap<Id, Element>>, StorageError>>()
279    }
280
281    fn get_all_iterations(&self) -> Result<Vec<u64>, StorageError> {
282        let iterations = self
283            .db
284            .tree_names()
285            .iter()
286            // TODO this should not be here! Fix it properly (I asked on sled discord)
287            .filter(|key| {
288                **key
289                    != sled::IVec::from(&[
290                        95, 95, 115, 108, 101, 100, 95, 95, 100, 101, 102, 97, 117, 108, 116,
291                    ])
292            })
293            .map(|tree_name_serialized| Self::key_to_iteration(tree_name_serialized))
294            .collect::<Result<Vec<_>, StorageError>>()?;
295
296        Ok(iterations)
297    }
298}