// agent_office/storage/memory.rs

1use crate::domain::{Edge, EdgeId, GraphQuery, Node, NodeId};
2use crate::storage::{EdgeDirection, GraphStorage, Result, StorageError, SearchQuery, SearchResults, OrderBy, OrderDirection};
3use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7
8#[derive(Clone)]
9pub struct InMemoryStorage {
10    nodes: Arc<RwLock<HashMap<NodeId, Node>>>,
11    edges: Arc<RwLock<HashMap<EdgeId, Edge>>>,
12}
13
14impl InMemoryStorage {
15    pub fn new() -> Self {
16        Self {
17            nodes: Arc::new(RwLock::new(HashMap::new())),
18            edges: Arc::new(RwLock::new(HashMap::new())),
19        }
20    }
21
22    fn matches_query(node: &Node, query: &GraphQuery) -> bool {
23        // Check node type filter
24        if let Some(ref types) = query.node_types {
25            if !types.contains(&node.node_type) {
26                return false;
27            }
28        }
29
30        // Check property filters
31        if let Some(ref filters) = query.property_filters {
32            for (key, expected_value) in filters {
33                match node.properties.get(key) {
34                    Some(actual_value) if actual_value == expected_value => continue,
35                    _ => return false,
36                }
37            }
38        }
39
40        true
41    }
42    
43    fn matches_search_query(node: &Node, query: &SearchQuery) -> bool {
44        // Check node types
45        if !query.node_types.is_empty() && !query.node_types.contains(&node.node_type) {
46            return false;
47        }
48        
49        // Check text search
50        if let Some(ref search_text) = query.search_text {
51            let search_lower = search_text.to_lowercase();
52            let node_text = serde_json::to_string(&node.properties).unwrap_or_default().to_lowercase();
53            if !node_text.contains(&search_lower) {
54                return false;
55            }
56        }
57        
58        // Check created time range
59        if let Some(after) = query.created_after {
60            if node.created_at < after {
61                return false;
62            }
63        }
64        if let Some(before) = query.created_before {
65            if node.created_at > before {
66                return false;
67            }
68        }
69        
70        // Check updated time range
71        if let Some(after) = query.updated_after {
72            if node.updated_at < after {
73                return false;
74            }
75        }
76        
77        // Check property filters
78        for (key, value) in &query.property_filters {
79            match node.properties.get(key) {
80                Some(prop_val) => {
81                    let prop_str = serde_json::to_string(prop_val).unwrap_or_default();
82                    let value_str = format!("\"{}\"", value);
83                    if prop_str != value_str && prop_str != *value {
84                        return false;
85                    }
86                }
87                None => return false,
88            }
89        }
90        
91        true
92    }
93}
94
95impl Default for InMemoryStorage {
96    fn default() -> Self {
97        Self::new()
98    }
99}
100
101#[async_trait]
102impl GraphStorage for InMemoryStorage {
103    async fn create_node(&self, node: &Node) -> Result<Node> {
104        let mut nodes = self.nodes.write().await;
105        if nodes.contains_key(&node.id) {
106            return Err(StorageError::ConstraintViolation(
107                format!("Node with ID {} already exists", node.id)
108            ));
109        }
110        nodes.insert(node.id, node.clone());
111        Ok(node.clone())
112    }
113
114    async fn get_node(&self, id: NodeId) -> Result<Node> {
115        let nodes = self.nodes.read().await;
116        nodes.get(&id)
117            .cloned()
118            .ok_or(StorageError::NodeNotFound(id))
119    }
120
121    async fn update_node(&self, node: &Node) -> Result<Node> {
122        let mut nodes = self.nodes.write().await;
123        if !nodes.contains_key(&node.id) {
124            return Err(StorageError::NodeNotFound(node.id));
125        }
126        nodes.insert(node.id, node.clone());
127        Ok(node.clone())
128    }
129
130    async fn delete_node(&self, id: NodeId) -> Result<()> {
131        let mut nodes = self.nodes.write().await;
132        let mut edges = self.edges.write().await;
133        
134        if !nodes.contains_key(&id) {
135            return Err(StorageError::NodeNotFound(id));
136        }
137        
138        // Remove all edges connected to this node
139        edges.retain(|_, edge| {
140            edge.from_node_id != id && edge.to_node_id != id
141        });
142        
143        nodes.remove(&id);
144        Ok(())
145    }
146
147    async fn query_nodes(&self, query: &GraphQuery) -> Result<Vec<Node>> {
148        let nodes = self.nodes.read().await;
149        let mut results: Vec<Node> = nodes
150            .values()
151            .filter(|node| Self::matches_query(node, query))
152            .cloned()
153            .collect();
154        
155        if let Some(limit) = query.limit {
156            results.truncate(limit);
157        }
158        
159        Ok(results)
160    }
161
162    async fn create_edge(&self, edge: &Edge) -> Result<Edge> {
163        let nodes = self.nodes.read().await;
164        
165        // Verify both nodes exist
166        if !nodes.contains_key(&edge.from_node_id) {
167            return Err(StorageError::NodeNotFound(edge.from_node_id));
168        }
169        if !nodes.contains_key(&edge.to_node_id) {
170            return Err(StorageError::NodeNotFound(edge.to_node_id));
171        }
172        
173        drop(nodes);
174        
175        let mut edges = self.edges.write().await;
176        edges.insert(edge.id, edge.clone());
177        Ok(edge.clone())
178    }
179
180    async fn get_edge(&self, id: EdgeId) -> Result<Edge> {
181        let edges = self.edges.read().await;
182        edges.get(&id)
183            .cloned()
184            .ok_or(StorageError::EdgeNotFound(id))
185    }
186
187    async fn delete_edge(&self, id: EdgeId) -> Result<()> {
188        let mut edges = self.edges.write().await;
189        if !edges.contains_key(&id) {
190            return Err(StorageError::EdgeNotFound(id));
191        }
192        edges.remove(&id);
193        Ok(())
194    }
195
196    async fn get_edges_from(&self, node_id: NodeId, edge_type: Option<&str>) -> Result<Vec<Edge>> {
197        let edges = self.edges.read().await;
198        let results: Vec<Edge> = edges
199            .values()
200            .filter(|edge| {
201                edge.from_node_id == node_id &&
202                edge_type.map_or(true, |et| edge.edge_type == et)
203            })
204            .cloned()
205            .collect();
206        Ok(results)
207    }
208
209    async fn get_edges_to(&self, node_id: NodeId, edge_type: Option<&str>) -> Result<Vec<Edge>> {
210        let edges = self.edges.read().await;
211        let results: Vec<Edge> = edges
212            .values()
213            .filter(|edge| {
214                edge.to_node_id == node_id &&
215                edge_type.map_or(true, |et| edge.edge_type == et)
216            })
217            .cloned()
218            .collect();
219        Ok(results)
220    }
221
222    async fn get_neighbors(
223        &self,
224        node_id: NodeId,
225        edge_type: Option<&str>,
226        direction: EdgeDirection,
227    ) -> Result<Vec<Node>> {
228        let edges = self.edges.read().await;
229        let nodes = self.nodes.read().await;
230        
231        let mut neighbor_ids: Vec<NodeId> = Vec::new();
232        
233        for edge in edges.values() {
234            let matches_type = edge_type.map_or(true, |et| edge.edge_type == et);
235            
236            match direction {
237                EdgeDirection::Outgoing if edge.from_node_id == node_id && matches_type => {
238                    neighbor_ids.push(edge.to_node_id);
239                }
240                EdgeDirection::Incoming if edge.to_node_id == node_id && matches_type => {
241                    neighbor_ids.push(edge.from_node_id);
242                }
243                EdgeDirection::Both if matches_type && 
244                    (edge.from_node_id == node_id || edge.to_node_id == node_id) => {
245                    let neighbor_id = if edge.from_node_id == node_id {
246                        edge.to_node_id
247                    } else {
248                        edge.from_node_id
249                    };
250                    neighbor_ids.push(neighbor_id);
251                }
252                _ => {}
253            }
254        }
255        
256        let neighbors: Vec<Node> = neighbor_ids
257            .into_iter()
258            .filter_map(|id| nodes.get(&id).cloned())
259            .collect();
260        
261        Ok(neighbors)
262    }
263    
264    async fn search_nodes(&self, query: &SearchQuery) -> Result<SearchResults<Node>> {
265        let nodes = self.nodes.read().await;
266        
267        // Filter nodes based on query criteria
268        let mut results: Vec<Node> = nodes.values()
269            .filter(|node| Self::matches_search_query(node, query))
270            .cloned()
271            .collect();
272        
273        // Sort results
274        results.sort_by(|a, b| {
275            let cmp = match query.order_by {
276                OrderBy::CreatedAt => a.created_at.cmp(&b.created_at),
277                OrderBy::UpdatedAt => a.updated_at.cmp(&b.updated_at),
278                OrderBy::Relevance => a.updated_at.cmp(&b.updated_at), // Fallback
279            };
280            
281            match query.order_direction {
282                OrderDirection::Asc => cmp,
283                OrderDirection::Desc => cmp.reverse(),
284            }
285        });
286        
287        let total_count = results.len();
288        let offset = query.offset;
289        let limit = query.limit;
290        
291        // Check if there are more results
292        let has_more = results.len() > offset + limit;
293        
294        // Apply pagination
295        let paginated: Vec<Node> = results.into_iter()
296            .skip(offset)
297            .take(limit)
298            .collect();
299        
300        let returned_count = paginated.len();
301        
302        Ok(SearchResults {
303            items: paginated,
304            total_count,
305            returned_count,
306            has_more,
307            limit,
308            offset,
309        })
310    }
311    
312    async fn count_nodes(&self, query: &SearchQuery) -> Result<usize> {
313        let nodes = self.nodes.read().await;
314        
315        let count = nodes.values()
316            .filter(|node| Self::matches_search_query(node, query))
317            .count();
318        
319        Ok(count)
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::Properties;

    /// Builds a node of `node_type` with an empty property bag.
    fn bare_node(node_type: &'static str) -> Node {
        Node::new(node_type, Properties::new())
    }

    /// Persists every node in `batch`, panicking on the first failure.
    async fn insert_all(store: &InMemoryStorage, batch: &[&Node]) {
        for &node in batch {
            store.create_node(node).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_create_and_get_node() {
        let store = InMemoryStorage::new();
        let original = bare_node("test");

        let stored = store.create_node(&original).await.unwrap();
        assert_eq!(stored.id, original.id);

        let fetched = store.get_node(original.id).await.unwrap();
        assert_eq!(fetched.id, original.id);
    }

    #[tokio::test]
    async fn test_create_edge_between_nodes() {
        let store = InMemoryStorage::new();
        let owner = bare_node("agent");
        let owned = bare_node("mailbox");
        insert_all(&store, &[&owner, &owned]).await;

        let link = Edge::new("owns", owner.id, owned.id, Properties::new());
        let stored = store.create_edge(&link).await.unwrap();

        assert_eq!(stored.from_node_id, owner.id);
        assert_eq!(stored.to_node_id, owned.id);
    }

    #[tokio::test]
    async fn test_query_nodes_with_type_filter() {
        let store = InMemoryStorage::new();
        let agent = bare_node("agent");
        let mailbox = bare_node("mailbox");
        insert_all(&store, &[&agent, &mailbox]).await;

        let only_agents = GraphQuery::new().with_node_type("agent");
        let found = store.query_nodes(&only_agents).await.unwrap();

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].node_type, "agent");
    }

    #[tokio::test]
    async fn test_get_neighbors() {
        let store = InMemoryStorage::new();
        let agent = bare_node("agent");
        let inbox_a = bare_node("mailbox");
        let inbox_b = bare_node("mailbox");
        insert_all(&store, &[&agent, &inbox_a, &inbox_b]).await;

        for target in [&inbox_a, &inbox_b] {
            let ownership = Edge::new("owns", agent.id, target.id, Properties::new());
            store.create_edge(&ownership).await.unwrap();
        }

        let neighbors = store
            .get_neighbors(agent.id, Some("owns"), EdgeDirection::Outgoing)
            .await
            .unwrap();
        assert_eq!(neighbors.len(), 2);
    }
}