// dynamo_runtime/storage/key_value_store/mem.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4use std::collections::hash_map::Entry;
5use std::collections::{HashMap, HashSet};
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use rand::Rng as _;
12use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
13
14use crate::storage::key_value_store::{Key, KeyValue, WatchEvent};
15
16use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
17
18#[derive(Clone, Debug)]
19enum MemoryEvent {
20    Put { key: String, value: bytes::Bytes },
21    Delete { key: String },
22}
23
24#[derive(Clone)]
25pub struct MemoryStore {
26    inner: Arc<MemoryStoreInner>,
27    connection_id: u64,
28}
29
30impl Default for MemoryStore {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36struct MemoryStoreInner {
37    data: parking_lot::Mutex<HashMap<String, MemoryBucket>>,
38    change_sender: UnboundedSender<MemoryEvent>,
39    change_receiver: tokio::sync::Mutex<UnboundedReceiver<MemoryEvent>>,
40}
41
42pub struct MemoryBucketRef {
43    name: String,
44    inner: Arc<MemoryStoreInner>,
45}
46
47struct MemoryBucket {
48    data: HashMap<String, (u64, bytes::Bytes)>,
49}
50
51impl MemoryBucket {
52    fn new() -> Self {
53        MemoryBucket {
54            data: HashMap::new(),
55        }
56    }
57}
58
59impl MemoryStore {
60    pub(super) fn new() -> Self {
61        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
62        MemoryStore {
63            inner: Arc::new(MemoryStoreInner {
64                data: parking_lot::Mutex::new(HashMap::new()),
65                change_sender: tx,
66                change_receiver: tokio::sync::Mutex::new(rx),
67            }),
68            connection_id: rand::rng().random(),
69        }
70    }
71}
72
73#[async_trait]
74impl KeyValueStore for MemoryStore {
75    type Bucket = MemoryBucketRef;
76
77    async fn get_or_create_bucket(
78        &self,
79        bucket_name: &str,
80        // MemoryStore doesn't respect TTL yet
81        _ttl: Option<Duration>,
82    ) -> Result<Self::Bucket, StoreError> {
83        let mut locked_data = self.inner.data.lock();
84        // Ensure the bucket exists
85        locked_data
86            .entry(bucket_name.to_string())
87            .or_insert_with(MemoryBucket::new);
88        // Return an object able to access it
89        Ok(MemoryBucketRef {
90            name: bucket_name.to_string(),
91            inner: self.inner.clone(),
92        })
93    }
94
95    /// This operation cannot fail on MemoryStore. Always returns Ok.
96    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
97        let locked_data = self.inner.data.lock();
98        match locked_data.get(bucket_name) {
99            Some(_) => Ok(Some(MemoryBucketRef {
100                name: bucket_name.to_string(),
101                inner: self.inner.clone(),
102            })),
103            None => Ok(None),
104        }
105    }
106
107    fn connection_id(&self) -> u64 {
108        self.connection_id
109    }
110
111    fn shutdown(&self) {}
112}
113
114#[async_trait]
115impl KeyValueBucket for MemoryBucketRef {
116    async fn insert(
117        &self,
118        key: &Key,
119        value: bytes::Bytes,
120        revision: u64,
121    ) -> Result<StoreOutcome, StoreError> {
122        let mut locked_data = self.inner.data.lock();
123        let mut b = locked_data.get_mut(&self.name);
124        let Some(bucket) = b.as_mut() else {
125            return Err(StoreError::MissingBucket(self.name.to_string()));
126        };
127        let outcome = match bucket.data.entry(key.to_string()) {
128            Entry::Vacant(e) => {
129                e.insert((revision, value.clone()));
130                let _ = self.inner.change_sender.send(MemoryEvent::Put {
131                    key: key.to_string(),
132                    value,
133                });
134                StoreOutcome::Created(revision)
135            }
136            Entry::Occupied(mut entry) => {
137                let (rev, _v) = entry.get();
138                if *rev == revision {
139                    StoreOutcome::Exists(revision)
140                } else {
141                    entry.insert((revision, value));
142                    StoreOutcome::Created(revision)
143                }
144            }
145        };
146        Ok(outcome)
147    }
148
149    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
150        let locked_data = self.inner.data.lock();
151        let Some(bucket) = locked_data.get(&self.name) else {
152            return Ok(None);
153        };
154        Ok(bucket.data.get(&key.0).map(|(_, v)| v.clone()))
155    }
156
157    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
158        let mut locked_data = self.inner.data.lock();
159        let Some(bucket) = locked_data.get_mut(&self.name) else {
160            return Err(StoreError::MissingBucket(self.name.to_string()));
161        };
162        if bucket.data.remove(&key.0).is_some() {
163            let _ = self.inner.change_sender.send(MemoryEvent::Delete {
164                key: key.to_string(),
165            });
166        }
167        Ok(())
168    }
169
170    /// All current values in the bucket first, then block waiting for new
171    /// values to be published.
172    /// Caller takes the lock so only a single caller may use this at once.
173    async fn watch(
174        &self,
175    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
176        // All the existing ones first
177        let mut existing_items = vec![];
178        let mut seen_keys = HashSet::new();
179        let data_lock = self.inner.data.lock();
180        let Some(bucket) = data_lock.get(&self.name) else {
181            return Err(StoreError::MissingBucket(self.name.to_string()));
182        };
183        for (key, (_rev, v)) in &bucket.data {
184            seen_keys.insert(key.clone());
185            let item = KeyValue::new(key.clone(), v.clone());
186            existing_items.push(WatchEvent::Put(item));
187        }
188        drop(data_lock);
189
190        Ok(Box::pin(async_stream::stream! {
191            for event in existing_items {
192                yield event;
193            }
194            // Now any new ones
195            let mut rcv_lock = self.inner.change_receiver.lock().await;
196            loop {
197                match rcv_lock.recv().await {
198                    None => {
199                        // Channel is closed, no more values coming
200                        break;
201                    },
202                    Some(MemoryEvent::Put { key, value }) => {
203                        if seen_keys.contains(&key) {
204                            continue;
205                        }
206                        let item = KeyValue::new(key, value);
207                        yield WatchEvent::Put(item);
208                    },
209                    Some(MemoryEvent::Delete { key }) => {
210                        yield WatchEvent::Delete(Key::from_raw(key));
211                    }
212                }
213            }
214        }))
215    }
216
217    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
218        let locked_data = self.inner.data.lock();
219        match locked_data.get(&self.name) {
220            Some(bucket) => Ok(bucket
221                .data
222                .iter()
223                .map(|(k, (_rev, v))| (k.to_string(), v.clone()))
224                .collect()),
225            None => Err(StoreError::MissingBucket(self.name.clone())),
226        }
227    }
228}