Skip to main content

clasp_registry/
store.rs

//! Entity storage trait and in-memory implementation
2
3use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::RwLock;
6
7use crate::entity::{Entity, EntityId, EntityStatus};
8use crate::error::{RegistryError, Result};
9
10/// Storage backend for entities
11#[async_trait]
12pub trait EntityStore: Send + Sync {
13    /// Create a new entity
14    async fn create(&self, entity: &Entity) -> Result<()>;
15
16    /// Get an entity by ID
17    async fn get(&self, id: &EntityId) -> Result<Option<Entity>>;
18
19    /// Find an entity by its public key
20    async fn find_by_public_key(&self, key: &[u8]) -> Result<Option<Entity>>;
21
22    /// Find entities by tag
23    async fn find_by_tag(&self, tag: &str) -> Result<Vec<Entity>>;
24
25    /// Find entities by namespace pattern
26    async fn find_by_namespace(&self, namespace: &str) -> Result<Vec<Entity>>;
27
28    /// List entities with pagination
29    async fn list(&self, offset: usize, limit: usize) -> Result<Vec<Entity>>;
30
31    /// Update an entity
32    async fn update(&self, entity: &Entity) -> Result<()>;
33
34    /// Update entity status
35    async fn update_status(&self, id: &EntityId, status: EntityStatus) -> Result<()>;
36
37    /// Delete an entity
38    async fn delete(&self, id: &EntityId) -> Result<bool>;
39
40    /// Count total entities
41    async fn count(&self) -> Result<usize>;
42}
43
44/// In-memory entity store for development and testing
45pub struct MemoryEntityStore {
46    entities: RwLock<HashMap<String, Entity>>,
47}
48
49impl MemoryEntityStore {
50    pub fn new() -> Self {
51        Self {
52            entities: RwLock::new(HashMap::new()),
53        }
54    }
55}
56
57impl Default for MemoryEntityStore {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63#[async_trait]
64impl EntityStore for MemoryEntityStore {
65    async fn create(&self, entity: &Entity) -> Result<()> {
66        let mut entities = self.entities.write().unwrap();
67        let key = entity.id.as_str().to_string();
68        if entities.contains_key(&key) {
69            return Err(RegistryError::AlreadyExists(key));
70        }
71        entities.insert(key, entity.clone());
72        Ok(())
73    }
74
75    async fn get(&self, id: &EntityId) -> Result<Option<Entity>> {
76        let entities = self.entities.read().unwrap();
77        Ok(entities.get(id.as_str()).cloned())
78    }
79
80    async fn find_by_public_key(&self, key: &[u8]) -> Result<Option<Entity>> {
81        let entities = self.entities.read().unwrap();
82        Ok(entities.values().find(|e| e.public_key == key).cloned())
83    }
84
85    async fn find_by_tag(&self, tag: &str) -> Result<Vec<Entity>> {
86        let entities = self.entities.read().unwrap();
87        Ok(entities
88            .values()
89            .filter(|e| e.tags.iter().any(|t| t == tag))
90            .cloned()
91            .collect())
92    }
93
94    async fn find_by_namespace(&self, namespace: &str) -> Result<Vec<Entity>> {
95        let entities = self.entities.read().unwrap();
96        Ok(entities
97            .values()
98            .filter(|e| e.namespaces.iter().any(|ns| ns == namespace))
99            .cloned()
100            .collect())
101    }
102
103    async fn list(&self, offset: usize, limit: usize) -> Result<Vec<Entity>> {
104        let entities = self.entities.read().unwrap();
105        Ok(entities
106            .values()
107            .skip(offset)
108            .take(limit)
109            .cloned()
110            .collect())
111    }
112
113    async fn update(&self, entity: &Entity) -> Result<()> {
114        let mut entities = self.entities.write().unwrap();
115        let key = entity.id.as_str().to_string();
116        if !entities.contains_key(&key) {
117            return Err(RegistryError::NotFound(key));
118        }
119        entities.insert(key, entity.clone());
120        Ok(())
121    }
122
123    async fn update_status(&self, id: &EntityId, status: EntityStatus) -> Result<()> {
124        let mut entities = self.entities.write().unwrap();
125        let key = id.as_str().to_string();
126        match entities.get_mut(&key) {
127            Some(entity) => {
128                entity.status = status;
129                Ok(())
130            }
131            None => Err(RegistryError::NotFound(key)),
132        }
133    }
134
135    async fn delete(&self, id: &EntityId) -> Result<bool> {
136        let mut entities = self.entities.write().unwrap();
137        Ok(entities.remove(id.as_str()).is_some())
138    }
139
140    async fn count(&self) -> Result<usize> {
141        Ok(self.entities.read().unwrap().len())
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entity::{EntityKeypair, EntityType};

    /// Builds a `Device` entity named `name` from a freshly generated keypair.
    fn create_test_entity(name: &str) -> Entity {
        let keypair = EntityKeypair::generate().unwrap();
        keypair.to_entity(EntityType::Device, name.to_string())
    }

    /// A created entity can be fetched back by its ID.
    #[tokio::test]
    async fn test_memory_store_create_get() {
        let store = MemoryEntityStore::new();
        let entity = create_test_entity("test-device");

        store.create(&entity).await.unwrap();

        let found = store.get(&entity.id).await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "test-device");
    }

    /// Creating the same entity twice is rejected and stores it only once.
    #[tokio::test]
    async fn test_memory_store_duplicate() {
        let store = MemoryEntityStore::new();
        let entity = create_test_entity("test-device");

        store.create(&entity).await.unwrap();
        assert!(store.create(&entity).await.is_err());
        assert_eq!(store.count().await.unwrap(), 1);
    }

    /// Lookup by public key returns the entity that owns the key.
    #[tokio::test]
    async fn test_memory_store_find_by_key() {
        let store = MemoryEntityStore::new();
        let entity = create_test_entity("test-device");
        let key = entity.public_key.clone();

        store.create(&entity).await.unwrap();

        let found = store.find_by_public_key(&key).await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().id, entity.id);
    }

    /// Tag lookup matches entities carrying the tag and nothing else.
    #[tokio::test]
    async fn test_memory_store_find_by_tag() {
        let store = MemoryEntityStore::new();
        let mut entity = create_test_entity("test-device");
        entity.tags = vec!["lighting".to_string(), "dmx".to_string()];

        store.create(&entity).await.unwrap();

        let found = store.find_by_tag("lighting").await.unwrap();
        assert_eq!(found.len(), 1);

        let found = store.find_by_tag("audio").await.unwrap();
        assert!(found.is_empty());
    }

    /// `update` fails for an unknown entity and replaces a known one in place.
    #[tokio::test]
    async fn test_memory_store_update() {
        let store = MemoryEntityStore::new();
        let mut entity = create_test_entity("test-device");

        // Updating before creation must fail and must not insert anything.
        assert!(store.update(&entity).await.is_err());
        assert_eq!(store.count().await.unwrap(), 0);

        store.create(&entity).await.unwrap();
        entity.tags = vec!["audio".to_string()];
        store.update(&entity).await.unwrap();

        let found = store.find_by_tag("audio").await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].id, entity.id);
        assert_eq!(store.count().await.unwrap(), 1);
    }

    /// A status change is visible on the next read.
    #[tokio::test]
    async fn test_memory_store_update_status() {
        let store = MemoryEntityStore::new();
        let entity = create_test_entity("test-device");
        let id = entity.id.clone();

        store.create(&entity).await.unwrap();
        store
            .update_status(&id, EntityStatus::Suspended)
            .await
            .unwrap();

        let found = store.get(&id).await.unwrap().unwrap();
        assert_eq!(found.status, EntityStatus::Suspended);
    }

    /// Changing the status of an entity that was never stored is an error.
    #[tokio::test]
    async fn test_memory_store_update_status_missing() {
        let store = MemoryEntityStore::new();
        let entity = create_test_entity("test-device");

        assert!(store
            .update_status(&entity.id, EntityStatus::Suspended)
            .await
            .is_err());
    }

    /// `delete` reports whether something was removed.
    #[tokio::test]
    async fn test_memory_store_delete() {
        let store = MemoryEntityStore::new();
        let entity = create_test_entity("test-device");
        let id = entity.id.clone();

        store.create(&entity).await.unwrap();
        assert!(store.delete(&id).await.unwrap());
        assert!(store.get(&id).await.unwrap().is_none());
        assert!(!store.delete(&id).await.unwrap());
    }

    /// Pagination returns pages of the expected size that do not overlap.
    #[tokio::test]
    async fn test_memory_store_list() {
        let store = MemoryEntityStore::new();

        for i in 0..5 {
            let entity = create_test_entity(&format!("device-{}", i));
            store.create(&entity).await.unwrap();
        }

        assert_eq!(store.count().await.unwrap(), 5);

        let page1 = store.list(0, 3).await.unwrap();
        assert_eq!(page1.len(), 3);

        let page2 = store.list(3, 3).await.unwrap();
        assert_eq!(page2.len(), 2);

        // The store is not modified between the two calls, so no entity may
        // appear on both pages.
        assert!(page1
            .iter()
            .all(|a| page2.iter().all(|b| a.id != b.id)));

        // Paging past the end yields nothing.
        assert!(store.list(5, 3).await.unwrap().is_empty());
    }
}