Skip to main content

oxide_graph/
graph.rs

//! Node / Edge types, the [`GraphStore`] trait, and the [`InMemoryGraph`] store.
2
3use std::collections::{BTreeSet, HashMap};
4use std::sync::RwLock;
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use uuid::Uuid;
9
10use crate::error::{GraphError, Result};
11
/// Caller-chosen node id that stays stable across upserts (typically of the
/// form `"resource:record_id"`).
pub type NodeId = String;

/// Edge id. [`Edge::new`] fills it with a random UUID v4 string. Because
/// [`Edge::id`] is public and `Edge` is deserializable, callers may also
/// supply their own ids.
pub type EdgeId = String;
17
18/// A property-graph node.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct Node {
21    /// Stable id.
22    pub id: NodeId,
23    /// One or more semantic labels (e.g. `["pet"]`, `["user", "agent"]`).
24    pub labels: Vec<String>,
25    /// Free-form properties.
26    pub properties: serde_json::Map<String, serde_json::Value>,
27}
28
29impl Node {
30    /// Build a node with a single label.
31    pub fn new(id: impl Into<NodeId>, label: impl Into<String>) -> Self {
32        Self {
33            id: id.into(),
34            labels: vec![label.into()],
35            properties: serde_json::Map::new(),
36        }
37    }
38
39    /// Builder helper.
40    #[must_use]
41    pub fn with_property(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
42        self.properties.insert(key.into(), value);
43        self
44    }
45
46    /// Builder helper.
47    #[must_use]
48    pub fn with_label(mut self, label: impl Into<String>) -> Self {
49        self.labels.push(label.into());
50        self
51    }
52}
53
54/// A directed labelled edge.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct Edge {
57    /// Auto-generated id.
58    pub id: EdgeId,
59    /// Tail node id.
60    pub from: NodeId,
61    /// Head node id.
62    pub to: NodeId,
63    /// Semantic label (`"owns"`, `"references"`, `"mentions"`).
64    pub label: String,
65    /// Free-form properties.
66    pub properties: serde_json::Map<String, serde_json::Value>,
67}
68
69impl Edge {
70    /// Build an edge with a generated id.
71    pub fn new(from: impl Into<NodeId>, to: impl Into<NodeId>, label: impl Into<String>) -> Self {
72        Self {
73            id: Uuid::new_v4().to_string(),
74            from: from.into(),
75            to: to.into(),
76            label: label.into(),
77            properties: serde_json::Map::new(),
78        }
79    }
80
81    /// Builder helper.
82    #[must_use]
83    pub fn with_property(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
84        self.properties.insert(key.into(), value);
85        self
86    }
87}
88
89/// Storage abstraction. Concrete implementations: [`InMemoryGraph`].
90#[async_trait]
91pub trait GraphStore: Send + Sync {
92    /// Insert / replace a node by id.
93    async fn upsert_node(&self, node: Node) -> Result<()>;
94    /// Insert an edge.
95    async fn add_edge(&self, edge: Edge) -> Result<EdgeId>;
96    /// Look up a node by id.
97    async fn get_node(&self, id: &NodeId) -> Result<Option<Node>>;
98    /// Look up an edge by id.
99    async fn get_edge(&self, id: &EdgeId) -> Result<Option<Edge>>;
100    /// All nodes carrying any of `labels`.
101    async fn nodes_by_label(&self, label: &str) -> Result<Vec<Node>>;
102    /// All edges leaving `from` (optionally filtered by label).
103    async fn edges_from(&self, from: &NodeId, label: Option<&str>) -> Result<Vec<Edge>>;
104    /// All edges entering `to` (optionally filtered by label).
105    async fn edges_to(&self, to: &NodeId, label: Option<&str>) -> Result<Vec<Edge>>;
106    /// Counts.
107    async fn stats(&self) -> Result<(usize, usize)>;
108}
109
110struct InnerGraph {
111    nodes: HashMap<NodeId, Node>,
112    edges: HashMap<EdgeId, Edge>,
113    by_label: HashMap<String, BTreeSet<NodeId>>,
114    out_edges: HashMap<NodeId, BTreeSet<EdgeId>>,
115    in_edges: HashMap<NodeId, BTreeSet<EdgeId>>,
116}
117
118/// In-process knowledge graph. Cheap to clone (wraps `Arc<RwLock<…>>`).
119#[derive(Clone)]
120pub struct InMemoryGraph {
121    inner: std::sync::Arc<RwLock<InnerGraph>>,
122}
123
124impl InMemoryGraph {
125    /// Build an empty graph.
126    pub fn new() -> Self {
127        Self {
128            inner: std::sync::Arc::new(RwLock::new(InnerGraph {
129                nodes: HashMap::new(),
130                edges: HashMap::new(),
131                by_label: HashMap::new(),
132                out_edges: HashMap::new(),
133                in_edges: HashMap::new(),
134            })),
135        }
136    }
137}
138
139impl Default for InMemoryGraph {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145#[async_trait]
146impl GraphStore for InMemoryGraph {
147    async fn upsert_node(&self, node: Node) -> Result<()> {
148        let mut g = self.inner.write().unwrap();
149        // Remove prior label indexing if replacing.
150        if let Some(prev) = g.nodes.get(&node.id) {
151            for label in prev.labels.clone() {
152                if let Some(set) = g.by_label.get_mut(&label) {
153                    set.remove(&node.id);
154                }
155            }
156        }
157        for label in &node.labels {
158            g.by_label
159                .entry(label.clone())
160                .or_default()
161                .insert(node.id.clone());
162        }
163        g.nodes.insert(node.id.clone(), node);
164        Ok(())
165    }
166
167    async fn add_edge(&self, edge: Edge) -> Result<EdgeId> {
168        let mut g = self.inner.write().unwrap();
169        if !g.nodes.contains_key(&edge.from) {
170            return Err(GraphError::UnknownNode(edge.from.clone()));
171        }
172        if !g.nodes.contains_key(&edge.to) {
173            return Err(GraphError::UnknownNode(edge.to.clone()));
174        }
175        let id = edge.id.clone();
176        g.out_edges
177            .entry(edge.from.clone())
178            .or_default()
179            .insert(id.clone());
180        g.in_edges
181            .entry(edge.to.clone())
182            .or_default()
183            .insert(id.clone());
184        g.edges.insert(id.clone(), edge);
185        Ok(id)
186    }
187
188    async fn get_node(&self, id: &NodeId) -> Result<Option<Node>> {
189        Ok(self.inner.read().unwrap().nodes.get(id).cloned())
190    }
191
192    async fn get_edge(&self, id: &EdgeId) -> Result<Option<Edge>> {
193        Ok(self.inner.read().unwrap().edges.get(id).cloned())
194    }
195
196    async fn nodes_by_label(&self, label: &str) -> Result<Vec<Node>> {
197        let g = self.inner.read().unwrap();
198        Ok(g.by_label
199            .get(label)
200            .map(|ids| ids.iter().filter_map(|i| g.nodes.get(i).cloned()).collect())
201            .unwrap_or_default())
202    }
203
204    async fn edges_from(&self, from: &NodeId, label: Option<&str>) -> Result<Vec<Edge>> {
205        let g = self.inner.read().unwrap();
206        Ok(g.out_edges
207            .get(from)
208            .map(|ids| {
209                ids.iter()
210                    .filter_map(|i| g.edges.get(i).cloned())
211                    .filter(|e| label.is_none_or(|l| e.label == l))
212                    .collect()
213            })
214            .unwrap_or_default())
215    }
216
217    async fn edges_to(&self, to: &NodeId, label: Option<&str>) -> Result<Vec<Edge>> {
218        let g = self.inner.read().unwrap();
219        Ok(g.in_edges
220            .get(to)
221            .map(|ids| {
222                ids.iter()
223                    .filter_map(|i| g.edges.get(i).cloned())
224                    .filter(|e| label.is_none_or(|l| e.label == l))
225                    .collect()
226            })
227            .unwrap_or_default())
228    }
229
230    async fn stats(&self) -> Result<(usize, usize)> {
231        let g = self.inner.read().unwrap();
232        Ok((g.nodes.len(), g.edges.len()))
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn upsert_and_query_by_label() {
        let g = InMemoryGraph::new();
        g.upsert_node(Node::new("pet:1", "pet").with_property("name", json!("Rex")))
            .await
            .unwrap();
        g.upsert_node(Node::new("pet:2", "pet").with_property("name", json!("Buddy")))
            .await
            .unwrap();
        g.upsert_node(Node::new("user:1", "user")).await.unwrap();

        let pets = g.nodes_by_label("pet").await.unwrap();
        assert_eq!(pets.len(), 2);
        let users = g.nodes_by_label("user").await.unwrap();
        assert_eq!(users.len(), 1);
    }

    #[tokio::test]
    async fn upsert_replaces_node_and_its_labels() {
        let g = InMemoryGraph::new();
        g.upsert_node(Node::new("doc:1", "draft").with_property("rev", json!(1)))
            .await
            .unwrap();
        g.upsert_node(Node::new("doc:1", "published").with_property("rev", json!(2)))
            .await
            .unwrap();

        // The old label must no longer match; the new one must.
        assert!(g.nodes_by_label("draft").await.unwrap().is_empty());
        assert_eq!(g.nodes_by_label("published").await.unwrap().len(), 1);

        let node = g
            .get_node(&"doc:1".into())
            .await
            .unwrap()
            .expect("node was upserted");
        assert_eq!(node.properties.get("rev"), Some(&json!(2)));
        assert_eq!(g.stats().await.unwrap(), (1, 0));
    }

    #[tokio::test]
    async fn unknown_ids_yield_empty_results() {
        let g = InMemoryGraph::new();
        assert!(g.get_node(&"nope".into()).await.unwrap().is_none());
        assert!(g.get_edge(&"nope".into()).await.unwrap().is_none());
        assert!(g.nodes_by_label("nope").await.unwrap().is_empty());
        assert!(g.edges_from(&"nope".into(), None).await.unwrap().is_empty());
        assert!(g.edges_to(&"nope".into(), None).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn edges_link_existing_nodes_only() {
        let g = InMemoryGraph::new();
        g.upsert_node(Node::new("a", "node")).await.unwrap();
        g.upsert_node(Node::new("b", "node")).await.unwrap();
        let id = g.add_edge(Edge::new("a", "b", "links")).await.unwrap();
        assert!(g.get_edge(&id).await.unwrap().is_some());

        let err = g
            .add_edge(Edge::new("a", "missing", "links"))
            .await
            .unwrap_err();
        assert!(matches!(err, GraphError::UnknownNode(_)));
    }

    #[tokio::test]
    async fn directional_edge_queries() {
        let g = InMemoryGraph::new();
        for n in ["a", "b", "c"] {
            g.upsert_node(Node::new(n, "n")).await.unwrap();
        }
        g.add_edge(Edge::new("a", "b", "knows")).await.unwrap();
        g.add_edge(Edge::new("a", "c", "knows")).await.unwrap();
        g.add_edge(Edge::new("b", "c", "owns")).await.unwrap();

        let from_a = g.edges_from(&"a".into(), None).await.unwrap();
        assert_eq!(from_a.len(), 2);
        let from_a_owns = g.edges_from(&"a".into(), Some("owns")).await.unwrap();
        assert_eq!(from_a_owns.len(), 0);
        let to_c = g.edges_to(&"c".into(), None).await.unwrap();
        assert_eq!(to_c.len(), 2);
        let to_c_owns = g.edges_to(&"c".into(), Some("owns")).await.unwrap();
        assert_eq!(to_c_owns.len(), 1);
    }

    #[tokio::test]
    async fn stats_reflect_inserts() {
        let g = InMemoryGraph::new();
        g.upsert_node(Node::new("a", "n")).await.unwrap();
        g.upsert_node(Node::new("b", "n")).await.unwrap();
        g.add_edge(Edge::new("a", "b", "x")).await.unwrap();
        assert_eq!(g.stats().await.unwrap(), (2, 1));
    }
}