1use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::RwLock;
6
7use crate::entity::{Entity, EntityId, EntityStatus};
8use crate::error::{RegistryError, Result};
9
/// Asynchronous storage interface for [`Entity`] records.
///
/// Implementors must be `Send + Sync`. Every method reports failure through
/// the crate-wide [`Result`] alias. The behavioral notes below describe what
/// the in-file `MemoryEntityStore` does; other backends are expected to
/// match, but that is not enforced by the trait itself.
#[async_trait]
pub trait EntityStore: Send + Sync {
    /// Stores a new entity.
    ///
    /// The in-memory store rejects an entity whose id is already present
    /// with [`RegistryError::AlreadyExists`].
    async fn create(&self, entity: &Entity) -> Result<()>;

    /// Looks up an entity by id, yielding `None` when no entity has that id.
    async fn get(&self, id: &EntityId) -> Result<Option<Entity>>;

    /// Looks up an entity whose `public_key` equals `key`, yielding `None`
    /// when there is no match.
    async fn find_by_public_key(&self, key: &[u8]) -> Result<Option<Entity>>;

    /// Returns every entity that has `tag` among its `tags` (exact match).
    async fn find_by_tag(&self, tag: &str) -> Result<Vec<Entity>>;

    /// Returns every entity that has `namespace` among its `namespaces`
    /// (exact match).
    async fn find_by_namespace(&self, namespace: &str) -> Result<Vec<Entity>>;

    /// Returns one page of entities: `offset` entities are skipped and at
    /// most `limit` are returned. The trait signature does not specify an
    /// ordering.
    async fn list(&self, offset: usize, limit: usize) -> Result<Vec<Entity>>;

    /// Replaces the stored entity that shares `entity`'s id.
    ///
    /// The in-memory store returns [`RegistryError::NotFound`] when no
    /// entity with that id exists.
    async fn update(&self, entity: &Entity) -> Result<()>;

    /// Sets the `status` field of the entity identified by `id`.
    ///
    /// The in-memory store returns [`RegistryError::NotFound`] when no
    /// entity with that id exists.
    async fn update_status(&self, id: &EntityId, status: EntityStatus) -> Result<()>;

    /// Removes the entity identified by `id`.
    ///
    /// In the in-memory store the returned flag is `true` when an entity
    /// was removed and `false` when the id was not present.
    async fn delete(&self, id: &EntityId) -> Result<bool>;

    /// Returns the number of stored entities.
    async fn count(&self) -> Result<usize>;
}
43
/// [`EntityStore`] implementation that keeps all entities in process memory.
///
/// Nothing is persisted: the contents live only as long as the value does.
pub struct MemoryEntityStore {
    /// Stored entities, keyed by the string form of their id
    /// (`entity.id.as_str()`), behind a `std::sync::RwLock`.
    entities: RwLock<HashMap<String, Entity>>,
}
48
49impl MemoryEntityStore {
50 pub fn new() -> Self {
51 Self {
52 entities: RwLock::new(HashMap::new()),
53 }
54 }
55}
56
57impl Default for MemoryEntityStore {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63#[async_trait]
64impl EntityStore for MemoryEntityStore {
65 async fn create(&self, entity: &Entity) -> Result<()> {
66 let mut entities = self.entities.write().unwrap();
67 let key = entity.id.as_str().to_string();
68 if entities.contains_key(&key) {
69 return Err(RegistryError::AlreadyExists(key));
70 }
71 entities.insert(key, entity.clone());
72 Ok(())
73 }
74
75 async fn get(&self, id: &EntityId) -> Result<Option<Entity>> {
76 let entities = self.entities.read().unwrap();
77 Ok(entities.get(id.as_str()).cloned())
78 }
79
80 async fn find_by_public_key(&self, key: &[u8]) -> Result<Option<Entity>> {
81 let entities = self.entities.read().unwrap();
82 Ok(entities.values().find(|e| e.public_key == key).cloned())
83 }
84
85 async fn find_by_tag(&self, tag: &str) -> Result<Vec<Entity>> {
86 let entities = self.entities.read().unwrap();
87 Ok(entities
88 .values()
89 .filter(|e| e.tags.iter().any(|t| t == tag))
90 .cloned()
91 .collect())
92 }
93
94 async fn find_by_namespace(&self, namespace: &str) -> Result<Vec<Entity>> {
95 let entities = self.entities.read().unwrap();
96 Ok(entities
97 .values()
98 .filter(|e| e.namespaces.iter().any(|ns| ns == namespace))
99 .cloned()
100 .collect())
101 }
102
103 async fn list(&self, offset: usize, limit: usize) -> Result<Vec<Entity>> {
104 let entities = self.entities.read().unwrap();
105 Ok(entities
106 .values()
107 .skip(offset)
108 .take(limit)
109 .cloned()
110 .collect())
111 }
112
113 async fn update(&self, entity: &Entity) -> Result<()> {
114 let mut entities = self.entities.write().unwrap();
115 let key = entity.id.as_str().to_string();
116 if !entities.contains_key(&key) {
117 return Err(RegistryError::NotFound(key));
118 }
119 entities.insert(key, entity.clone());
120 Ok(())
121 }
122
123 async fn update_status(&self, id: &EntityId, status: EntityStatus) -> Result<()> {
124 let mut entities = self.entities.write().unwrap();
125 let key = id.as_str().to_string();
126 match entities.get_mut(&key) {
127 Some(entity) => {
128 entity.status = status;
129 Ok(())
130 }
131 None => Err(RegistryError::NotFound(key)),
132 }
133 }
134
135 async fn delete(&self, id: &EntityId) -> Result<bool> {
136 let mut entities = self.entities.write().unwrap();
137 Ok(entities.remove(id.as_str()).is_some())
138 }
139
140 async fn count(&self) -> Result<usize> {
141 Ok(self.entities.read().unwrap().len())
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entity::{EntityKeypair, EntityType};

    /// Builds a device entity named `name` from a freshly generated keypair.
    fn create_test_entity(name: &str) -> Entity {
        EntityKeypair::generate()
            .unwrap()
            .to_entity(EntityType::Device, name.to_string())
    }

    /// A created entity can be fetched back by its id.
    #[tokio::test]
    async fn test_memory_store_create_get() {
        let store = MemoryEntityStore::new();
        let device = create_test_entity("test-device");

        store.create(&device).await.unwrap();

        let fetched = store.get(&device.id).await.unwrap().unwrap();
        assert_eq!(fetched.name, "test-device");
    }

    /// Creating the same entity twice fails the second time.
    #[tokio::test]
    async fn test_memory_store_duplicate() {
        let store = MemoryEntityStore::new();
        let device = create_test_entity("test-device");

        store.create(&device).await.unwrap();

        let second_attempt = store.create(&device).await;
        assert!(second_attempt.is_err());
    }

    /// An entity can be located through its public key.
    #[tokio::test]
    async fn test_memory_store_find_by_key() {
        let store = MemoryEntityStore::new();
        let device = create_test_entity("test-device");

        store.create(&device).await.unwrap();

        let matched = store
            .find_by_public_key(&device.public_key)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(matched.id, device.id);
    }

    /// Tag lookup returns tagged entities and nothing for unknown tags.
    #[tokio::test]
    async fn test_memory_store_find_by_tag() {
        let store = MemoryEntityStore::new();
        let mut device = create_test_entity("test-device");
        device.tags = vec![String::from("lighting"), String::from("dmx")];

        store.create(&device).await.unwrap();

        let lighting = store.find_by_tag("lighting").await.unwrap();
        assert_eq!(lighting.len(), 1);

        let audio = store.find_by_tag("audio").await.unwrap();
        assert!(audio.is_empty());
    }

    /// A status change is visible on the next read.
    #[tokio::test]
    async fn test_memory_store_update_status() {
        let store = MemoryEntityStore::new();
        let device = create_test_entity("test-device");

        store.create(&device).await.unwrap();
        store
            .update_status(&device.id, EntityStatus::Suspended)
            .await
            .unwrap();

        let reloaded = store.get(&device.id).await.unwrap().unwrap();
        assert_eq!(reloaded.status, EntityStatus::Suspended);
    }

    /// Deleting reports `true` once, then `false` for the missing id.
    #[tokio::test]
    async fn test_memory_store_delete() {
        let store = MemoryEntityStore::new();
        let device = create_test_entity("test-device");

        store.create(&device).await.unwrap();

        let removed_first = store.delete(&device.id).await.unwrap();
        assert!(removed_first);

        let after_delete = store.get(&device.id).await.unwrap();
        assert!(after_delete.is_none());

        let removed_again = store.delete(&device.id).await.unwrap();
        assert!(!removed_again);
    }

    /// Counting and paging over five entities yields pages of 3 and 2.
    #[tokio::test]
    async fn test_memory_store_list() {
        let store = MemoryEntityStore::new();

        for index in 0..5 {
            let device = create_test_entity(&format!("device-{}", index));
            store.create(&device).await.unwrap();
        }

        let total = store.count().await.unwrap();
        assert_eq!(total, 5);

        let first_page = store.list(0, 3).await.unwrap();
        assert_eq!(first_page.len(), 3);

        let second_page = store.list(3, 3).await.unwrap();
        assert_eq!(second_page.len(), 2);
    }
}