// dynamo_runtime/storage/key_value_store/mem.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::hash_map::Entry;
5use std::collections::{HashMap, HashSet};
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use rand::Rng as _;
12use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
13
14use crate::storage::key_value_store::{Key, KeyValue, WatchEvent};
15
16use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
17
/// A change notification published on the store-wide change channel.
#[derive(Debug, Clone)]
enum MemoryEvent {
    /// `key` was written with `value`.
    Put { key: String, value: String },
    /// `key` was removed.
    Delete { key: String },
}
23
24#[derive(Clone)]
25pub struct MemoryStore {
26    inner: Arc<MemoryStoreInner>,
27    connection_id: u64,
28}
29
30impl Default for MemoryStore {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36struct MemoryStoreInner {
37    data: parking_lot::Mutex<HashMap<String, MemoryBucket>>,
38    change_sender: UnboundedSender<MemoryEvent>,
39    change_receiver: tokio::sync::Mutex<UnboundedReceiver<MemoryEvent>>,
40}
41
42pub struct MemoryBucketRef {
43    name: String,
44    inner: Arc<MemoryStoreInner>,
45}
46
/// Contents of a single bucket: each key maps to the revision it was last
/// written at and its value.
struct MemoryBucket {
    data: HashMap<String, (u64, String)>,
}

impl MemoryBucket {
    /// Creates a bucket with no entries.
    fn new() -> Self {
        Self {
            data: HashMap::default(),
        }
    }
}
58
59impl MemoryStore {
60    pub fn new() -> Self {
61        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
62        MemoryStore {
63            inner: Arc::new(MemoryStoreInner {
64                data: parking_lot::Mutex::new(HashMap::new()),
65                change_sender: tx,
66                change_receiver: tokio::sync::Mutex::new(rx),
67            }),
68            connection_id: rand::rng().random(),
69        }
70    }
71}
72
73#[async_trait]
74impl KeyValueStore for MemoryStore {
75    type Bucket = MemoryBucketRef;
76
77    async fn get_or_create_bucket(
78        &self,
79        bucket_name: &str,
80        // MemoryStore doesn't respect TTL yet
81        _ttl: Option<Duration>,
82    ) -> Result<Self::Bucket, StoreError> {
83        let mut locked_data = self.inner.data.lock();
84        // Ensure the bucket exists
85        locked_data
86            .entry(bucket_name.to_string())
87            .or_insert_with(MemoryBucket::new);
88        // Return an object able to access it
89        Ok(MemoryBucketRef {
90            name: bucket_name.to_string(),
91            inner: self.inner.clone(),
92        })
93    }
94
95    /// This operation cannot fail on MemoryStore. Always returns Ok.
96    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
97        let locked_data = self.inner.data.lock();
98        match locked_data.get(bucket_name) {
99            Some(_) => Ok(Some(MemoryBucketRef {
100                name: bucket_name.to_string(),
101                inner: self.inner.clone(),
102            })),
103            None => Ok(None),
104        }
105    }
106
107    fn connection_id(&self) -> u64 {
108        self.connection_id
109    }
110}
111
112#[async_trait]
113impl KeyValueBucket for MemoryBucketRef {
114    async fn insert(
115        &self,
116        key: &Key,
117        value: &str,
118        revision: u64,
119    ) -> Result<StoreOutcome, StoreError> {
120        let mut locked_data = self.inner.data.lock();
121        let mut b = locked_data.get_mut(&self.name);
122        let Some(bucket) = b.as_mut() else {
123            return Err(StoreError::MissingBucket(self.name.to_string()));
124        };
125        let outcome = match bucket.data.entry(key.to_string()) {
126            Entry::Vacant(e) => {
127                e.insert((revision, value.to_string()));
128                let _ = self.inner.change_sender.send(MemoryEvent::Put {
129                    key: key.to_string(),
130                    value: value.to_string(),
131                });
132                StoreOutcome::Created(revision)
133            }
134            Entry::Occupied(mut entry) => {
135                let (rev, _v) = entry.get();
136                if *rev == revision {
137                    StoreOutcome::Exists(revision)
138                } else {
139                    entry.insert((revision, value.to_string()));
140                    StoreOutcome::Created(revision)
141                }
142            }
143        };
144        Ok(outcome)
145    }
146
147    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
148        let locked_data = self.inner.data.lock();
149        let Some(bucket) = locked_data.get(&self.name) else {
150            return Ok(None);
151        };
152        Ok(bucket
153            .data
154            .get(&key.0)
155            .map(|(_, v)| bytes::Bytes::from(v.clone())))
156    }
157
158    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
159        let mut locked_data = self.inner.data.lock();
160        let Some(bucket) = locked_data.get_mut(&self.name) else {
161            return Err(StoreError::MissingBucket(self.name.to_string()));
162        };
163        if bucket.data.remove(&key.0).is_some() {
164            let _ = self.inner.change_sender.send(MemoryEvent::Delete {
165                key: key.to_string(),
166            });
167        }
168        Ok(())
169    }
170
171    /// All current values in the bucket first, then block waiting for new
172    /// values to be published.
173    /// Caller takes the lock so only a single caller may use this at once.
174    async fn watch(
175        &self,
176    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
177        // All the existing ones first
178        let mut existing_items = vec![];
179        let mut seen_keys = HashSet::new();
180        let data_lock = self.inner.data.lock();
181        let Some(bucket) = data_lock.get(&self.name) else {
182            return Err(StoreError::MissingBucket(self.name.to_string()));
183        };
184        for (key, (_rev, v)) in &bucket.data {
185            seen_keys.insert(key.clone());
186            let item = KeyValue::new(key.clone(), bytes::Bytes::from(v.clone().into_bytes()));
187            existing_items.push(WatchEvent::Put(item));
188        }
189        drop(data_lock);
190
191        Ok(Box::pin(async_stream::stream! {
192            for event in existing_items {
193                yield event;
194            }
195            // Now any new ones
196            let mut rcv_lock = self.inner.change_receiver.lock().await;
197            loop {
198                match rcv_lock.recv().await {
199                    None => {
200                        // Channel is closed, no more values coming
201                        break;
202                    },
203                    Some(MemoryEvent::Put { key, value }) => {
204                        if seen_keys.contains(&key) {
205                            continue;
206                        }
207                        let item = KeyValue::new(key, bytes::Bytes::from(value));
208                        yield WatchEvent::Put(item);
209                    },
210                    Some(MemoryEvent::Delete { key }) => {
211                        let item = KeyValue::new(key, bytes::Bytes::new());
212                        yield WatchEvent::Delete(item);
213                    }
214                }
215            }
216        }))
217    }
218
219    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
220        let locked_data = self.inner.data.lock();
221        match locked_data.get(&self.name) {
222            Some(bucket) => Ok(bucket
223                .data
224                .iter()
225                .map(|(k, (_rev, v))| (k.to_string(), bytes::Bytes::from(v.clone())))
226                .collect()),
227            None => Err(StoreError::MissingBucket(self.name.clone())),
228        }
229    }
230}