Skip to main content

oxide_graph/
kernel.rs

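//! `oxide-k` bus integration for the knowledge graph: [`GraphModule`] serves graph
//! methods (`ingest`, `node_query`, `traverse`, ...) sent to it as `Command::Invoke`.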
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use oxide_k::bus::{Command, Event, Message, MessageBus};
7use oxide_k::module::{Module, ModuleKind, ModuleMetadata};
8use oxide_k::{KernelError, Result as KernelResult};
9use serde::Deserialize;
10use tokio::task::JoinHandle;
11
12use crate::graph::{Edge, GraphStore, InMemoryGraph, Node, NodeId};
13use crate::ingest::{ingest_record, RecordRef};
14use crate::query::{traverse, EdgeQuery, NodeQuery};
15
16/// Default module id.
17pub const DEFAULT_MODULE_ID: &str = "graph";
18
19/// Knowledge-graph module wrapped around an [`InMemoryGraph`].
20pub struct GraphModule {
21    id: String,
22    store: Arc<dyn GraphStore>,
23    listener: Option<JoinHandle<()>>,
24}
25
26impl GraphModule {
27    /// Build with the default id and a fresh in-memory store.
28    pub fn new() -> Self {
29        Self::with_store(Arc::new(InMemoryGraph::new()))
30    }
31
32    /// Build with an explicit store.
33    pub fn with_store(store: Arc<dyn GraphStore>) -> Self {
34        Self {
35            id: DEFAULT_MODULE_ID.into(),
36            store,
37            listener: None,
38        }
39    }
40
41    /// Access the underlying store.
42    pub fn store(&self) -> Arc<dyn GraphStore> {
43        self.store.clone()
44    }
45}
46
47impl Default for GraphModule {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53#[async_trait]
54impl Module for GraphModule {
55    fn metadata(&self) -> ModuleMetadata {
56        ModuleMetadata {
57            id: self.id.clone(),
58            name: "Oxide Knowledge Graph".into(),
59            version: env!("CARGO_PKG_VERSION").into(),
60            kind: ModuleKind::Native,
61            description: Some(
62                "In-memory typed property graph; ingests mirrored records, answers pattern + traversal queries.".into(),
63            ),
64        }
65    }
66
67    async fn init(&mut self, bus: MessageBus) -> KernelResult<()> {
68        let mut sub = bus.subscribe().await;
69        let store = self.store.clone();
70        let id = self.id.clone();
71        let bus_emit = bus.clone();
72        let handle = tokio::spawn(async move {
73            while let Some(env) = sub.receiver.recv().await {
74                let Message::Command(Command::Invoke {
75                    module_id,
76                    method,
77                    payload,
78                }) = env.message
79                else {
80                    continue;
81                };
82                if module_id != id {
83                    continue;
84                }
85                let result = dispatch(&store, &method, payload).await;
86                let event = match result {
87                    Ok(value) => Event::Custom {
88                        module_id: id.clone(),
89                        kind: format!("{method}.ok"),
90                        payload: value,
91                    },
92                    Err(err) => Event::Custom {
93                        module_id: id.clone(),
94                        kind: format!("{method}.err"),
95                        payload: serde_json::json!({ "error": err.to_string() }),
96                    },
97                };
98                let _ = bus_emit.emit_event(id.clone(), event).await;
99            }
100        });
101        self.listener = Some(handle);
102        Ok(())
103    }
104
105    async fn start(&mut self) -> KernelResult<()> {
106        Ok(())
107    }
108
109    async fn stop(&mut self) -> KernelResult<()> {
110        if let Some(h) = self.listener.take() {
111            h.abort();
112        }
113        Ok(())
114    }
115}
116
117#[derive(Deserialize)]
118struct IngestPayload {
119    resource: String,
120    record_id: String,
121    source: String,
122    payload: serde_json::Value,
123}
124
125#[derive(Deserialize)]
126struct UpsertNodePayload {
127    node: Node,
128}
129
130#[derive(Deserialize)]
131struct AddEdgePayload {
132    edge: Edge,
133}
134
135#[derive(Deserialize)]
136struct GetNodePayload {
137    id: NodeId,
138}
139
140#[derive(Deserialize)]
141struct NodeQueryPayload {
142    label: String,
143    #[serde(default)]
144    property_eq: Vec<(String, serde_json::Value)>,
145}
146
147#[derive(Deserialize)]
148struct EdgeQueryPayload {
149    anchor: NodeId,
150    #[serde(default)]
151    label: Option<String>,
152    #[serde(default)]
153    direction: Option<String>,
154}
155
156#[derive(Deserialize)]
157struct TraversePayload {
158    start: NodeId,
159    #[serde(default)]
160    edge_label: Option<String>,
161    #[serde(default = "default_depth")]
162    max_depth: usize,
163}
164
165fn default_depth() -> usize {
166    2
167}
168
169async fn dispatch(
170    store: &Arc<dyn GraphStore>,
171    method: &str,
172    payload: serde_json::Value,
173) -> KernelResult<serde_json::Value> {
174    let to_kernel = |e: crate::error::GraphError| KernelError::Other(anyhow::anyhow!(e));
175
176    match method {
177        "ingest" => {
178            let p: IngestPayload = serde_json::from_value(payload)?;
179            let id = ingest_record(
180                store.as_ref(),
181                RecordRef {
182                    resource: &p.resource,
183                    record_id: &p.record_id,
184                    payload: &p.payload,
185                    source: &p.source,
186                },
187            )
188            .await
189            .map_err(to_kernel)?;
190            Ok(serde_json::json!({ "id": id }))
191        }
192        "upsert_node" => {
193            let p: UpsertNodePayload = serde_json::from_value(payload)?;
194            store.upsert_node(p.node).await.map_err(to_kernel)?;
195            Ok(serde_json::json!({"ok": true}))
196        }
197        "add_edge" => {
198            let p: AddEdgePayload = serde_json::from_value(payload)?;
199            let id = store.add_edge(p.edge).await.map_err(to_kernel)?;
200            Ok(serde_json::json!({"id": id}))
201        }
202        "get_node" => {
203            let p: GetNodePayload = serde_json::from_value(payload)?;
204            let node = store.get_node(&p.id).await.map_err(to_kernel)?;
205            Ok(serde_json::to_value(node)?)
206        }
207        "node_query" => {
208            let p: NodeQueryPayload = serde_json::from_value(payload)?;
209            let mut q = NodeQuery::label(p.label);
210            q.property_eq = p.property_eq;
211            let nodes = q.run(store.as_ref()).await.map_err(to_kernel)?;
212            Ok(serde_json::to_value(nodes)?)
213        }
214        "edge_query" => {
215            let p: EdgeQueryPayload = serde_json::from_value(payload)?;
216            let direction = match p.direction.as_deref() {
217                Some("in") => crate::query::EdgeDirection::In,
218                Some("either") => crate::query::EdgeDirection::Either,
219                _ => crate::query::EdgeDirection::Out,
220            };
221            let q = EdgeQuery {
222                label: p.label,
223                direction,
224            };
225            let edges = q.run(store.as_ref(), &p.anchor).await.map_err(to_kernel)?;
226            Ok(serde_json::to_value(edges)?)
227        }
228        "traverse" => {
229            let p: TraversePayload = serde_json::from_value(payload)?;
230            let nodes = traverse(
231                store.as_ref(),
232                &p.start,
233                p.edge_label.as_deref(),
234                p.max_depth,
235            )
236            .await
237            .map_err(to_kernel)?;
238            Ok(serde_json::to_value(nodes)?)
239        }
240        "stats" => {
241            let (n, e) = store.stats().await.map_err(to_kernel)?;
242            Ok(serde_json::json!({"nodes": n, "edges": e}))
243        }
244        other => Err(KernelError::Other(anyhow::anyhow!(
245            "unknown graph method `{other}`"
246        ))),
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::time::Duration;

    /// How long a single receive may block before the wait is abandoned.
    const RECV_TIMEOUT: Duration = Duration::from_millis(400);
    /// Upper bound on bus messages inspected while waiting for one reply.
    const MAX_MESSAGES: usize = 15;

    /// End-to-end: ingest one record over the bus, then find it with a
    /// `node_query`, checking both `*.ok` replies.
    #[tokio::test]
    async fn bus_ingest_then_query() {
        let bus = MessageBus::new();
        // Subscribe before `init` so none of the module's replies are missed.
        let mut sub = bus.subscribe().await;
        let mut module = GraphModule::new();
        Module::init(&mut module, bus.clone()).await.unwrap();
        Module::start(&mut module).await.unwrap();

        let invoke = |method: &str, payload: serde_json::Value| Command::Invoke {
            module_id: DEFAULT_MODULE_ID.into(),
            method: method.into(),
            payload,
        };

        // --- ingest -------------------------------------------------------
        let ingest_request = invoke(
            "ingest",
            json!({
                "resource": "pet",
                "record_id": "1",
                "source": "petstore",
                "payload": {"name": "Rex", "status": "available"}
            }),
        );
        bus.send_command("test", ingest_request).await.unwrap();

        let mut ingest_reply = None;
        for _ in 0..MAX_MESSAGES {
            let Ok(Some(envelope)) = tokio::time::timeout(RECV_TIMEOUT, sub.receiver.recv()).await
            else {
                break;
            };
            if let Message::Event(Event::Custom { kind, payload, .. }) = envelope.message {
                if kind == "ingest.ok" {
                    ingest_reply = Some(payload);
                    break;
                }
            }
        }
        assert!(ingest_reply.is_some());

        // --- node_query ---------------------------------------------------
        let query_request = invoke(
            "node_query",
            json!({"label": "pet", "property_eq": [["status", "available"]]}),
        );
        bus.send_command("test", query_request).await.unwrap();

        let mut query_reply = None;
        for _ in 0..MAX_MESSAGES {
            let Ok(Some(envelope)) = tokio::time::timeout(RECV_TIMEOUT, sub.receiver.recv()).await
            else {
                break;
            };
            if let Message::Event(Event::Custom { kind, payload, .. }) = envelope.message {
                if kind == "node_query.ok" {
                    query_reply = Some(payload);
                    break;
                }
            }
        }
        assert!(query_reply.is_some());
        let nodes = query_reply.as_ref().and_then(|v| v.as_array()).unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0]["id"], json!("pet:1"));

        Module::stop(&mut module).await.unwrap();
    }
}