// storage/memory.rs

1use async_trait::async_trait;
2use common::{DakeraError, NamespaceId, Result, Vector, VectorId};
3use parking_lot::RwLock;
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::traits::VectorStorage;
8
9/// In-memory storage implementation for Phase 1
10/// Uses RwLock for concurrent access, wrapped in Arc for Clone support
11#[derive(Clone)]
12pub struct InMemoryStorage {
13    /// namespace -> (vector_id -> vector)
14    namespaces: Arc<RwLock<HashMap<NamespaceId, HashMap<VectorId, Vector>>>>,
15}
16
17impl InMemoryStorage {
18    pub fn new() -> Self {
19        Self {
20            namespaces: Arc::new(RwLock::new(HashMap::new())),
21        }
22    }
23}
24
25impl Default for InMemoryStorage {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Callers read the clock once and pass the value to `is_expired_at` for every vector
/// they scan, instead of querying the system clock per vector.
fn now_secs() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        // A clock reading earlier than the epoch is treated as time zero rather than an error.
        Err(_) => 0,
    }
}
38
39#[async_trait]
40impl VectorStorage for InMemoryStorage {
41    async fn upsert(&self, namespace: &NamespaceId, vectors: Vec<Vector>) -> Result<usize> {
42        let mut namespaces = self.namespaces.write();
43        let ns = namespaces.entry(namespace.clone()).or_default();
44
45        // Validate dimensions: check against existing vectors OR ensure intra-batch consistency
46        if !vectors.is_empty() {
47            let now = now_secs();
48            let expected_dim = ns
49                .values()
50                .find(|v| !v.is_expired_at(now))
51                .map(|v| v.values.len())
52                .unwrap_or_else(|| vectors[0].values.len());
53            for v in &vectors {
54                if v.values.len() != expected_dim {
55                    return Err(DakeraError::DimensionMismatch {
56                        expected: expected_dim,
57                        actual: v.values.len(),
58                    });
59                }
60            }
61        }
62
63        let count = vectors.len();
64        for mut vector in vectors {
65            // Apply TTL if specified
66            vector.apply_ttl();
67            ns.insert(vector.id.clone(), vector);
68        }
69
70        tracing::debug!(
71            namespace = %namespace,
72            count = count,
73            "Upserted vectors"
74        );
75
76        Ok(count)
77    }
78
79    async fn get(&self, namespace: &NamespaceId, ids: &[VectorId]) -> Result<Vec<Vector>> {
80        let namespaces = self.namespaces.read();
81        let ns = namespaces
82            .get(namespace)
83            .ok_or_else(|| DakeraError::NamespaceNotFound(namespace.clone()))?;
84
85        // Filter out expired vectors
86        let now = now_secs();
87        Ok(ids
88            .iter()
89            .filter_map(|id| ns.get(id).cloned())
90            .filter(|v| !v.is_expired_at(now))
91            .collect())
92    }
93
94    async fn get_all(&self, namespace: &NamespaceId) -> Result<Vec<Vector>> {
95        let namespaces = self.namespaces.read();
96        let ns = namespaces
97            .get(namespace)
98            .ok_or_else(|| DakeraError::NamespaceNotFound(namespace.clone()))?;
99
100        // Filter out expired vectors
101        let now = now_secs();
102        Ok(ns
103            .values()
104            .filter(|v| !v.is_expired_at(now))
105            .cloned()
106            .collect())
107    }
108
109    async fn delete(&self, namespace: &NamespaceId, ids: &[VectorId]) -> Result<usize> {
110        let mut namespaces = self.namespaces.write();
111        let ns = namespaces
112            .get_mut(namespace)
113            .ok_or_else(|| DakeraError::NamespaceNotFound(namespace.clone()))?;
114
115        let mut deleted = 0;
116        for id in ids {
117            if ns.remove(id).is_some() {
118                deleted += 1;
119            }
120        }
121
122        tracing::debug!(
123            namespace = %namespace,
124            deleted = deleted,
125            "Deleted vectors"
126        );
127
128        Ok(deleted)
129    }
130
131    async fn namespace_exists(&self, namespace: &NamespaceId) -> Result<bool> {
132        Ok(self.namespaces.read().contains_key(namespace))
133    }
134
135    async fn ensure_namespace(&self, namespace: &NamespaceId) -> Result<()> {
136        self.namespaces
137            .write()
138            .entry(namespace.clone())
139            .or_default();
140        Ok(())
141    }
142
143    async fn count(&self, namespace: &NamespaceId) -> Result<usize> {
144        let namespaces = self.namespaces.read();
145        let now = now_secs();
146        Ok(namespaces
147            .get(namespace)
148            .map(|ns| ns.values().filter(|v| !v.is_expired_at(now)).count())
149            .unwrap_or(0))
150    }
151
152    async fn dimension(&self, namespace: &NamespaceId) -> Result<Option<usize>> {
153        let namespaces = self.namespaces.read();
154        let now = now_secs();
155        Ok(namespaces
156            .get(namespace)
157            .and_then(|ns| ns.values().find(|v| !v.is_expired_at(now)))
158            .map(|v| v.values.len()))
159    }
160
161    async fn list_namespaces(&self) -> Result<Vec<NamespaceId>> {
162        let namespaces = self.namespaces.read();
163        Ok(namespaces.keys().cloned().collect())
164    }
165
166    async fn delete_namespace(&self, namespace: &NamespaceId) -> Result<bool> {
167        let mut namespaces = self.namespaces.write();
168        let existed = namespaces.remove(namespace).is_some();
169
170        if existed {
171            tracing::debug!(
172                namespace = %namespace,
173                "Deleted namespace"
174            );
175        }
176
177        Ok(existed)
178    }
179
180    async fn cleanup_expired(&self, namespace: &NamespaceId) -> Result<usize> {
181        let mut namespaces = self.namespaces.write();
182        let ns = match namespaces.get_mut(namespace) {
183            Some(ns) => ns,
184            None => return Ok(0),
185        };
186
187        let now = now_secs();
188        let before_count = ns.len();
189        ns.retain(|_, v| !v.is_expired_at(now));
190        let removed = before_count - ns.len();
191
192        if removed > 0 {
193            tracing::debug!(
194                namespace = %namespace,
195                removed = removed,
196                "Cleaned up expired vectors"
197            );
198        }
199
200        Ok(removed)
201    }
202
203    async fn cleanup_all_expired(&self) -> Result<usize> {
204        let mut namespaces = self.namespaces.write();
205        let mut total_removed = 0;
206
207        let now = now_secs();
208        for (namespace, ns) in namespaces.iter_mut() {
209            let before_count = ns.len();
210            ns.retain(|_, v| !v.is_expired_at(now));
211            let removed = before_count - ns.len();
212            total_removed += removed;
213
214            if removed > 0 {
215                tracing::debug!(
216                    namespace = %namespace,
217                    removed = removed,
218                    "Cleaned up expired vectors"
219                );
220            }
221        }
222
223        if total_removed > 0 {
224            tracing::info!(
225                total_removed = total_removed,
226                "Cleaned up expired vectors across all namespaces"
227            );
228        }
229
230        Ok(total_removed)
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Vector with the given id and every other field empty or `None`.
    ///
    /// Tests override the fields they care about with struct-update syntax, e.g.
    /// `Vector { values: vec![1.0], ..bare_vector("v1") }`.
    fn bare_vector(id: &str) -> Vector {
        Vector {
            id: id.to_string(),
            values: Vec::new(),
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Unix timestamp 100 seconds in the past, used to build already-expired vectors.
    fn past_timestamp() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
            - 100
    }

    #[tokio::test]
    async fn test_upsert_and_get() {
        let storage = InMemoryStorage::new();
        let namespace = "test".to_string();

        storage.ensure_namespace(&namespace).await.unwrap();
        let count = storage
            .upsert(
                &namespace,
                vec![Vector {
                    values: vec![1.0, 2.0, 3.0],
                    ..bare_vector("v1")
                }],
            )
            .await
            .unwrap();
        assert_eq!(count, 1);

        let retrieved = storage.get(&namespace, &["v1".to_string()]).await.unwrap();
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].id, "v1");
    }

    #[tokio::test]
    async fn test_dimension_mismatch() {
        let storage = InMemoryStorage::new();
        let namespace = "test".to_string();

        storage.ensure_namespace(&namespace).await.unwrap();

        // First vector sets dimension
        storage
            .upsert(
                &namespace,
                vec![Vector {
                    values: vec![1.0, 2.0, 3.0],
                    ..bare_vector("v1")
                }],
            )
            .await
            .unwrap();

        // Second vector has wrong dimension
        let result = storage
            .upsert(
                &namespace,
                vec![Vector {
                    values: vec![1.0, 2.0],
                    ..bare_vector("v2")
                }],
            )
            .await;

        assert!(matches!(result, Err(DakeraError::DimensionMismatch { .. })));
    }

    #[tokio::test]
    async fn test_dimension_mismatch_within_batch() {
        let storage = InMemoryStorage::new();
        let namespace = "test".to_string();

        storage.ensure_namespace(&namespace).await.unwrap();

        // With no existing vectors, the first vector of the batch sets the dimension.
        let result = storage
            .upsert(
                &namespace,
                vec![
                    Vector {
                        values: vec![1.0, 2.0],
                        ..bare_vector("v1")
                    },
                    Vector {
                        values: vec![1.0, 2.0, 3.0],
                        ..bare_vector("v2")
                    },
                ],
            )
            .await;

        assert!(matches!(
            result,
            Err(DakeraError::DimensionMismatch {
                expected: 2,
                actual: 3
            })
        ));
        // A rejected batch stores nothing, not even its valid vectors.
        assert_eq!(storage.count(&namespace).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_delete() {
        let storage = InMemoryStorage::new();
        let namespace = "test".to_string();

        storage.ensure_namespace(&namespace).await.unwrap();
        storage
            .upsert(
                &namespace,
                vec![
                    Vector {
                        values: vec![1.0],
                        ..bare_vector("v1")
                    },
                    Vector {
                        values: vec![2.0],
                        ..bare_vector("v2")
                    },
                ],
            )
            .await
            .unwrap();

        let deleted = storage
            .delete(&namespace, &["v1".to_string()])
            .await
            .unwrap();
        assert_eq!(deleted, 1);

        let count = storage.count(&namespace).await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_get_all() {
        let storage = InMemoryStorage::new();
        let namespace = "test".to_string();

        storage.ensure_namespace(&namespace).await.unwrap();
        storage
            .upsert(
                &namespace,
                vec![
                    Vector {
                        values: vec![1.0, 2.0],
                        ..bare_vector("v1")
                    },
                    Vector {
                        values: vec![3.0, 4.0],
                        ..bare_vector("v2")
                    },
                ],
            )
            .await
            .unwrap();

        let all = storage.get_all(&namespace).await.unwrap();
        assert_eq!(all.len(), 2);
    }

    #[tokio::test]
    async fn test_ttl_expired_vectors_filtered() {
        let storage = InMemoryStorage::new();
        let namespace = "test".to_string();

        storage.ensure_namespace(&namespace).await.unwrap();

        // Insert a vector that's already expired (expires_at in the past)
        storage
            .upsert(
                &namespace,
                vec![
                    Vector {
                        values: vec![1.0, 2.0],
                        expires_at: Some(past_timestamp()),
                        ..bare_vector("expired")
                    },
                    Vector {
                        values: vec![3.0, 4.0],
                        ..bare_vector("valid")
                    },
                ],
            )
            .await
            .unwrap();

        // get() should only return the valid vector
        let retrieved = storage
            .get(&namespace, &["expired".to_string(), "valid".to_string()])
            .await
            .unwrap();
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].id, "valid");

        // get_all() should only return the valid vector
        let all = storage.get_all(&namespace).await.unwrap();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].id, "valid");

        // count() should only count the valid vector
        let count = storage.count(&namespace).await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_ttl_applied_on_upsert() {
        let storage = InMemoryStorage::new();
        let namespace = "test".to_string();

        storage.ensure_namespace(&namespace).await.unwrap();

        // Insert a vector with TTL
        storage
            .upsert(
                &namespace,
                vec![Vector {
                    values: vec![1.0, 2.0],
                    ttl_seconds: Some(3600), // 1 hour TTL
                    ..bare_vector("with_ttl")
                }],
            )
            .await
            .unwrap();

        // The vector should have expires_at set
        let namespaces = storage.namespaces.read();
        let ns = namespaces.get(&namespace).unwrap();
        let stored = ns.get("with_ttl").unwrap();
        assert!(stored.expires_at.is_some());

        // expires_at should be roughly 1 hour from now
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let expires_at = stored.expires_at.unwrap();
        assert!(expires_at > now);
        assert!(expires_at <= now + 3601); // Allow 1 second tolerance
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let storage = InMemoryStorage::new();
        let namespace = "test".to_string();

        storage.ensure_namespace(&namespace).await.unwrap();

        let past = past_timestamp();
        storage
            .upsert(
                &namespace,
                vec![
                    Vector {
                        values: vec![1.0],
                        expires_at: Some(past),
                        ..bare_vector("expired1")
                    },
                    Vector {
                        values: vec![2.0],
                        expires_at: Some(past),
                        ..bare_vector("expired2")
                    },
                    Vector {
                        values: vec![3.0],
                        ..bare_vector("valid")
                    },
                ],
            )
            .await
            .unwrap();

        // Before cleanup, raw storage has 3 vectors
        {
            let namespaces = storage.namespaces.read();
            let ns = namespaces.get(&namespace).unwrap();
            assert_eq!(ns.len(), 3);
        }

        // Cleanup should remove 2 expired vectors
        let removed = storage.cleanup_expired(&namespace).await.unwrap();
        assert_eq!(removed, 2);

        // After cleanup, storage has 1 vector
        {
            let namespaces = storage.namespaces.read();
            let ns = namespaces.get(&namespace).unwrap();
            assert_eq!(ns.len(), 1);
            assert!(ns.contains_key("valid"));
        }
    }

    #[tokio::test]
    async fn test_cleanup_all_expired() {
        let storage = InMemoryStorage::new();
        let ns1 = "test1".to_string();
        let ns2 = "test2".to_string();

        storage.ensure_namespace(&ns1).await.unwrap();
        storage.ensure_namespace(&ns2).await.unwrap();

        let past = past_timestamp();

        // Add expired vectors to both namespaces
        storage
            .upsert(
                &ns1,
                vec![Vector {
                    values: vec![1.0],
                    expires_at: Some(past),
                    ..bare_vector("expired")
                }],
            )
            .await
            .unwrap();

        storage
            .upsert(
                &ns2,
                vec![
                    Vector {
                        values: vec![2.0],
                        expires_at: Some(past),
                        ..bare_vector("expired")
                    },
                    Vector {
                        values: vec![3.0],
                        ..bare_vector("valid")
                    },
                ],
            )
            .await
            .unwrap();

        // Cleanup all should remove 2 expired vectors total
        let removed = storage.cleanup_all_expired().await.unwrap();
        assert_eq!(removed, 2);

        // ns1 should be empty, ns2 should have 1 valid vector
        {
            let namespaces = storage.namespaces.read();
            assert_eq!(namespaces.get(&ns1).unwrap().len(), 0);
            assert_eq!(namespaces.get(&ns2).unwrap().len(), 1);
        }
    }

    #[tokio::test]
    async fn test_dimension_ignores_expired_vectors() {
        let storage = InMemoryStorage::new();
        let namespace = "test".to_string();

        storage.ensure_namespace(&namespace).await.unwrap();
        assert_eq!(storage.dimension(&namespace).await.unwrap(), None);

        storage
            .upsert(
                &namespace,
                vec![Vector {
                    values: vec![1.0, 2.0, 3.0],
                    expires_at: Some(past_timestamp()),
                    ..bare_vector("old")
                }],
            )
            .await
            .unwrap();

        // Only an expired vector is present, so the namespace has no live dimension...
        assert_eq!(storage.dimension(&namespace).await.unwrap(), None);

        // ...and a batch of a different dimension is accepted.
        storage
            .upsert(
                &namespace,
                vec![Vector {
                    values: vec![1.0, 2.0],
                    ..bare_vector("new")
                }],
            )
            .await
            .unwrap();
        assert_eq!(storage.dimension(&namespace).await.unwrap(), Some(2));
    }

    #[tokio::test]
    async fn test_missing_namespace() {
        let storage = InMemoryStorage::new();
        let namespace = "missing".to_string();
        let ids = ["v1".to_string()];

        assert!(!storage.namespace_exists(&namespace).await.unwrap());
        assert!(matches!(
            storage.get(&namespace, &ids).await,
            Err(DakeraError::NamespaceNotFound(_))
        ));
        assert!(matches!(
            storage.get_all(&namespace).await,
            Err(DakeraError::NamespaceNotFound(_))
        ));
        assert!(matches!(
            storage.delete(&namespace, &ids).await,
            Err(DakeraError::NamespaceNotFound(_))
        ));

        // These report an empty result instead of an error.
        assert_eq!(storage.count(&namespace).await.unwrap(), 0);
        assert_eq!(storage.dimension(&namespace).await.unwrap(), None);
        assert_eq!(storage.cleanup_expired(&namespace).await.unwrap(), 0);
        assert!(!storage.delete_namespace(&namespace).await.unwrap());

        // None of the calls above may create the namespace.
        assert!(!storage.namespace_exists(&namespace).await.unwrap());
    }

    #[tokio::test]
    async fn test_namespace_lifecycle() {
        let storage = InMemoryStorage::new();
        let namespace = "test".to_string();

        storage.ensure_namespace(&namespace).await.unwrap();
        assert!(storage.namespace_exists(&namespace).await.unwrap());
        assert_eq!(
            storage.list_namespaces().await.unwrap(),
            vec![namespace.clone()]
        );

        assert!(storage.delete_namespace(&namespace).await.unwrap());
        assert!(!storage.namespace_exists(&namespace).await.unwrap());
        assert!(storage.list_namespaces().await.unwrap().is_empty());
    }
}