Skip to main content

dynamo_runtime/storage/kv/
mem.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::hash_map::Entry;
5use std::collections::{HashMap, HashSet};
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use rand::Rng as _;
12use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
13
14use super::{Bucket, Key, KeyValue, Store, StoreError, StoreOutcome, WatchEvent};
15
16#[derive(Clone, Debug)]
17enum MemoryEvent {
18    Put { key: String, value: bytes::Bytes },
19    Delete { key: String },
20}
21
22#[derive(Clone)]
23pub struct MemoryStore {
24    inner: Arc<MemoryStoreInner>,
25    connection_id: u64,
26}
27
28impl Default for MemoryStore {
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34struct MemoryStoreInner {
35    data: parking_lot::Mutex<HashMap<String, MemoryBucket>>,
36    change_sender: UnboundedSender<MemoryEvent>,
37    change_receiver: tokio::sync::Mutex<UnboundedReceiver<MemoryEvent>>,
38}
39
40pub struct MemoryBucketRef {
41    name: String,
42    inner: Arc<MemoryStoreInner>,
43}
44
45struct MemoryBucket {
46    data: HashMap<String, (u64, bytes::Bytes)>,
47}
48
49impl MemoryBucket {
50    fn new() -> Self {
51        MemoryBucket {
52            data: HashMap::new(),
53        }
54    }
55}
56
57impl MemoryStore {
58    pub(super) fn new() -> Self {
59        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
60        MemoryStore {
61            inner: Arc::new(MemoryStoreInner {
62                data: parking_lot::Mutex::new(HashMap::new()),
63                change_sender: tx,
64                change_receiver: tokio::sync::Mutex::new(rx),
65            }),
66            connection_id: rand::rng().random(),
67        }
68    }
69}
70
71#[async_trait]
72impl Store for MemoryStore {
73    type Bucket = MemoryBucketRef;
74
75    async fn get_or_create_bucket(
76        &self,
77        bucket_name: &str,
78        // MemoryStore doesn't respect TTL yet
79        _ttl: Option<Duration>,
80    ) -> Result<Self::Bucket, StoreError> {
81        let mut locked_data = self.inner.data.lock();
82        // Ensure the bucket exists
83        locked_data
84            .entry(bucket_name.to_string())
85            .or_insert_with(MemoryBucket::new);
86        // Return an object able to access it
87        Ok(MemoryBucketRef {
88            name: bucket_name.to_string(),
89            inner: self.inner.clone(),
90        })
91    }
92
93    /// This operation cannot fail on MemoryStore. Always returns Ok.
94    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
95        let locked_data = self.inner.data.lock();
96        match locked_data.get(bucket_name) {
97            Some(_) => Ok(Some(MemoryBucketRef {
98                name: bucket_name.to_string(),
99                inner: self.inner.clone(),
100            })),
101            None => Ok(None),
102        }
103    }
104
105    fn connection_id(&self) -> u64 {
106        self.connection_id
107    }
108
109    fn shutdown(&self) {}
110}
111
112#[async_trait]
113impl Bucket for MemoryBucketRef {
114    async fn insert(
115        &self,
116        key: &Key,
117        value: bytes::Bytes,
118        revision: u64,
119    ) -> Result<StoreOutcome, StoreError> {
120        let mut locked_data = self.inner.data.lock();
121        let mut b = locked_data.get_mut(&self.name);
122        let Some(bucket) = b.as_mut() else {
123            return Err(StoreError::MissingBucket(self.name.to_string()));
124        };
125        let outcome = match bucket.data.entry(key.to_string()) {
126            Entry::Vacant(e) => {
127                e.insert((revision, value.clone()));
128                let _ = self.inner.change_sender.send(MemoryEvent::Put {
129                    key: key.to_string(),
130                    value,
131                });
132                StoreOutcome::Created(revision)
133            }
134            Entry::Occupied(mut entry) => {
135                let (rev, _v) = entry.get();
136                if *rev == revision {
137                    StoreOutcome::Exists(revision)
138                } else {
139                    entry.insert((revision, value));
140                    StoreOutcome::Created(revision)
141                }
142            }
143        };
144        Ok(outcome)
145    }
146
147    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
148        let locked_data = self.inner.data.lock();
149        let Some(bucket) = locked_data.get(&self.name) else {
150            return Ok(None);
151        };
152        Ok(bucket.data.get(&key.0).map(|(_, v)| v.clone()))
153    }
154
155    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
156        let mut locked_data = self.inner.data.lock();
157        let Some(bucket) = locked_data.get_mut(&self.name) else {
158            return Err(StoreError::MissingBucket(self.name.to_string()));
159        };
160        if bucket.data.remove(&key.0).is_some() {
161            let _ = self.inner.change_sender.send(MemoryEvent::Delete {
162                key: key.to_string(),
163            });
164        }
165        Ok(())
166    }
167
168    /// All current values in the bucket first, then block waiting for new
169    /// values to be published.
170    /// Caller takes the lock so only a single caller may use this at once.
171    async fn watch(
172        &self,
173    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
174        // All the existing ones first
175        let mut existing_items = vec![];
176        let mut seen_keys = HashSet::new();
177        let data_lock = self.inner.data.lock();
178        let Some(bucket) = data_lock.get(&self.name) else {
179            return Err(StoreError::MissingBucket(self.name.to_string()));
180        };
181        for (key, (_rev, v)) in &bucket.data {
182            seen_keys.insert(key.clone());
183            let item = KeyValue::new(Key::new(key.clone()), v.clone());
184            existing_items.push(WatchEvent::Put(item));
185        }
186        drop(data_lock);
187
188        Ok(Box::pin(async_stream::stream! {
189            for event in existing_items {
190                yield event;
191            }
192            // Now any new ones
193            let mut rcv_lock = self.inner.change_receiver.lock().await;
194            loop {
195                match rcv_lock.recv().await {
196                    None => {
197                        // Channel is closed, no more values coming
198                        break;
199                    },
200                    Some(MemoryEvent::Put { key, value }) => {
201                        if seen_keys.contains(&key) {
202                            continue;
203                        }
204                        let item = KeyValue::new(Key::new(key), value);
205                        yield WatchEvent::Put(item);
206                    },
207                    Some(MemoryEvent::Delete { key }) => {
208                        yield WatchEvent::Delete(Key::new(key));
209                    }
210                }
211            }
212        }))
213    }
214
215    async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError> {
216        let locked_data = self.inner.data.lock();
217        match locked_data.get(&self.name) {
218            Some(bucket) => {
219                let mut out = HashMap::new();
220                for (k, (_rev, v)) in bucket.data.iter() {
221                    let key = Key::new([self.name.clone(), k.to_string()].join("/"));
222                    let value = v.clone();
223                    out.insert(key, value);
224                }
225                Ok(out)
226            }
227            None => Err(StoreError::MissingBucket(self.name.clone())),
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    use crate::storage::kv::{Bucket as _, Key, MemoryStore, Store as _};
    use std::collections::HashSet;

    /// `entries()` must report each key prefixed with the name of its bucket.
    #[tokio::test]
    async fn test_entries_full_path() {
        let store = MemoryStore::new();
        let bucket = store.get_or_create_bucket("bucket1", None).await.unwrap();

        for (k, v) in [("key1", "value1"), ("key2", "value2")] {
            let _ = bucket
                .insert(&Key::new(k.to_string()), v.into(), 0)
                .await
                .unwrap();
        }

        let keys: HashSet<Key> = bucket.entries().await.unwrap().into_keys().collect();
        for expected in ["bucket1/key1", "bucket1/key2"] {
            assert!(keys.contains(&Key::new(expected.to_string())));
        }
    }
}