// arrow_graph_core/store.rs
//! ArrowGraphStore — the core partitioned graph store.
//!
//! Holds triples partitioned by namespace (string keys), with an optional layer
//! column for sub-partitioning. Supports add, query, delete (logical), and batch
//! operations.
//!
//! # Namespace Partitioning
//!
//! Namespaces are arbitrary strings — you define them when constructing the store.
//! This keeps the core generic: a knowledge graph might use `["world", "code", "self"]`,
//! while a document store might use `["drafts", "published", "archived"]`.
//!
//! ```rust
//! use arrow_graph_core::store::ArrowGraphStore;
//!
//! let mut store = ArrowGraphStore::new(&["world", "code", "self"]);
//! ```

use crate::schema::{col, triples_schema};

use arrow::array::{
    Array, BooleanArray, Float64Array, RecordBatch, StringArray, TimestampMillisecondArray,
    UInt8Array,
};
use arrow::compute;
use arrow::datatypes::SchemaRef;

use std::collections::HashMap;
use std::sync::Arc;
29
30/// Errors from store operations.
31#[derive(Debug, thiserror::Error)]
32pub enum StoreError {
33    #[error("Arrow error: {0}")]
34    Arrow(#[from] arrow::error::ArrowError),
35
36    #[error("Unknown namespace: {0}")]
37    UnknownNamespace(String),
38
39    #[error("Invalid layer: {0}")]
40    InvalidLayer(u8),
41
42    #[error("Triple not found: {0}")]
43    TripleNotFound(String),
44}
45
46pub type Result<T> = std::result::Result<T, StoreError>;
47
48/// A single triple to be added to the store.
49#[derive(Debug, Clone)]
50pub struct Triple {
51    pub subject: String,
52    pub predicate: String,
53    pub object: String,
54    pub graph: Option<String>,
55    pub confidence: Option<f64>,
56    pub source_document: Option<String>,
57    /// FK to ChunkTable for fine-grained provenance.
58    pub source_chunk_id: Option<String>,
59    pub extracted_by: Option<String>,
60    /// The triple_id of the triple that caused this one (causal chain).
61    pub caused_by: Option<String>,
62    /// The triple_id of the triple this was derived from.
63    pub derived_from: Option<String>,
64    /// Timestamp (ms since epoch) when this triple was consolidated.
65    pub consolidated_at: Option<i64>,
66}
67
68/// Filter specification for queries.
69#[derive(Debug, Default, Clone)]
70pub struct QuerySpec {
71    pub subject: Option<String>,
72    pub predicate: Option<String>,
73    pub object: Option<String>,
74    pub namespace: Option<String>,
75    pub layer: Option<u8>,
76    pub include_deleted: bool,
77}
78
79/// A node in a causal derivation chain.
80#[derive(Debug, Clone, PartialEq)]
81pub struct CausalNode {
82    pub triple_id: String,
83    pub caused_by: Option<String>,
84    pub derived_from: Option<String>,
85}
86
87/// The core Arrow-native graph store, partitioned by namespace.
88///
89/// Each namespace holds a vector of RecordBatches (appended over time).
90/// Queries filter by namespace first, then by column predicates.
91///
92/// Namespaces are arbitrary strings defined at construction time.
93pub struct ArrowGraphStore {
94    schema: SchemaRef,
95    /// Per-namespace storage: `Vec<RecordBatch>` (append-only within a partition).
96    partitions: HashMap<String, Vec<RecordBatch>>,
97    /// Ordered list of namespace names (for deterministic iteration).
98    namespace_names: Vec<String>,
99}
100
101impl ArrowGraphStore {
102    /// Create a new empty store with the given namespace partitions.
103    ///
104    /// # Example
105    ///
106    /// ```rust
107    /// use arrow_graph_core::store::ArrowGraphStore;
108    ///
109    /// // Knowledge graph with three partitions
110    /// let store = ArrowGraphStore::new(&["world", "code", "self"]);
111    ///
112    /// // Document store with different partitions
113    /// let store = ArrowGraphStore::new(&["drafts", "published"]);
114    /// ```
115    pub fn new(namespaces: &[&str]) -> Self {
116        let schema = Arc::new(triples_schema());
117        let mut partitions = HashMap::new();
118        let mut namespace_names = Vec::new();
119        for ns in namespaces {
120            partitions.insert(ns.to_string(), Vec::new());
121            namespace_names.push(ns.to_string());
122        }
123        ArrowGraphStore {
124            schema,
125            partitions,
126            namespace_names,
127        }
128    }
129
130    /// Get the triples schema.
131    pub fn schema(&self) -> &SchemaRef {
132        &self.schema
133    }
134
135    /// Get the list of namespace names.
136    pub fn namespaces(&self) -> &[String] {
137        &self.namespace_names
138    }
139
140    /// Add a single triple to the specified namespace and layer.
141    ///
142    /// The `layer` parameter is optional — pass `None` to use layer 0 (default).
143    pub fn add_triple(
144        &mut self,
145        triple: &Triple,
146        namespace: &str,
147        layer: Option<u8>,
148    ) -> Result<String> {
149        self.add_batch(std::slice::from_ref(triple), namespace, layer)
150            .map(|ids| ids.into_iter().next().unwrap())
151    }
152
153    /// Add a batch of triples to the specified namespace and layer.
154    /// Returns the generated triple IDs.
155    pub fn add_batch(
156        &mut self,
157        triples: &[Triple],
158        namespace: &str,
159        layer: Option<u8>,
160    ) -> Result<Vec<String>> {
161        if !self.partitions.contains_key(namespace) {
162            return Err(StoreError::UnknownNamespace(namespace.to_string()));
163        }
164
165        let n = triples.len();
166        if n == 0 {
167            return Ok(vec![]);
168        }
169
170        let now_ms = chrono::Utc::now().timestamp_millis();
171        let layer_val = layer.unwrap_or(0);
172
173        let ids: Vec<String> = (0..n).map(|_| uuid::Uuid::new_v4().to_string()).collect();
174
175        let subjects: Vec<&str> = triples.iter().map(|t| t.subject.as_str()).collect();
176        let predicates: Vec<&str> = triples.iter().map(|t| t.predicate.as_str()).collect();
177        let objects: Vec<&str> = triples.iter().map(|t| t.object.as_str()).collect();
178        let graphs: Vec<Option<&str>> = triples.iter().map(|t| t.graph.as_deref()).collect();
179        let ns_vals: Vec<&str> = vec![namespace; n];
180        let layer_vals: Vec<u8> = vec![layer_val; n];
181        let confidences: Vec<Option<f64>> = triples.iter().map(|t| t.confidence).collect();
182        let source_docs: Vec<Option<&str>> = triples
183            .iter()
184            .map(|t| t.source_document.as_deref())
185            .collect();
186        let source_chunks: Vec<Option<&str>> = triples
187            .iter()
188            .map(|t| t.source_chunk_id.as_deref())
189            .collect();
190        let extracted: Vec<Option<&str>> =
191            triples.iter().map(|t| t.extracted_by.as_deref()).collect();
192        let caused_by: Vec<Option<&str>> = triples.iter().map(|t| t.caused_by.as_deref()).collect();
193        let derived_from: Vec<Option<&str>> =
194            triples.iter().map(|t| t.derived_from.as_deref()).collect();
195        let consolidated_at: Vec<Option<i64>> = triples.iter().map(|t| t.consolidated_at).collect();
196        let timestamps: Vec<i64> = vec![now_ms; n];
197        let deleted: Vec<bool> = vec![false; n];
198        let id_strs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
199
200        let batch = RecordBatch::try_new(
201            self.schema.clone(),
202            vec![
203                Arc::new(StringArray::from(id_strs)),
204                Arc::new(StringArray::from(subjects)),
205                Arc::new(StringArray::from(predicates)),
206                Arc::new(StringArray::from(objects)),
207                Arc::new(StringArray::from(graphs)),
208                Arc::new(StringArray::from(ns_vals)),
209                Arc::new(UInt8Array::from(layer_vals)),
210                Arc::new(Float64Array::from(confidences)),
211                Arc::new(StringArray::from(source_docs)),
212                Arc::new(StringArray::from(source_chunks)),
213                Arc::new(StringArray::from(extracted)),
214                Arc::new(TimestampMillisecondArray::from(timestamps).with_timezone("UTC")),
215                Arc::new(StringArray::from(caused_by)),
216                Arc::new(StringArray::from(derived_from)),
217                Arc::new(TimestampMillisecondArray::from(consolidated_at).with_timezone("UTC")),
218                Arc::new(BooleanArray::from(deleted)),
219            ],
220        )?;
221
222        self.partitions.get_mut(namespace).unwrap().push(batch);
223
224        Ok(ids)
225    }
226
227    /// Query triples matching the given spec.
228    pub fn query(&self, spec: &QuerySpec) -> Result<Vec<RecordBatch>> {
229        let namespaces: Vec<&String> = match &spec.namespace {
230            Some(ns) => {
231                if !self.partitions.contains_key(ns) {
232                    return Err(StoreError::UnknownNamespace(ns.clone()));
233                }
234                vec![ns]
235            }
236            None => self.namespace_names.iter().collect(),
237        };
238
239        let mut results = Vec::new();
240
241        for ns in namespaces {
242            let batches = self.partitions.get(ns).unwrap();
243            for batch in batches {
244                let filtered = self.filter_batch(batch, spec)?;
245                if filtered.num_rows() > 0 {
246                    results.push(filtered);
247                }
248            }
249        }
250
251        Ok(results)
252    }
253
254    /// Total number of non-deleted triples across all namespaces.
255    pub fn len(&self) -> usize {
256        let spec = QuerySpec::default();
257        self.query(&spec)
258            .unwrap_or_default()
259            .iter()
260            .map(|b| b.num_rows())
261            .sum()
262    }
263
264    /// Whether the store has no non-deleted triples.
265    pub fn is_empty(&self) -> bool {
266        self.len() == 0
267    }
268
269    /// Total number of triples including deleted.
270    pub fn len_all(&self) -> usize {
271        self.partitions
272            .values()
273            .flat_map(|batches| batches.iter())
274            .map(|b| b.num_rows())
275            .sum()
276    }
277
278    /// Logically delete a triple by ID (sets deleted=true).
279    pub fn delete(&mut self, triple_id: &str) -> Result<bool> {
280        for batches in self.partitions.values_mut() {
281            for batch in batches.iter_mut() {
282                let id_col = batch
283                    .column(col::TRIPLE_ID)
284                    .as_any()
285                    .downcast_ref::<StringArray>()
286                    .expect("triple_id column must be StringArray");
287
288                let mut found_idx = None;
289                for i in 0..id_col.len() {
290                    if id_col.value(i) == triple_id {
291                        found_idx = Some(i);
292                        break;
293                    }
294                }
295
296                if let Some(idx) = found_idx {
297                    let del_col = batch
298                        .column(col::DELETED)
299                        .as_any()
300                        .downcast_ref::<BooleanArray>()
301                        .expect("deleted column must be BooleanArray");
302
303                    let mut new_del: Vec<bool> =
304                        (0..del_col.len()).map(|i| del_col.value(i)).collect();
305                    new_del[idx] = true;
306
307                    let mut columns: Vec<Arc<dyn Array>> = Vec::new();
308                    for c in 0..batch.num_columns() {
309                        if c == col::DELETED {
310                            columns.push(Arc::new(BooleanArray::from(new_del.clone())));
311                        } else {
312                            columns.push(batch.column(c).clone());
313                        }
314                    }
315
316                    *batch = RecordBatch::try_new(self.schema.clone(), columns)?;
317                    return Ok(true);
318                }
319            }
320        }
321        Ok(false)
322    }
323
324    /// Get all RecordBatches for a given namespace (including deleted triples).
325    pub fn get_namespace_batches(&self, namespace: &str) -> &[RecordBatch] {
326        self.partitions.get(namespace).map_or(&[], |v| v.as_slice())
327    }
328
329    /// Replace all data for a namespace (used by checkout/restore).
330    pub fn set_namespace_batches(&mut self, namespace: &str, batches: Vec<RecordBatch>) {
331        if let Some(existing) = self.partitions.get_mut(namespace) {
332            *existing = batches;
333        }
334    }
335
336    /// Follow the caused_by/derived_from chain from a triple to build a derivation graph.
337    ///
338    /// Returns a list of causal nodes representing the full ancestry of the given
339    /// triple. The first element is always the queried triple itself.
340    /// Traversal is breadth-first.
341    pub fn causal_chain(&self, triple_id: &str) -> Vec<CausalNode> {
342        let mut result = Vec::new();
343        let mut visited = std::collections::HashSet::new();
344        let mut queue = std::collections::VecDeque::new();
345        queue.push_back(triple_id.to_string());
346
347        // Build an index for efficient lookup
348        let mut index: HashMap<String, (Option<String>, Option<String>)> = HashMap::new();
349        for batches in self.partitions.values() {
350            for batch in batches {
351                let id_col = batch
352                    .column(col::TRIPLE_ID)
353                    .as_any()
354                    .downcast_ref::<StringArray>()
355                    .expect("triple_id column");
356                let caused_col = batch
357                    .column(col::CAUSED_BY)
358                    .as_any()
359                    .downcast_ref::<StringArray>()
360                    .expect("caused_by column");
361                let derived_col = batch
362                    .column(col::DERIVED_FROM)
363                    .as_any()
364                    .downcast_ref::<StringArray>()
365                    .expect("derived_from column");
366                let del_col = batch
367                    .column(col::DELETED)
368                    .as_any()
369                    .downcast_ref::<BooleanArray>()
370                    .expect("deleted column");
371
372                for i in 0..batch.num_rows() {
373                    if del_col.value(i) {
374                        continue;
375                    }
376                    let id = id_col.value(i).to_string();
377                    let caused = if caused_col.is_null(i) {
378                        None
379                    } else {
380                        Some(caused_col.value(i).to_string())
381                    };
382                    let derived = if derived_col.is_null(i) {
383                        None
384                    } else {
385                        Some(derived_col.value(i).to_string())
386                    };
387                    index.insert(id, (caused, derived));
388                }
389            }
390        }
391
392        while let Some(tid) = queue.pop_front() {
393            if !visited.insert(tid.clone()) {
394                continue;
395            }
396            if let Some((caused, derived)) = index.get(&tid) {
397                result.push(CausalNode {
398                    triple_id: tid.clone(),
399                    caused_by: caused.clone(),
400                    derived_from: derived.clone(),
401                });
402                if let Some(cb) = caused
403                    && !visited.contains(cb)
404                {
405                    queue.push_back(cb.clone());
406                }
407                if let Some(df) = derived
408                    && !visited.contains(df)
409                {
410                    queue.push_back(df.clone());
411                }
412            }
413        }
414
415        result
416    }
417
418    /// Clear all data from the store.
419    pub fn clear(&mut self) {
420        for batches in self.partitions.values_mut() {
421            batches.clear();
422        }
423    }
424
425    /// Filter a RecordBatch by the QuerySpec predicates.
426    fn filter_batch(&self, batch: &RecordBatch, spec: &QuerySpec) -> Result<RecordBatch> {
427        let n = batch.num_rows();
428        let mut mask = BooleanArray::from(vec![true; n]);
429
430        // Filter out deleted unless include_deleted
431        if !spec.include_deleted {
432            let del_col = batch
433                .column(col::DELETED)
434                .as_any()
435                .downcast_ref::<BooleanArray>()
436                .expect("deleted column must be BooleanArray");
437            let not_deleted = compute::not(del_col)?;
438            mask = compute::and(&mask, &not_deleted)?;
439        }
440
441        if let Some(ref subj) = spec.subject {
442            let c = batch
443                .column(col::SUBJECT)
444                .as_any()
445                .downcast_ref::<StringArray>()
446                .expect("subject column must be StringArray");
447            let eq = string_eq_scalar(c, subj);
448            mask = compute::and(&mask, &eq)?;
449        }
450
451        if let Some(ref pred) = spec.predicate {
452            let c = batch
453                .column(col::PREDICATE)
454                .as_any()
455                .downcast_ref::<StringArray>()
456                .expect("predicate column must be StringArray");
457            let eq = string_eq_scalar(c, pred);
458            mask = compute::and(&mask, &eq)?;
459        }
460
461        if let Some(ref obj) = spec.object {
462            let c = batch
463                .column(col::OBJECT)
464                .as_any()
465                .downcast_ref::<StringArray>()
466                .expect("object column must be StringArray");
467            let eq = string_eq_scalar(c, obj);
468            mask = compute::and(&mask, &eq)?;
469        }
470
471        if let Some(layer) = spec.layer {
472            let c = batch
473                .column(col::LAYER)
474                .as_any()
475                .downcast_ref::<UInt8Array>()
476                .expect("layer column must be UInt8Array");
477            let eq = u8_eq_scalar(c, layer);
478            mask = compute::and(&mask, &eq)?;
479        }
480
481        let filtered = compute::filter_record_batch(batch, &mask)?;
482        Ok(filtered)
483    }
484}
485
486/// Scalar string equality: returns BooleanArray where each element == value.
487fn string_eq_scalar(array: &StringArray, value: &str) -> BooleanArray {
488    let bools: Vec<bool> = (0..array.len()).map(|i| array.value(i) == value).collect();
489    BooleanArray::from(bools)
490}
491
492/// Scalar u8 equality.
493fn u8_eq_scalar(array: &UInt8Array, value: u8) -> BooleanArray {
494    let bools: Vec<bool> = (0..array.len()).map(|i| array.value(i) == value).collect();
495    BooleanArray::from(bools)
496}
497
498#[cfg(test)]
499mod tests {
500    use super::*;
501
502    fn sample_triple(subj: &str, pred: &str, obj: &str) -> Triple {
503        Triple {
504            subject: subj.to_string(),
505            predicate: pred.to_string(),
506            object: obj.to_string(),
507            graph: None,
508            confidence: Some(0.9),
509            source_document: None,
510            source_chunk_id: None,
511            extracted_by: Some("test".to_string()),
512            caused_by: None,
513            derived_from: None,
514            consolidated_at: None,
515        }
516    }
517
518    #[test]
519    fn test_add_and_query_single() {
520        let mut store = ArrowGraphStore::new(&["world", "work"]);
521        let id = store
522            .add_triple(&sample_triple("s1", "p1", "o1"), "world", Some(1))
523            .unwrap();
524
525        assert!(!id.is_empty());
526        assert_eq!(store.len(), 1);
527
528        let results = store
529            .query(&QuerySpec {
530                subject: Some("s1".to_string()),
531                ..Default::default()
532            })
533            .unwrap();
534        let total: usize = results.iter().map(|b| b.num_rows()).sum();
535        assert_eq!(total, 1);
536    }
537
538    #[test]
539    fn test_namespace_isolation() {
540        let mut store = ArrowGraphStore::new(&["world", "work"]);
541
542        let world_triples: Vec<Triple> = (0..100)
543            .map(|i| sample_triple(&format!("w{i}"), "rdf:type", "Thing"))
544            .collect();
545        store.add_batch(&world_triples, "world", Some(1)).unwrap();
546
547        let work_triples: Vec<Triple> = (0..100)
548            .map(|i| sample_triple(&format!("k{i}"), "rdf:type", "Task"))
549            .collect();
550        store.add_batch(&work_triples, "work", Some(1)).unwrap();
551
552        let world_results = store
553            .query(&QuerySpec {
554                namespace: Some("world".to_string()),
555                ..Default::default()
556            })
557            .unwrap();
558        let world_count: usize = world_results.iter().map(|b| b.num_rows()).sum();
559        assert_eq!(world_count, 100);
560
561        let work_results = store
562            .query(&QuerySpec {
563                namespace: Some("work".to_string()),
564                ..Default::default()
565            })
566            .unwrap();
567        let work_count: usize = work_results.iter().map(|b| b.num_rows()).sum();
568        assert_eq!(work_count, 100);
569
570        assert_eq!(store.len(), 200);
571    }
572
573    #[test]
574    fn test_layer_query() {
575        let mut store = ArrowGraphStore::new(&["default"]);
576
577        store
578            .add_triple(&sample_triple("s1", "p1", "o1"), "default", Some(0))
579            .unwrap();
580        store
581            .add_triple(&sample_triple("s2", "p2", "o2"), "default", Some(1))
582            .unwrap();
583
584        let l0_results = store
585            .query(&QuerySpec {
586                layer: Some(0),
587                ..Default::default()
588            })
589            .unwrap();
590        let l0_count: usize = l0_results.iter().map(|b| b.num_rows()).sum();
591        assert_eq!(l0_count, 1);
592    }
593
594    #[test]
595    fn test_default_layer() {
596        let mut store = ArrowGraphStore::new(&["default"]);
597
598        // None layer defaults to 0
599        store
600            .add_triple(&sample_triple("s1", "p1", "o1"), "default", None)
601            .unwrap();
602
603        let results = store
604            .query(&QuerySpec {
605                layer: Some(0),
606                ..Default::default()
607            })
608            .unwrap();
609        let count: usize = results.iter().map(|b| b.num_rows()).sum();
610        assert_eq!(count, 1);
611    }
612
613    #[test]
614    fn test_unknown_namespace_error() {
615        let mut store = ArrowGraphStore::new(&["world"]);
616        let result = store.add_triple(&sample_triple("s", "p", "o"), "unknown", None);
617        assert!(matches!(result, Err(StoreError::UnknownNamespace(_))));
618    }
619
620    #[test]
621    fn test_logical_delete() {
622        let mut store = ArrowGraphStore::new(&["world"]);
623        let id = store
624            .add_triple(&sample_triple("s1", "p1", "o1"), "world", Some(1))
625            .unwrap();
626
627        assert_eq!(store.len(), 1);
628        assert!(store.delete(&id).unwrap());
629        assert_eq!(store.len(), 0);
630        assert_eq!(store.len_all(), 1); // Still physically present
631    }
632
633    #[test]
634    fn test_batch_add_performance() {
635        let mut store = ArrowGraphStore::new(&["world"]);
636
637        let triples: Vec<Triple> = (0..10_000)
638            .map(|i| sample_triple(&format!("s{i}"), "rdf:type", "Entity"))
639            .collect();
640
641        let start = std::time::Instant::now();
642        store.add_batch(&triples, "world", Some(1)).unwrap();
643        let elapsed = start.elapsed();
644
645        assert_eq!(store.len(), 10_000);
646        assert!(
647            elapsed.as_millis() < 100,
648            "Batch add took too long: {:?}",
649            elapsed
650        );
651    }
652
653    #[test]
654    fn test_custom_namespaces() {
655        let mut store = ArrowGraphStore::new(&["drafts", "published", "archived"]);
656        assert_eq!(store.namespaces().len(), 3);
657
658        store
659            .add_triple(&sample_triple("doc1", "status", "draft"), "drafts", None)
660            .unwrap();
661        store
662            .add_triple(&sample_triple("doc2", "status", "live"), "published", None)
663            .unwrap();
664
665        assert_eq!(store.len(), 2);
666
667        let drafts = store
668            .query(&QuerySpec {
669                namespace: Some("drafts".to_string()),
670                ..Default::default()
671            })
672            .unwrap();
673        assert_eq!(drafts.iter().map(|b| b.num_rows()).sum::<usize>(), 1);
674    }
675
676    #[test]
677    fn test_causal_chain_linear() {
678        let mut store = ArrowGraphStore::new(&["world"]);
679
680        let t0 = Triple {
681            caused_by: None,
682            derived_from: None,
683            ..sample_triple("s0", "p", "o0")
684        };
685        let id0 = store.add_triple(&t0, "world", Some(1)).unwrap();
686
687        let t1 = Triple {
688            caused_by: Some(id0.clone()),
689            derived_from: None,
690            ..sample_triple("s1", "p", "o1")
691        };
692        let id1 = store.add_triple(&t1, "world", Some(1)).unwrap();
693
694        let t2 = Triple {
695            caused_by: Some(id1.clone()),
696            derived_from: None,
697            ..sample_triple("s2", "p", "o2")
698        };
699        let id2 = store.add_triple(&t2, "world", Some(1)).unwrap();
700
701        let chain = store.causal_chain(&id2);
702        assert_eq!(chain.len(), 3);
703        assert_eq!(chain[0].triple_id, id2);
704        assert_eq!(chain[0].caused_by, Some(id1.clone()));
705        assert_eq!(chain[1].triple_id, id1);
706        assert_eq!(chain[1].caused_by, Some(id0.clone()));
707        assert_eq!(chain[2].triple_id, id0);
708        assert_eq!(chain[2].caused_by, None);
709    }
710
711    #[test]
712    fn test_causal_chain_nonexistent_triple() {
713        let store = ArrowGraphStore::new(&["world"]);
714        let chain = store.causal_chain("nonexistent");
715        assert!(chain.is_empty());
716    }
717}