Skip to main content

nusy_arrow_core/
store.rs

1//! ArrowGraphStore — the core partitioned graph store.
2//!
3//! Holds triples partitioned by namespace, with Y-layer as a column filter.
4//! Supports add, query, delete (logical), and batch operations.
5
6use crate::namespace::Namespace;
7use crate::schema::{col, triples_schema};
8use crate::y_layer::YLayer;
9
10use arrow::array::{
11    Array, BooleanArray, Float64Array, RecordBatch, StringArray, TimestampMillisecondArray,
12    UInt8Array,
13};
14use arrow::compute;
15use arrow::datatypes::SchemaRef;
16use std::collections::HashMap;
17use std::sync::Arc;
18
/// Errors from store operations.
///
/// In the code visible in this file only `Arrow` is actually produced (through `?`
/// on Arrow calls); the other variants are declared but not constructed here.
#[derive(Debug, thiserror::Error)]
pub enum StoreError {
    /// An error reported by the Arrow library. Converted automatically from
    /// `arrow::error::ArrowError` via `#[from]`, so `?` works on Arrow results.
    #[error("Arrow error: {0}")]
    Arrow(#[from] arrow::error::ArrowError),

    /// A namespace that is not recognised; the payload is rendered in the message.
    #[error("Unknown namespace: {0}")]
    UnknownNamespace(String),

    /// A raw Y-layer value that is not valid; the payload is the offending `u8`.
    #[error("Invalid Y-layer: {0}")]
    InvalidYLayer(u8),

    /// A triple that could not be located; the payload is rendered in the message.
    /// NOTE(review): presumably the payload is the triple id — confirm against callers.
    #[error("Triple not found: {0}")]
    TripleNotFound(String),
}
34
/// Convenience alias for `std::result::Result` with [`StoreError`] as the error type.
pub type Result<T> = std::result::Result<T, StoreError>;
36
/// A single triple to be added to the store.
///
/// This is the caller-supplied part of a stored row. The triple id, namespace,
/// Y-layer, creation timestamp and `deleted` flag are not part of this struct;
/// `ArrowGraphStore::add_batch` fills those in when the triple is written.
/// Every `Option` field that is `None` is stored as a null in its column.
#[derive(Debug, Clone)]
pub struct Triple {
    /// Subject of the triple (required).
    pub subject: String,
    /// Predicate of the triple (required).
    pub predicate: String,
    /// Object of the triple (required).
    pub object: String,
    /// Optional graph value stored alongside the triple.
    /// NOTE(review): presumably a named-graph identifier — confirm.
    pub graph: Option<String>,
    /// Optional confidence value.
    /// NOTE(review): the valid range is not enforced here; presumably 0.0..=1.0 — confirm.
    pub confidence: Option<f64>,
    /// Optional identifier of the document this triple came from.
    pub source_document: Option<String>,
    /// FK to ChunkTable for fine-grained Y0 provenance.
    pub source_chunk_id: Option<String>,
    /// Optional name of whatever extracted this triple.
    pub extracted_by: Option<String>,
    /// The triple_id of the triple that caused this one (causal chain).
    pub caused_by: Option<String>,
    /// The triple_id of the triple this was derived from.
    pub derived_from: Option<String>,
    /// Timestamp (ms since epoch) when this triple was consolidated.
    pub consolidated_at: Option<i64>,
    /// Certifiability class: "symbolic" (graph-backed, provable),
    /// "neural" (LLM-generated, probabilistic), "co-voted" (both agreed).
    /// EX-3570: Tags triples for PAR tracking and routing decisions.
    pub certifiability_class: Option<String>,
}
60
/// Filter specification for queries.
///
/// Every `Option` field set to `None` means "do not filter on this column";
/// all filters that are set must match for a row to be returned. The derived
/// `Default` therefore matches every non-deleted triple in every namespace.
#[derive(Debug, Default, Clone)]
pub struct QuerySpec {
    /// Keep only rows whose subject equals this string exactly.
    pub subject: Option<String>,
    /// Keep only rows whose predicate equals this string exactly.
    pub predicate: Option<String>,
    /// Keep only rows whose object equals this string exactly.
    pub object: Option<String>,
    /// Restrict the search to one namespace partition; `None` searches all of them.
    pub namespace: Option<Namespace>,
    /// Keep only rows stored with this Y-layer.
    pub y_layer: Option<YLayer>,
    /// When `false` (the default), logically deleted rows are filtered out.
    pub include_deleted: bool,
}
71
/// A node in a causal derivation chain.
///
/// One node is produced per triple visited by `ArrowGraphStore::causal_chain`;
/// the two link fields are copied from that triple's stored columns.
#[derive(Debug, Clone, PartialEq)]
pub struct CausalNode {
    /// Id of the triple this node describes.
    pub triple_id: String,
    /// Id of the triple recorded as the cause of this one, if any.
    pub caused_by: Option<String>,
    /// Id of the triple this one was recorded as derived from, if any.
    pub derived_from: Option<String>,
}
79
/// The core Arrow-native graph store, partitioned by namespace.
///
/// Each namespace holds a vector of RecordBatches (appended over time).
/// Queries filter by namespace first, then by column predicates.
///
/// Deletion is logical: a row's `deleted` flag is set and the row stays in its batch.
pub struct ArrowGraphStore {
    /// Schema shared by every batch the store builds (from `triples_schema()`).
    schema: SchemaRef,
    /// Per-namespace storage: Vec<RecordBatch> (append-only within a partition).
    partitions: HashMap<Namespace, Vec<RecordBatch>>,
}
89
90impl ArrowGraphStore {
91    /// Create a new empty store.
92    pub fn new() -> Self {
93        let schema = Arc::new(triples_schema());
94        let mut partitions = HashMap::new();
95        for ns in Namespace::ALL {
96            partitions.insert(ns, Vec::new());
97        }
98        ArrowGraphStore { schema, partitions }
99    }
100
    /// Get the triples schema.
    ///
    /// Returns a borrow of the shared `SchemaRef` that the store uses when it
    /// builds or rebuilds record batches.
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }
105
106    /// Add a single triple to the specified namespace and Y-layer.
107    pub fn add_triple(
108        &mut self,
109        triple: &Triple,
110        namespace: Namespace,
111        y_layer: YLayer,
112    ) -> Result<String> {
113        self.add_batch(std::slice::from_ref(triple), namespace, y_layer)
114            .map(|ids| ids.into_iter().next().unwrap())
115    }
116
117    /// Add a batch of triples to the specified namespace and Y-layer.
118    /// Returns the generated triple IDs.
119    pub fn add_batch(
120        &mut self,
121        triples: &[Triple],
122        namespace: Namespace,
123        y_layer: YLayer,
124    ) -> Result<Vec<String>> {
125        let n = triples.len();
126        if n == 0 {
127            return Ok(vec![]);
128        }
129
130        let now_ms = chrono::Utc::now().timestamp_millis();
131        let ns_str = namespace.as_str();
132        let layer_val = y_layer.as_u8();
133
134        let ids: Vec<String> = (0..n).map(|_| uuid::Uuid::new_v4().to_string()).collect();
135
136        let subjects: Vec<&str> = triples.iter().map(|t| t.subject.as_str()).collect();
137        let predicates: Vec<&str> = triples.iter().map(|t| t.predicate.as_str()).collect();
138        let objects: Vec<&str> = triples.iter().map(|t| t.object.as_str()).collect();
139        let graphs: Vec<Option<&str>> = triples.iter().map(|t| t.graph.as_deref()).collect();
140        let ns_vals: Vec<&str> = vec![ns_str; n];
141        let layer_vals: Vec<u8> = vec![layer_val; n];
142        let confidences: Vec<Option<f64>> = triples.iter().map(|t| t.confidence).collect();
143        let source_docs: Vec<Option<&str>> = triples
144            .iter()
145            .map(|t| t.source_document.as_deref())
146            .collect();
147        let source_chunks: Vec<Option<&str>> = triples
148            .iter()
149            .map(|t| t.source_chunk_id.as_deref())
150            .collect();
151        let extracted: Vec<Option<&str>> =
152            triples.iter().map(|t| t.extracted_by.as_deref()).collect();
153        let caused_by: Vec<Option<&str>> = triples.iter().map(|t| t.caused_by.as_deref()).collect();
154        let derived_from: Vec<Option<&str>> =
155            triples.iter().map(|t| t.derived_from.as_deref()).collect();
156        let consolidated_at: Vec<Option<i64>> = triples.iter().map(|t| t.consolidated_at).collect();
157        let timestamps: Vec<i64> = vec![now_ms; n];
158        let deleted: Vec<bool> = vec![false; n];
159        let certifiability_class: Vec<Option<&str>> = triples
160            .iter()
161            .map(|t| t.certifiability_class.as_deref())
162            .collect();
163        let id_strs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
164
165        let batch = RecordBatch::try_new(
166            self.schema.clone(),
167            vec![
168                Arc::new(StringArray::from(id_strs)),
169                Arc::new(StringArray::from(subjects)),
170                Arc::new(StringArray::from(predicates)),
171                Arc::new(StringArray::from(objects)),
172                Arc::new(StringArray::from(graphs)),
173                Arc::new(StringArray::from(ns_vals)),
174                Arc::new(UInt8Array::from(layer_vals)),
175                Arc::new(Float64Array::from(confidences)),
176                Arc::new(StringArray::from(source_docs)),
177                Arc::new(StringArray::from(source_chunks)),
178                Arc::new(StringArray::from(extracted)),
179                Arc::new(TimestampMillisecondArray::from(timestamps).with_timezone("UTC")),
180                Arc::new(StringArray::from(caused_by)),
181                Arc::new(StringArray::from(derived_from)),
182                Arc::new(TimestampMillisecondArray::from(consolidated_at).with_timezone("UTC")),
183                Arc::new(BooleanArray::from(deleted)),
184                Arc::new(StringArray::from(certifiability_class)),
185            ],
186        )?;
187
188        self.partitions.get_mut(&namespace).unwrap().push(batch);
189
190        Ok(ids)
191    }
192
193    /// Query triples matching the given spec.
194    pub fn query(&self, spec: &QuerySpec) -> Result<Vec<RecordBatch>> {
195        let namespaces: Vec<Namespace> = match spec.namespace {
196            Some(ns) => vec![ns],
197            None => Namespace::ALL.to_vec(),
198        };
199
200        let mut results = Vec::new();
201
202        for ns in namespaces {
203            let batches = self.partitions.get(&ns).unwrap();
204            for batch in batches {
205                let filtered = self.filter_batch(batch, spec)?;
206                if filtered.num_rows() > 0 {
207                    results.push(filtered);
208                }
209            }
210        }
211
212        Ok(results)
213    }
214
215    /// Total number of non-deleted triples across all namespaces.
216    pub fn len(&self) -> usize {
217        let spec = QuerySpec::default();
218        self.query(&spec)
219            .unwrap_or_default()
220            .iter()
221            .map(|b| b.num_rows())
222            .sum()
223    }
224
    /// Whether the store has no non-deleted triples.
    ///
    /// Defined as `self.len() == 0`, so a store that contains only logically
    /// deleted rows is reported as empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
229
230    /// Total number of triples including deleted.
231    pub fn len_all(&self) -> usize {
232        self.partitions
233            .values()
234            .flat_map(|batches| batches.iter())
235            .map(|b| b.num_rows())
236            .sum()
237    }
238
239    /// Logically delete a triple by ID (sets deleted=true).
240    pub fn delete(&mut self, triple_id: &str) -> Result<bool> {
241        for batches in self.partitions.values_mut() {
242            for batch in batches.iter_mut() {
243                let id_col = batch
244                    .column(col::TRIPLE_ID)
245                    .as_any()
246                    .downcast_ref::<StringArray>()
247                    .expect("triple_id column must be StringArray");
248
249                let mut found_idx = None;
250                for i in 0..id_col.len() {
251                    if id_col.value(i) == triple_id {
252                        found_idx = Some(i);
253                        break;
254                    }
255                }
256
257                if let Some(idx) = found_idx {
258                    // Rebuild the batch with the deleted flag set
259                    let del_col = batch
260                        .column(col::DELETED)
261                        .as_any()
262                        .downcast_ref::<BooleanArray>()
263                        .expect("deleted column must be BooleanArray");
264
265                    let mut new_del: Vec<bool> =
266                        (0..del_col.len()).map(|i| del_col.value(i)).collect();
267                    new_del[idx] = true;
268
269                    let mut columns: Vec<Arc<dyn Array>> = Vec::new();
270                    for c in 0..batch.num_columns() {
271                        if c == col::DELETED {
272                            columns.push(Arc::new(BooleanArray::from(new_del.clone())));
273                        } else {
274                            columns.push(batch.column(c).clone());
275                        }
276                    }
277
278                    *batch = RecordBatch::try_new(self.schema.clone(), columns)?;
279                    return Ok(true);
280                }
281            }
282        }
283        Ok(false)
284    }
285
286    /// Get all RecordBatches for a given namespace (including deleted triples).
287    pub fn get_namespace_batches(&self, namespace: Namespace) -> &[RecordBatch] {
288        self.partitions
289            .get(&namespace)
290            .map_or(&[], |v| v.as_slice())
291    }
292
    /// Replace all data for a namespace (used by checkout/restore).
    ///
    /// Whatever batches the namespace held before are dropped. The supplied batches
    /// are stored as given; this method does not check them against the store schema.
    pub fn set_namespace_batches(&mut self, namespace: Namespace, batches: Vec<RecordBatch>) {
        self.partitions.insert(namespace, batches);
    }
297
298    /// Follow the caused_by/derived_from chain from a triple to build a derivation graph.
299    ///
300    /// Returns a list of (triple_id, caused_by, derived_from) tuples representing
301    /// the full causal ancestry of the given triple. The first element is always the
302    /// queried triple itself. Traversal is breadth-first, following both `caused_by`
303    /// and `derived_from` links.
304    pub fn causal_chain(&self, triple_id: &str) -> Vec<CausalNode> {
305        let mut result = Vec::new();
306        let mut visited = std::collections::HashSet::new();
307        let mut queue = std::collections::VecDeque::new();
308        queue.push_back(triple_id.to_string());
309
310        // Build an index of triple_id → (caused_by, derived_from) for efficient lookup
311        let mut index: HashMap<String, (Option<String>, Option<String>)> = HashMap::new();
312        for batches in self.partitions.values() {
313            for batch in batches {
314                let id_col = batch
315                    .column(col::TRIPLE_ID)
316                    .as_any()
317                    .downcast_ref::<StringArray>()
318                    .expect("triple_id column");
319                let caused_col = batch
320                    .column(col::CAUSED_BY)
321                    .as_any()
322                    .downcast_ref::<StringArray>()
323                    .expect("caused_by column");
324                let derived_col = batch
325                    .column(col::DERIVED_FROM)
326                    .as_any()
327                    .downcast_ref::<StringArray>()
328                    .expect("derived_from column");
329                let del_col = batch
330                    .column(col::DELETED)
331                    .as_any()
332                    .downcast_ref::<BooleanArray>()
333                    .expect("deleted column");
334
335                for i in 0..batch.num_rows() {
336                    if del_col.value(i) {
337                        continue;
338                    }
339                    let id = id_col.value(i).to_string();
340                    let caused = if caused_col.is_null(i) {
341                        None
342                    } else {
343                        Some(caused_col.value(i).to_string())
344                    };
345                    let derived = if derived_col.is_null(i) {
346                        None
347                    } else {
348                        Some(derived_col.value(i).to_string())
349                    };
350                    index.insert(id, (caused, derived));
351                }
352            }
353        }
354
355        while let Some(tid) = queue.pop_front() {
356            if !visited.insert(tid.clone()) {
357                continue;
358            }
359            if let Some((caused, derived)) = index.get(&tid) {
360                result.push(CausalNode {
361                    triple_id: tid.clone(),
362                    caused_by: caused.clone(),
363                    derived_from: derived.clone(),
364                });
365                if let Some(cb) = caused
366                    && !visited.contains(cb)
367                {
368                    queue.push_back(cb.clone());
369                }
370                if let Some(df) = derived
371                    && !visited.contains(df)
372                {
373                    queue.push_back(df.clone());
374                }
375            }
376        }
377
378        result
379    }
380
381    /// Clear all data from the store.
382    pub fn clear(&mut self) {
383        for batches in self.partitions.values_mut() {
384            batches.clear();
385        }
386    }
387
388    /// Filter a RecordBatch by the QuerySpec predicates.
389    fn filter_batch(&self, batch: &RecordBatch, spec: &QuerySpec) -> Result<RecordBatch> {
390        let n = batch.num_rows();
391        let mut mask = BooleanArray::from(vec![true; n]);
392
393        // Filter out deleted unless include_deleted
394        if !spec.include_deleted {
395            let del_col = batch
396                .column(col::DELETED)
397                .as_any()
398                .downcast_ref::<BooleanArray>()
399                .expect("deleted column must be BooleanArray");
400            let not_deleted = compute::not(del_col)?;
401            mask = compute::and(&mask, &not_deleted)?;
402        }
403
404        // Filter by subject
405        if let Some(ref subj) = spec.subject {
406            let c = batch
407                .column(col::SUBJECT)
408                .as_any()
409                .downcast_ref::<StringArray>()
410                .expect("subject column must be StringArray");
411            let eq = string_eq_scalar(c, subj);
412            mask = compute::and(&mask, &eq)?;
413        }
414
415        // Filter by predicate
416        if let Some(ref pred) = spec.predicate {
417            let c = batch
418                .column(col::PREDICATE)
419                .as_any()
420                .downcast_ref::<StringArray>()
421                .expect("predicate column must be StringArray");
422            let eq = string_eq_scalar(c, pred);
423            mask = compute::and(&mask, &eq)?;
424        }
425
426        // Filter by object
427        if let Some(ref obj) = spec.object {
428            let c = batch
429                .column(col::OBJECT)
430                .as_any()
431                .downcast_ref::<StringArray>()
432                .expect("object column must be StringArray");
433            let eq = string_eq_scalar(c, obj);
434            mask = compute::and(&mask, &eq)?;
435        }
436
437        // Filter by Y-layer
438        if let Some(layer) = spec.y_layer {
439            let c = batch
440                .column(col::Y_LAYER)
441                .as_any()
442                .downcast_ref::<UInt8Array>()
443                .expect("y_layer column must be UInt8Array");
444            let eq = u8_eq_scalar(c, layer.as_u8());
445            mask = compute::and(&mask, &eq)?;
446        }
447
448        let filtered = compute::filter_record_batch(batch, &mask)?;
449        Ok(filtered)
450    }
451}
452
/// `Default` yields the same empty store as [`ArrowGraphStore::new`].
impl Default for ArrowGraphStore {
    fn default() -> Self {
        Self::new()
    }
}
458
459/// Scalar string equality: returns BooleanArray where each element == value.
460fn string_eq_scalar(array: &StringArray, value: &str) -> BooleanArray {
461    let bools: Vec<bool> = (0..array.len()).map(|i| array.value(i) == value).collect();
462    BooleanArray::from(bools)
463}
464
465/// Scalar u8 equality.
466fn u8_eq_scalar(array: &UInt8Array, value: u8) -> BooleanArray {
467    let bools: Vec<bool> = (0..array.len()).map(|i| array.value(i) == value).collect();
468    BooleanArray::from(bools)
469}
470
471#[cfg(test)]
472mod tests {
473    use super::*;
474
475    fn sample_triple(subj: &str, pred: &str, obj: &str) -> Triple {
476        Triple {
477            subject: subj.to_string(),
478            predicate: pred.to_string(),
479            object: obj.to_string(),
480            graph: None,
481            confidence: Some(0.9),
482            source_document: None,
483            source_chunk_id: None,
484            extracted_by: Some("test".to_string()),
485            caused_by: None,
486            derived_from: None,
487            consolidated_at: None,
488            certifiability_class: None,
489        }
490    }
491
492    #[test]
493    fn test_add_and_query_single() {
494        let mut store = ArrowGraphStore::new();
495        let id = store
496            .add_triple(
497                &sample_triple("s1", "p1", "o1"),
498                Namespace::World,
499                YLayer::Semantic,
500            )
501            .unwrap();
502
503        assert!(!id.is_empty());
504        assert_eq!(store.len(), 1);
505
506        let results = store
507            .query(&QuerySpec {
508                subject: Some("s1".to_string()),
509                ..Default::default()
510            })
511            .unwrap();
512        let total: usize = results.iter().map(|b| b.num_rows()).sum();
513        assert_eq!(total, 1);
514    }
515
516    #[test]
517    fn test_namespace_isolation() {
518        let mut store = ArrowGraphStore::new();
519
520        // Add 100 triples to world
521        let world_triples: Vec<Triple> = (0..100)
522            .map(|i| sample_triple(&format!("w{i}"), "rdf:type", "Thing"))
523            .collect();
524        store
525            .add_batch(&world_triples, Namespace::World, YLayer::Semantic)
526            .unwrap();
527
528        // Add 100 triples to work
529        let work_triples: Vec<Triple> = (0..100)
530            .map(|i| sample_triple(&format!("k{i}"), "rdf:type", "Task"))
531            .collect();
532        store
533            .add_batch(&work_triples, Namespace::Work, YLayer::Semantic)
534            .unwrap();
535
536        // Query world — should return exactly 100
537        let world_results = store
538            .query(&QuerySpec {
539                namespace: Some(Namespace::World),
540                ..Default::default()
541            })
542            .unwrap();
543        let world_count: usize = world_results.iter().map(|b| b.num_rows()).sum();
544        assert_eq!(world_count, 100);
545
546        // Query work — should return exactly 100
547        let work_results = store
548            .query(&QuerySpec {
549                namespace: Some(Namespace::Work),
550                ..Default::default()
551            })
552            .unwrap();
553        let work_count: usize = work_results.iter().map(|b| b.num_rows()).sum();
554        assert_eq!(work_count, 100);
555
556        // Total
557        assert_eq!(store.len(), 200);
558    }
559
560    #[test]
561    fn test_ylayer_query() {
562        let mut store = ArrowGraphStore::new();
563
564        store
565            .add_triple(
566                &sample_triple("s1", "p1", "o1"),
567                Namespace::World,
568                YLayer::Prose,
569            )
570            .unwrap();
571        store
572            .add_triple(
573                &sample_triple("s2", "p2", "o2"),
574                Namespace::World,
575                YLayer::Semantic,
576            )
577            .unwrap();
578
579        let y0_results = store
580            .query(&QuerySpec {
581                y_layer: Some(YLayer::Prose),
582                ..Default::default()
583            })
584            .unwrap();
585        let y0_count: usize = y0_results.iter().map(|b| b.num_rows()).sum();
586        assert_eq!(y0_count, 1);
587    }
588
589    #[test]
590    fn test_logical_delete() {
591        let mut store = ArrowGraphStore::new();
592        let id = store
593            .add_triple(
594                &sample_triple("s1", "p1", "o1"),
595                Namespace::World,
596                YLayer::Semantic,
597            )
598            .unwrap();
599
600        assert_eq!(store.len(), 1);
601        assert!(store.delete(&id).unwrap());
602        assert_eq!(store.len(), 0);
603        assert_eq!(store.len_all(), 1); // Still physically present
604    }
605
    /// Smoke-checks that adding 10,000 triples in one batch is fast.
    ///
    /// NOTE(review): this asserts on wall-clock time, so it may be flaky on slow or
    /// heavily loaded machines.
    #[test]
    fn test_batch_add_performance() {
        let mut store = ArrowGraphStore::new();

        let triples: Vec<Triple> = (0..10_000)
            .map(|i| sample_triple(&format!("s{i}"), "rdf:type", "Entity"))
            .collect();

        // Time only the insert itself, not fixture construction.
        let start = std::time::Instant::now();
        store
            .add_batch(&triples, Namespace::World, YLayer::Semantic)
            .unwrap();
        let elapsed = start.elapsed();

        assert_eq!(store.len(), 10_000);
        // Expected to be quick (the original note says well under 10ms); the bound
        // actually asserted is a looser 100ms.
        assert!(
            elapsed.as_millis() < 100,
            "Batch add took too long: {:?}",
            elapsed
        );
    }
628
629    #[test]
630    fn test_causal_chain_linear() {
631        let mut store = ArrowGraphStore::new();
632
633        // Create a chain: t0 → t1 → t2 (each caused_by the previous)
634        let t0 = Triple {
635            subject: "s0".to_string(),
636            predicate: "p".to_string(),
637            object: "o0".to_string(),
638            caused_by: None,
639            derived_from: None,
640            ..sample_triple("s0", "p", "o0")
641        };
642        let id0 = store
643            .add_triple(&t0, Namespace::World, YLayer::Semantic)
644            .unwrap();
645
646        let t1 = Triple {
647            subject: "s1".to_string(),
648            predicate: "p".to_string(),
649            object: "o1".to_string(),
650            caused_by: Some(id0.clone()),
651            derived_from: None,
652            ..sample_triple("s1", "p", "o1")
653        };
654        let id1 = store
655            .add_triple(&t1, Namespace::World, YLayer::Semantic)
656            .unwrap();
657
658        let t2 = Triple {
659            subject: "s2".to_string(),
660            predicate: "p".to_string(),
661            object: "o2".to_string(),
662            caused_by: Some(id1.clone()),
663            derived_from: None,
664            ..sample_triple("s2", "p", "o2")
665        };
666        let id2 = store
667            .add_triple(&t2, Namespace::World, YLayer::Semantic)
668            .unwrap();
669
670        // Causal chain from t2 should traverse t2 → t1 → t0
671        let chain = store.causal_chain(&id2);
672        assert_eq!(chain.len(), 3);
673        assert_eq!(chain[0].triple_id, id2);
674        assert_eq!(chain[0].caused_by, Some(id1.clone()));
675        assert_eq!(chain[1].triple_id, id1);
676        assert_eq!(chain[1].caused_by, Some(id0.clone()));
677        assert_eq!(chain[2].triple_id, id0);
678        assert_eq!(chain[2].caused_by, None);
679    }
680
681    #[test]
682    fn test_causal_chain_with_derived_from() {
683        let mut store = ArrowGraphStore::new();
684
685        let t0 = Triple {
686            subject: "base".to_string(),
687            predicate: "p".to_string(),
688            object: "original".to_string(),
689            caused_by: None,
690            derived_from: None,
691            ..sample_triple("base", "p", "original")
692        };
693        let id0 = store
694            .add_triple(&t0, Namespace::World, YLayer::Reasoning)
695            .unwrap();
696
697        let t1 = Triple {
698            subject: "derived".to_string(),
699            predicate: "p".to_string(),
700            object: "derived_val".to_string(),
701            caused_by: None,
702            derived_from: Some(id0.clone()),
703            ..sample_triple("derived", "p", "derived_val")
704        };
705        let id1 = store
706            .add_triple(&t1, Namespace::World, YLayer::Reasoning)
707            .unwrap();
708
709        let chain = store.causal_chain(&id1);
710        assert_eq!(chain.len(), 2);
711        assert_eq!(chain[0].derived_from, Some(id0.clone()));
712        assert_eq!(chain[1].triple_id, id0);
713    }
714
715    #[test]
716    fn test_causal_chain_nonexistent_triple() {
717        let store = ArrowGraphStore::new();
718        let chain = store.causal_chain("nonexistent");
719        assert!(chain.is_empty());
720    }
721}