1use std::collections::{BTreeSet, HashMap};
4use std::sync::RwLock;
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use uuid::Uuid;
9
10use crate::error::{GraphError, Result};
11
/// Identifier of a [`Node`]; a plain `String` alias.
pub type NodeId = String;

/// Identifier of an [`Edge`]; a plain `String` alias.
pub type EdgeId = String;
17
/// A graph vertex: an id, one or more labels and a JSON property bag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node {
    /// Key under which the node is stored; upserting a node with the same id
    /// replaces the previous one.
    pub id: NodeId,
    /// Labels attached to the node; the in-memory store indexes nodes by each
    /// of these.
    pub labels: Vec<String>,
    /// Arbitrary JSON properties keyed by name.
    pub properties: serde_json::Map<String, serde_json::Value>,
}
28
29impl Node {
30 pub fn new(id: impl Into<NodeId>, label: impl Into<String>) -> Self {
32 Self {
33 id: id.into(),
34 labels: vec![label.into()],
35 properties: serde_json::Map::new(),
36 }
37 }
38
39 #[must_use]
41 pub fn with_property(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
42 self.properties.insert(key.into(), value);
43 self
44 }
45
46 #[must_use]
48 pub fn with_label(mut self, label: impl Into<String>) -> Self {
49 self.labels.push(label.into());
50 self
51 }
52}
53
/// A directed, labelled connection between two nodes with a JSON property bag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Edge {
    /// Key under which the edge is stored; `Edge::new` fills it with a random
    /// v4 UUID string.
    pub id: EdgeId,
    /// Id of the node the edge starts at.
    pub from: NodeId,
    /// Id of the node the edge points to.
    pub to: NodeId,
    /// Edge label; the edge queries can filter on it.
    pub label: String,
    /// Arbitrary JSON properties keyed by name.
    pub properties: serde_json::Map<String, serde_json::Value>,
}
68
69impl Edge {
70 pub fn new(from: impl Into<NodeId>, to: impl Into<NodeId>, label: impl Into<String>) -> Self {
72 Self {
73 id: Uuid::new_v4().to_string(),
74 from: from.into(),
75 to: to.into(),
76 label: label.into(),
77 properties: serde_json::Map::new(),
78 }
79 }
80
81 #[must_use]
83 pub fn with_property(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
84 self.properties.insert(key.into(), value);
85 self
86 }
87}
88
/// Asynchronous storage interface for a labelled property graph.
///
/// Implementors must be `Send + Sync` so a store can be shared across tasks.
#[async_trait]
pub trait GraphStore: Send + Sync {
    /// Inserts `node`, or replaces the stored node that has the same id.
    async fn upsert_node(&self, node: Node) -> Result<()>;
    /// Stores `edge` and returns its id.
    async fn add_edge(&self, edge: Edge) -> Result<EdgeId>;
    /// Looks up a node by id; `Ok(None)` when no such node is stored.
    async fn get_node(&self, id: &NodeId) -> Result<Option<Node>>;
    /// Looks up an edge by id; `Ok(None)` when no such edge is stored.
    async fn get_edge(&self, id: &EdgeId) -> Result<Option<Edge>>;
    /// Returns the nodes that carry `label`.
    async fn nodes_by_label(&self, label: &str) -> Result<Vec<Node>>;
    /// Returns the edges starting at `from`, restricted to edges whose label
    /// equals `label` when one is given.
    async fn edges_from(&self, from: &NodeId, label: Option<&str>) -> Result<Vec<Edge>>;
    /// Returns the edges pointing at `to`, restricted to edges whose label
    /// equals `label` when one is given.
    async fn edges_to(&self, to: &NodeId, label: Option<&str>) -> Result<Vec<Edge>>;
    /// Returns a pair of counts; the in-memory implementation reports
    /// `(node count, edge count)`.
    async fn stats(&self) -> Result<(usize, usize)>;
}
109
/// Lock-protected state of [`InMemoryGraph`]: the primary node/edge maps plus
/// the secondary indexes derived from them.
struct InnerGraph {
    /// All nodes keyed by id.
    nodes: HashMap<NodeId, Node>,
    /// All edges keyed by id.
    edges: HashMap<EdgeId, Edge>,
    /// Label -> ids of the nodes carrying that label.
    by_label: HashMap<String, BTreeSet<NodeId>>,
    /// Node id -> ids of the edges whose `from` is that node.
    out_edges: HashMap<NodeId, BTreeSet<EdgeId>>,
    /// Node id -> ids of the edges whose `to` is that node.
    in_edges: HashMap<NodeId, BTreeSet<EdgeId>>,
}
117
/// [`GraphStore`] implementation that keeps the whole graph in process memory.
///
/// Cloning copies only the `Arc`, so every clone reads and writes the same
/// underlying graph.
#[derive(Clone)]
pub struct InMemoryGraph {
    /// Shared graph state behind a reader/writer lock.
    inner: std::sync::Arc<RwLock<InnerGraph>>,
}
123
124impl InMemoryGraph {
125 pub fn new() -> Self {
127 Self {
128 inner: std::sync::Arc::new(RwLock::new(InnerGraph {
129 nodes: HashMap::new(),
130 edges: HashMap::new(),
131 by_label: HashMap::new(),
132 out_edges: HashMap::new(),
133 in_edges: HashMap::new(),
134 })),
135 }
136 }
137}
138
139impl Default for InMemoryGraph {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145#[async_trait]
146impl GraphStore for InMemoryGraph {
147 async fn upsert_node(&self, node: Node) -> Result<()> {
148 let mut g = self.inner.write().unwrap();
149 if let Some(prev) = g.nodes.get(&node.id) {
151 for label in prev.labels.clone() {
152 if let Some(set) = g.by_label.get_mut(&label) {
153 set.remove(&node.id);
154 }
155 }
156 }
157 for label in &node.labels {
158 g.by_label
159 .entry(label.clone())
160 .or_default()
161 .insert(node.id.clone());
162 }
163 g.nodes.insert(node.id.clone(), node);
164 Ok(())
165 }
166
167 async fn add_edge(&self, edge: Edge) -> Result<EdgeId> {
168 let mut g = self.inner.write().unwrap();
169 if !g.nodes.contains_key(&edge.from) {
170 return Err(GraphError::UnknownNode(edge.from.clone()));
171 }
172 if !g.nodes.contains_key(&edge.to) {
173 return Err(GraphError::UnknownNode(edge.to.clone()));
174 }
175 let id = edge.id.clone();
176 g.out_edges
177 .entry(edge.from.clone())
178 .or_default()
179 .insert(id.clone());
180 g.in_edges
181 .entry(edge.to.clone())
182 .or_default()
183 .insert(id.clone());
184 g.edges.insert(id.clone(), edge);
185 Ok(id)
186 }
187
188 async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
189 Ok(self.inner.read().unwrap().nodes.get(id).cloned())
190 }
191
192 async fn get_edge(&self, id: &EdgeId) -> Result<Option<Edge>> {
193 Ok(self.inner.read().unwrap().edges.get(id).cloned())
194 }
195
196 async fn nodes_by_label(&self, label: &str) -> Result<Vec<Node>> {
197 let g = self.inner.read().unwrap();
198 Ok(g.by_label
199 .get(label)
200 .map(|ids| ids.iter().filter_map(|i| g.nodes.get(i).cloned()).collect())
201 .unwrap_or_default())
202 }
203
204 async fn edges_from(&self, from: &NodeId, label: Option<&str>) -> Result<Vec<Edge>> {
205 let g = self.inner.read().unwrap();
206 Ok(g.out_edges
207 .get(from)
208 .map(|ids| {
209 ids.iter()
210 .filter_map(|i| g.edges.get(i).cloned())
211 .filter(|e| label.is_none_or(|l| e.label == l))
212 .collect()
213 })
214 .unwrap_or_default())
215 }
216
217 async fn edges_to(&self, to: &NodeId, label: Option<&str>) -> Result<Vec<Edge>> {
218 let g = self.inner.read().unwrap();
219 Ok(g.in_edges
220 .get(to)
221 .map(|ids| {
222 ids.iter()
223 .filter_map(|i| g.edges.get(i).cloned())
224 .filter(|e| label.is_none_or(|l| e.label == l))
225 .collect()
226 })
227 .unwrap_or_default())
228 }
229
230 async fn stats(&self) -> Result<(usize, usize)> {
231 let g = self.inner.read().unwrap();
232 Ok((g.nodes.len(), g.edges.len()))
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Upserted nodes are reachable through the label index.
    #[tokio::test]
    async fn upsert_and_query_by_label() {
        let graph = InMemoryGraph::new();
        let rex = Node::new("pet:1", "pet").with_property("name", json!("Rex"));
        let buddy = Node::new("pet:2", "pet").with_property("name", json!("Buddy"));
        let owner = Node::new("user:1", "user");
        for node in [rex, buddy, owner] {
            graph.upsert_node(node).await.unwrap();
        }

        let pet_count = graph.nodes_by_label("pet").await.unwrap().len();
        assert_eq!(pet_count, 2);
        let user_count = graph.nodes_by_label("user").await.unwrap().len();
        assert_eq!(user_count, 1);
    }

    /// An edge is accepted only when both of its endpoints are stored nodes.
    #[tokio::test]
    async fn edges_link_existing_nodes_only() {
        let graph = InMemoryGraph::new();
        for name in ["a", "b"] {
            graph.upsert_node(Node::new(name, "node")).await.unwrap();
        }

        let edge_id = graph.add_edge(Edge::new("a", "b", "links")).await.unwrap();
        let stored = graph.get_edge(&edge_id).await.unwrap();
        assert!(stored.is_some());

        let dangling = Edge::new("a", "missing", "links");
        let failure = graph.add_edge(dangling).await.unwrap_err();
        assert!(matches!(failure, GraphError::UnknownNode(_)));
    }

    /// Outgoing and incoming queries respect direction and the label filter.
    #[tokio::test]
    async fn directional_edge_queries() {
        let graph = InMemoryGraph::new();
        for name in ["a", "b", "c"] {
            graph.upsert_node(Node::new(name, "n")).await.unwrap();
        }
        let links = [("a", "b", "knows"), ("a", "c", "knows"), ("b", "c", "owns")];
        for (src, dst, kind) in links {
            graph.add_edge(Edge::new(src, dst, kind)).await.unwrap();
        }

        let a: NodeId = "a".into();
        let c: NodeId = "c".into();

        let out_of_a = graph.edges_from(&a, None).await.unwrap();
        assert_eq!(out_of_a.len(), 2);
        let owned_by_a = graph.edges_from(&a, Some("owns")).await.unwrap();
        assert!(owned_by_a.is_empty());
        let into_c = graph.edges_to(&c, None).await.unwrap();
        assert_eq!(into_c.len(), 2);
        let owning_c = graph.edges_to(&c, Some("owns")).await.unwrap();
        assert_eq!(owning_c.len(), 1);
    }

    /// `stats` reports the number of stored nodes and edges.
    #[tokio::test]
    async fn stats_reflect_inserts() {
        let graph = InMemoryGraph::new();
        for name in ["a", "b"] {
            graph.upsert_node(Node::new(name, "n")).await.unwrap();
        }
        graph.add_edge(Edge::new("a", "b", "x")).await.unwrap();

        let (node_count, edge_count) = graph.stats().await.unwrap();
        assert_eq!(node_count, 2);
        assert_eq!(edge_count, 1);
    }
}