1use crate::schema::{col, triples_schema};
20
21use arrow::array::{
22 Array, BooleanArray, Float64Array, RecordBatch, StringArray, TimestampMillisecondArray,
23 UInt8Array,
24};
25use arrow::compute;
26use arrow::datatypes::SchemaRef;
27use std::collections::HashMap;
28use std::sync::Arc;
29
/// Errors produced by [`ArrowGraphStore`] operations.
#[derive(Debug, thiserror::Error)]
pub enum StoreError {
    /// Wraps any error surfaced by the underlying Arrow kernels or batch construction.
    #[error("Arrow error: {0}")]
    Arrow(#[from] arrow::error::ArrowError),

    /// The requested namespace was not configured when the store was built.
    #[error("Unknown namespace: {0}")]
    UnknownNamespace(String),

    // NOTE(review): declared but never constructed anywhere in this file — confirm
    // whether layer validation was intended (e.g. in `add_batch`) or this is dead.
    #[error("Invalid layer: {0}")]
    InvalidLayer(u8),

    // NOTE(review): declared but never constructed anywhere in this file; `delete`
    // signals a missing triple via `Ok(false)` instead.
    #[error("Triple not found: {0}")]
    TripleNotFound(String),
}

/// Convenience alias used throughout the store API.
pub type Result<T> = std::result::Result<T, StoreError>;
47
/// A single subject–predicate–object statement plus provenance metadata.
///
/// Optional fields map to nullable columns in the Arrow schema.
#[derive(Debug, Clone)]
pub struct Triple {
    pub subject: String,
    pub predicate: String,
    pub object: String,
    /// Optional named-graph identifier.
    pub graph: Option<String>,
    /// Extraction confidence; stored as a nullable Float64 column.
    pub confidence: Option<f64>,
    /// Document this triple was extracted from, if known.
    pub source_document: Option<String>,
    /// Chunk within the source document, if known.
    pub source_chunk_id: Option<String>,
    /// Identifier of the extractor that produced this triple.
    pub extracted_by: Option<String>,
    /// Id of a triple that caused this one (see [`ArrowGraphStore::causal_chain`]).
    pub caused_by: Option<String>,
    /// Id of a triple this one was derived from.
    pub derived_from: Option<String>,
    /// Consolidation timestamp in epoch milliseconds, if consolidated.
    pub consolidated_at: Option<i64>,
}
67
/// Filter specification for [`ArrowGraphStore::query`].
///
/// `None` fields match everything; set fields are exact-match AND-combined.
#[derive(Debug, Default, Clone)]
pub struct QuerySpec {
    pub subject: Option<String>,
    pub predicate: Option<String>,
    pub object: Option<String>,
    /// Restrict the scan to a single namespace partition.
    pub namespace: Option<String>,
    pub layer: Option<u8>,
    /// When false (the default), logically deleted rows are filtered out.
    pub include_deleted: bool,
}
78
/// One node of a causal chain: a triple id plus its outgoing causal links.
#[derive(Debug, Clone, PartialEq)]
pub struct CausalNode {
    pub triple_id: String,
    /// Id of the triple that caused this one, if any.
    pub caused_by: Option<String>,
    /// Id of the triple this one was derived from, if any.
    pub derived_from: Option<String>,
}
86
/// In-memory triple store backed by Arrow [`RecordBatch`]es, partitioned by namespace.
pub struct ArrowGraphStore {
    /// Shared schema applied to every batch (built once from `triples_schema`).
    schema: SchemaRef,
    /// Per-namespace append-only batch lists.
    partitions: HashMap<String, Vec<RecordBatch>>,
    /// Namespace names in construction order (HashMap iteration order is unspecified).
    namespace_names: Vec<String>,
}
100
101impl ArrowGraphStore {
102 pub fn new(namespaces: &[&str]) -> Self {
116 let schema = Arc::new(triples_schema());
117 let mut partitions = HashMap::new();
118 let mut namespace_names = Vec::new();
119 for ns in namespaces {
120 partitions.insert(ns.to_string(), Vec::new());
121 namespace_names.push(ns.to_string());
122 }
123 ArrowGraphStore {
124 schema,
125 partitions,
126 namespace_names,
127 }
128 }
129
    /// Returns the shared Arrow schema used for every batch in the store.
    pub fn schema(&self) -> &SchemaRef {
        &self.schema
    }
134
    /// Returns the configured namespace names in construction order.
    pub fn namespaces(&self) -> &[String] {
        &self.namespace_names
    }
139
140 pub fn add_triple(
144 &mut self,
145 triple: &Triple,
146 namespace: &str,
147 layer: Option<u8>,
148 ) -> Result<String> {
149 self.add_batch(std::slice::from_ref(triple), namespace, layer)
150 .map(|ids| ids.into_iter().next().unwrap())
151 }
152
153 pub fn add_batch(
156 &mut self,
157 triples: &[Triple],
158 namespace: &str,
159 layer: Option<u8>,
160 ) -> Result<Vec<String>> {
161 if !self.partitions.contains_key(namespace) {
162 return Err(StoreError::UnknownNamespace(namespace.to_string()));
163 }
164
165 let n = triples.len();
166 if n == 0 {
167 return Ok(vec![]);
168 }
169
170 let now_ms = chrono::Utc::now().timestamp_millis();
171 let layer_val = layer.unwrap_or(0);
172
173 let ids: Vec<String> = (0..n).map(|_| uuid::Uuid::new_v4().to_string()).collect();
174
175 let subjects: Vec<&str> = triples.iter().map(|t| t.subject.as_str()).collect();
176 let predicates: Vec<&str> = triples.iter().map(|t| t.predicate.as_str()).collect();
177 let objects: Vec<&str> = triples.iter().map(|t| t.object.as_str()).collect();
178 let graphs: Vec<Option<&str>> = triples.iter().map(|t| t.graph.as_deref()).collect();
179 let ns_vals: Vec<&str> = vec![namespace; n];
180 let layer_vals: Vec<u8> = vec![layer_val; n];
181 let confidences: Vec<Option<f64>> = triples.iter().map(|t| t.confidence).collect();
182 let source_docs: Vec<Option<&str>> = triples
183 .iter()
184 .map(|t| t.source_document.as_deref())
185 .collect();
186 let source_chunks: Vec<Option<&str>> = triples
187 .iter()
188 .map(|t| t.source_chunk_id.as_deref())
189 .collect();
190 let extracted: Vec<Option<&str>> =
191 triples.iter().map(|t| t.extracted_by.as_deref()).collect();
192 let caused_by: Vec<Option<&str>> = triples.iter().map(|t| t.caused_by.as_deref()).collect();
193 let derived_from: Vec<Option<&str>> =
194 triples.iter().map(|t| t.derived_from.as_deref()).collect();
195 let consolidated_at: Vec<Option<i64>> = triples.iter().map(|t| t.consolidated_at).collect();
196 let timestamps: Vec<i64> = vec![now_ms; n];
197 let deleted: Vec<bool> = vec![false; n];
198 let id_strs: Vec<&str> = ids.iter().map(|s| s.as_str()).collect();
199
200 let batch = RecordBatch::try_new(
201 self.schema.clone(),
202 vec![
203 Arc::new(StringArray::from(id_strs)),
204 Arc::new(StringArray::from(subjects)),
205 Arc::new(StringArray::from(predicates)),
206 Arc::new(StringArray::from(objects)),
207 Arc::new(StringArray::from(graphs)),
208 Arc::new(StringArray::from(ns_vals)),
209 Arc::new(UInt8Array::from(layer_vals)),
210 Arc::new(Float64Array::from(confidences)),
211 Arc::new(StringArray::from(source_docs)),
212 Arc::new(StringArray::from(source_chunks)),
213 Arc::new(StringArray::from(extracted)),
214 Arc::new(TimestampMillisecondArray::from(timestamps).with_timezone("UTC")),
215 Arc::new(StringArray::from(caused_by)),
216 Arc::new(StringArray::from(derived_from)),
217 Arc::new(TimestampMillisecondArray::from(consolidated_at).with_timezone("UTC")),
218 Arc::new(BooleanArray::from(deleted)),
219 ],
220 )?;
221
222 self.partitions.get_mut(namespace).unwrap().push(batch);
223
224 Ok(ids)
225 }
226
227 pub fn query(&self, spec: &QuerySpec) -> Result<Vec<RecordBatch>> {
229 let namespaces: Vec<&String> = match &spec.namespace {
230 Some(ns) => {
231 if !self.partitions.contains_key(ns) {
232 return Err(StoreError::UnknownNamespace(ns.clone()));
233 }
234 vec![ns]
235 }
236 None => self.namespace_names.iter().collect(),
237 };
238
239 let mut results = Vec::new();
240
241 for ns in namespaces {
242 let batches = self.partitions.get(ns).unwrap();
243 for batch in batches {
244 let filtered = self.filter_batch(batch, spec)?;
245 if filtered.num_rows() > 0 {
246 results.push(filtered);
247 }
248 }
249 }
250
251 Ok(results)
252 }
253
254 pub fn len(&self) -> usize {
256 let spec = QuerySpec::default();
257 self.query(&spec)
258 .unwrap_or_default()
259 .iter()
260 .map(|b| b.num_rows())
261 .sum()
262 }
263
    /// True when no live (non-deleted) triples exist; delegates to [`Self::len`].
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
268
269 pub fn len_all(&self) -> usize {
271 self.partitions
272 .values()
273 .flat_map(|batches| batches.iter())
274 .map(|b| b.num_rows())
275 .sum()
276 }
277
278 pub fn delete(&mut self, triple_id: &str) -> Result<bool> {
280 for batches in self.partitions.values_mut() {
281 for batch in batches.iter_mut() {
282 let id_col = batch
283 .column(col::TRIPLE_ID)
284 .as_any()
285 .downcast_ref::<StringArray>()
286 .expect("triple_id column must be StringArray");
287
288 let mut found_idx = None;
289 for i in 0..id_col.len() {
290 if id_col.value(i) == triple_id {
291 found_idx = Some(i);
292 break;
293 }
294 }
295
296 if let Some(idx) = found_idx {
297 let del_col = batch
298 .column(col::DELETED)
299 .as_any()
300 .downcast_ref::<BooleanArray>()
301 .expect("deleted column must be BooleanArray");
302
303 let mut new_del: Vec<bool> =
304 (0..del_col.len()).map(|i| del_col.value(i)).collect();
305 new_del[idx] = true;
306
307 let mut columns: Vec<Arc<dyn Array>> = Vec::new();
308 for c in 0..batch.num_columns() {
309 if c == col::DELETED {
310 columns.push(Arc::new(BooleanArray::from(new_del.clone())));
311 } else {
312 columns.push(batch.column(c).clone());
313 }
314 }
315
316 *batch = RecordBatch::try_new(self.schema.clone(), columns)?;
317 return Ok(true);
318 }
319 }
320 }
321 Ok(false)
322 }
323
324 pub fn get_namespace_batches(&self, namespace: &str) -> &[RecordBatch] {
326 self.partitions.get(namespace).map_or(&[], |v| v.as_slice())
327 }
328
329 pub fn set_namespace_batches(&mut self, namespace: &str, batches: Vec<RecordBatch>) {
331 if let Some(existing) = self.partitions.get_mut(namespace) {
332 *existing = batches;
333 }
334 }
335
336 pub fn causal_chain(&self, triple_id: &str) -> Vec<CausalNode> {
342 let mut result = Vec::new();
343 let mut visited = std::collections::HashSet::new();
344 let mut queue = std::collections::VecDeque::new();
345 queue.push_back(triple_id.to_string());
346
347 let mut index: HashMap<String, (Option<String>, Option<String>)> = HashMap::new();
349 for batches in self.partitions.values() {
350 for batch in batches {
351 let id_col = batch
352 .column(col::TRIPLE_ID)
353 .as_any()
354 .downcast_ref::<StringArray>()
355 .expect("triple_id column");
356 let caused_col = batch
357 .column(col::CAUSED_BY)
358 .as_any()
359 .downcast_ref::<StringArray>()
360 .expect("caused_by column");
361 let derived_col = batch
362 .column(col::DERIVED_FROM)
363 .as_any()
364 .downcast_ref::<StringArray>()
365 .expect("derived_from column");
366 let del_col = batch
367 .column(col::DELETED)
368 .as_any()
369 .downcast_ref::<BooleanArray>()
370 .expect("deleted column");
371
372 for i in 0..batch.num_rows() {
373 if del_col.value(i) {
374 continue;
375 }
376 let id = id_col.value(i).to_string();
377 let caused = if caused_col.is_null(i) {
378 None
379 } else {
380 Some(caused_col.value(i).to_string())
381 };
382 let derived = if derived_col.is_null(i) {
383 None
384 } else {
385 Some(derived_col.value(i).to_string())
386 };
387 index.insert(id, (caused, derived));
388 }
389 }
390 }
391
392 while let Some(tid) = queue.pop_front() {
393 if !visited.insert(tid.clone()) {
394 continue;
395 }
396 if let Some((caused, derived)) = index.get(&tid) {
397 result.push(CausalNode {
398 triple_id: tid.clone(),
399 caused_by: caused.clone(),
400 derived_from: derived.clone(),
401 });
402 if let Some(cb) = caused
403 && !visited.contains(cb)
404 {
405 queue.push_back(cb.clone());
406 }
407 if let Some(df) = derived
408 && !visited.contains(df)
409 {
410 queue.push_back(df.clone());
411 }
412 }
413 }
414
415 result
416 }
417
418 pub fn clear(&mut self) {
420 for batches in self.partitions.values_mut() {
421 batches.clear();
422 }
423 }
424
425 fn filter_batch(&self, batch: &RecordBatch, spec: &QuerySpec) -> Result<RecordBatch> {
427 let n = batch.num_rows();
428 let mut mask = BooleanArray::from(vec![true; n]);
429
430 if !spec.include_deleted {
432 let del_col = batch
433 .column(col::DELETED)
434 .as_any()
435 .downcast_ref::<BooleanArray>()
436 .expect("deleted column must be BooleanArray");
437 let not_deleted = compute::not(del_col)?;
438 mask = compute::and(&mask, ¬_deleted)?;
439 }
440
441 if let Some(ref subj) = spec.subject {
442 let c = batch
443 .column(col::SUBJECT)
444 .as_any()
445 .downcast_ref::<StringArray>()
446 .expect("subject column must be StringArray");
447 let eq = string_eq_scalar(c, subj);
448 mask = compute::and(&mask, &eq)?;
449 }
450
451 if let Some(ref pred) = spec.predicate {
452 let c = batch
453 .column(col::PREDICATE)
454 .as_any()
455 .downcast_ref::<StringArray>()
456 .expect("predicate column must be StringArray");
457 let eq = string_eq_scalar(c, pred);
458 mask = compute::and(&mask, &eq)?;
459 }
460
461 if let Some(ref obj) = spec.object {
462 let c = batch
463 .column(col::OBJECT)
464 .as_any()
465 .downcast_ref::<StringArray>()
466 .expect("object column must be StringArray");
467 let eq = string_eq_scalar(c, obj);
468 mask = compute::and(&mask, &eq)?;
469 }
470
471 if let Some(layer) = spec.layer {
472 let c = batch
473 .column(col::LAYER)
474 .as_any()
475 .downcast_ref::<UInt8Array>()
476 .expect("layer column must be UInt8Array");
477 let eq = u8_eq_scalar(c, layer);
478 mask = compute::and(&mask, &eq)?;
479 }
480
481 let filtered = compute::filter_record_batch(batch, &mask)?;
482 Ok(filtered)
483 }
484}
485
486fn string_eq_scalar(array: &StringArray, value: &str) -> BooleanArray {
488 let bools: Vec<bool> = (0..array.len()).map(|i| array.value(i) == value).collect();
489 BooleanArray::from(bools)
490}
491
492fn u8_eq_scalar(array: &UInt8Array, value: u8) -> BooleanArray {
494 let bools: Vec<bool> = (0..array.len()).map(|i| array.value(i) == value).collect();
495 BooleanArray::from(bools)
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a triple with fixed provenance defaults for tests.
    fn sample_triple(subj: &str, pred: &str, obj: &str) -> Triple {
        Triple {
            subject: subj.to_string(),
            predicate: pred.to_string(),
            object: obj.to_string(),
            graph: None,
            confidence: Some(0.9),
            source_document: None,
            source_chunk_id: None,
            extracted_by: Some("test".to_string()),
            caused_by: None,
            derived_from: None,
            consolidated_at: None,
        }
    }

    #[test]
    fn test_add_and_query_single() {
        let mut store = ArrowGraphStore::new(&["world", "work"]);
        let id = store
            .add_triple(&sample_triple("s1", "p1", "o1"), "world", Some(1))
            .unwrap();

        assert!(!id.is_empty());
        assert_eq!(store.len(), 1);

        let results = store
            .query(&QuerySpec {
                subject: Some("s1".to_string()),
                ..Default::default()
            })
            .unwrap();
        let total: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total, 1);
    }

    #[test]
    fn test_namespace_isolation() {
        let mut store = ArrowGraphStore::new(&["world", "work"]);

        let world_triples: Vec<Triple> = (0..100)
            .map(|i| sample_triple(&format!("w{i}"), "rdf:type", "Thing"))
            .collect();
        store.add_batch(&world_triples, "world", Some(1)).unwrap();

        let work_triples: Vec<Triple> = (0..100)
            .map(|i| sample_triple(&format!("k{i}"), "rdf:type", "Task"))
            .collect();
        store.add_batch(&work_triples, "work", Some(1)).unwrap();

        // Each namespace query sees only its own 100 rows.
        let world_results = store
            .query(&QuerySpec {
                namespace: Some("world".to_string()),
                ..Default::default()
            })
            .unwrap();
        let world_count: usize = world_results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(world_count, 100);

        let work_results = store
            .query(&QuerySpec {
                namespace: Some("work".to_string()),
                ..Default::default()
            })
            .unwrap();
        let work_count: usize = work_results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(work_count, 100);

        assert_eq!(store.len(), 200);
    }

    #[test]
    fn test_layer_query() {
        let mut store = ArrowGraphStore::new(&["default"]);

        store
            .add_triple(&sample_triple("s1", "p1", "o1"), "default", Some(0))
            .unwrap();
        store
            .add_triple(&sample_triple("s2", "p2", "o2"), "default", Some(1))
            .unwrap();

        let l0_results = store
            .query(&QuerySpec {
                layer: Some(0),
                ..Default::default()
            })
            .unwrap();
        let l0_count: usize = l0_results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(l0_count, 1);
    }

    #[test]
    fn test_default_layer() {
        let mut store = ArrowGraphStore::new(&["default"]);

        // `layer: None` should default to layer 0.
        store
            .add_triple(&sample_triple("s1", "p1", "o1"), "default", None)
            .unwrap();

        let results = store
            .query(&QuerySpec {
                layer: Some(0),
                ..Default::default()
            })
            .unwrap();
        let count: usize = results.iter().map(|b| b.num_rows()).sum();
        assert_eq!(count, 1);
    }

    #[test]
    fn test_unknown_namespace_error() {
        let mut store = ArrowGraphStore::new(&["world"]);
        let result = store.add_triple(&sample_triple("s", "p", "o"), "unknown", None);
        assert!(matches!(result, Err(StoreError::UnknownNamespace(_))));
    }

    #[test]
    fn test_logical_delete() {
        let mut store = ArrowGraphStore::new(&["world"]);
        let id = store
            .add_triple(&sample_triple("s1", "p1", "o1"), "world", Some(1))
            .unwrap();

        assert_eq!(store.len(), 1);
        assert!(store.delete(&id).unwrap());
        // Logically deleted: invisible to `len`, still physically present.
        assert_eq!(store.len(), 0);
        assert_eq!(store.len_all(), 1);
    }

    #[test]
    fn test_batch_add_performance() {
        let mut store = ArrowGraphStore::new(&["world"]);

        let triples: Vec<Triple> = (0..10_000)
            .map(|i| sample_triple(&format!("s{i}"), "rdf:type", "Entity"))
            .collect();

        let start = std::time::Instant::now();
        store.add_batch(&triples, "world", Some(1)).unwrap();
        let elapsed = start.elapsed();

        assert_eq!(store.len(), 10_000);
        // Wall-clock assertions are flaky under unoptimized debug builds and
        // loaded CI machines, so only enforce the budget in release builds.
        if cfg!(not(debug_assertions)) {
            assert!(
                elapsed.as_millis() < 100,
                "Batch add took too long: {:?}",
                elapsed
            );
        }
    }

    #[test]
    fn test_custom_namespaces() {
        let mut store = ArrowGraphStore::new(&["drafts", "published", "archived"]);
        assert_eq!(store.namespaces().len(), 3);

        store
            .add_triple(&sample_triple("doc1", "status", "draft"), "drafts", None)
            .unwrap();
        store
            .add_triple(&sample_triple("doc2", "status", "live"), "published", None)
            .unwrap();

        assert_eq!(store.len(), 2);

        let drafts = store
            .query(&QuerySpec {
                namespace: Some("drafts".to_string()),
                ..Default::default()
            })
            .unwrap();
        assert_eq!(drafts.iter().map(|b| b.num_rows()).sum::<usize>(), 1);
    }

    #[test]
    fn test_causal_chain_linear() {
        let mut store = ArrowGraphStore::new(&["world"]);

        // Build a 3-link chain: t2 -> t1 -> t0.
        let t0 = Triple {
            caused_by: None,
            derived_from: None,
            ..sample_triple("s0", "p", "o0")
        };
        let id0 = store.add_triple(&t0, "world", Some(1)).unwrap();

        let t1 = Triple {
            caused_by: Some(id0.clone()),
            derived_from: None,
            ..sample_triple("s1", "p", "o1")
        };
        let id1 = store.add_triple(&t1, "world", Some(1)).unwrap();

        let t2 = Triple {
            caused_by: Some(id1.clone()),
            derived_from: None,
            ..sample_triple("s2", "p", "o2")
        };
        let id2 = store.add_triple(&t2, "world", Some(1)).unwrap();

        let chain = store.causal_chain(&id2);
        assert_eq!(chain.len(), 3);
        assert_eq!(chain[0].triple_id, id2);
        assert_eq!(chain[0].caused_by, Some(id1.clone()));
        assert_eq!(chain[1].triple_id, id1);
        assert_eq!(chain[1].caused_by, Some(id0.clone()));
        assert_eq!(chain[2].triple_id, id0);
        assert_eq!(chain[2].caused_by, None);
    }

    #[test]
    fn test_causal_chain_nonexistent_triple() {
        let store = ArrowGraphStore::new(&["world"]);
        let chain = store.causal_chain("nonexistent");
        assert!(chain.is_empty());
    }
}