bonsaidb_local/database/
keyvalue.rs

1use std::borrow::Cow;
2use std::collections::{btree_map, BTreeMap, VecDeque};
3use std::sync::{Arc, Weak};
4use std::time::Duration;
5
6use bonsaidb_core::connection::{Connection, HasSession};
7use bonsaidb_core::keyvalue::{
8    Command, KeyCheck, KeyOperation, KeyStatus, KeyValue, Numeric, Output, SetCommand, Timestamp,
9    Value,
10};
11use bonsaidb_core::permissions::bonsai::{
12    keyvalue_key_resource_name, BonsaiAction, DatabaseAction, KeyValueAction,
13};
14use bonsaidb_core::transaction::{ChangedKey, Changes};
15use nebari::io::any::AnyFile;
16use nebari::tree::{CompareSwap, Operation, Root, ScanEvaluation, Unversioned};
17use nebari::{AbortError, ArcBytes, Roots};
18use parking_lot::Mutex;
19use serde::{Deserialize, Serialize};
20use watchable::{Watchable, Watcher};
21
22use crate::config::KeyValuePersistence;
23use crate::database::compat;
24use crate::storage::StorageLock;
25use crate::tasks::{Job, Keyed, Task};
26use crate::{Database, DatabaseNonBlocking, Error};
27
/// A single key-value entry as it is stored by this module.
///
/// Entries are serialized with `bincode` when written to and read from the
/// key-value tree.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Entry {
    /// The stored value.
    pub value: Value,
    /// When this entry expires, or `None` if it has no expiration.
    pub expiration: Option<Timestamp>,
    /// When this entry was last written. `#[serde(default)]` allows the field
    /// to be absent from serialized data (presumably data written before the
    /// field existed -- TODO confirm).
    #[serde(default)]
    pub last_updated: Timestamp,
}
35
36impl Entry {
37    pub(crate) fn restore(
38        self,
39        namespace: Option<String>,
40        key: String,
41        database: &Database,
42    ) -> Result<(), bonsaidb_core::Error> {
43        database.execute_key_operation(KeyOperation {
44            namespace,
45            key,
46            command: Command::Set(SetCommand {
47                value: self.value,
48                expiration: self.expiration,
49                keep_existing_expiration: false,
50                check: None,
51                return_previous_value: false,
52            }),
53        })?;
54        Ok(())
55    }
56}
57
58impl KeyValue for Database {
59    fn execute_key_operation(&self, op: KeyOperation) -> Result<Output, bonsaidb_core::Error> {
60        self.check_permission(
61            keyvalue_key_resource_name(self.name(), op.namespace.as_deref(), &op.key),
62            &BonsaiAction::Database(DatabaseAction::KeyValue(KeyValueAction::ExecuteOperation)),
63        )?;
64        self.data.context.perform_kv_operation(op)
65    }
66}
67
68impl Database {
69    pub(crate) fn all_key_value_entries(
70        &self,
71    ) -> Result<BTreeMap<(Option<String>, String), Entry>, Error> {
72        // Lock the state so that new new modifications can be made while we gather this snapshot.
73        let state = self.data.context.key_value_state.lock();
74        let database = self.clone();
75        // Initialize our entries with any dirty keys and any keys that are about to be persisted.
76        let mut all_entries = BTreeMap::new();
77        database
78            .roots()
79            .tree(Unversioned::tree(KEY_TREE))?
80            .scan::<Error, _, _, _, _>(
81                &(..),
82                true,
83                |_, _, _| ScanEvaluation::ReadData,
84                |_, _| ScanEvaluation::ReadData,
85                |key, _, entry: ArcBytes<'static>| {
86                    let entry = bincode::deserialize::<Entry>(&entry)
87                        .map_err(|err| AbortError::Other(Error::from(err)))?;
88                    let full_key = std::str::from_utf8(&key)
89                        .map_err(|err| AbortError::Other(Error::from(err)))?;
90
91                    if let Some(split_key) = split_key(full_key) {
92                        // Do not overwrite the existing key
93                        all_entries.entry(split_key).or_insert(entry);
94                    }
95
96                    Ok(())
97                },
98            )?;
99
100        // Apply the pending writes first
101        if let Some(pending_keys) = &state.keys_being_persisted {
102            for (key, possible_entry) in pending_keys.iter() {
103                let (namespace, key) = split_key(key).unwrap();
104                if let Some(updated_entry) = possible_entry {
105                    all_entries.insert((namespace, key), updated_entry.clone());
106                } else {
107                    all_entries.remove(&(namespace, key));
108                }
109            }
110        }
111
112        for (key, possible_entry) in &state.dirty_keys {
113            let (namespace, key) = split_key(key).unwrap();
114            if let Some(updated_entry) = possible_entry {
115                all_entries.insert((namespace, key), updated_entry.clone());
116            } else {
117                all_entries.remove(&(namespace, key));
118            }
119        }
120
121        Ok(all_entries)
122    }
123}
124
/// Name of the tree in which key-value entries are stored.
pub(crate) const KEY_TREE: &str = "kv";
126
/// Builds the tree key for `key` inside `namespace`.
///
/// The result is the namespace (nothing when `None`), a NUL separator, then
/// the key. `None` and `Some("")` therefore produce the same tree key.
fn full_key(namespace: Option<&str>, key: &str) -> String {
    let prefix = namespace.unwrap_or_default();
    let mut combined = String::with_capacity(prefix.len() + key.len() + 1);
    combined.push_str(prefix);
    combined.push('\0');
    combined.push_str(key);
    combined
}
137
/// Splits a tree key back into `(namespace, key)`.
///
/// The split happens at the first NUL character. An empty namespace is
/// reported as `None`. Returns `None` when `full_key` has no NUL separator.
fn split_key(full_key: &str) -> Option<(Option<String>, String)> {
    let (namespace, key) = full_key.split_once('\0')?;
    let namespace = Some(namespace)
        .filter(|namespace| !namespace.is_empty())
        .map(String::from);
    Some((namespace, String::from(key)))
}
150
151fn increment(existing: &Numeric, amount: &Numeric, saturating: bool) -> Numeric {
152    match amount {
153        Numeric::Integer(amount) => {
154            let existing_value = existing.as_i64_lossy(saturating);
155            let new_value = if saturating {
156                existing_value.saturating_add(*amount)
157            } else {
158                existing_value.wrapping_add(*amount)
159            };
160            Numeric::Integer(new_value)
161        }
162        Numeric::UnsignedInteger(amount) => {
163            let existing_value = existing.as_u64_lossy(saturating);
164            let new_value = if saturating {
165                existing_value.saturating_add(*amount)
166            } else {
167                existing_value.wrapping_add(*amount)
168            };
169            Numeric::UnsignedInteger(new_value)
170        }
171        Numeric::Float(amount) => {
172            let existing_value = existing.as_f64_lossy();
173            let new_value = existing_value + *amount;
174            Numeric::Float(new_value)
175        }
176    }
177}
178
179fn decrement(existing: &Numeric, amount: &Numeric, saturating: bool) -> Numeric {
180    match amount {
181        Numeric::Integer(amount) => {
182            let existing_value = existing.as_i64_lossy(saturating);
183            let new_value = if saturating {
184                existing_value.saturating_sub(*amount)
185            } else {
186                existing_value.wrapping_sub(*amount)
187            };
188            Numeric::Integer(new_value)
189        }
190        Numeric::UnsignedInteger(amount) => {
191            let existing_value = existing.as_u64_lossy(saturating);
192            let new_value = if saturating {
193                existing_value.saturating_sub(*amount)
194            } else {
195                existing_value.wrapping_sub(*amount)
196            };
197            Numeric::UnsignedInteger(new_value)
198        }
199        Numeric::Float(amount) => {
200            let existing_value = existing.as_f64_lossy();
201            let new_value = existing_value - *amount;
202            Numeric::Float(new_value)
203        }
204    }
205}
206
/// In-memory state of the key-value store: pending changes, expiration
/// tracking, and the bookkeeping needed to persist changes in the background.
#[derive(Debug)]
pub struct KeyValueState {
    /// Storage roots used to read entries from, and write entries to, the
    /// key-value tree.
    roots: Roots<AnyFile>,
    /// Policy consulted to decide whether and when dirty keys are committed.
    persistence: KeyValuePersistence,
    /// When a commit was last started (initialized to the creation time).
    last_commit: Timestamp,
    /// Tells the background worker when it should next run.
    background_worker_target: Watchable<BackgroundWorkerProcessTarget>,
    /// Expiration timestamp of every tracked key, indexed by tree key.
    expiring_keys: BTreeMap<String, Timestamp>,
    /// Tree keys that have an expiration, ordered soonest-first.
    expiration_order: VecDeque<String>,
    /// Changes that have not been staged for persistence yet. A `None` value
    /// marks the key as deleted.
    dirty_keys: BTreeMap<String, Option<Entry>>,
    /// The set of changes currently being written by a persist thread, if
    /// any. A `None` value marks the key as deleted.
    keys_being_persisted: Option<Arc<BTreeMap<String, Option<Entry>>>>,
    /// Updated with the current time each time a persist pass finishes.
    last_persistence: Watchable<Timestamp>,
    /// Set while shutting down; notified once no dirty keys remain.
    shutdown: Option<flume::Sender<()>>,
}
220
221impl KeyValueState {
222    pub fn new(
223        persistence: KeyValuePersistence,
224        roots: Roots<AnyFile>,
225        background_worker_target: Watchable<BackgroundWorkerProcessTarget>,
226    ) -> Self {
227        Self {
228            roots,
229            persistence,
230            last_commit: Timestamp::now(),
231            expiring_keys: BTreeMap::new(),
232            background_worker_target,
233            expiration_order: VecDeque::new(),
234            dirty_keys: BTreeMap::new(),
235            keys_being_persisted: None,
236            last_persistence: Watchable::new(Timestamp::MIN),
237            shutdown: None,
238        }
239    }
240
241    pub fn shutdown(&mut self, state: &Arc<Mutex<KeyValueState>>) -> Option<flume::Receiver<()>> {
242        if self.keys_being_persisted.is_none() && self.commit_dirty_keys(state) {
243            let (shutdown_sender, shutdown_receiver) = flume::bounded(1);
244            self.shutdown = Some(shutdown_sender);
245            Some(shutdown_receiver)
246        } else {
247            None
248        }
249    }
250
251    pub fn perform_kv_operation(
252        &mut self,
253        op: KeyOperation,
254        state: &Arc<Mutex<KeyValueState>>,
255    ) -> Result<Output, bonsaidb_core::Error> {
256        let now = Timestamp::now();
257        // If there are any keys that have expired, clear them before executing any operations.
258        self.remove_expired_keys(now);
259        let result = match op.command {
260            Command::Set(command) => {
261                self.execute_set_operation(op.namespace.as_deref(), &op.key, command, now)
262            }
263            Command::Get { delete } => {
264                self.execute_get_operation(op.namespace.as_deref(), &op.key, delete)
265            }
266            Command::Delete => self.execute_delete_operation(op.namespace.as_deref(), &op.key),
267            Command::Increment { amount, saturating } => self.execute_increment_operation(
268                op.namespace.as_deref(),
269                &op.key,
270                &amount,
271                saturating,
272                now,
273            ),
274            Command::Decrement { amount, saturating } => self.execute_decrement_operation(
275                op.namespace.as_deref(),
276                &op.key,
277                &amount,
278                saturating,
279                now,
280            ),
281        };
282        if result.is_ok() {
283            if self.needs_commit(now) {
284                self.commit_dirty_keys(state);
285            }
286            self.update_background_worker_target();
287        }
288        result
289    }
290
291    #[cfg_attr(
292        feature = "tracing",
293        tracing::instrument(level = "trace", skip(self, set, now),)
294    )]
295    fn execute_set_operation(
296        &mut self,
297        namespace: Option<&str>,
298        key: &str,
299        set: SetCommand,
300        now: Timestamp,
301    ) -> Result<Output, bonsaidb_core::Error> {
302        let mut entry = Entry {
303            value: set.value.validate()?,
304            expiration: set.expiration,
305            last_updated: now,
306        };
307        let full_key = full_key(namespace, key);
308        let possible_existing_value =
309            if set.check.is_some() || set.return_previous_value || set.keep_existing_expiration {
310                Some(self.get(&full_key).map_err(Error::from)?)
311            } else {
312                None
313            };
314        let existing_value_ref = possible_existing_value.as_ref().and_then(Option::as_ref);
315
316        let updating = match set.check {
317            Some(KeyCheck::OnlyIfPresent) => existing_value_ref.is_some(),
318            Some(KeyCheck::OnlyIfVacant) => existing_value_ref.is_none(),
319            None => true,
320        };
321        if updating {
322            if set.keep_existing_expiration {
323                if let Some(existing_value) = existing_value_ref {
324                    entry.expiration = existing_value.expiration;
325                }
326            }
327            self.update_key_expiration(&full_key, entry.expiration);
328
329            let previous_value = if let Some(existing_value) = possible_existing_value {
330                // we already fetched, no need to ask for the existing value back
331                self.set(full_key, entry);
332                existing_value
333            } else {
334                self.replace(full_key, entry).map_err(Error::from)?
335            };
336            if set.return_previous_value {
337                Ok(Output::Value(previous_value.map(|entry| entry.value)))
338            } else if previous_value.is_none() {
339                Ok(Output::Status(KeyStatus::Inserted))
340            } else {
341                Ok(Output::Status(KeyStatus::Updated))
342            }
343        } else {
344            Ok(Output::Status(KeyStatus::NotChanged))
345        }
346    }
347
348    #[cfg_attr(
349        feature = "tracing",
350        tracing::instrument(level = "trace", skip(self, tree_key, expiration))
351    )]
352    pub fn update_key_expiration<'key>(
353        &mut self,
354        tree_key: impl Into<Cow<'key, str>>,
355        expiration: Option<Timestamp>,
356    ) {
357        let tree_key = tree_key.into();
358        let mut changed_first_expiration = false;
359        if let Some(expiration) = expiration {
360            let key = if self.expiring_keys.contains_key(tree_key.as_ref()) {
361                // Update the existing entry.
362                let existing_entry_index = self
363                    .expiration_order
364                    .iter()
365                    .enumerate()
366                    .find_map(
367                        |(index, key)| {
368                            if &tree_key == key {
369                                Some(index)
370                            } else {
371                                None
372                            }
373                        },
374                    )
375                    .unwrap();
376                changed_first_expiration = existing_entry_index == 0;
377                self.expiration_order.remove(existing_entry_index).unwrap()
378            } else {
379                tree_key.into_owned()
380            };
381
382            // Insert the key into the expiration_order queue
383            let mut insert_at = None;
384            for (index, expiring_key) in self.expiration_order.iter().enumerate() {
385                if self.expiring_keys.get(expiring_key).unwrap() > &expiration {
386                    insert_at = Some(index);
387                    break;
388                }
389            }
390            if let Some(insert_at) = insert_at {
391                changed_first_expiration |= insert_at == 0;
392
393                self.expiration_order.insert(insert_at, key.clone());
394            } else {
395                changed_first_expiration |= self.expiration_order.is_empty();
396                self.expiration_order.push_back(key.clone());
397            }
398            self.expiring_keys.insert(key, expiration);
399        } else if self.expiring_keys.remove(tree_key.as_ref()).is_some() {
400            let index = self
401                .expiration_order
402                .iter()
403                .enumerate()
404                .find_map(|(index, key)| {
405                    if tree_key.as_ref() == key {
406                        Some(index)
407                    } else {
408                        None
409                    }
410                })
411                .unwrap();
412
413            changed_first_expiration |= index == 0;
414            self.expiration_order.remove(index);
415        }
416
417        if changed_first_expiration {
418            self.update_background_worker_target();
419        }
420    }
421
422    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
423    fn execute_get_operation(
424        &mut self,
425        namespace: Option<&str>,
426        key: &str,
427        delete: bool,
428    ) -> Result<Output, bonsaidb_core::Error> {
429        let full_key = full_key(namespace, key);
430        let entry = if delete {
431            self.remove(full_key).map_err(Error::from)?
432        } else {
433            self.get(&full_key).map_err(Error::from)?
434        };
435
436        Ok(Output::Value(entry.map(|e| e.value)))
437    }
438
439    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
440    fn execute_delete_operation(
441        &mut self,
442        namespace: Option<&str>,
443        key: &str,
444    ) -> Result<Output, bonsaidb_core::Error> {
445        let full_key = full_key(namespace, key);
446        let value = self.remove(full_key).map_err(Error::from)?;
447        if value.is_some() {
448            Ok(Output::Status(KeyStatus::Deleted))
449        } else {
450            Ok(Output::Status(KeyStatus::NotChanged))
451        }
452    }
453
454    #[cfg_attr(
455        feature = "tracing",
456        tracing::instrument(level = "trace", skip(self, amount, saturating, now))
457    )]
458    fn execute_increment_operation(
459        &mut self,
460        namespace: Option<&str>,
461        key: &str,
462        amount: &Numeric,
463        saturating: bool,
464        now: Timestamp,
465    ) -> Result<Output, bonsaidb_core::Error> {
466        self.execute_numeric_operation(namespace, key, amount, saturating, now, increment)
467    }
468
469    #[cfg_attr(
470        feature = "tracing",
471        tracing::instrument(level = "trace", skip(self, amount, saturating, now))
472    )]
473    fn execute_decrement_operation(
474        &mut self,
475        namespace: Option<&str>,
476        key: &str,
477        amount: &Numeric,
478        saturating: bool,
479        now: Timestamp,
480    ) -> Result<Output, bonsaidb_core::Error> {
481        self.execute_numeric_operation(namespace, key, amount, saturating, now, decrement)
482    }
483
484    fn execute_numeric_operation<F: Fn(&Numeric, &Numeric, bool) -> Numeric>(
485        &mut self,
486        namespace: Option<&str>,
487        key: &str,
488        amount: &Numeric,
489        saturating: bool,
490        now: Timestamp,
491        op: F,
492    ) -> Result<Output, bonsaidb_core::Error> {
493        let full_key = full_key(namespace, key);
494        let current = self.get(&full_key).map_err(Error::from)?;
495        let mut entry = current.unwrap_or(Entry {
496            value: Value::Numeric(Numeric::UnsignedInteger(0)),
497            expiration: None,
498            last_updated: now,
499        });
500
501        match entry.value {
502            Value::Numeric(existing) => {
503                let value = Value::Numeric(op(&existing, amount, saturating).validate()?);
504                entry.value = value.clone();
505
506                self.set(full_key, entry);
507                Ok(Output::Value(Some(value)))
508            }
509            Value::Bytes(_) => Err(bonsaidb_core::Error::other(
510                "bonsaidb-local",
511                "type of stored `Value` is not `Numeric`",
512            )),
513        }
514    }
515
516    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
517    fn remove(&mut self, key: String) -> Result<Option<Entry>, nebari::Error> {
518        self.update_key_expiration(&key, None);
519
520        if let Some(dirty_entry) = self.dirty_keys.get_mut(&key) {
521            Ok(dirty_entry.take())
522        } else if let Some(persisting_entry) = self
523            .keys_being_persisted
524            .as_ref()
525            .and_then(|keys| keys.get(&key))
526        {
527            self.dirty_keys.insert(key, None);
528            Ok(persisting_entry.clone())
529        } else {
530            // There might be a value on-disk we need to remove.
531            let previous_value = Self::retrieve_key_from_disk(&self.roots, &key)?;
532            self.dirty_keys.insert(key, None);
533            Ok(previous_value)
534        }
535    }
536
537    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
538    fn get(&self, key: &str) -> Result<Option<Entry>, nebari::Error> {
539        if let Some(entry) = self.dirty_keys.get(key) {
540            Ok(entry.clone())
541        } else if let Some(persisting_entry) = self
542            .keys_being_persisted
543            .as_ref()
544            .and_then(|keys| keys.get(key))
545        {
546            Ok(persisting_entry.clone())
547        } else {
548            Self::retrieve_key_from_disk(&self.roots, key)
549        }
550    }
551
552    fn set(&mut self, key: String, value: Entry) {
553        self.dirty_keys.insert(key, Some(value));
554    }
555
556    fn replace(&mut self, key: String, value: Entry) -> Result<Option<Entry>, nebari::Error> {
557        let mut value = Some(value);
558        let map_entry = self.dirty_keys.entry(key);
559        if matches!(map_entry, btree_map::Entry::Vacant(_)) {
560            // This key is clean, and the caller is expecting the previous
561            // value.
562            let stored_value = if let Some(persisting_entry) = self
563                .keys_being_persisted
564                .as_ref()
565                .and_then(|keys| keys.get(map_entry.key()))
566            {
567                persisting_entry.clone()
568            } else {
569                Self::retrieve_key_from_disk(&self.roots, map_entry.key())?
570            };
571            map_entry.or_insert(value);
572            Ok(stored_value)
573        } else {
574            // This key is already dirty, we can just replace the value and
575            // return the old value.
576            map_entry.and_modify(|map_entry| {
577                std::mem::swap(&mut value, map_entry);
578            });
579            Ok(value)
580        }
581    }
582
583    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(roots)))]
584    fn retrieve_key_from_disk(
585        roots: &Roots<AnyFile>,
586        key: &str,
587    ) -> Result<Option<Entry>, nebari::Error> {
588        roots
589            .tree(Unversioned::tree(KEY_TREE))?
590            .get(key.as_bytes())
591            .map(|current| current.and_then(|current| bincode::deserialize::<Entry>(&current).ok()))
592    }
593
594    fn update_background_worker_target(&mut self) {
595        let key_expiration_target = self.expiration_order.get(0).map(|key| {
596            let expiration_timeout = self.expiring_keys.get(key).unwrap();
597            *expiration_timeout
598        });
599        let now = Timestamp::now();
600        let persisting = self.keys_being_persisted.is_some();
601        let commit_target = (!persisting)
602            .then(|| {
603                self.persistence.duration_until_next_commit(
604                    self.dirty_keys.len(),
605                    (now - self.last_commit).unwrap_or_default(),
606                )
607            })
608            .flatten()
609            .map(|duration| now + duration);
610        match (commit_target, key_expiration_target) {
611            (Some(target), _) | (_, Some(target)) if target <= now => {
612                self.background_worker_target
613                    .replace(BackgroundWorkerProcessTarget::Now);
614            }
615            (Some(commit_target), Some(key_target)) => {
616                let closest_target = key_target.min(commit_target);
617                let new_target = BackgroundWorkerProcessTarget::Timestamp(closest_target);
618                let _: Result<_, _> = self.background_worker_target.update(new_target);
619            }
620            (Some(target), None) | (None, Some(target)) => {
621                let _: Result<_, _> = self
622                    .background_worker_target
623                    .update(BackgroundWorkerProcessTarget::Timestamp(target));
624            }
625            (None, None) => {
626                let _: Result<_, _> = self
627                    .background_worker_target
628                    .update(BackgroundWorkerProcessTarget::Never);
629            }
630        }
631    }
632
633    fn remove_expired_keys(&mut self, now: Timestamp) {
634        while !self.expiration_order.is_empty()
635            && self.expiring_keys.get(&self.expiration_order[0]).unwrap() <= &now
636        {
637            let key = self.expiration_order.pop_front().unwrap();
638            self.expiring_keys.remove(&key);
639            self.dirty_keys.insert(key, None);
640        }
641    }
642
643    fn needs_commit(&mut self, now: Timestamp) -> bool {
644        if self.keys_being_persisted.is_some() {
645            false
646        } else {
647            let since_last_commit = (now - self.last_commit).unwrap_or_default();
648            self.persistence
649                .should_commit(self.dirty_keys.len(), since_last_commit)
650        }
651    }
652
653    fn stage_dirty_keys(&mut self) -> Option<Arc<BTreeMap<String, Option<Entry>>>> {
654        if !self.dirty_keys.is_empty() && self.keys_being_persisted.is_none() {
655            let keys = Arc::new(std::mem::take(&mut self.dirty_keys));
656            self.keys_being_persisted = Some(keys.clone());
657            Some(keys)
658        } else {
659            None
660        }
661    }
662
663    pub fn commit_dirty_keys(&mut self, state: &Arc<Mutex<KeyValueState>>) -> bool {
664        if let Some(keys) = self.stage_dirty_keys() {
665            let roots = self.roots.clone();
666            let state = state.clone();
667            std::thread::Builder::new()
668                .name(String::from("keyvalue-persist"))
669                .spawn(move || Self::persist_keys(&state, &roots, &keys))
670                .unwrap();
671            self.last_commit = Timestamp::now();
672            true
673        } else {
674            false
675        }
676    }
677
    /// Returns a watcher over `last_persistence`, which is updated each time
    /// a persist pass finishes. Only compiled for tests.
    #[cfg(test)]
    pub fn persistence_watcher(&self) -> Watcher<Timestamp> {
        self.last_persistence.watch()
    }
682
683    #[cfg_attr(feature = "instrument", tracing::instrument(level = "trace", skip_all))]
684    fn persist_keys(
685        key_value_state: &Arc<Mutex<KeyValueState>>,
686        roots: &Roots<AnyFile>,
687        keys: &BTreeMap<String, Option<Entry>>,
688    ) -> Result<(), bonsaidb_core::Error> {
689        let mut transaction = roots
690            .transaction(&[Unversioned::tree(KEY_TREE)])
691            .map_err(Error::from)?;
692        let all_keys = keys
693            .keys()
694            .map(|key| ArcBytes::from(key.as_bytes().to_vec()))
695            .collect();
696        let mut changed_keys = Vec::new();
697        transaction
698            .tree::<Unversioned>(0)
699            .unwrap()
700            .modify(
701                all_keys,
702                Operation::CompareSwap(CompareSwap::new(&mut |key, existing_value| {
703                    let full_key = std::str::from_utf8(key).unwrap();
704                    let (namespace, key) = split_key(full_key).unwrap();
705
706                    if let Some(new_value) = keys.get(full_key).unwrap() {
707                        changed_keys.push(ChangedKey {
708                            namespace,
709                            key,
710                            deleted: false,
711                        });
712                        let bytes = bincode::serialize(new_value).unwrap();
713                        nebari::tree::KeyOperation::Set(ArcBytes::from(bytes))
714                    } else if existing_value.is_some() {
715                        changed_keys.push(ChangedKey {
716                            namespace,
717                            key,
718                            deleted: existing_value.is_some(),
719                        });
720                        nebari::tree::KeyOperation::Remove
721                    } else {
722                        nebari::tree::KeyOperation::Skip
723                    }
724                })),
725            )
726            .map_err(Error::from)?;
727
728        if !changed_keys.is_empty() {
729            transaction
730                .entry_mut()
731                .set_data(compat::serialize_executed_transaction_changes(
732                    &Changes::Keys(changed_keys),
733                )?)
734                .map_err(Error::from)?;
735            transaction.commit().map_err(Error::from)?;
736        }
737
738        // If we are shutting down, check if we still have dirty keys.
739        let final_keys = {
740            let mut state = key_value_state.lock();
741            state.last_persistence.replace(Timestamp::now());
742            state.keys_being_persisted = None;
743            state.update_background_worker_target();
744            // This block is a little ugly to avoid having to acquire the lock
745            // twice. If we're shutting down and have no dirty keys, we notify
746            // the waiting shutdown task. If we have any dirty keys, we wait do
747            // to that step because we're going to recurse and reach this spot
748            // again.
749            if state.shutdown.is_some() {
750                let staged_keys = state.stage_dirty_keys();
751                if staged_keys.is_none() {
752                    let shutdown = state.shutdown.take().unwrap();
753                    let _: Result<_, _> = shutdown.send(());
754                }
755                staged_keys
756            } else {
757                None
758            }
759        };
760        if let Some(final_keys) = final_keys {
761            Self::persist_keys(key_value_state, roots, &final_keys)?;
762        }
763        Ok(())
764    }
765}
766
767pub fn background_worker(
768    key_value_state: &Weak<Mutex<KeyValueState>>,
769    timestamp_receiver: &mut Watcher<BackgroundWorkerProcessTarget>,
770    storage_lock: Option<StorageLock>,
771) {
772    loop {
773        let mut perform_operations = false;
774        let current_target = *timestamp_receiver.read();
775        match current_target {
776            // With no target, sleep until we receive a target.
777            BackgroundWorkerProcessTarget::Never => {
778                if timestamp_receiver.watch().is_err() {
779                    break;
780                }
781            }
782            BackgroundWorkerProcessTarget::Timestamp(target) => {
783                // With a target, we need to wait to receive a target only as
784                // long as there is time remaining.
785                let remaining = target - Timestamp::now();
786                if let Some(remaining) = remaining {
787                    // recv_timeout panics if Instant::checked_add(remaining)
788                    // fails. So, we will cap the sleep time at 1 day.
789                    let remaining = remaining.min(Duration::from_secs(60 * 60 * 24));
790                    match timestamp_receiver.watch_timeout(remaining) {
791                        Ok(_) | Err(watchable::TimeoutError::Timeout) => {
792                            perform_operations = true;
793                        }
794                        Err(watchable::TimeoutError::Disconnected) => break,
795                    }
796                } else {
797                    perform_operations = true;
798                }
799            }
800            BackgroundWorkerProcessTarget::Now => {
801                perform_operations = true;
802            }
803        };
804
805        let Some(key_value_state) = key_value_state.upgrade() else {
806            break;
807        };
808
809        if perform_operations {
810            let mut state = key_value_state.lock();
811            let now = Timestamp::now();
812            state.remove_expired_keys(now);
813            if state.needs_commit(now) {
814                state.commit_dirty_keys(&key_value_state);
815            }
816            state.update_background_worker_target();
817        }
818    }
819
820    // The key-value store's delayed persistence can cause the key-value storage
821    // to be written past when the last reference to the storage is still held.
822    // The storage lock being held ensures that another reader/writer doesn't
823    // begin accessing this same storage again.
824    drop(storage_lock);
825}
826
827#[derive(Debug, PartialEq, Eq, Clone, Copy)]
828pub enum BackgroundWorkerProcessTarget {
829    Now,
830    Timestamp(Timestamp),
831    Never,
832}
833
834#[derive(Debug)]
835pub struct ExpirationLoader {
836    pub database: Database,
837    pub launched_at: Timestamp,
838}
839
840impl Keyed<Task> for ExpirationLoader {
841    fn key(&self) -> Task {
842        Task::ExpirationLoader(self.database.data.name.clone())
843    }
844}
845
846impl Job for ExpirationLoader {
847    type Error = Error;
848    type Output = ();
849
850    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip_all))]
851    fn execute(&mut self) -> Result<Self::Output, Self::Error> {
852        let database = self.database.clone();
853        let launched_at = self.launched_at;
854
855        for ((namespace, key), entry) in database.all_key_value_entries()? {
856            if entry.last_updated < launched_at && entry.expiration.is_some() {
857                self.database
858                    .update_key_expiration(full_key(namespace.as_deref(), &key), entry.expiration);
859            }
860        }
861
862        self.database
863            .storage()
864            .instance
865            .tasks()
866            .mark_key_value_expiration_loaded(self.database.data.name.clone());
867
868        Ok(())
869    }
870}
871
872#[cfg(test)]
873mod tests {
874    use std::time::{Duration, Instant};
875
876    use bonsaidb_core::arc_bytes::serde::Bytes;
877    use bonsaidb_core::test_util::{TestDirectory, TimingTest};
878    use nebari::io::any::{AnyFile, AnyFileManager};
879
880    use super::*;
881    use crate::config::PersistenceThreshold;
882    use crate::database::Context;
883
884    fn run_test_with_persistence<
885        F: Fn(Context, nebari::Roots<AnyFile>) -> anyhow::Result<()> + Send,
886    >(
887        name: &str,
888        persistence: KeyValuePersistence,
889        test_contents: &F,
890    ) -> anyhow::Result<()> {
891        let dir = TestDirectory::new(name);
892        let sled = nebari::Config::new(&dir)
893            .file_manager(AnyFileManager::std())
894            .open()?;
895
896        let context = Context::new(sled.clone(), persistence, None);
897
898        test_contents(context, sled)?;
899
900        Ok(())
901    }
902
903    fn run_test<F: Fn(Context, nebari::Roots<AnyFile>) -> anyhow::Result<()> + Send>(
904        name: &str,
905        test_contents: F,
906    ) -> anyhow::Result<()> {
907        run_test_with_persistence(name, KeyValuePersistence::default(), &test_contents)
908    }
909
910    #[test]
911    fn basic_expiration() -> anyhow::Result<()> {
912        run_test("kv-basic-expiration", |context, roots| {
913            // Initialize the test state
914            let mut persistence_watcher = context.kv_persistence_watcher();
915            roots.delete_tree(KEY_TREE)?;
916            let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
917            tree.set(b"atree\0akey", b"somevalue")?;
918
919            // Expire the existing key
920            context.update_key_expiration(
921                full_key(Some("atree"), "akey"),
922                Some(Timestamp::now() + Duration::from_millis(100)),
923            );
924            // Wait for persistence.
925            persistence_watcher.next_value()?;
926
927            // Verify it is gone.
928            assert!(tree.get(b"akey")?.is_none());
929
930            Ok(())
931        })
932    }
933
934    #[test]
935    fn updating_expiration() -> anyhow::Result<()> {
936        run_test("kv-updating-expiration", |context, roots| {
937            // Initialize the test state
938            let mut persistence_watcher = context.kv_persistence_watcher();
939            roots.delete_tree(KEY_TREE)?;
940            let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
941            tree.set(b"atree\0akey", b"somevalue")?;
942            let start = Timestamp::now();
943
944            // Set the expiration once.
945            context.update_key_expiration(
946                full_key(Some("atree"), "akey"),
947                Some(start + Duration::from_millis(100)),
948            );
949            // Set the expiration to a longer value.
950            let correct_expiration = start + Duration::from_secs(1);
951            context
952                .update_key_expiration(full_key(Some("atree"), "akey"), Some(correct_expiration));
953
954            // Wait for persistence, and ensure that the next persistence is
955            // after our expiration timestamp.
956            assert!(persistence_watcher.next_value()? > correct_expiration);
957
958            // Verify the key is gone now.
959            assert_eq!(tree.get(b"atree\0akey")?, None);
960
961            Ok(())
962        })
963    }
964
965    #[test]
966    fn multiple_keys_expiration() -> anyhow::Result<()> {
967        run_test("kv-multiple-keys-expiration", |context, roots| {
968            // Initialize the test state
969            let mut persistence_watcher = context.kv_persistence_watcher();
970            roots.delete_tree(KEY_TREE)?;
971            let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
972            tree.set(b"atree\0akey", b"somevalue")?;
973            tree.set(b"atree\0bkey", b"somevalue")?;
974
975            // Expire both keys, one for a shorter time than the other.
976            context.update_key_expiration(
977                full_key(Some("atree"), "akey"),
978                Some(Timestamp::now() + Duration::from_millis(100)),
979            );
980            context.update_key_expiration(
981                full_key(Some("atree"), "bkey"),
982                Some(Timestamp::now() + Duration::from_secs(1)),
983            );
984
985            // Wait for the first persistence.
986            persistence_watcher.next_value()?;
987            assert!(tree.get(b"atree\0akey")?.is_none());
988            assert!(tree.get(b"atree\0bkey")?.is_some());
989
990            // Wait for the second persistence.
991            persistence_watcher.next_value()?;
992            assert!(tree.get(b"atree\0bkey")?.is_none());
993
994            Ok(())
995        })
996    }
997
998    #[test]
999    fn clearing_expiration() -> anyhow::Result<()> {
1000        run_test("kv-clearing-expiration", |sender, sled| {
1001            loop {
1002                sled.delete_tree(KEY_TREE)?;
1003                let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
1004                tree.set(b"atree\0akey", b"somevalue")?;
1005                let timing = TimingTest::new(Duration::from_millis(100));
1006                sender.update_key_expiration(
1007                    full_key(Some("atree"), "akey"),
1008                    Some(Timestamp::now() + Duration::from_millis(100)),
1009                );
1010                sender.update_key_expiration(full_key(Some("atree"), "akey"), None);
1011                if timing.elapsed() > Duration::from_millis(100) {
1012                    // Restart, took too long.
1013                    continue;
1014                }
1015                timing.wait_until(Duration::from_millis(150));
1016                assert!(tree.get(b"atree\0akey")?.is_some());
1017                break;
1018            }
1019
1020            Ok(())
1021        })
1022    }
1023
1024    #[test]
1025    fn out_of_order_expiration() -> anyhow::Result<()> {
1026        run_test("kv-out-of-order-expiration", |context, roots| loop {
1027            context.update_key_expiration(full_key(Some("atree"), "akey"), None);
1028            context.update_key_expiration(full_key(Some("atree"), "bkey"), None);
1029            context.update_key_expiration(full_key(Some("atree"), "ckey"), None);
1030            let mut persistence_watcher = context.kv_persistence_watcher();
1031            drop(roots.delete_tree(KEY_TREE));
1032            let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
1033            tree.set(b"atree\0akey", b"somevalue")?;
1034            tree.set(b"atree\0bkey", b"somevalue")?;
1035            tree.set(b"atree\0ckey", b"somevalue")?;
1036            let timing = TimingTest::new(Duration::from_millis(100));
1037            context.update_key_expiration(
1038                full_key(Some("atree"), "akey"),
1039                Some(Timestamp::now() + Duration::from_secs(3)),
1040            );
1041            context.update_key_expiration(
1042                full_key(Some("atree"), "ckey"),
1043                Some(Timestamp::now() + Duration::from_secs(1)),
1044            );
1045            context.update_key_expiration(
1046                full_key(Some("atree"), "bkey"),
1047                Some(Timestamp::now() + Duration::from_secs(2)),
1048            );
1049            persistence_watcher.mark_read();
1050            if timing.elapsed() > Duration::from_millis(500) {
1051                println!("Restarting");
1052                continue;
1053            }
1054
1055            // Wait for the first key to expire.
1056            persistence_watcher
1057                .watch_timeout(Duration::from_secs(5))
1058                .unwrap();
1059            persistence_watcher.mark_read();
1060            if timing.elapsed() > Duration::from_millis(1500) {
1061                println!("Restarting");
1062                continue;
1063            }
1064            assert!(tree.get(b"atree\0akey")?.is_some());
1065            assert!(tree.get(b"atree\0bkey")?.is_some());
1066            assert!(tree.get(b"atree\0ckey")?.is_none());
1067
1068            // Wait for the next key to expire.
1069            persistence_watcher
1070                .watch_timeout(Duration::from_secs(5))
1071                .unwrap();
1072            persistence_watcher.mark_read();
1073            if timing.elapsed() > Duration::from_millis(2500) {
1074                println!("Restarting");
1075                continue;
1076            }
1077            assert!(tree.get(b"atree\0akey")?.is_some());
1078            assert!(tree.get(b"atree\0bkey")?.is_none());
1079
1080            // Wait for the final key to expire.
1081            persistence_watcher
1082                .watch_timeout(Duration::from_secs(5))
1083                .unwrap();
1084            if timing.elapsed() > Duration::from_millis(3500) {
1085                println!("Restarting");
1086                continue;
1087            }
1088            assert!(tree.get(b"atree\0akey")?.is_none());
1089
1090            return Ok(());
1091        })
1092    }
1093
1094    #[test]
1095    fn basic_persistence() -> anyhow::Result<()> {
1096        run_test_with_persistence(
1097            "kv-basic-persistence",
1098            KeyValuePersistence::lazy([
1099                PersistenceThreshold::after_changes(2),
1100                PersistenceThreshold::after_changes(1).and_duration(Duration::from_secs(2)),
1101            ]),
1102            &|context, roots| {
1103                // Initialize the test state
1104                let mut persistence_watcher = context.kv_persistence_watcher();
1105                let tree = roots.tree(Unversioned::tree(KEY_TREE))?;
1106                let start = Instant::now();
1107                // Set three keys in quick succession. The first two should
1108                // persist immediately after the second is set, and the
1109                // third should show up after 2 seconds.
1110                context
1111                    .perform_kv_operation(KeyOperation {
1112                        namespace: None,
1113                        key: String::from("key1"),
1114                        command: Command::Set(SetCommand {
1115                            value: Value::Bytes(Bytes::default()),
1116                            expiration: None,
1117                            keep_existing_expiration: false,
1118                            check: None,
1119                            return_previous_value: false,
1120                        }),
1121                    })
1122                    .unwrap();
1123                context
1124                    .perform_kv_operation(KeyOperation {
1125                        namespace: None,
1126                        key: String::from("key2"),
1127                        command: Command::Set(SetCommand {
1128                            value: Value::Bytes(Bytes::default()),
1129                            expiration: None,
1130                            keep_existing_expiration: false,
1131                            check: None,
1132                            return_previous_value: false,
1133                        }),
1134                    })
1135                    .unwrap();
1136                context
1137                    .perform_kv_operation(KeyOperation {
1138                        namespace: None,
1139                        key: String::from("key3"),
1140                        command: Command::Set(SetCommand {
1141                            value: Value::Bytes(Bytes::default()),
1142                            expiration: None,
1143                            keep_existing_expiration: false,
1144                            check: None,
1145                            return_previous_value: false,
1146                        }),
1147                    })
1148                    .unwrap();
1149                // Wait for the first persistence to occur.
1150                persistence_watcher.next_value()?;
1151
1152                assert!(tree.get(b"\0key1").unwrap().is_some());
1153                assert!(tree.get(b"\0key2").unwrap().is_some());
1154                assert!(tree.get(b"\0key3").unwrap().is_none());
1155
1156                // Wait for the second persistence
1157                persistence_watcher.next_value()?;
1158                assert!(tree.get(b"\0key3").unwrap().is_some());
1159                // The total operation should have taken *at least* two seconds,
1160                // since the second persistence should have delayed for two
1161                // seconds itself.
1162                assert!(start.elapsed() > Duration::from_secs(2));
1163
1164                Ok(())
1165            },
1166        )
1167    }
1168
1169    #[test]
1170    fn saves_on_drop() -> anyhow::Result<()> {
1171        let dir = TestDirectory::new("saves-on-drop.bonsaidb");
1172        let sled = nebari::Config::new(&dir)
1173            .file_manager(AnyFileManager::std())
1174            .open()?;
1175        let tree = sled.tree(Unversioned::tree(KEY_TREE))?;
1176
1177        let context = Context::new(
1178            sled,
1179            KeyValuePersistence::lazy([PersistenceThreshold::after_changes(2)]),
1180            None,
1181        );
1182        context
1183            .perform_kv_operation(KeyOperation {
1184                namespace: None,
1185                key: String::from("key1"),
1186                command: Command::Set(SetCommand {
1187                    value: Value::Bytes(Bytes::default()),
1188                    expiration: None,
1189                    keep_existing_expiration: false,
1190                    check: None,
1191                    return_previous_value: false,
1192                }),
1193            })
1194            .unwrap();
1195        assert!(tree.get(b"\0key1").unwrap().is_none());
1196        drop(context);
1197        // Dropping spawns a task that should persist the keys. Give a moment
1198        // for the runtime to execute the task.
1199        std::thread::sleep(Duration::from_millis(100));
1200        assert!(tree.get(b"\0key1").unwrap().is_some());
1201
1202        Ok(())
1203    }
1204}