dynamo_llm/key_value_store/
mem.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::collections::hash_map::Entry;
17use std::collections::{HashMap, HashSet};
18use std::pin::Pin;
19use std::sync::Arc;
20use std::time::Duration;
21
22use async_trait::async_trait;
23use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
24use tokio::sync::Mutex;
25
26use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
27
28#[derive(Clone)]
29pub struct MemoryStorage {
30    inner: Arc<MemoryStorageInner>,
31}
32
33impl Default for MemoryStorage {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39struct MemoryStorageInner {
40    data: Mutex<HashMap<String, MemoryBucket>>,
41    change_sender: UnboundedSender<(String, String)>,
42    change_receiver: Mutex<UnboundedReceiver<(String, String)>>,
43}
44
45pub struct MemoryBucketRef {
46    name: String,
47    inner: Arc<MemoryStorageInner>,
48}
49
/// Contents of a single bucket.
struct MemoryBucket {
    /// Maps each key to its `(revision, value)` pair.
    data: HashMap<String, (u64, String)>,
}

impl MemoryBucket {
    /// Creates a bucket with no entries.
    fn new() -> Self {
        Self {
            data: HashMap::default(),
        }
    }
}
61
62impl MemoryStorage {
63    pub fn new() -> Self {
64        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
65        MemoryStorage {
66            inner: Arc::new(MemoryStorageInner {
67                data: Mutex::new(HashMap::new()),
68                change_sender: tx,
69                change_receiver: Mutex::new(rx),
70            }),
71        }
72    }
73}
74
75#[async_trait]
76impl KeyValueStore for MemoryStorage {
77    async fn get_or_create_bucket(
78        &self,
79        bucket_name: &str,
80        // MemoryStorage doesn't respect TTL yet
81        _ttl: Option<Duration>,
82    ) -> Result<Box<dyn KeyValueBucket>, StorageError> {
83        let mut locked_data = self.inner.data.lock().await;
84        // Ensure the bucket exists
85        locked_data
86            .entry(bucket_name.to_string())
87            .or_insert_with(MemoryBucket::new);
88        // Return an object able to access it
89        Ok(Box::new(MemoryBucketRef {
90            name: bucket_name.to_string(),
91            inner: self.inner.clone(),
92        }))
93    }
94
95    /// This operation cannot fail on MemoryStorage. Always returns Ok.
96    async fn get_bucket(
97        &self,
98        bucket_name: &str,
99    ) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
100        let locked_data = self.inner.data.lock().await;
101        match locked_data.get(bucket_name) {
102            Some(_) => Ok(Some(Box::new(MemoryBucketRef {
103                name: bucket_name.to_string(),
104                inner: self.inner.clone(),
105            }))),
106            None => Ok(None),
107        }
108    }
109}
110
111#[async_trait]
112impl KeyValueBucket for MemoryBucketRef {
113    async fn insert(
114        &self,
115        key: String,
116        value: String,
117        revision: u64,
118    ) -> Result<StorageOutcome, StorageError> {
119        let mut locked_data = self.inner.data.lock().await;
120        let mut b = locked_data.get_mut(&self.name);
121        let Some(bucket) = b.as_mut() else {
122            return Err(StorageError::MissingBucket(self.name.to_string()));
123        };
124        let outcome = match bucket.data.entry(key.to_string()) {
125            Entry::Vacant(e) => {
126                e.insert((revision, value.clone()));
127                let _ = self.inner.change_sender.send((key, value));
128                StorageOutcome::Created(revision)
129            }
130            Entry::Occupied(mut entry) => {
131                let (rev, _v) = entry.get();
132                if *rev == revision {
133                    StorageOutcome::Exists(revision)
134                } else {
135                    entry.insert((revision, value));
136                    StorageOutcome::Created(revision)
137                }
138            }
139        };
140        Ok(outcome)
141    }
142
143    async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> {
144        let locked_data = self.inner.data.lock().await;
145        let Some(bucket) = locked_data.get(&self.name) else {
146            return Ok(None);
147        };
148        Ok(bucket
149            .data
150            .get(key)
151            .map(|(_, v)| bytes::Bytes::from(v.clone())))
152    }
153
154    async fn delete(&self, key: &str) -> Result<(), StorageError> {
155        let mut locked_data = self.inner.data.lock().await;
156        let Some(bucket) = locked_data.get_mut(&self.name) else {
157            return Err(StorageError::MissingBucket(self.name.to_string()));
158        };
159        bucket.data.remove(key);
160        Ok(())
161    }
162
163    /// All current values in the bucket first, then block waiting for new
164    /// values to be published.
165    /// Caller takes the lock so only a single caller may use this at once.
166    async fn watch(
167        &self,
168    ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
169    {
170        Ok(Box::pin(async_stream::stream! {
171            // All the existing ones first
172            let mut seen = HashSet::new();
173            let data_lock = self.inner.data.lock().await;
174            let Some(bucket) = data_lock.get(&self.name) else {
175                tracing::error!(bucket_name = self.name, "watch: Missing bucket");
176                return;
177            };
178            for (_rev, v) in bucket.data.values() {
179                seen.insert(v.clone());
180                yield bytes::Bytes::from(v.clone());
181            }
182            drop(data_lock);
183            // Now any new ones
184            let mut rcv_lock = self.inner.change_receiver.lock().await;
185            loop {
186                match rcv_lock.recv().await {
187                    None => {
188                        // Channel is closed, no more values coming
189                        break;
190                    },
191                    Some((_k, v)) => {
192                        if seen.contains(&v) {
193                            continue;
194                        }
195                        yield bytes::Bytes::from(v.clone());
196                    }
197                }
198            }
199        }))
200    }
201
202    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
203        let locked_data = self.inner.data.lock().await;
204        match locked_data.get(&self.name) {
205            Some(bucket) => Ok(bucket
206                .data
207                .iter()
208                .map(|(k, (_rev, v))| (k.to_string(), bytes::Bytes::from(v.clone())))
209                .collect()),
210            None => Err(StorageError::MissingBucket(self.name.clone())),
211        }
212    }
213}