1use crate::schema::col;
24use crate::store::{ArrowGraphStore, QuerySpec, StoreError, Triple};
25
26use arrow::array::{Array, Float64Array, RecordBatch, StringArray};
27use std::collections::HashMap;
28
/// Namespace used by [`SimpleTripleStore::new`] (and therefore by `Default`).
const DEFAULT_NAMESPACE: &str = "default";
31
/// An owned, row-level copy of one triple read back from the Arrow store.
///
/// Instances are produced by [`extract_stored_triple`] from a single row of a
/// query result batch.
#[derive(Debug, Clone)]
pub struct StoredTriple {
    /// Identifier taken from the triple-id column.
    pub id: String,
    /// Subject string of the triple.
    pub subject: String,
    /// Predicate string of the triple.
    pub predicate: String,
    /// Object string of the triple.
    pub object: String,
    /// Value of the graph column, or `None` when that cell is null.
    pub graph: Option<String>,
    /// Confidence score; `1.0` when the stored confidence cell is null.
    pub confidence: f64,
    /// Provenance label read from the `extracted_by` column (`None` when null).
    pub source: Option<String>,
}
43
/// Aggregate counts over the triples in a [`SimpleTripleStore`]'s namespace,
/// as returned by [`SimpleTripleStore::stats`].
#[derive(Debug, Clone)]
pub struct StoreStats {
    /// Number of triples returned by the namespace query.
    pub total_triples: usize,
    /// Number of distinct subject strings.
    pub unique_subjects: usize,
    /// Number of distinct predicate strings.
    pub unique_predicates: usize,
    /// Number of distinct object strings.
    pub unique_objects: usize,
    /// Triple count per source label; triples with no source are not counted here.
    pub by_source: HashMap<String, usize>,
}
53
/// A convenience wrapper around [`ArrowGraphStore`] for reading and writing
/// plain subject/predicate/object triples within a single namespace.
pub struct SimpleTripleStore {
    /// The underlying Arrow-backed store.
    inner: ArrowGraphStore,
    /// Namespace passed to every write and to the namespace-scoped queries
    /// (`query`, `count`, `group_by`, `stats`).
    namespace: String,
    /// Layer tag passed to every write (`add`, `add_batch`, `update_confidence`).
    layer: Option<u8>,
}
63
64impl SimpleTripleStore {
65 pub fn new() -> Self {
67 Self {
68 inner: ArrowGraphStore::new(&[DEFAULT_NAMESPACE]),
69 namespace: DEFAULT_NAMESPACE.to_string(),
70 layer: Some(0),
71 }
72 }
73
74 pub fn with_defaults(namespace: &str, layer: Option<u8>) -> Self {
76 Self {
77 inner: ArrowGraphStore::new(&[namespace]),
78 namespace: namespace.to_string(),
79 layer,
80 }
81 }
82
83 pub fn add(
85 &mut self,
86 subject: &str,
87 predicate: &str,
88 object: &str,
89 confidence: f64,
90 source: &str,
91 ) -> Result<String, StoreError> {
92 let triple = Triple {
93 subject: subject.to_string(),
94 predicate: predicate.to_string(),
95 object: object.to_string(),
96 graph: None,
97 confidence: Some(confidence),
98 source_document: Some(source.to_string()),
99 source_chunk_id: None,
100 extracted_by: Some(source.to_string()),
101 caused_by: None,
102 derived_from: None,
103 consolidated_at: None,
104 };
105 self.inner.add_triple(&triple, &self.namespace, self.layer)
106 }
107
108 pub fn add_batch(
110 &mut self,
111 triples: &[(&str, &str, &str, f64, &str)],
112 ) -> Result<Vec<String>, StoreError> {
113 let ts: Vec<Triple> = triples
114 .iter()
115 .map(|(s, p, o, conf, src)| Triple {
116 subject: s.to_string(),
117 predicate: p.to_string(),
118 object: o.to_string(),
119 graph: None,
120 confidence: Some(*conf),
121 source_document: Some(src.to_string()),
122 source_chunk_id: None,
123 extracted_by: Some(src.to_string()),
124 caused_by: None,
125 derived_from: None,
126 consolidated_at: None,
127 })
128 .collect();
129 self.inner.add_batch(&ts, &self.namespace, self.layer)
130 }
131
132 pub fn remove(&mut self, triple_id: &str) -> Result<bool, StoreError> {
134 self.inner.delete(triple_id)
135 }
136
137 pub fn query(
139 &self,
140 subject: Option<&str>,
141 predicate: Option<&str>,
142 object: Option<&str>,
143 ) -> Result<Vec<StoredTriple>, StoreError> {
144 let spec = QuerySpec {
145 subject: subject.map(|s| s.to_string()),
146 predicate: predicate.map(|s| s.to_string()),
147 object: object.map(|s| s.to_string()),
148 namespace: Some(self.namespace.clone()),
149 ..Default::default()
150 };
151 let batches = self.inner.query(&spec)?;
152 Ok(batches_to_stored_triples(&batches))
153 }
154
155 pub fn count(
157 &self,
158 subject: Option<&str>,
159 predicate: Option<&str>,
160 object: Option<&str>,
161 ) -> usize {
162 self.query(subject, predicate, object)
163 .map(|v| v.len())
164 .unwrap_or(0)
165 }
166
167 pub fn get(&self, triple_id: &str) -> Option<StoredTriple> {
169 let spec = QuerySpec {
170 include_deleted: false,
171 ..Default::default()
172 };
173 let batches = self.inner.query(&spec).ok()?;
174 for batch in &batches {
175 let ids = batch
176 .column(col::TRIPLE_ID)
177 .as_any()
178 .downcast_ref::<StringArray>()?;
179 for i in 0..ids.len() {
180 if ids.value(i) == triple_id {
181 return Some(extract_stored_triple(batch, i));
182 }
183 }
184 }
185 None
186 }
187
188 pub fn update_confidence(
190 &mut self,
191 triple_id: &str,
192 confidence: f64,
193 ) -> Result<bool, StoreError> {
194 let existing = match self.get(triple_id) {
195 Some(t) => t,
196 None => return Ok(false),
197 };
198 self.inner.delete(triple_id)?;
199 let triple = Triple {
200 subject: existing.subject,
201 predicate: existing.predicate,
202 object: existing.object,
203 graph: existing.graph,
204 confidence: Some(confidence),
205 source_document: existing.source.clone(),
206 source_chunk_id: None,
207 extracted_by: existing.source,
208 caused_by: None,
209 derived_from: None,
210 consolidated_at: None,
211 };
212 self.inner
213 .add_triple(&triple, &self.namespace, self.layer)?;
214 Ok(true)
215 }
216
217 pub fn group_by(&self, field: &str) -> Result<HashMap<String, usize>, StoreError> {
219 let col_idx = match field {
220 "subject" => col::SUBJECT,
221 "predicate" => col::PREDICATE,
222 "object" => col::OBJECT,
223 _ => {
224 return Err(StoreError::Arrow(
225 arrow::error::ArrowError::InvalidArgumentError(format!(
226 "invalid group_by field: {field}"
227 )),
228 ));
229 }
230 };
231 let spec = QuerySpec {
232 namespace: Some(self.namespace.clone()),
233 ..Default::default()
234 };
235 let batches = self.inner.query(&spec)?;
236 let mut counts: HashMap<String, usize> = HashMap::new();
237 for batch in &batches {
238 let col_array = batch
239 .column(col_idx)
240 .as_any()
241 .downcast_ref::<StringArray>()
242 .expect("column must be StringArray");
243 for i in 0..col_array.len() {
244 *counts.entry(col_array.value(i).to_string()).or_insert(0) += 1;
245 }
246 }
247 Ok(counts)
248 }
249
250 pub fn stats(&self) -> StoreStats {
252 let spec = QuerySpec {
253 namespace: Some(self.namespace.clone()),
254 ..Default::default()
255 };
256 let batches = self.inner.query(&spec).unwrap_or_default();
257 let triples = batches_to_stored_triples(&batches);
258
259 let mut subjects = std::collections::HashSet::new();
260 let mut predicates = std::collections::HashSet::new();
261 let mut objects = std::collections::HashSet::new();
262 let mut by_source: HashMap<String, usize> = HashMap::new();
263
264 for t in &triples {
265 subjects.insert(t.subject.clone());
266 predicates.insert(t.predicate.clone());
267 objects.insert(t.object.clone());
268 if let Some(ref src) = t.source {
269 *by_source.entry(src.clone()).or_insert(0) += 1;
270 }
271 }
272
273 StoreStats {
274 total_triples: triples.len(),
275 unique_subjects: subjects.len(),
276 unique_predicates: predicates.len(),
277 unique_objects: objects.len(),
278 by_source,
279 }
280 }
281
282 pub fn len(&self) -> usize {
284 self.count(None, None, None)
285 }
286
287 pub fn is_empty(&self) -> bool {
289 self.len() == 0
290 }
291
292 pub fn inner(&self) -> &ArrowGraphStore {
294 &self.inner
295 }
296
297 pub fn inner_mut(&mut self) -> &mut ArrowGraphStore {
299 &mut self.inner
300 }
301}
302
303impl Default for SimpleTripleStore {
304 fn default() -> Self {
305 Self::new()
306 }
307}
308
309pub fn extract_stored_triple(batch: &RecordBatch, idx: usize) -> StoredTriple {
313 let ids = batch
314 .column(col::TRIPLE_ID)
315 .as_any()
316 .downcast_ref::<StringArray>()
317 .expect("triple_id column");
318 let subjects = batch
319 .column(col::SUBJECT)
320 .as_any()
321 .downcast_ref::<StringArray>()
322 .expect("subject column");
323 let predicates = batch
324 .column(col::PREDICATE)
325 .as_any()
326 .downcast_ref::<StringArray>()
327 .expect("predicate column");
328 let objects = batch
329 .column(col::OBJECT)
330 .as_any()
331 .downcast_ref::<StringArray>()
332 .expect("object column");
333 let graphs = batch
334 .column(col::GRAPH)
335 .as_any()
336 .downcast_ref::<StringArray>()
337 .expect("graph column");
338 let confidences = batch
339 .column(col::CONFIDENCE)
340 .as_any()
341 .downcast_ref::<Float64Array>()
342 .expect("confidence column");
343 let sources = batch
344 .column(col::EXTRACTED_BY)
345 .as_any()
346 .downcast_ref::<StringArray>()
347 .expect("extracted_by column");
348
349 StoredTriple {
350 id: ids.value(idx).to_string(),
351 subject: subjects.value(idx).to_string(),
352 predicate: predicates.value(idx).to_string(),
353 object: objects.value(idx).to_string(),
354 graph: if graphs.is_null(idx) {
355 None
356 } else {
357 Some(graphs.value(idx).to_string())
358 },
359 confidence: if confidences.is_null(idx) {
360 1.0
361 } else {
362 confidences.value(idx)
363 },
364 source: if sources.is_null(idx) {
365 None
366 } else {
367 Some(sources.value(idx).to_string())
368 },
369 }
370}
371
372pub fn batches_to_stored_triples(batches: &[RecordBatch]) -> Vec<StoredTriple> {
374 let mut result = Vec::new();
375 for batch in batches {
376 for i in 0..batch.num_rows() {
377 result.push(extract_stored_triple(batch, i));
378 }
379 }
380 result
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_and_query() {
        let mut store = SimpleTripleStore::new();
        let id = store.add("Alice", "knows", "Bob", 0.9, "test").unwrap();
        assert!(!id.is_empty());
        assert_eq!(store.len(), 1);

        let results = store.query(Some("Alice"), None, None).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].subject, "Alice");
        assert_eq!(results[0].predicate, "knows");
        assert_eq!(results[0].object, "Bob");
        assert!((results[0].confidence - 0.9).abs() < 1e-10);
    }

    #[test]
    fn test_remove() {
        let mut store = SimpleTripleStore::new();
        let id = store.add("s", "p", "o", 1.0, "test").unwrap();
        assert_eq!(store.len(), 1);

        assert!(store.remove(&id).unwrap());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_query_wildcard() {
        let mut store = SimpleTripleStore::new();
        store.add("Alice", "knows", "Bob", 0.9, "test").unwrap();
        store.add("Alice", "likes", "Carol", 0.8, "test").unwrap();
        store.add("Bob", "knows", "Carol", 0.7, "test").unwrap();

        assert_eq!(store.query(Some("Alice"), None, None).unwrap().len(), 2);
        assert_eq!(store.query(None, Some("knows"), None).unwrap().len(), 2);
        assert_eq!(store.query(None, None, Some("Carol")).unwrap().len(), 2);
        assert_eq!(
            store
                .query(Some("Alice"), Some("knows"), None)
                .unwrap()
                .len(),
            1
        );
        assert_eq!(store.query(None, None, None).unwrap().len(), 3);
    }

    #[test]
    fn test_count_matches_query() {
        let mut store = SimpleTripleStore::new();
        store.add("Alice", "knows", "Bob", 0.9, "test").unwrap();
        store.add("Alice", "likes", "Carol", 0.8, "test").unwrap();

        assert_eq!(store.count(Some("Alice"), None, None), 2);
        assert_eq!(store.count(None, Some("likes"), None), 1);
        assert_eq!(store.count(None, None, Some("Dave")), 0);
        assert_eq!(store.count(None, None, None), store.len());
    }

    #[test]
    fn test_group_by() {
        let mut store = SimpleTripleStore::new();
        store.add("Alice", "knows", "Bob", 1.0, "test").unwrap();
        store.add("Alice", "likes", "Carol", 1.0, "test").unwrap();
        store.add("Bob", "knows", "Carol", 1.0, "test").unwrap();

        let by_subj = store.group_by("subject").unwrap();
        assert_eq!(by_subj["Alice"], 2);
        assert_eq!(by_subj["Bob"], 1);
    }

    #[test]
    fn test_group_by_invalid_field() {
        let store = SimpleTripleStore::new();
        assert!(store.group_by("confidence").is_err());
    }

    #[test]
    fn test_get_returns_added_triple() {
        let mut store = SimpleTripleStore::new();
        let id = store.add("Alice", "knows", "Bob", 0.9, "test").unwrap();

        let triple = store.get(&id).expect("triple should be retrievable by id");
        assert_eq!(triple.id, id);
        assert_eq!(triple.subject, "Alice");
        assert_eq!(triple.object, "Bob");
        assert_eq!(triple.source.as_deref(), Some("test"));
        assert!(store.get("no-such-id").is_none());
    }

    #[test]
    fn test_update_confidence() {
        let mut store = SimpleTripleStore::new();
        let id = store.add("s", "p", "o", 0.9, "test").unwrap();

        assert!(store.update_confidence(&id, 0.4).unwrap());
        assert_eq!(store.len(), 1);
        let results = store.query(Some("s"), Some("p"), Some("o")).unwrap();
        assert_eq!(results.len(), 1);
        assert!((results[0].confidence - 0.4).abs() < 1e-10);
        assert_eq!(results[0].source.as_deref(), Some("test"));

        assert!(!store.update_confidence("no-such-id", 0.1).unwrap());
    }

    #[test]
    fn test_stats() {
        let mut store = SimpleTripleStore::new();
        store.add("s1", "p1", "o1", 1.0, "src_a").unwrap();
        store.add("s2", "p1", "o2", 1.0, "src_a").unwrap();
        store.add("s1", "p2", "o1", 1.0, "src_b").unwrap();

        let stats = store.stats();
        assert_eq!(stats.total_triples, 3);
        assert_eq!(stats.unique_subjects, 2);
        assert_eq!(stats.unique_predicates, 2);
        assert_eq!(stats.by_source["src_a"], 2);
        assert_eq!(stats.by_source["src_b"], 1);
    }

    #[test]
    fn test_batch_add() {
        let mut store = SimpleTripleStore::new();
        let ids = store
            .add_batch(&[
                ("s1", "p", "o1", 0.9, "batch"),
                ("s2", "p", "o2", 0.8, "batch"),
                ("s3", "p", "o3", 0.7, "batch"),
            ])
            .unwrap();
        assert_eq!(ids.len(), 3);
        assert_eq!(store.len(), 3);
    }

    #[test]
    fn test_custom_namespace() {
        let store = SimpleTripleStore::with_defaults("my_namespace", Some(5));
        assert!(store.is_empty());
    }
}