Skip to main content

arrow_graph_core/
triple_store.rs

//! SimpleTripleStore — lightweight Arrow-native triple store.
//!
//! Wraps [`ArrowGraphStore`] with a simplified API: the namespace and layer are
//! chosen once at construction instead of being passed explicitly on every
//! operation.
//!
//! # Quick Start
//!
//! ```rust
//! use arrow_graph_core::triple_store::SimpleTripleStore;
//!
//! let mut store = SimpleTripleStore::new();
//! let id = store.add("Alice", "knows", "Bob", 0.9, "test").unwrap();
//! assert_eq!(store.count(None, None, None), 1);
//!
//! let results = store.query(Some("Alice"), None, None).unwrap();
//! assert_eq!(results.len(), 1);
//! assert_eq!(results[0].subject, "Alice");
//!
//! store.remove(&id).unwrap();
//! assert_eq!(store.count(None, None, None), 0);
//! ```
22
23use crate::schema::col;
24use crate::store::{ArrowGraphStore, QuerySpec, StoreError, Triple};
25
26use arrow::array::{Array, Float64Array, RecordBatch, StringArray};
27use std::collections::HashMap;
28
/// Namespace used by [`SimpleTripleStore::new`] when no custom namespace is given.
const DEFAULT_NAMESPACE: &str = "default";
31
/// A triple read back from the store, together with its metadata.
///
/// Equality compares every field, including `id`, so two rows with identical
/// content but different IDs are not equal.
#[derive(Debug, Clone, PartialEq)]
pub struct StoredTriple {
    /// Identifier of the stored row (from the `triple_id` column).
    pub id: String,
    /// Subject of the triple.
    pub subject: String,
    /// Predicate of the triple.
    pub predicate: String,
    /// Object of the triple.
    pub object: String,
    /// Named graph, or `None` when the `graph` column is null.
    pub graph: Option<String>,
    /// Confidence score; reported as `1.0` when the stored value is null.
    pub confidence: f64,
    /// Provenance label, read from the `extracted_by` column; `None` when null.
    pub source: Option<String>,
}
43
/// Summary statistics over the triples in a store's namespace.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct StoreStats {
    /// Number of live triples.
    pub total_triples: usize,
    /// Number of distinct subject strings.
    pub unique_subjects: usize,
    /// Number of distinct predicate strings.
    pub unique_predicates: usize,
    /// Number of distinct object strings.
    pub unique_objects: usize,
    /// Triple count per source label; triples without a source are not counted here.
    pub by_source: HashMap<String, usize>,
}
53
54/// Lightweight Arrow-native triple store.
55///
56/// Wraps `ArrowGraphStore` with defaults for namespace and layer,
57/// providing a simplified API.
58pub struct SimpleTripleStore {
59    inner: ArrowGraphStore,
60    namespace: String,
61    layer: Option<u8>,
62}
63
64impl SimpleTripleStore {
65    /// Create a new empty store with a single "default" namespace and layer 0.
66    pub fn new() -> Self {
67        Self {
68            inner: ArrowGraphStore::new(&[DEFAULT_NAMESPACE]),
69            namespace: DEFAULT_NAMESPACE.to_string(),
70            layer: Some(0),
71        }
72    }
73
74    /// Create with custom namespace and layer defaults.
75    pub fn with_defaults(namespace: &str, layer: Option<u8>) -> Self {
76        Self {
77            inner: ArrowGraphStore::new(&[namespace]),
78            namespace: namespace.to_string(),
79            layer,
80        }
81    }
82
83    /// Add a triple. Returns the generated triple ID.
84    pub fn add(
85        &mut self,
86        subject: &str,
87        predicate: &str,
88        object: &str,
89        confidence: f64,
90        source: &str,
91    ) -> Result<String, StoreError> {
92        let triple = Triple {
93            subject: subject.to_string(),
94            predicate: predicate.to_string(),
95            object: object.to_string(),
96            graph: None,
97            confidence: Some(confidence),
98            source_document: Some(source.to_string()),
99            source_chunk_id: None,
100            extracted_by: Some(source.to_string()),
101            caused_by: None,
102            derived_from: None,
103            consolidated_at: None,
104        };
105        self.inner.add_triple(&triple, &self.namespace, self.layer)
106    }
107
108    /// Add a batch of triples. Returns generated IDs.
109    pub fn add_batch(
110        &mut self,
111        triples: &[(&str, &str, &str, f64, &str)],
112    ) -> Result<Vec<String>, StoreError> {
113        let ts: Vec<Triple> = triples
114            .iter()
115            .map(|(s, p, o, conf, src)| Triple {
116                subject: s.to_string(),
117                predicate: p.to_string(),
118                object: o.to_string(),
119                graph: None,
120                confidence: Some(*conf),
121                source_document: Some(src.to_string()),
122                source_chunk_id: None,
123                extracted_by: Some(src.to_string()),
124                caused_by: None,
125                derived_from: None,
126                consolidated_at: None,
127            })
128            .collect();
129        self.inner.add_batch(&ts, &self.namespace, self.layer)
130    }
131
132    /// Remove a triple by ID. Returns true if found.
133    pub fn remove(&mut self, triple_id: &str) -> Result<bool, StoreError> {
134        self.inner.delete(triple_id)
135    }
136
137    /// Query triples matching pattern. None means wildcard (match all).
138    pub fn query(
139        &self,
140        subject: Option<&str>,
141        predicate: Option<&str>,
142        object: Option<&str>,
143    ) -> Result<Vec<StoredTriple>, StoreError> {
144        let spec = QuerySpec {
145            subject: subject.map(|s| s.to_string()),
146            predicate: predicate.map(|s| s.to_string()),
147            object: object.map(|s| s.to_string()),
148            namespace: Some(self.namespace.clone()),
149            ..Default::default()
150        };
151        let batches = self.inner.query(&spec)?;
152        Ok(batches_to_stored_triples(&batches))
153    }
154
155    /// Count triples matching pattern.
156    pub fn count(
157        &self,
158        subject: Option<&str>,
159        predicate: Option<&str>,
160        object: Option<&str>,
161    ) -> usize {
162        self.query(subject, predicate, object)
163            .map(|v| v.len())
164            .unwrap_or(0)
165    }
166
167    /// Get a single triple by ID.
168    pub fn get(&self, triple_id: &str) -> Option<StoredTriple> {
169        let spec = QuerySpec {
170            include_deleted: false,
171            ..Default::default()
172        };
173        let batches = self.inner.query(&spec).ok()?;
174        for batch in &batches {
175            let ids = batch
176                .column(col::TRIPLE_ID)
177                .as_any()
178                .downcast_ref::<StringArray>()?;
179            for i in 0..ids.len() {
180                if ids.value(i) == triple_id {
181                    return Some(extract_stored_triple(batch, i));
182                }
183            }
184        }
185        None
186    }
187
188    /// Update confidence on an existing triple. Returns true if found.
189    pub fn update_confidence(
190        &mut self,
191        triple_id: &str,
192        confidence: f64,
193    ) -> Result<bool, StoreError> {
194        let existing = match self.get(triple_id) {
195            Some(t) => t,
196            None => return Ok(false),
197        };
198        self.inner.delete(triple_id)?;
199        let triple = Triple {
200            subject: existing.subject,
201            predicate: existing.predicate,
202            object: existing.object,
203            graph: existing.graph,
204            confidence: Some(confidence),
205            source_document: existing.source.clone(),
206            source_chunk_id: None,
207            extracted_by: existing.source,
208            caused_by: None,
209            derived_from: None,
210            consolidated_at: None,
211        };
212        self.inner
213            .add_triple(&triple, &self.namespace, self.layer)?;
214        Ok(true)
215    }
216
217    /// Group and count triples by a field ("subject", "predicate", or "object").
218    pub fn group_by(&self, field: &str) -> Result<HashMap<String, usize>, StoreError> {
219        let col_idx = match field {
220            "subject" => col::SUBJECT,
221            "predicate" => col::PREDICATE,
222            "object" => col::OBJECT,
223            _ => {
224                return Err(StoreError::Arrow(
225                    arrow::error::ArrowError::InvalidArgumentError(format!(
226                        "invalid group_by field: {field}"
227                    )),
228                ));
229            }
230        };
231        let spec = QuerySpec {
232            namespace: Some(self.namespace.clone()),
233            ..Default::default()
234        };
235        let batches = self.inner.query(&spec)?;
236        let mut counts: HashMap<String, usize> = HashMap::new();
237        for batch in &batches {
238            let col_array = batch
239                .column(col_idx)
240                .as_any()
241                .downcast_ref::<StringArray>()
242                .expect("column must be StringArray");
243            for i in 0..col_array.len() {
244                *counts.entry(col_array.value(i).to_string()).or_insert(0) += 1;
245            }
246        }
247        Ok(counts)
248    }
249
250    /// Get store statistics.
251    pub fn stats(&self) -> StoreStats {
252        let spec = QuerySpec {
253            namespace: Some(self.namespace.clone()),
254            ..Default::default()
255        };
256        let batches = self.inner.query(&spec).unwrap_or_default();
257        let triples = batches_to_stored_triples(&batches);
258
259        let mut subjects = std::collections::HashSet::new();
260        let mut predicates = std::collections::HashSet::new();
261        let mut objects = std::collections::HashSet::new();
262        let mut by_source: HashMap<String, usize> = HashMap::new();
263
264        for t in &triples {
265            subjects.insert(t.subject.clone());
266            predicates.insert(t.predicate.clone());
267            objects.insert(t.object.clone());
268            if let Some(ref src) = t.source {
269                *by_source.entry(src.clone()).or_insert(0) += 1;
270            }
271        }
272
273        StoreStats {
274            total_triples: triples.len(),
275            unique_subjects: subjects.len(),
276            unique_predicates: predicates.len(),
277            unique_objects: objects.len(),
278            by_source,
279        }
280    }
281
282    /// Total number of triples.
283    pub fn len(&self) -> usize {
284        self.count(None, None, None)
285    }
286
287    /// Whether the store is empty.
288    pub fn is_empty(&self) -> bool {
289        self.len() == 0
290    }
291
292    /// Get a reference to the underlying ArrowGraphStore.
293    pub fn inner(&self) -> &ArrowGraphStore {
294        &self.inner
295    }
296
297    /// Get a mutable reference to the underlying ArrowGraphStore.
298    pub fn inner_mut(&mut self) -> &mut ArrowGraphStore {
299        &mut self.inner
300    }
301}
302
303impl Default for SimpleTripleStore {
304    fn default() -> Self {
305        Self::new()
306    }
307}
308
309// ── Helper functions ──────────────────────────────────────────────────
310
311/// Extract a [`StoredTriple`] from a RecordBatch at the given row index.
312pub fn extract_stored_triple(batch: &RecordBatch, idx: usize) -> StoredTriple {
313    let ids = batch
314        .column(col::TRIPLE_ID)
315        .as_any()
316        .downcast_ref::<StringArray>()
317        .expect("triple_id column");
318    let subjects = batch
319        .column(col::SUBJECT)
320        .as_any()
321        .downcast_ref::<StringArray>()
322        .expect("subject column");
323    let predicates = batch
324        .column(col::PREDICATE)
325        .as_any()
326        .downcast_ref::<StringArray>()
327        .expect("predicate column");
328    let objects = batch
329        .column(col::OBJECT)
330        .as_any()
331        .downcast_ref::<StringArray>()
332        .expect("object column");
333    let graphs = batch
334        .column(col::GRAPH)
335        .as_any()
336        .downcast_ref::<StringArray>()
337        .expect("graph column");
338    let confidences = batch
339        .column(col::CONFIDENCE)
340        .as_any()
341        .downcast_ref::<Float64Array>()
342        .expect("confidence column");
343    let sources = batch
344        .column(col::EXTRACTED_BY)
345        .as_any()
346        .downcast_ref::<StringArray>()
347        .expect("extracted_by column");
348
349    StoredTriple {
350        id: ids.value(idx).to_string(),
351        subject: subjects.value(idx).to_string(),
352        predicate: predicates.value(idx).to_string(),
353        object: objects.value(idx).to_string(),
354        graph: if graphs.is_null(idx) {
355            None
356        } else {
357            Some(graphs.value(idx).to_string())
358        },
359        confidence: if confidences.is_null(idx) {
360            1.0
361        } else {
362            confidences.value(idx)
363        },
364        source: if sources.is_null(idx) {
365            None
366        } else {
367            Some(sources.value(idx).to_string())
368        },
369    }
370}
371
372/// Convert a slice of RecordBatches into a Vec of [`StoredTriple`].
373pub fn batches_to_stored_triples(batches: &[RecordBatch]) -> Vec<StoredTriple> {
374    let mut result = Vec::new();
375    for batch in batches {
376        for i in 0..batch.num_rows() {
377            result.push(extract_stored_triple(batch, i));
378        }
379    }
380    result
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing stored confidences.
    const EPS: f64 = 1e-10;

    #[test]
    fn test_add_and_query() {
        let mut store = SimpleTripleStore::new();
        let id = store.add("Alice", "knows", "Bob", 0.9, "test").unwrap();
        assert!(!id.is_empty());
        assert_eq!(store.len(), 1);

        let results = store.query(Some("Alice"), None, None).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].subject, "Alice");
        assert_eq!(results[0].predicate, "knows");
        assert_eq!(results[0].object, "Bob");
        assert!((results[0].confidence - 0.9).abs() < EPS);
    }

    #[test]
    fn test_remove() {
        let mut store = SimpleTripleStore::new();
        let id = store.add("s", "p", "o", 1.0, "test").unwrap();
        assert_eq!(store.len(), 1);

        assert!(store.remove(&id).unwrap());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_query_wildcard() {
        let mut store = SimpleTripleStore::new();
        store.add("Alice", "knows", "Bob", 0.9, "test").unwrap();
        store.add("Alice", "likes", "Carol", 0.8, "test").unwrap();
        store.add("Bob", "knows", "Carol", 0.7, "test").unwrap();

        assert_eq!(store.query(Some("Alice"), None, None).unwrap().len(), 2);
        assert_eq!(store.query(None, Some("knows"), None).unwrap().len(), 2);
        assert_eq!(store.query(None, None, Some("Carol")).unwrap().len(), 2);
        assert_eq!(
            store
                .query(Some("Alice"), Some("knows"), None)
                .unwrap()
                .len(),
            1
        );
        assert_eq!(store.query(None, None, None).unwrap().len(), 3);
    }

    #[test]
    fn test_count_matches_query() {
        let mut store = SimpleTripleStore::new();
        store.add("Alice", "knows", "Bob", 0.9, "test").unwrap();
        store.add("Alice", "likes", "Carol", 0.8, "test").unwrap();
        store.add("Bob", "knows", "Carol", 0.7, "test").unwrap();

        assert_eq!(store.count(Some("Alice"), None, None), 2);
        assert_eq!(store.count(None, Some("knows"), Some("Carol")), 1);
        assert_eq!(store.count(Some("Nobody"), None, None), 0);
        assert_eq!(
            store.count(None, None, None),
            store.query(None, None, None).unwrap().len()
        );
    }

    #[test]
    fn test_get_by_id() {
        let mut store = SimpleTripleStore::new();
        let id = store.add("Alice", "knows", "Bob", 0.9, "test").unwrap();
        store.add("Bob", "knows", "Carol", 0.7, "test").unwrap();

        let t = store.get(&id).expect("triple should be retrievable by id");
        assert_eq!(t.id, id);
        assert_eq!(t.subject, "Alice");
        assert_eq!(t.object, "Bob");
        assert_eq!(t.source.as_deref(), Some("test"));
        assert!(store.get("no-such-id").is_none());
    }

    #[test]
    fn test_get_after_remove() {
        let mut store = SimpleTripleStore::new();
        let id = store.add("s", "p", "o", 1.0, "test").unwrap();
        assert!(store.remove(&id).unwrap());
        assert!(store.get(&id).is_none());
    }

    #[test]
    fn test_update_confidence() {
        let mut store = SimpleTripleStore::new();
        let id = store.add("s", "p", "o", 0.9, "test").unwrap();
        assert!(store.update_confidence(&id, 0.4).unwrap());

        let results = store.query(Some("s"), None, None).unwrap();
        assert_eq!(results.len(), 1);
        assert!((results[0].confidence - 0.4).abs() < EPS);
        assert_eq!(results[0].predicate, "p");
        assert_eq!(results[0].object, "o");
        assert_eq!(results[0].source.as_deref(), Some("test"));
    }

    #[test]
    fn test_update_confidence_missing() {
        let mut store = SimpleTripleStore::new();
        assert!(!store.update_confidence("missing", 0.5).unwrap());
        assert!(store.is_empty());
    }

    #[test]
    fn test_group_by() {
        let mut store = SimpleTripleStore::new();
        store.add("Alice", "knows", "Bob", 1.0, "test").unwrap();
        store.add("Alice", "likes", "Carol", 1.0, "test").unwrap();
        store.add("Bob", "knows", "Carol", 1.0, "test").unwrap();

        let by_subj = store.group_by("subject").unwrap();
        assert_eq!(by_subj["Alice"], 2);
        assert_eq!(by_subj["Bob"], 1);

        let by_pred = store.group_by("predicate").unwrap();
        assert_eq!(by_pred["knows"], 2);
        assert_eq!(by_pred["likes"], 1);

        let by_obj = store.group_by("object").unwrap();
        assert_eq!(by_obj["Carol"], 2);
        assert_eq!(by_obj["Bob"], 1);
    }

    #[test]
    fn test_group_by_invalid_field() {
        let store = SimpleTripleStore::new();
        assert!(store.group_by("graph").is_err());
        assert!(store.group_by("").is_err());
    }

    #[test]
    fn test_stats() {
        let mut store = SimpleTripleStore::new();
        store.add("s1", "p1", "o1", 1.0, "src_a").unwrap();
        store.add("s2", "p1", "o2", 1.0, "src_a").unwrap();
        store.add("s1", "p2", "o1", 1.0, "src_b").unwrap();

        let stats = store.stats();
        assert_eq!(stats.total_triples, 3);
        assert_eq!(stats.unique_subjects, 2);
        assert_eq!(stats.unique_predicates, 2);
        assert_eq!(stats.unique_objects, 2);
        assert_eq!(stats.by_source["src_a"], 2);
        assert_eq!(stats.by_source["src_b"], 1);
    }

    #[test]
    fn test_stats_empty() {
        let store = SimpleTripleStore::new();
        let stats = store.stats();
        assert_eq!(stats.total_triples, 0);
        assert_eq!(stats.unique_subjects, 0);
        assert!(stats.by_source.is_empty());
    }

    #[test]
    fn test_batch_add() {
        let mut store = SimpleTripleStore::new();
        let ids = store
            .add_batch(&[
                ("s1", "p", "o1", 0.9, "batch"),
                ("s2", "p", "o2", 0.8, "batch"),
                ("s3", "p", "o3", 0.7, "batch"),
            ])
            .unwrap();
        assert_eq!(ids.len(), 3);
        assert_eq!(store.len(), 3);
    }

    #[test]
    fn test_custom_namespace() {
        let store = SimpleTripleStore::with_defaults("my_namespace", Some(5));
        assert!(store.is_empty());
    }
}