1use async_trait::async_trait;
2use common::{DakeraError, NamespaceId, Result, Vector, VectorId};
3use parking_lot::RwLock;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::traits::VectorStorage;
8
9#[derive(Clone)]
12pub struct InMemoryStorage {
13 namespaces: Arc<RwLock<HashMap<NamespaceId, HashMap<VectorId, Vector>>>>,
15}
16
17impl InMemoryStorage {
18 pub fn new() -> Self {
19 Self {
20 namespaces: Arc::new(RwLock::new(HashMap::new())),
21 }
22 }
23}
24
25impl Default for InMemoryStorage {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns `0`
/// instead of panicking.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock is set before 1970: treat it as the epoch itself.
        Err(_) => 0,
    }
}
38
39#[async_trait]
40impl VectorStorage for InMemoryStorage {
41 async fn upsert(&self, namespace: &NamespaceId, vectors: Vec<Vector>) -> Result<usize> {
42 let mut namespaces = self.namespaces.write();
43 let ns = namespaces.entry(namespace.clone()).or_default();
44
45 if !vectors.is_empty() {
47 let now = now_secs();
48 let expected_dim = ns
49 .values()
50 .find(|v| !v.is_expired_at(now))
51 .map(|v| v.values.len())
52 .unwrap_or_else(|| vectors[0].values.len());
53 for v in &vectors {
54 if v.values.len() != expected_dim {
55 return Err(DakeraError::DimensionMismatch {
56 expected: expected_dim,
57 actual: v.values.len(),
58 });
59 }
60 }
61 }
62
63 let count = vectors.len();
64 for mut vector in vectors {
65 vector.apply_ttl();
67 ns.insert(vector.id.clone(), vector);
68 }
69
70 tracing::debug!(
71 namespace = %namespace,
72 count = count,
73 "Upserted vectors"
74 );
75
76 Ok(count)
77 }
78
79 async fn get(&self, namespace: &NamespaceId, ids: &[VectorId]) -> Result<Vec<Vector>> {
80 let namespaces = self.namespaces.read();
81 let ns = namespaces
82 .get(namespace)
83 .ok_or_else(|| DakeraError::NamespaceNotFound(namespace.clone()))?;
84
85 let now = now_secs();
87 Ok(ids
88 .iter()
89 .filter_map(|id| ns.get(id).cloned())
90 .filter(|v| !v.is_expired_at(now))
91 .collect())
92 }
93
94 async fn get_all(&self, namespace: &NamespaceId) -> Result<Vec<Vector>> {
95 let namespaces = self.namespaces.read();
96 let ns = namespaces
97 .get(namespace)
98 .ok_or_else(|| DakeraError::NamespaceNotFound(namespace.clone()))?;
99
100 let now = now_secs();
102 Ok(ns
103 .values()
104 .filter(|v| !v.is_expired_at(now))
105 .cloned()
106 .collect())
107 }
108
109 async fn delete(&self, namespace: &NamespaceId, ids: &[VectorId]) -> Result<usize> {
110 let mut namespaces = self.namespaces.write();
111 let ns = namespaces
112 .get_mut(namespace)
113 .ok_or_else(|| DakeraError::NamespaceNotFound(namespace.clone()))?;
114
115 let mut deleted = 0;
116 for id in ids {
117 if ns.remove(id).is_some() {
118 deleted += 1;
119 }
120 }
121
122 tracing::debug!(
123 namespace = %namespace,
124 deleted = deleted,
125 "Deleted vectors"
126 );
127
128 Ok(deleted)
129 }
130
131 async fn namespace_exists(&self, namespace: &NamespaceId) -> Result<bool> {
132 Ok(self.namespaces.read().contains_key(namespace))
133 }
134
135 async fn ensure_namespace(&self, namespace: &NamespaceId) -> Result<()> {
136 self.namespaces
137 .write()
138 .entry(namespace.clone())
139 .or_default();
140 Ok(())
141 }
142
143 async fn count(&self, namespace: &NamespaceId) -> Result<usize> {
144 let namespaces = self.namespaces.read();
145 let now = now_secs();
146 Ok(namespaces
147 .get(namespace)
148 .map(|ns| ns.values().filter(|v| !v.is_expired_at(now)).count())
149 .unwrap_or(0))
150 }
151
152 async fn dimension(&self, namespace: &NamespaceId) -> Result<Option<usize>> {
153 let namespaces = self.namespaces.read();
154 let now = now_secs();
155 Ok(namespaces
156 .get(namespace)
157 .and_then(|ns| ns.values().find(|v| !v.is_expired_at(now)))
158 .map(|v| v.values.len()))
159 }
160
161 async fn list_namespaces(&self) -> Result<Vec<NamespaceId>> {
162 let namespaces = self.namespaces.read();
163 Ok(namespaces.keys().cloned().collect())
164 }
165
166 async fn delete_namespace(&self, namespace: &NamespaceId) -> Result<bool> {
167 let mut namespaces = self.namespaces.write();
168 let existed = namespaces.remove(namespace).is_some();
169
170 if existed {
171 tracing::debug!(
172 namespace = %namespace,
173 "Deleted namespace"
174 );
175 }
176
177 Ok(existed)
178 }
179
180 async fn cleanup_expired(&self, namespace: &NamespaceId) -> Result<usize> {
181 let mut namespaces = self.namespaces.write();
182 let ns = match namespaces.get_mut(namespace) {
183 Some(ns) => ns,
184 None => return Ok(0),
185 };
186
187 let now = now_secs();
188 let before_count = ns.len();
189 ns.retain(|_, v| !v.is_expired_at(now));
190 let removed = before_count - ns.len();
191
192 if removed > 0 {
193 tracing::debug!(
194 namespace = %namespace,
195 removed = removed,
196 "Cleaned up expired vectors"
197 );
198 }
199
200 Ok(removed)
201 }
202
203 async fn cleanup_all_expired(&self) -> Result<usize> {
204 let mut namespaces = self.namespaces.write();
205 let mut total_removed = 0;
206
207 let now = now_secs();
208 for (namespace, ns) in namespaces.iter_mut() {
209 let before_count = ns.len();
210 ns.retain(|_, v| !v.is_expired_at(now));
211 let removed = before_count - ns.len();
212 total_removed += removed;
213
214 if removed > 0 {
215 tracing::debug!(
216 namespace = %namespace,
217 removed = removed,
218 "Cleaned up expired vectors"
219 );
220 }
221 }
222
223 if total_removed > 0 {
224 tracing::info!(
225 total_removed = total_removed,
226 "Cleaned up expired vectors across all namespaces"
227 );
228 }
229
230 Ok(total_removed)
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a `Vector` with no metadata. The two-argument form also leaves
    // `ttl_seconds` and `expires_at` unset.
    macro_rules! vector {
        ($id:expr, $values:expr) => {
            vector!($id, $values, ttl: None, expires_at: None)
        };
        ($id:expr, $values:expr, ttl: $ttl:expr, expires_at: $expires_at:expr) => {
            Vector {
                id: $id.to_string(),
                values: $values,
                metadata: None,
                ttl_seconds: $ttl,
                expires_at: $expires_at,
            }
        };
    }

    // Current Unix time in whole seconds; panics if the clock is before the epoch.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    // Creates a fresh storage in which each of `names` is already registered.
    async fn storage_with(names: &[&str]) -> InMemoryStorage {
        let storage = InMemoryStorage::new();
        for name in names {
            storage.ensure_namespace(&name.to_string()).await.unwrap();
        }
        storage
    }

    #[tokio::test]
    async fn test_upsert_and_get() {
        let namespace = "test".to_string();
        let storage = storage_with(&["test"]).await;

        let batch = vec![vector!("v1", vec![1.0, 2.0, 3.0])];
        let inserted = storage.upsert(&namespace, batch).await.unwrap();
        assert_eq!(inserted, 1);

        let found = storage.get(&namespace, &["v1".to_string()]).await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].id, "v1");
    }

    #[tokio::test]
    async fn test_dimension_mismatch() {
        let namespace = "test".to_string();
        let storage = storage_with(&["test"]).await;

        // The first upsert fixes the namespace dimension at 3.
        storage
            .upsert(&namespace, vec![vector!("v1", vec![1.0, 2.0, 3.0])])
            .await
            .unwrap();

        // A 2-dimensional vector must now be rejected.
        let outcome = storage
            .upsert(&namespace, vec![vector!("v2", vec![1.0, 2.0])])
            .await;

        assert!(matches!(outcome, Err(DakeraError::DimensionMismatch { .. })));
    }

    #[tokio::test]
    async fn test_delete() {
        let namespace = "test".to_string();
        let storage = storage_with(&["test"]).await;

        let batch = vec![vector!("v1", vec![1.0]), vector!("v2", vec![2.0])];
        storage.upsert(&namespace, batch).await.unwrap();

        let deleted = storage
            .delete(&namespace, &["v1".to_string()])
            .await
            .unwrap();
        assert_eq!(deleted, 1);

        let remaining = storage.count(&namespace).await.unwrap();
        assert_eq!(remaining, 1);
    }

    #[tokio::test]
    async fn test_get_all() {
        let namespace = "test".to_string();
        let storage = storage_with(&["test"]).await;

        let batch = vec![
            vector!("v1", vec![1.0, 2.0]),
            vector!("v2", vec![3.0, 4.0]),
        ];
        storage.upsert(&namespace, batch).await.unwrap();

        let everything = storage.get_all(&namespace).await.unwrap();
        assert_eq!(everything.len(), 2);
    }

    #[tokio::test]
    async fn test_ttl_expired_vectors_filtered() {
        let namespace = "test".to_string();
        let storage = storage_with(&["test"]).await;

        // 100 seconds in the past, i.e. already expired.
        let past_timestamp = unix_now() - 100;

        let batch = vec![
            vector!("expired", vec![1.0, 2.0], ttl: None, expires_at: Some(past_timestamp)),
            vector!("valid", vec![3.0, 4.0]),
        ];
        storage.upsert(&namespace, batch).await.unwrap();

        // Lookup by id skips the expired vector.
        let by_id = storage
            .get(&namespace, &["expired".to_string(), "valid".to_string()])
            .await
            .unwrap();
        assert_eq!(by_id.len(), 1);
        assert_eq!(by_id[0].id, "valid");

        // So does a full listing.
        let everything = storage.get_all(&namespace).await.unwrap();
        assert_eq!(everything.len(), 1);
        assert_eq!(everything[0].id, "valid");

        // And the count.
        let live = storage.count(&namespace).await.unwrap();
        assert_eq!(live, 1);
    }

    #[tokio::test]
    async fn test_ttl_applied_on_upsert() {
        let namespace = "test".to_string();
        let storage = storage_with(&["test"]).await;

        let batch = vec![vector!(
            "with_ttl",
            vec![1.0, 2.0],
            ttl: Some(3600),
            expires_at: None
        )];
        storage.upsert(&namespace, batch).await.unwrap();

        // Read the stored entry directly to see what upsert wrote.
        let expires_at = {
            let guard = storage.namespaces.read();
            let stored = guard.get(&namespace).unwrap().get("with_ttl").unwrap();
            assert!(stored.expires_at.is_some());
            stored.expires_at.unwrap()
        };

        let now = unix_now();
        assert!(expires_at > now);
        // One second of slack for the clock ticking between upsert and now.
        assert!(expires_at <= now + 3601);
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let namespace = "test".to_string();
        let storage = storage_with(&["test"]).await;

        let past_timestamp = unix_now() - 100;

        let batch = vec![
            vector!("expired1", vec![1.0], ttl: None, expires_at: Some(past_timestamp)),
            vector!("expired2", vec![2.0], ttl: None, expires_at: Some(past_timestamp)),
            vector!("valid", vec![3.0]),
        ];
        storage.upsert(&namespace, batch).await.unwrap();

        // Before cleanup all three entries are still physically stored.
        assert_eq!(
            storage.namespaces.read().get(&namespace).unwrap().len(),
            3
        );

        let removed = storage.cleanup_expired(&namespace).await.unwrap();
        assert_eq!(removed, 2);

        // After cleanup only the unexpired entry remains.
        let guard = storage.namespaces.read();
        let remaining = guard.get(&namespace).unwrap();
        assert_eq!(remaining.len(), 1);
        assert!(remaining.contains_key("valid"));
    }

    #[tokio::test]
    async fn test_cleanup_all_expired() {
        let ns1 = "test1".to_string();
        let ns2 = "test2".to_string();
        let storage = storage_with(&["test1", "test2"]).await;

        let past_timestamp = unix_now() - 100;

        // First namespace: a single expired vector.
        storage
            .upsert(
                &ns1,
                vec![vector!("expired", vec![1.0], ttl: None, expires_at: Some(past_timestamp))],
            )
            .await
            .unwrap();

        // Second namespace: one expired and one live vector.
        storage
            .upsert(
                &ns2,
                vec![
                    vector!("expired", vec![2.0], ttl: None, expires_at: Some(past_timestamp)),
                    vector!("valid", vec![3.0]),
                ],
            )
            .await
            .unwrap();

        let removed = storage.cleanup_all_expired().await.unwrap();
        assert_eq!(removed, 2);

        let guard = storage.namespaces.read();
        assert_eq!(guard.get(&ns1).unwrap().len(), 0);
        assert_eq!(guard.get(&ns2).unwrap().len(), 1);
    }
}