1use crate::domain::{Edge, EdgeId, GraphQuery, Node, NodeId};
2use crate::storage::{EdgeDirection, GraphStorage, Result, StorageError, SearchQuery, SearchResults, OrderBy, OrderDirection};
3use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
/// Graph storage backend that keeps every node and edge in process memory.
///
/// Nodes and edges live in two separate maps keyed by their IDs, each guarded by its own
/// async `RwLock`. Because both maps are wrapped in `Arc`, cloning an `InMemoryStorage` is
/// cheap and every clone shares (and observes) the same data.
///
/// NOTE(review): methods that lock both maps should acquire them in one consistent order
/// (e.g. `nodes` before `edges`) to avoid lock-order-inversion deadlocks.
#[derive(Clone)]
pub struct InMemoryStorage {
    // All stored nodes, keyed by node ID.
    nodes: Arc<RwLock<HashMap<NodeId, Node>>>,
    // All stored edges, keyed by edge ID.
    edges: Arc<RwLock<HashMap<EdgeId, Edge>>>,
}
13
14impl InMemoryStorage {
15 pub fn new() -> Self {
16 Self {
17 nodes: Arc::new(RwLock::new(HashMap::new())),
18 edges: Arc::new(RwLock::new(HashMap::new())),
19 }
20 }
21
22 fn matches_query(node: &Node, query: &GraphQuery) -> bool {
23 if let Some(ref types) = query.node_types {
25 if !types.contains(&node.node_type) {
26 return false;
27 }
28 }
29
30 if let Some(ref filters) = query.property_filters {
32 for (key, expected_value) in filters {
33 match node.properties.get(key) {
34 Some(actual_value) if actual_value == expected_value => continue,
35 _ => return false,
36 }
37 }
38 }
39
40 true
41 }
42
43 fn matches_search_query(node: &Node, query: &SearchQuery) -> bool {
44 if !query.node_types.is_empty() && !query.node_types.contains(&node.node_type) {
46 return false;
47 }
48
49 if let Some(ref search_text) = query.search_text {
51 let search_lower = search_text.to_lowercase();
52 let node_text = serde_json::to_string(&node.properties).unwrap_or_default().to_lowercase();
53 if !node_text.contains(&search_lower) {
54 return false;
55 }
56 }
57
58 if let Some(after) = query.created_after {
60 if node.created_at < after {
61 return false;
62 }
63 }
64 if let Some(before) = query.created_before {
65 if node.created_at > before {
66 return false;
67 }
68 }
69
70 if let Some(after) = query.updated_after {
72 if node.updated_at < after {
73 return false;
74 }
75 }
76
77 for (key, value) in &query.property_filters {
79 match node.properties.get(key) {
80 Some(prop_val) => {
81 let prop_str = serde_json::to_string(prop_val).unwrap_or_default();
82 let value_str = format!("\"{}\"", value);
83 if prop_str != value_str && prop_str != *value {
84 return false;
85 }
86 }
87 None => return false,
88 }
89 }
90
91 true
92 }
93}
94
95impl Default for InMemoryStorage {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
101#[async_trait]
102impl GraphStorage for InMemoryStorage {
103 async fn create_node(&self, node: &Node) -> Result<Node> {
104 let mut nodes = self.nodes.write().await;
105 if nodes.contains_key(&node.id) {
106 return Err(StorageError::ConstraintViolation(
107 format!("Node with ID {} already exists", node.id)
108 ));
109 }
110 nodes.insert(node.id, node.clone());
111 Ok(node.clone())
112 }
113
114 async fn get_node(&self, id: NodeId) -> Result<Node> {
115 let nodes = self.nodes.read().await;
116 nodes.get(&id)
117 .cloned()
118 .ok_or(StorageError::NodeNotFound(id))
119 }
120
121 async fn update_node(&self, node: &Node) -> Result<Node> {
122 let mut nodes = self.nodes.write().await;
123 if !nodes.contains_key(&node.id) {
124 return Err(StorageError::NodeNotFound(node.id));
125 }
126 nodes.insert(node.id, node.clone());
127 Ok(node.clone())
128 }
129
130 async fn delete_node(&self, id: NodeId) -> Result<()> {
131 let mut nodes = self.nodes.write().await;
132 let mut edges = self.edges.write().await;
133
134 if !nodes.contains_key(&id) {
135 return Err(StorageError::NodeNotFound(id));
136 }
137
138 edges.retain(|_, edge| {
140 edge.from_node_id != id && edge.to_node_id != id
141 });
142
143 nodes.remove(&id);
144 Ok(())
145 }
146
147 async fn query_nodes(&self, query: &GraphQuery) -> Result<Vec<Node>> {
148 let nodes = self.nodes.read().await;
149 let mut results: Vec<Node> = nodes
150 .values()
151 .filter(|node| Self::matches_query(node, query))
152 .cloned()
153 .collect();
154
155 if let Some(limit) = query.limit {
156 results.truncate(limit);
157 }
158
159 Ok(results)
160 }
161
162 async fn create_edge(&self, edge: &Edge) -> Result<Edge> {
163 let nodes = self.nodes.read().await;
164
165 if !nodes.contains_key(&edge.from_node_id) {
167 return Err(StorageError::NodeNotFound(edge.from_node_id));
168 }
169 if !nodes.contains_key(&edge.to_node_id) {
170 return Err(StorageError::NodeNotFound(edge.to_node_id));
171 }
172
173 drop(nodes);
174
175 let mut edges = self.edges.write().await;
176 edges.insert(edge.id, edge.clone());
177 Ok(edge.clone())
178 }
179
180 async fn get_edge(&self, id: EdgeId) -> Result<Edge> {
181 let edges = self.edges.read().await;
182 edges.get(&id)
183 .cloned()
184 .ok_or(StorageError::EdgeNotFound(id))
185 }
186
187 async fn delete_edge(&self, id: EdgeId) -> Result<()> {
188 let mut edges = self.edges.write().await;
189 if !edges.contains_key(&id) {
190 return Err(StorageError::EdgeNotFound(id));
191 }
192 edges.remove(&id);
193 Ok(())
194 }
195
196 async fn get_edges_from(&self, node_id: NodeId, edge_type: Option<&str>) -> Result<Vec<Edge>> {
197 let edges = self.edges.read().await;
198 let results: Vec<Edge> = edges
199 .values()
200 .filter(|edge| {
201 edge.from_node_id == node_id &&
202 edge_type.map_or(true, |et| edge.edge_type == et)
203 })
204 .cloned()
205 .collect();
206 Ok(results)
207 }
208
209 async fn get_edges_to(&self, node_id: NodeId, edge_type: Option<&str>) -> Result<Vec<Edge>> {
210 let edges = self.edges.read().await;
211 let results: Vec<Edge> = edges
212 .values()
213 .filter(|edge| {
214 edge.to_node_id == node_id &&
215 edge_type.map_or(true, |et| edge.edge_type == et)
216 })
217 .cloned()
218 .collect();
219 Ok(results)
220 }
221
222 async fn get_neighbors(
223 &self,
224 node_id: NodeId,
225 edge_type: Option<&str>,
226 direction: EdgeDirection,
227 ) -> Result<Vec<Node>> {
228 let edges = self.edges.read().await;
229 let nodes = self.nodes.read().await;
230
231 let mut neighbor_ids: Vec<NodeId> = Vec::new();
232
233 for edge in edges.values() {
234 let matches_type = edge_type.map_or(true, |et| edge.edge_type == et);
235
236 match direction {
237 EdgeDirection::Outgoing if edge.from_node_id == node_id && matches_type => {
238 neighbor_ids.push(edge.to_node_id);
239 }
240 EdgeDirection::Incoming if edge.to_node_id == node_id && matches_type => {
241 neighbor_ids.push(edge.from_node_id);
242 }
243 EdgeDirection::Both if matches_type &&
244 (edge.from_node_id == node_id || edge.to_node_id == node_id) => {
245 let neighbor_id = if edge.from_node_id == node_id {
246 edge.to_node_id
247 } else {
248 edge.from_node_id
249 };
250 neighbor_ids.push(neighbor_id);
251 }
252 _ => {}
253 }
254 }
255
256 let neighbors: Vec<Node> = neighbor_ids
257 .into_iter()
258 .filter_map(|id| nodes.get(&id).cloned())
259 .collect();
260
261 Ok(neighbors)
262 }
263
264 async fn search_nodes(&self, query: &SearchQuery) -> Result<SearchResults<Node>> {
265 let nodes = self.nodes.read().await;
266
267 let mut results: Vec<Node> = nodes.values()
269 .filter(|node| Self::matches_search_query(node, query))
270 .cloned()
271 .collect();
272
273 results.sort_by(|a, b| {
275 let cmp = match query.order_by {
276 OrderBy::CreatedAt => a.created_at.cmp(&b.created_at),
277 OrderBy::UpdatedAt => a.updated_at.cmp(&b.updated_at),
278 OrderBy::Relevance => a.updated_at.cmp(&b.updated_at), };
280
281 match query.order_direction {
282 OrderDirection::Asc => cmp,
283 OrderDirection::Desc => cmp.reverse(),
284 }
285 });
286
287 let total_count = results.len();
288 let offset = query.offset;
289 let limit = query.limit;
290
291 let has_more = results.len() > offset + limit;
293
294 let paginated: Vec<Node> = results.into_iter()
296 .skip(offset)
297 .take(limit)
298 .collect();
299
300 let returned_count = paginated.len();
301
302 Ok(SearchResults {
303 items: paginated,
304 total_count,
305 returned_count,
306 has_more,
307 limit,
308 offset,
309 })
310 }
311
312 async fn count_nodes(&self, query: &SearchQuery) -> Result<usize> {
313 let nodes = self.nodes.read().await;
314
315 let count = nodes.values()
316 .filter(|node| Self::matches_search_query(node, query))
317 .count();
318
319 Ok(count)
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::Properties;

    /// Stores each node in `store`, failing the test on the first error.
    async fn insert_all(store: &InMemoryStorage, batch: &[&Node]) {
        for item in batch.iter() {
            store.create_node(item).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_create_and_get_node() {
        let store = InMemoryStorage::new();
        let original = Node::new("test", Properties::new());

        let stored = store.create_node(&original).await.unwrap();
        assert_eq!(stored.id, original.id);

        let fetched = store.get_node(original.id).await.unwrap();
        assert_eq!(fetched.id, original.id);
    }

    #[tokio::test]
    async fn test_create_edge_between_nodes() {
        let store = InMemoryStorage::new();
        let owner = Node::new("agent", Properties::new());
        let owned = Node::new("mailbox", Properties::new());
        insert_all(&store, &[&owner, &owned]).await;

        let link = Edge::new("owns", owner.id, owned.id, Properties::new());
        let stored = store.create_edge(&link).await.unwrap();

        assert_eq!(stored.from_node_id, owner.id);
        assert_eq!(stored.to_node_id, owned.id);
    }

    #[tokio::test]
    async fn test_query_nodes_with_type_filter() {
        let store = InMemoryStorage::new();
        let agent_node = Node::new("agent", Properties::new());
        let mailbox_node = Node::new("mailbox", Properties::new());
        insert_all(&store, &[&agent_node, &mailbox_node]).await;

        let only_agents = GraphQuery::new().with_node_type("agent");
        let hits = store.query_nodes(&only_agents).await.unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].node_type, "agent");
    }

    #[tokio::test]
    async fn test_get_neighbors() {
        let store = InMemoryStorage::new();
        let hub = Node::new("agent", Properties::new());
        let first_box = Node::new("mailbox", Properties::new());
        let second_box = Node::new("mailbox", Properties::new());
        insert_all(&store, &[&hub, &first_box, &second_box]).await;

        for target in [&first_box, &second_box].iter() {
            let link = Edge::new("owns", hub.id, target.id, Properties::new());
            store.create_edge(&link).await.unwrap();
        }

        let reached = store
            .get_neighbors(hub.id, Some("owns"), EdgeDirection::Outgoing)
            .await
            .unwrap();
        assert_eq!(reached.len(), 2);
    }
}