dynamo_runtime/storage/key_value_store/
mem.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::hash_map::Entry;
5use std::collections::{HashMap, HashSet};
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use rand::Rng as _;
12use tokio::sync::Mutex;
13use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
14
15use crate::storage::key_value_store::Key;
16
17use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
18
19#[derive(Clone)]
20pub struct MemoryStore {
21    inner: Arc<MemoryStoreInner>,
22    connection_id: u64,
23}
24
25impl Default for MemoryStore {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31struct MemoryStoreInner {
32    data: Mutex<HashMap<String, MemoryBucket>>,
33    change_sender: UnboundedSender<(String, String)>,
34    change_receiver: Mutex<UnboundedReceiver<(String, String)>>,
35}
36
37pub struct MemoryBucketRef {
38    name: String,
39    inner: Arc<MemoryStoreInner>,
40}
41
/// Contents of a single bucket.
struct MemoryBucket {
    /// Maps each key to its `(revision, value)` pair.
    data: HashMap<String, (u64, String)>,
}

impl MemoryBucket {
    /// Creates a bucket that holds no entries.
    fn new() -> Self {
        Self {
            data: HashMap::new(),
        }
    }
}
53
54impl MemoryStore {
55    pub fn new() -> Self {
56        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
57        MemoryStore {
58            inner: Arc::new(MemoryStoreInner {
59                data: Mutex::new(HashMap::new()),
60                change_sender: tx,
61                change_receiver: Mutex::new(rx),
62            }),
63            connection_id: rand::rng().random(),
64        }
65    }
66}
67
68#[async_trait]
69impl KeyValueStore for MemoryStore {
70    async fn get_or_create_bucket(
71        &self,
72        bucket_name: &str,
73        // MemoryStore doesn't respect TTL yet
74        _ttl: Option<Duration>,
75    ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
76        let mut locked_data = self.inner.data.lock().await;
77        // Ensure the bucket exists
78        locked_data
79            .entry(bucket_name.to_string())
80            .or_insert_with(MemoryBucket::new);
81        // Return an object able to access it
82        Ok(Box::new(MemoryBucketRef {
83            name: bucket_name.to_string(),
84            inner: self.inner.clone(),
85        }))
86    }
87
88    /// This operation cannot fail on MemoryStore. Always returns Ok.
89    async fn get_bucket(
90        &self,
91        bucket_name: &str,
92    ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
93        let locked_data = self.inner.data.lock().await;
94        match locked_data.get(bucket_name) {
95            Some(_) => Ok(Some(Box::new(MemoryBucketRef {
96                name: bucket_name.to_string(),
97                inner: self.inner.clone(),
98            }))),
99            None => Ok(None),
100        }
101    }
102
103    fn connection_id(&self) -> u64 {
104        self.connection_id
105    }
106}
107
108#[async_trait]
109impl KeyValueBucket for MemoryBucketRef {
110    async fn insert(
111        &self,
112        key: &Key,
113        value: &str,
114        revision: u64,
115    ) -> Result<StoreOutcome, StoreError> {
116        let mut locked_data = self.inner.data.lock().await;
117        let mut b = locked_data.get_mut(&self.name);
118        let Some(bucket) = b.as_mut() else {
119            return Err(StoreError::MissingBucket(self.name.to_string()));
120        };
121        let outcome = match bucket.data.entry(key.to_string()) {
122            Entry::Vacant(e) => {
123                e.insert((revision, value.to_string()));
124                let _ = self
125                    .inner
126                    .change_sender
127                    .send((key.to_string(), value.to_string()));
128                StoreOutcome::Created(revision)
129            }
130            Entry::Occupied(mut entry) => {
131                let (rev, _v) = entry.get();
132                if *rev == revision {
133                    StoreOutcome::Exists(revision)
134                } else {
135                    entry.insert((revision, value.to_string()));
136                    StoreOutcome::Created(revision)
137                }
138            }
139        };
140        Ok(outcome)
141    }
142
143    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
144        let locked_data = self.inner.data.lock().await;
145        let Some(bucket) = locked_data.get(&self.name) else {
146            return Ok(None);
147        };
148        Ok(bucket
149            .data
150            .get(&key.0)
151            .map(|(_, v)| bytes::Bytes::from(v.clone())))
152    }
153
154    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
155        let mut locked_data = self.inner.data.lock().await;
156        let Some(bucket) = locked_data.get_mut(&self.name) else {
157            return Err(StoreError::MissingBucket(self.name.to_string()));
158        };
159        bucket.data.remove(&key.0);
160        Ok(())
161    }
162
163    /// All current values in the bucket first, then block waiting for new
164    /// values to be published.
165    /// Caller takes the lock so only a single caller may use this at once.
166    async fn watch(
167        &self,
168    ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
169    {
170        Ok(Box::pin(async_stream::stream! {
171            // All the existing ones first
172            let mut seen = HashSet::new();
173            let data_lock = self.inner.data.lock().await;
174            let Some(bucket) = data_lock.get(&self.name) else {
175                tracing::error!(bucket_name = self.name, "watch: Missing bucket");
176                return;
177            };
178            for (_rev, v) in bucket.data.values() {
179                seen.insert(v.clone());
180                yield bytes::Bytes::from(v.clone());
181            }
182            drop(data_lock);
183            // Now any new ones
184            let mut rcv_lock = self.inner.change_receiver.lock().await;
185            loop {
186                match rcv_lock.recv().await {
187                    None => {
188                        // Channel is closed, no more values coming
189                        break;
190                    },
191                    Some((_k, v)) => {
192                        if seen.contains(&v) {
193                            continue;
194                        }
195                        yield bytes::Bytes::from(v.clone());
196                    }
197                }
198            }
199        }))
200    }
201
202    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
203        let locked_data = self.inner.data.lock().await;
204        match locked_data.get(&self.name) {
205            Some(bucket) => Ok(bucket
206                .data
207                .iter()
208                .map(|(k, (_rev, v))| (k.to_string(), bytes::Bytes::from(v.clone())))
209                .collect()),
210            None => Err(StoreError::MissingBucket(self.name.clone())),
211        }
212    }
213}