1use crate::namespace::Namespace;
7use crate::schema::{col, triples_schema};
8use crate::y_layer::YLayer;
9
10use arrow::array::{
11 Array, BooleanArray, Float64Array, RecordBatch, StringArray, TimestampMillisecondArray,
12 UInt8Array,
13};
14use arrow::compute;
15use arrow::datatypes::SchemaRef;
16use std::collections::HashMap;
17use std::sync::Arc;
18
/// Errors returned by [`ArrowGraphStore`] operations.
#[derive(Debug, thiserror::Error)]
pub enum StoreError {
    /// An Arrow kernel or `RecordBatch` construction failed; converted
    /// automatically from `ArrowError` by `?`.
    #[error("Arrow error: {0}")]
    Arrow(#[from] arrow::error::ArrowError),

    /// An unknown namespace identifier (the offending string).
    #[error("Unknown namespace: {0}")]
    UnknownNamespace(String),

    /// An invalid Y-layer value (the offending raw `u8`).
    #[error("Invalid Y-layer: {0}")]
    InvalidYLayer(u8),

    /// No triple was found for the given identifier.
    #[error("Triple not found: {0}")]
    TripleNotFound(String),
}
34
/// Convenience alias for results whose error type is [`StoreError`].
pub type Result<T> = std::result::Result<T, StoreError>;
36
/// A subject–predicate–object triple plus optional provenance metadata, as
/// accepted by [`ArrowGraphStore::add_triple`] and [`ArrowGraphStore::add_batch`].
///
/// Every `None` field is written as a null in the corresponding Arrow column.
#[derive(Debug, Clone)]
pub struct Triple {
    /// Subject term.
    pub subject: String,
    /// Predicate term.
    pub predicate: String,
    /// Object term.
    pub object: String,
    /// Optional graph name.
    pub graph: Option<String>,
    /// Optional confidence score, stored as Float64. No range is enforced by
    /// the store.
    pub confidence: Option<f64>,
    /// Optional reference to the document this triple came from.
    pub source_document: Option<String>,
    /// Optional identifier of the chunk within `source_document`.
    pub source_chunk_id: Option<String>,
    /// Optional label for whatever produced the triple (tests use `"test"`).
    pub extracted_by: Option<String>,
    /// Optional id of another triple that caused this one; followed by
    /// [`ArrowGraphStore::causal_chain`].
    pub caused_by: Option<String>,
    /// Optional id of another triple this one was derived from; also followed
    /// by [`ArrowGraphStore::causal_chain`].
    pub derived_from: Option<String>,
    /// Optional timestamp in milliseconds, stored in a UTC
    /// `TimestampMillisecond` column.
    pub consolidated_at: Option<i64>,
    /// Optional classification label, stored verbatim as a string.
    pub certifiability_class: Option<String>,
}
60
/// Filter criteria for [`ArrowGraphStore::query`].
///
/// A `None` field places no constraint. All `Some` fields must match at the
/// same time (logical AND). `QuerySpec::default()` matches every live
/// (non-deleted) triple in every namespace.
#[derive(Debug, Default, Clone)]
pub struct QuerySpec {
    /// Exact match on the subject column.
    pub subject: Option<String>,
    /// Exact match on the predicate column.
    pub predicate: Option<String>,
    /// Exact match on the object column.
    pub object: Option<String>,
    /// Restrict the search to one namespace partition; `None` searches every
    /// namespace in `Namespace::ALL`.
    pub namespace: Option<Namespace>,
    /// Exact match on the Y-layer column.
    pub y_layer: Option<YLayer>,
    /// When `true`, logically deleted rows are returned as well.
    pub include_deleted: bool,
}
71
/// One triple visited by [`ArrowGraphStore::causal_chain`], with the two
/// provenance links that the walk follows.
#[derive(Debug, Clone, PartialEq)]
pub struct CausalNode {
    /// Id of the visited triple.
    pub triple_id: String,
    /// The triple's `caused_by` link, if any.
    pub caused_by: Option<String>,
    /// The triple's `derived_from` link, if any.
    pub derived_from: Option<String>,
}
79
/// In-memory triple store backed by Arrow `RecordBatch`es, partitioned by
/// namespace.
pub struct ArrowGraphStore {
    /// Schema (from `triples_schema()`) used for every batch the store builds.
    schema: SchemaRef,
    /// Batches per namespace. Each `add_batch` call appends one batch. A
    /// logical delete replaces the affected batch with a copy whose `deleted`
    /// flag is set.
    partitions: HashMap<Namespace, Vec<RecordBatch>>,
}
89
90impl ArrowGraphStore {
91 pub fn new() -> Self {
93 let schema = Arc::new(triples_schema());
94 let mut partitions = HashMap::new();
95 for ns in Namespace::ALL {
96 partitions.insert(ns, Vec::new());
97 }
98 ArrowGraphStore { schema, partitions }
99 }
100
101 pub fn schema(&self) -> &SchemaRef {
103 &self.schema
104 }
105
106 pub fn add_triple(
108 &mut self,
109 triple: &Triple,
110 namespace: Namespace,
111 y_layer: YLayer,
112 ) -> Result<String> {
113 self.add_batch(std::slice::from_ref(triple), namespace, y_layer)
114 .map(|ids| ids.into_iter().next().unwrap())
115 }
116
117 pub fn add_batch(
120 &mut self,
121 triples: &[Triple],
122 namespace: Namespace,
123 y_layer: YLayer,
124 ) -> Result<Vec<String>> {
125 let n = triples.len();
126 if n == 0 {
127 return Ok(vec![]);
128 }
129
130 let now_ms = chrono::Utc::now().timestamp_millis();
131 let ns_str = namespace.as_str();
132 let layer_val = y_layer.as_u8();
133
134 let ids: Vec<String> = (0..n).map(|_| uuid::Uuid::new_v4().to_string()).collect();
135
136 let subjects: Vec<&str> = triples.iter().map(|t| t.subject.as_str()).collect();
137 let predicates: Vec<&str> = triples.iter().map(|t| t.predicate.as_str()).collect();
138 let objects: Vec<&str> = triples.iter().map(|t| t.object.as_str()).collect();
139 let graphs: Vec<Option<&str>> = triples.iter().map(|t| t.graph.as_deref()).collect();
140 let ns_vals: Vec<&str> = vec![ns_str; n];
141 let layer_vals: Vec<u8> = vec![layer_val; n];
142 let confidences: Vec<Option<f64>> = triples.iter().map(|t| t.confidence).collect();
143 let source_docs: Vec<Option<&str>> = triples
144 .iter()
145 .map(|t| t.source_document.as_deref())
146 .collect();
147 let source_chunks: Vec<Option<&str>> = triples
148 .iter()
149 .map(|t| t.source_chunk_id.as_deref())
150 .collect();
151 let extracted: Vec<Option<&str>> =
152 triples.iter().map(|t| t.extracted_by.as_deref()).collect();
153 let caused_by: Vec<Option<&str>> = triples.iter().map(|t| t.caused_by.as_deref()).collect();
154 let derived_from: Vec<Option<&str>> =
155 triples.iter().map(|t| t.derived_from.as_deref()).collect();
156 let consolidated_at: Vec<Option<i64>> = triples.iter().map(|t| t.consolidated_at).collect();
157 let timestamps: Vec<i64> = vec![now_ms; n];
158 let deleted: Vec<bool> = vec![false; n];
159 let certifiability_class: Vec<Option<&str>> = triples
160 .iter()
161 .map(|t| t.certifiability_class.as_deref())
162 .collect();
163 let id_strs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
164
165 let batch = RecordBatch::try_new(
166 self.schema.clone(),
167 vec![
168 Arc::new(StringArray::from(id_strs)),
169 Arc::new(StringArray::from(subjects)),
170 Arc::new(StringArray::from(predicates)),
171 Arc::new(StringArray::from(objects)),
172 Arc::new(StringArray::from(graphs)),
173 Arc::new(StringArray::from(ns_vals)),
174 Arc::new(UInt8Array::from(layer_vals)),
175 Arc::new(Float64Array::from(confidences)),
176 Arc::new(StringArray::from(source_docs)),
177 Arc::new(StringArray::from(source_chunks)),
178 Arc::new(StringArray::from(extracted)),
179 Arc::new(TimestampMillisecondArray::from(timestamps).with_timezone("UTC")),
180 Arc::new(StringArray::from(caused_by)),
181 Arc::new(StringArray::from(derived_from)),
182 Arc::new(TimestampMillisecondArray::from(consolidated_at).with_timezone("UTC")),
183 Arc::new(BooleanArray::from(deleted)),
184 Arc::new(StringArray::from(certifiability_class)),
185 ],
186 )?;
187
188 self.partitions.get_mut(&namespace).unwrap().push(batch);
189
190 Ok(ids)
191 }
192
193 pub fn query(&self, spec: &QuerySpec) -> Result<Vec<RecordBatch>> {
195 let namespaces: Vec<Namespace> = match spec.namespace {
196 Some(ns) => vec![ns],
197 None => Namespace::ALL.to_vec(),
198 };
199
200 let mut results = Vec::new();
201
202 for ns in namespaces {
203 let batches = self.partitions.get(&ns).unwrap();
204 for batch in batches {
205 let filtered = self.filter_batch(batch, spec)?;
206 if filtered.num_rows() > 0 {
207 results.push(filtered);
208 }
209 }
210 }
211
212 Ok(results)
213 }
214
215 pub fn len(&self) -> usize {
217 let spec = QuerySpec::default();
218 self.query(&spec)
219 .unwrap_or_default()
220 .iter()
221 .map(|b| b.num_rows())
222 .sum()
223 }
224
225 pub fn is_empty(&self) -> bool {
227 self.len() == 0
228 }
229
230 pub fn len_all(&self) -> usize {
232 self.partitions
233 .values()
234 .flat_map(|batches| batches.iter())
235 .map(|b| b.num_rows())
236 .sum()
237 }
238
239 pub fn delete(&mut self, triple_id: &str) -> Result<bool> {
241 for batches in self.partitions.values_mut() {
242 for batch in batches.iter_mut() {
243 let id_col = batch
244 .column(col::TRIPLE_ID)
245 .as_any()
246 .downcast_ref::<StringArray>()
247 .expect("triple_id column must be StringArray");
248
249 let mut found_idx = None;
250 for i in 0..id_col.len() {
251 if id_col.value(i) == triple_id {
252 found_idx = Some(i);
253 break;
254 }
255 }
256
257 if let Some(idx) = found_idx {
258 let del_col = batch
260 .column(col::DELETED)
261 .as_any()
262 .downcast_ref::<BooleanArray>()
263 .expect("deleted column must be BooleanArray");
264
265 let mut new_del: Vec<bool> =
266 (0..del_col.len()).map(|i| del_col.value(i)).collect();
267 new_del[idx] = true;
268
269 let mut columns: Vec<Arc<dyn Array>> = Vec::new();
270 for c in 0..batch.num_columns() {
271 if c == col::DELETED {
272 columns.push(Arc::new(BooleanArray::from(new_del.clone())));
273 } else {
274 columns.push(batch.column(c).clone());
275 }
276 }
277
278 *batch = RecordBatch::try_new(self.schema.clone(), columns)?;
279 return Ok(true);
280 }
281 }
282 }
283 Ok(false)
284 }
285
286 pub fn get_namespace_batches(&self, namespace: Namespace) -> &[RecordBatch] {
288 self.partitions
289 .get(&namespace)
290 .map_or(&[], |v| v.as_slice())
291 }
292
293 pub fn set_namespace_batches(&mut self, namespace: Namespace, batches: Vec<RecordBatch>) {
295 self.partitions.insert(namespace, batches);
296 }
297
298 pub fn causal_chain(&self, triple_id: &str) -> Vec<CausalNode> {
305 let mut result = Vec::new();
306 let mut visited = std::collections::HashSet::new();
307 let mut queue = std::collections::VecDeque::new();
308 queue.push_back(triple_id.to_string());
309
310 let mut index: HashMap<String, (Option<String>, Option<String>)> = HashMap::new();
312 for batches in self.partitions.values() {
313 for batch in batches {
314 let id_col = batch
315 .column(col::TRIPLE_ID)
316 .as_any()
317 .downcast_ref::<StringArray>()
318 .expect("triple_id column");
319 let caused_col = batch
320 .column(col::CAUSED_BY)
321 .as_any()
322 .downcast_ref::<StringArray>()
323 .expect("caused_by column");
324 let derived_col = batch
325 .column(col::DERIVED_FROM)
326 .as_any()
327 .downcast_ref::<StringArray>()
328 .expect("derived_from column");
329 let del_col = batch
330 .column(col::DELETED)
331 .as_any()
332 .downcast_ref::<BooleanArray>()
333 .expect("deleted column");
334
335 for i in 0..batch.num_rows() {
336 if del_col.value(i) {
337 continue;
338 }
339 let id = id_col.value(i).to_string();
340 let caused = if caused_col.is_null(i) {
341 None
342 } else {
343 Some(caused_col.value(i).to_string())
344 };
345 let derived = if derived_col.is_null(i) {
346 None
347 } else {
348 Some(derived_col.value(i).to_string())
349 };
350 index.insert(id, (caused, derived));
351 }
352 }
353 }
354
355 while let Some(tid) = queue.pop_front() {
356 if !visited.insert(tid.clone()) {
357 continue;
358 }
359 if let Some((caused, derived)) = index.get(&tid) {
360 result.push(CausalNode {
361 triple_id: tid.clone(),
362 caused_by: caused.clone(),
363 derived_from: derived.clone(),
364 });
365 if let Some(cb) = caused
366 && !visited.contains(cb)
367 {
368 queue.push_back(cb.clone());
369 }
370 if let Some(df) = derived
371 && !visited.contains(df)
372 {
373 queue.push_back(df.clone());
374 }
375 }
376 }
377
378 result
379 }
380
381 pub fn clear(&mut self) {
383 for batches in self.partitions.values_mut() {
384 batches.clear();
385 }
386 }
387
388 fn filter_batch(&self, batch: &RecordBatch, spec: &QuerySpec) -> Result<RecordBatch> {
390 let n = batch.num_rows();
391 let mut mask = BooleanArray::from(vec![true; n]);
392
393 if !spec.include_deleted {
395 let del_col = batch
396 .column(col::DELETED)
397 .as_any()
398 .downcast_ref::<BooleanArray>()
399 .expect("deleted column must be BooleanArray");
400 let not_deleted = compute::not(del_col)?;
401 mask = compute::and(&mask, ¬_deleted)?;
402 }
403
404 if let Some(ref subj) = spec.subject {
406 let c = batch
407 .column(col::SUBJECT)
408 .as_any()
409 .downcast_ref::<StringArray>()
410 .expect("subject column must be StringArray");
411 let eq = string_eq_scalar(c, subj);
412 mask = compute::and(&mask, &eq)?;
413 }
414
415 if let Some(ref pred) = spec.predicate {
417 let c = batch
418 .column(col::PREDICATE)
419 .as_any()
420 .downcast_ref::<StringArray>()
421 .expect("predicate column must be StringArray");
422 let eq = string_eq_scalar(c, pred);
423 mask = compute::and(&mask, &eq)?;
424 }
425
426 if let Some(ref obj) = spec.object {
428 let c = batch
429 .column(col::OBJECT)
430 .as_any()
431 .downcast_ref::<StringArray>()
432 .expect("object column must be StringArray");
433 let eq = string_eq_scalar(c, obj);
434 mask = compute::and(&mask, &eq)?;
435 }
436
437 if let Some(layer) = spec.y_layer {
439 let c = batch
440 .column(col::Y_LAYER)
441 .as_any()
442 .downcast_ref::<UInt8Array>()
443 .expect("y_layer column must be UInt8Array");
444 let eq = u8_eq_scalar(c, layer.as_u8());
445 mask = compute::and(&mask, &eq)?;
446 }
447
448 let filtered = compute::filter_record_batch(batch, &mask)?;
449 Ok(filtered)
450 }
451}
452
453impl Default for ArrowGraphStore {
454 fn default() -> Self {
455 Self::new()
456 }
457}
458
459fn string_eq_scalar(array: &StringArray, value: &str) -> BooleanArray {
461 let bools: Vec<bool> = (0..array.len()).map(|i| array.value(i) == value).collect();
462 BooleanArray::from(bools)
463}
464
465fn u8_eq_scalar(array: &UInt8Array, value: u8) -> BooleanArray {
467 let bools: Vec<bool> = (0..array.len()).map(|i| array.value(i) == value).collect();
468 BooleanArray::from(bools)
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a triple with the given subject/predicate/object, confidence 0.9,
    /// `extracted_by = "test"`, and every other optional field unset.
    fn sample_triple(subj: &str, pred: &str, obj: &str) -> Triple {
        Triple {
            subject: subj.to_string(),
            predicate: pred.to_string(),
            object: obj.to_string(),
            graph: None,
            confidence: Some(0.9),
            source_document: None,
            source_chunk_id: None,
            extracted_by: Some("test".to_string()),
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
            certifiability_class: None,
        }
    }

    /// Total number of rows returned by `store.query(spec)`.
    fn count(store: &ArrowGraphStore, spec: &QuerySpec) -> usize {
        store
            .query(spec)
            .unwrap()
            .iter()
            .map(|b| b.num_rows())
            .sum()
    }

    #[test]
    fn test_add_and_query_single() {
        let mut store = ArrowGraphStore::new();
        let id = store
            .add_triple(
                &sample_triple("s1", "p1", "o1"),
                Namespace::World,
                YLayer::Semantic,
            )
            .unwrap();

        assert!(!id.is_empty());
        assert_eq!(store.len(), 1);

        let spec = QuerySpec {
            subject: Some("s1".to_string()),
            ..Default::default()
        };
        assert_eq!(count(&store, &spec), 1);
    }

    #[test]
    fn test_add_empty_batch_is_noop() {
        let mut store = ArrowGraphStore::new();
        let ids = store
            .add_batch(&[], Namespace::World, YLayer::Semantic)
            .unwrap();

        assert!(ids.is_empty());
        assert!(store.is_empty());
        assert_eq!(store.len_all(), 0);
    }

    #[test]
    fn test_query_filters_are_conjunctive() {
        let mut store = ArrowGraphStore::new();
        let triples = [
            sample_triple("s1", "p1", "o1"),
            sample_triple("s1", "p2", "o2"),
            sample_triple("s2", "p1", "o1"),
        ];
        store
            .add_batch(&triples, Namespace::World, YLayer::Semantic)
            .unwrap();

        let by_subject = QuerySpec {
            subject: Some("s1".to_string()),
            ..Default::default()
        };
        assert_eq!(count(&store, &by_subject), 2);

        let by_pred_obj = QuerySpec {
            predicate: Some("p1".to_string()),
            object: Some("o1".to_string()),
            ..Default::default()
        };
        assert_eq!(count(&store, &by_pred_obj), 2);

        let exact = QuerySpec {
            subject: Some("s1".to_string()),
            predicate: Some("p1".to_string()),
            object: Some("o1".to_string()),
            ..Default::default()
        };
        assert_eq!(count(&store, &exact), 1);

        // Batches with no matching rows are dropped entirely.
        let missing = QuerySpec {
            subject: Some("nope".to_string()),
            ..Default::default()
        };
        assert!(store.query(&missing).unwrap().is_empty());
    }

    #[test]
    fn test_namespace_isolation() {
        let mut store = ArrowGraphStore::new();

        let world_triples: Vec<Triple> = (0..100)
            .map(|i| sample_triple(&format!("w{i}"), "rdf:type", "Thing"))
            .collect();
        store
            .add_batch(&world_triples, Namespace::World, YLayer::Semantic)
            .unwrap();

        let work_triples: Vec<Triple> = (0..100)
            .map(|i| sample_triple(&format!("k{i}"), "rdf:type", "Task"))
            .collect();
        store
            .add_batch(&work_triples, Namespace::Work, YLayer::Semantic)
            .unwrap();

        let world = QuerySpec {
            namespace: Some(Namespace::World),
            ..Default::default()
        };
        assert_eq!(count(&store, &world), 100);

        let work = QuerySpec {
            namespace: Some(Namespace::Work),
            ..Default::default()
        };
        assert_eq!(count(&store, &work), 100);

        assert_eq!(store.len(), 200);
    }

    #[test]
    fn test_ylayer_query() {
        let mut store = ArrowGraphStore::new();
        store
            .add_triple(
                &sample_triple("s1", "p1", "o1"),
                Namespace::World,
                YLayer::Prose,
            )
            .unwrap();
        store
            .add_triple(
                &sample_triple("s2", "p2", "o2"),
                Namespace::World,
                YLayer::Semantic,
            )
            .unwrap();

        let prose = QuerySpec {
            y_layer: Some(YLayer::Prose),
            ..Default::default()
        };
        assert_eq!(count(&store, &prose), 1);

        let semantic = QuerySpec {
            y_layer: Some(YLayer::Semantic),
            ..Default::default()
        };
        assert_eq!(count(&store, &semantic), 1);
    }

    #[test]
    fn test_logical_delete() {
        let mut store = ArrowGraphStore::new();
        let id = store
            .add_triple(
                &sample_triple("s1", "p1", "o1"),
                Namespace::World,
                YLayer::Semantic,
            )
            .unwrap();

        assert_eq!(store.len(), 1);
        assert!(store.delete(&id).unwrap());
        assert_eq!(store.len(), 0);
        assert_eq!(store.len_all(), 1);

        let with_deleted = QuerySpec {
            include_deleted: true,
            ..Default::default()
        };
        assert_eq!(count(&store, &with_deleted), 1);
    }

    #[test]
    fn test_delete_unknown_id_returns_false() {
        let mut store = ArrowGraphStore::new();
        store
            .add_triple(
                &sample_triple("s1", "p1", "o1"),
                Namespace::World,
                YLayer::Semantic,
            )
            .unwrap();

        assert!(!store.delete("no-such-id").unwrap());
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_clear_removes_everything() {
        let mut store = ArrowGraphStore::new();
        store
            .add_batch(
                &[sample_triple("a", "p", "o"), sample_triple("b", "p", "o")],
                Namespace::World,
                YLayer::Semantic,
            )
            .unwrap();

        store.clear();
        assert!(store.is_empty());
        assert_eq!(store.len_all(), 0);
    }

    #[test]
    fn test_batch_add_performance() {
        let mut store = ArrowGraphStore::new();

        let triples: Vec<Triple> = (0..10_000)
            .map(|i| sample_triple(&format!("s{i}"), "rdf:type", "Entity"))
            .collect();

        let start = std::time::Instant::now();
        store
            .add_batch(&triples, Namespace::World, YLayer::Semantic)
            .unwrap();
        let elapsed = start.elapsed();

        assert_eq!(store.len(), 10_000);
        // A wall-clock budget only means something for optimized builds.
        // Unoptimized test builds on loaded CI machines can legitimately
        // exceed it, which made this test flaky.
        if !cfg!(debug_assertions) {
            assert!(
                elapsed.as_millis() < 100,
                "Batch add took too long: {:?}",
                elapsed
            );
        }
    }

    #[test]
    fn test_causal_chain_linear() {
        let mut store = ArrowGraphStore::new();

        let id0 = store
            .add_triple(
                &sample_triple("s0", "p", "o0"),
                Namespace::World,
                YLayer::Semantic,
            )
            .unwrap();

        let t1 = Triple {
            caused_by: Some(id0.clone()),
            ..sample_triple("s1", "p", "o1")
        };
        let id1 = store
            .add_triple(&t1, Namespace::World, YLayer::Semantic)
            .unwrap();

        let t2 = Triple {
            caused_by: Some(id1.clone()),
            ..sample_triple("s2", "p", "o2")
        };
        let id2 = store
            .add_triple(&t2, Namespace::World, YLayer::Semantic)
            .unwrap();

        let chain = store.causal_chain(&id2);
        assert_eq!(chain.len(), 3);
        assert_eq!(chain[0].triple_id, id2);
        assert_eq!(chain[0].caused_by, Some(id1.clone()));
        assert_eq!(chain[1].triple_id, id1);
        assert_eq!(chain[1].caused_by, Some(id0.clone()));
        assert_eq!(chain[2].triple_id, id0);
        assert_eq!(chain[2].caused_by, None);
    }

    #[test]
    fn test_causal_chain_with_derived_from() {
        let mut store = ArrowGraphStore::new();

        let id0 = store
            .add_triple(
                &sample_triple("base", "p", "original"),
                Namespace::World,
                YLayer::Reasoning,
            )
            .unwrap();

        let t1 = Triple {
            derived_from: Some(id0.clone()),
            ..sample_triple("derived", "p", "derived_val")
        };
        let id1 = store
            .add_triple(&t1, Namespace::World, YLayer::Reasoning)
            .unwrap();

        let chain = store.causal_chain(&id1);
        assert_eq!(chain.len(), 2);
        assert_eq!(chain[0].derived_from, Some(id0.clone()));
        assert_eq!(chain[1].triple_id, id0);
    }

    #[test]
    fn test_causal_chain_shared_parent_visited_once() {
        let mut store = ArrowGraphStore::new();

        let id0 = store
            .add_triple(
                &sample_triple("base", "p", "o"),
                Namespace::World,
                YLayer::Semantic,
            )
            .unwrap();

        let t1 = Triple {
            caused_by: Some(id0.clone()),
            derived_from: Some(id0.clone()),
            ..sample_triple("child", "p", "o")
        };
        let id1 = store
            .add_triple(&t1, Namespace::World, YLayer::Semantic)
            .unwrap();

        let chain = store.causal_chain(&id1);
        assert_eq!(chain.len(), 2);
        assert_eq!(chain[1].triple_id, id0);
    }

    #[test]
    fn test_causal_chain_stops_at_deleted_ancestor() {
        let mut store = ArrowGraphStore::new();

        let id0 = store
            .add_triple(
                &sample_triple("s0", "p", "o0"),
                Namespace::World,
                YLayer::Semantic,
            )
            .unwrap();
        let t1 = Triple {
            caused_by: Some(id0.clone()),
            ..sample_triple("s1", "p", "o1")
        };
        let id1 = store
            .add_triple(&t1, Namespace::World, YLayer::Semantic)
            .unwrap();

        assert!(store.delete(&id0).unwrap());

        let chain = store.causal_chain(&id1);
        assert_eq!(chain.len(), 1);
        assert_eq!(chain[0].triple_id, id1);
        assert_eq!(chain[0].caused_by, Some(id0));
    }

    #[test]
    fn test_causal_chain_nonexistent_triple() {
        let store = ArrowGraphStore::new();
        let chain = store.causal_chain("nonexistent");
        assert!(chain.is_empty());
    }
}