dynamo_llm/key_value_store/
mem.rs1use std::collections::hash_map::Entry;
17use std::collections::{HashMap, HashSet};
18use std::pin::Pin;
19use std::sync::Arc;
20use std::time::Duration;
21
22use async_trait::async_trait;
23use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
24use tokio::sync::Mutex;
25
26use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
27
28#[derive(Clone)]
29pub struct MemoryStorage {
30 inner: Arc<MemoryStorageInner>,
31}
32
33impl Default for MemoryStorage {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39struct MemoryStorageInner {
40 data: Mutex<HashMap<String, MemoryBucket>>,
41 change_sender: UnboundedSender<(String, String)>,
42 change_receiver: Mutex<UnboundedReceiver<(String, String)>>,
43}
44
45pub struct MemoryBucketRef {
46 name: String,
47 inner: Arc<MemoryStorageInner>,
48}
49
/// Contents of a single bucket: each key maps to its `(revision, value)` pair.
struct MemoryBucket {
    data: HashMap<String, (u64, String)>,
}
53
54impl MemoryBucket {
55 fn new() -> Self {
56 MemoryBucket {
57 data: HashMap::new(),
58 }
59 }
60}
61
62impl MemoryStorage {
63 pub fn new() -> Self {
64 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
65 MemoryStorage {
66 inner: Arc::new(MemoryStorageInner {
67 data: Mutex::new(HashMap::new()),
68 change_sender: tx,
69 change_receiver: Mutex::new(rx),
70 }),
71 }
72 }
73}
74
75#[async_trait]
76impl KeyValueStore for MemoryStorage {
77 async fn get_or_create_bucket(
78 &self,
79 bucket_name: &str,
80 _ttl: Option<Duration>,
82 ) -> Result<Box<dyn KeyValueBucket>, StorageError> {
83 let mut locked_data = self.inner.data.lock().await;
84 locked_data
86 .entry(bucket_name.to_string())
87 .or_insert_with(MemoryBucket::new);
88 Ok(Box::new(MemoryBucketRef {
90 name: bucket_name.to_string(),
91 inner: self.inner.clone(),
92 }))
93 }
94
95 async fn get_bucket(
97 &self,
98 bucket_name: &str,
99 ) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
100 let locked_data = self.inner.data.lock().await;
101 match locked_data.get(bucket_name) {
102 Some(_) => Ok(Some(Box::new(MemoryBucketRef {
103 name: bucket_name.to_string(),
104 inner: self.inner.clone(),
105 }))),
106 None => Ok(None),
107 }
108 }
109}
110
111#[async_trait]
112impl KeyValueBucket for MemoryBucketRef {
113 async fn insert(
114 &self,
115 key: String,
116 value: String,
117 revision: u64,
118 ) -> Result<StorageOutcome, StorageError> {
119 let mut locked_data = self.inner.data.lock().await;
120 let mut b = locked_data.get_mut(&self.name);
121 let Some(bucket) = b.as_mut() else {
122 return Err(StorageError::MissingBucket(self.name.to_string()));
123 };
124 let outcome = match bucket.data.entry(key.to_string()) {
125 Entry::Vacant(e) => {
126 e.insert((revision, value.clone()));
127 let _ = self.inner.change_sender.send((key, value));
128 StorageOutcome::Created(revision)
129 }
130 Entry::Occupied(mut entry) => {
131 let (rev, _v) = entry.get();
132 if *rev == revision {
133 StorageOutcome::Exists(revision)
134 } else {
135 entry.insert((revision, value));
136 StorageOutcome::Created(revision)
137 }
138 }
139 };
140 Ok(outcome)
141 }
142
143 async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> {
144 let locked_data = self.inner.data.lock().await;
145 let Some(bucket) = locked_data.get(&self.name) else {
146 return Ok(None);
147 };
148 Ok(bucket
149 .data
150 .get(key)
151 .map(|(_, v)| bytes::Bytes::from(v.clone())))
152 }
153
154 async fn delete(&self, key: &str) -> Result<(), StorageError> {
155 let mut locked_data = self.inner.data.lock().await;
156 let Some(bucket) = locked_data.get_mut(&self.name) else {
157 return Err(StorageError::MissingBucket(self.name.to_string()));
158 };
159 bucket.data.remove(key);
160 Ok(())
161 }
162
163 async fn watch(
167 &self,
168 ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
169 {
170 Ok(Box::pin(async_stream::stream! {
171 let mut seen = HashSet::new();
173 let data_lock = self.inner.data.lock().await;
174 let Some(bucket) = data_lock.get(&self.name) else {
175 tracing::error!(bucket_name = self.name, "watch: Missing bucket");
176 return;
177 };
178 for (_rev, v) in bucket.data.values() {
179 seen.insert(v.clone());
180 yield bytes::Bytes::from(v.clone());
181 }
182 drop(data_lock);
183 let mut rcv_lock = self.inner.change_receiver.lock().await;
185 loop {
186 match rcv_lock.recv().await {
187 None => {
188 break;
190 },
191 Some((_k, v)) => {
192 if seen.contains(&v) {
193 continue;
194 }
195 yield bytes::Bytes::from(v.clone());
196 }
197 }
198 }
199 }))
200 }
201
202 async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
203 let locked_data = self.inner.data.lock().await;
204 match locked_data.get(&self.name) {
205 Some(bucket) => Ok(bucket
206 .data
207 .iter()
208 .map(|(k, (_rev, v))| (k.to_string(), bytes::Bytes::from(v.clone())))
209 .collect()),
210 None => Err(StorageError::MissingBucket(self.name.clone())),
211 }
212 }
213}