1use crate::error::{Result, TdbError};
35use crate::index::btree_index::{EncodedTriple, TripleIndexSet};
36use std::collections::HashMap;
37use std::sync::{Arc, Mutex};
38use std::time::{Duration, Instant};
39
/// A node in an RDF graph: an IRI, a blank node, or a literal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RdfNode {
    /// An IRI, stored without the surrounding angle brackets
    /// (they are added by `to_canonical_string`).
    Iri(String),
    /// A blank node label, stored without the leading `_:` prefix.
    BlankNode(String),
    /// A literal value, optionally typed or language-tagged.
    Literal {
        /// The lexical form of the literal.
        value: String,
        /// Optional datatype IRI (stored without angle brackets).
        datatype: Option<String>,
        /// Optional language tag, e.g. `"fr"`.
        lang: Option<String>,
    },
}
61
62impl RdfNode {
63 pub fn to_canonical_string(&self) -> String {
67 match self {
68 Self::Iri(iri) => format!("<{}>", iri),
69 Self::BlankNode(label) => format!("_:{}", label),
70 Self::Literal {
71 value,
72 datatype,
73 lang,
74 } => {
75 if let Some(lang) = lang {
76 format!("\"{}\"@{}", value.replace('"', "\\\""), lang)
77 } else if let Some(dt) = datatype {
78 format!("\"{}\"^^<{}>", value.replace('"', "\\\""), dt)
79 } else {
80 format!("\"{}\"", value.replace('"', "\\\""))
81 }
82 }
83 }
84 }
85
86 pub fn from_canonical_string(s: &str) -> Result<Self> {
88 if s.starts_with('<') && s.ends_with('>') {
89 Ok(Self::Iri(s[1..s.len() - 1].to_string()))
90 } else if let Some(label) = s.strip_prefix("_:") {
91 Ok(Self::BlankNode(label.to_string()))
92 } else if let Some(s_inner) = s.strip_prefix('"') {
93 let (value, rest) = parse_quoted_string(s_inner)?;
95 if rest.is_empty() {
96 Ok(Self::Literal {
97 value,
98 datatype: None,
99 lang: None,
100 })
101 } else if let Some(lang) = rest.strip_prefix('@') {
102 Ok(Self::Literal {
103 value,
104 datatype: None,
105 lang: Some(lang.to_string()),
106 })
107 } else if let Some(dt_part) = rest.strip_prefix("^^<") {
108 let dt = dt_part.strip_suffix('>').ok_or_else(|| {
109 TdbError::InvalidInput(format!(
110 "malformed datatype in canonical literal: {}",
111 s
112 ))
113 })?;
114 Ok(Self::Literal {
115 value,
116 datatype: Some(dt.to_string()),
117 lang: None,
118 })
119 } else {
120 Err(TdbError::InvalidInput(format!(
121 "unrecognised canonical RDF node suffix: {}",
122 rest
123 )))
124 }
125 } else {
126 Err(TdbError::InvalidInput(format!(
127 "cannot parse canonical RDF node: {}",
128 s
129 )))
130 }
131 }
132}
133
134fn parse_quoted_string(s: &str) -> Result<(String, &str)> {
137 let mut value = String::new();
138 let mut chars = s.char_indices();
139 loop {
140 match chars.next() {
141 None => {
142 return Err(TdbError::InvalidInput(
143 "unterminated string literal".to_string(),
144 ))
145 }
146 Some((_, '\\')) => match chars.next() {
147 Some((_, '"')) => value.push('"'),
148 Some((_, '\\')) => value.push('\\'),
149 Some((_, 'n')) => value.push('\n'),
150 Some((_, 't')) => value.push('\t'),
151 Some((_, other)) => {
152 value.push('\\');
153 value.push(other);
154 }
155 None => {
156 return Err(TdbError::InvalidInput(
157 "trailing backslash in string literal".to_string(),
158 ))
159 }
160 },
161 Some((pos, '"')) => {
162 let rest = &s[pos + 1..];
163 return Ok((value, rest));
164 }
165 Some((_, ch)) => value.push(ch),
166 }
167 }
168}
169
/// Bidirectional mapping between RDF nodes and dense `u64` ids.
///
/// Ids start at 1; 0 is never assigned and always decodes to `None`.
pub struct NodeDictionary {
    /// Canonical-string form of each node -> its assigned id.
    node_to_id: HashMap<String, u64>,
    /// Reverse lookup: id `n` lives at index `n - 1`.
    id_to_node: Vec<RdfNode>,
    /// Next id to hand out.
    next_id: u64,
}
183
184impl Default for NodeDictionary {
185 fn default() -> Self {
186 Self::new()
187 }
188}
189
190impl NodeDictionary {
191 pub fn new() -> Self {
193 Self {
194 node_to_id: HashMap::new(),
195 id_to_node: Vec::new(), next_id: 1,
197 }
198 }
199
200 pub fn encode(&mut self, node: &RdfNode) -> u64 {
202 let key = node.to_canonical_string();
203 if let Some(&id) = self.node_to_id.get(&key) {
204 return id;
205 }
206 let id = self.next_id;
207 self.next_id += 1;
208 self.node_to_id.insert(key, id);
209 self.id_to_node.push(node.clone());
210 id
211 }
212
213 pub fn decode(&self, id: u64) -> Option<&RdfNode> {
215 if id == 0 || id as usize > self.id_to_node.len() {
216 return None;
217 }
218 self.id_to_node.get((id - 1) as usize)
219 }
220
221 pub fn get_id(&self, node: &RdfNode) -> Option<u64> {
223 let key = node.to_canonical_string();
224 self.node_to_id.get(&key).copied()
225 }
226
227 pub fn size(&self) -> usize {
229 self.node_to_id.len()
230 }
231
232 pub fn memory_bytes(&self) -> usize {
234 let key_bytes: usize = self
236 .node_to_id
237 .keys()
238 .map(|k| k.len() + std::mem::size_of::<String>())
239 .sum();
240 let node_bytes = self.id_to_node.len() * std::mem::size_of::<RdfNode>();
241 key_bytes + node_bytes + std::mem::size_of::<Self>()
242 }
243}
244
/// A pull-based producer of raw triples for bulk loading.
///
/// `Send` so a source can be handed off to a worker thread.
pub trait TripleSource: Send {
    /// Returns up to `batch_size` triples; an empty vec signals that no
    /// more data is available (the loader stops on it).
    fn next_batch(&mut self, batch_size: usize) -> Vec<RawTriple>;

    /// True once the source cannot produce any further triples.
    fn is_exhausted(&self) -> bool;

    /// Expected total triple count, if known; used only for progress
    /// reporting.
    fn estimated_total(&self) -> Option<usize>;
}
263
/// One unparsed triple (or quad) as produced by a [`TripleSource`];
/// each term is expected in canonical string form.
#[derive(Debug, Clone)]
pub struct RawTriple {
    /// Subject term.
    pub subject: String,
    /// Predicate term.
    pub predicate: String,
    /// Object term.
    pub object: String,
    /// Named-graph term, or `None` for the default graph.
    pub graph: Option<String>,
}

impl RawTriple {
    /// Builds a default-graph triple from borrowed term strings.
    pub fn new(subject: &str, predicate: &str, object: &str) -> Self {
        Self {
            subject: subject.into(),
            predicate: predicate.into(),
            object: object.into(),
            graph: None,
        }
    }
}
288
289pub struct VecTripleSource {
297 triples: Vec<RawTriple>,
298 pos: usize,
299}
300
301impl VecTripleSource {
302 pub fn new(triples: Vec<RawTriple>) -> Self {
304 Self { triples, pos: 0 }
305 }
306
307 pub fn total(&self) -> usize {
309 self.triples.len()
310 }
311}
312
313impl TripleSource for VecTripleSource {
314 fn next_batch(&mut self, batch_size: usize) -> Vec<RawTriple> {
315 let end = (self.pos + batch_size).min(self.triples.len());
316 let batch = self.triples[self.pos..end].to_vec();
317 self.pos = end;
318 batch
319 }
320
321 fn is_exhausted(&self) -> bool {
322 self.pos >= self.triples.len()
323 }
324
325 fn estimated_total(&self) -> Option<usize> {
326 Some(self.triples.len())
327 }
328}
329
/// Tuning knobs for [`ParallelBulkLoader`].
#[derive(Debug, Clone)]
pub struct ParallelBulkLoadConfig {
    /// Number of raw triples pulled from the source per batch.
    pub batch_size: usize,
    /// Sort each batch by (s, p, o) before index insertion.
    pub sort_before_insert: bool,
    /// Approximate number of loaded triples between progress callbacks;
    /// 0 disables progress reporting.
    pub progress_interval: usize,
}

impl Default for ParallelBulkLoadConfig {
    /// 100k-triple batches, sorted inserts, progress roughly every
    /// million triples.
    fn default() -> Self {
        ParallelBulkLoadConfig {
            progress_interval: 1_000_000,
            sort_before_insert: true,
            batch_size: 100_000,
        }
    }
}
354
/// Summary of a completed bulk-load run.
#[derive(Debug, Clone)]
pub struct ParallelBulkLoadStats {
    /// Triples actually added to the index (duplicates excluded).
    pub triples_loaded: usize,
    /// Raw triples skipped because a term failed canonical parsing.
    pub parse_errors: usize,
    /// Wall-clock duration of the load.
    pub elapsed: Duration,
    /// Load throughput; 0.0 when `elapsed` is zero.
    pub triples_per_second: f64,
}

impl ParallelBulkLoadStats {
    /// Derives throughput from the counters, guarding against a
    /// zero-length duration (which would divide by zero).
    fn new(triples_loaded: usize, parse_errors: usize, elapsed: Duration) -> Self {
        let secs = elapsed.as_secs_f64();
        let triples_per_second = if secs > 0.0 {
            triples_loaded as f64 / secs
        } else {
            0.0
        };
        Self {
            triples_loaded,
            parse_errors,
            elapsed,
            triples_per_second,
        }
    }
}
383
/// Batch-oriented bulk loader: pulls raw triples from a [`TripleSource`],
/// encodes them through a shared [`NodeDictionary`], and inserts the
/// encoded triples into a [`TripleIndexSet`].
pub struct ParallelBulkLoader {
    /// Tuning parameters (batch size, sorting, progress cadence).
    config: ParallelBulkLoadConfig,
}

impl Default for ParallelBulkLoader {
    /// A loader with [`ParallelBulkLoadConfig::default`] settings.
    fn default() -> Self {
        Self::new(ParallelBulkLoadConfig::default())
    }
}
402
impl ParallelBulkLoader {
    /// Creates a loader with the supplied configuration.
    pub fn new(config: ParallelBulkLoadConfig) -> Self {
        Self { config }
    }

    /// Drains `source` to exhaustion, encoding each raw triple through the
    /// shared `dict` and inserting the result into `index`.
    ///
    /// Triples with an unparsable term are counted in the returned stats'
    /// `parse_errors` and skipped; they do not abort the load.
    /// `triples_loaded` counts triples newly added to the index, so
    /// duplicates are not counted twice.
    ///
    /// `progress_cb`, when given, is invoked with
    /// `(loaded_so_far, estimated_total)` roughly every
    /// `progress_interval` loaded triples.
    ///
    /// NOTE(review): despite the type's name, encoding currently runs
    /// serially while holding the dictionary mutex — confirm whether
    /// parallel encoding is still planned.
    ///
    /// # Errors
    /// Fails only if the dictionary mutex is poisoned.
    pub fn load(
        &self,
        source: &mut dyn TripleSource,
        dict: &Arc<Mutex<NodeDictionary>>,
        index: &mut TripleIndexSet,
        progress_cb: Option<&dyn Fn(usize, usize)>,
    ) -> Result<ParallelBulkLoadStats> {
        let start = Instant::now();
        let estimated_total = source.estimated_total().unwrap_or(0);
        let mut triples_loaded = 0usize;
        let mut parse_errors = 0usize;

        while !source.is_exhausted() {
            let raw_batch = source.next_batch(self.config.batch_size);
            if raw_batch.is_empty() {
                break;
            }

            let (encoded, errors) = Self::encode_batch(&raw_batch, dict)?;
            parse_errors += errors;

            // Presenting the index with an (s, p, o)-sorted batch; presumably
            // this favours the B-tree insert pattern — see btree_index.
            let mut sorted = encoded;
            if self.config.sort_before_insert {
                sorted.sort_by_key(|t| (t.s, t.p, t.o));
            }

            let inserted = Self::insert_batch(sorted, index);
            triples_loaded += inserted;

            if let Some(cb) = progress_cb {
                // Fire when the running total crosses a multiple of
                // progress_interval; the `< batch_size` comparison
                // compensates for totals landing between exact multiples.
                if self.config.progress_interval > 0
                    && triples_loaded % self.config.progress_interval < self.config.batch_size
                {
                    cb(triples_loaded, estimated_total);
                }
            }
        }

        let elapsed = start.elapsed();
        Ok(ParallelBulkLoadStats::new(
            triples_loaded,
            parse_errors,
            elapsed,
        ))
    }

    /// Parses each raw triple's terms and interns them in the dictionary,
    /// returning `(encoded_triples, parse_error_count)`.
    ///
    /// The dictionary mutex is taken once for the whole batch to avoid
    /// per-term lock traffic. A triple with ANY unparsable term is
    /// skipped entirely and counted as a single parse error.
    fn encode_batch(
        batch: &[RawTriple],
        dict: &Arc<Mutex<NodeDictionary>>,
    ) -> Result<(Vec<EncodedTriple>, usize)> {
        let mut encoded = Vec::with_capacity(batch.len());
        let mut errors = 0usize;

        let mut dict_guard = dict
            .lock()
            .map_err(|_| TdbError::Other("bulk loader: dictionary mutex poisoned".to_string()))?;

        for raw in batch {
            let s_node = match RdfNode::from_canonical_string(&raw.subject) {
                Ok(n) => n,
                Err(_) => {
                    errors += 1;
                    continue;
                }
            };
            let p_node = match RdfNode::from_canonical_string(&raw.predicate) {
                Ok(n) => n,
                Err(_) => {
                    errors += 1;
                    continue;
                }
            };
            let o_node = match RdfNode::from_canonical_string(&raw.object) {
                Ok(n) => n,
                Err(_) => {
                    errors += 1;
                    continue;
                }
            };

            let s = dict_guard.encode(&s_node);
            let p = dict_guard.encode(&p_node);
            let o = dict_guard.encode(&o_node);
            encoded.push(EncodedTriple::new(s, p, o));
        }

        Ok((encoded, errors))
    }

    /// Inserts the batch and returns how many triples were NEWLY added —
    /// measured as the index-size delta, so pre-existing duplicates are
    /// not counted.
    fn insert_batch(batch: Vec<EncodedTriple>, index: &mut TripleIndexSet) -> usize {
        let before = index.len();
        for triple in batch {
            index.insert(triple);
        }
        index.len() - before
    }
}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps bare IRI strings in angle brackets to build a canonical triple.
    fn make_raw(s: &str, p: &str, o: &str) -> RawTriple {
        RawTriple::new(
            &format!("<{}>", s),
            &format!("<{}>", p),
            &format!("<{}>", o),
        )
    }

    /// Fresh shared dictionary for loader tests.
    fn default_dict() -> Arc<Mutex<NodeDictionary>> {
        Arc::new(Mutex::new(NodeDictionary::new()))
    }

    // ----- RdfNode canonical-string round-trips -----

    #[test]
    fn test_rdf_node_iri_roundtrip() {
        let node = RdfNode::Iri("http://example.org/foo".to_string());
        let canonical = node.to_canonical_string();
        let parsed = RdfNode::from_canonical_string(&canonical).unwrap();
        assert_eq!(node, parsed);
    }

    #[test]
    fn test_rdf_node_blank_roundtrip() {
        let node = RdfNode::BlankNode("b42".to_string());
        let canonical = node.to_canonical_string();
        let parsed = RdfNode::from_canonical_string(&canonical).unwrap();
        assert_eq!(node, parsed);
    }

    #[test]
    fn test_rdf_node_plain_literal_roundtrip() {
        let node = RdfNode::Literal {
            value: "hello world".to_string(),
            datatype: None,
            lang: None,
        };
        let canonical = node.to_canonical_string();
        let parsed = RdfNode::from_canonical_string(&canonical).unwrap();
        assert_eq!(node, parsed);
    }

    #[test]
    fn test_rdf_node_typed_literal_roundtrip() {
        let node = RdfNode::Literal {
            value: "42".to_string(),
            datatype: Some("http://www.w3.org/2001/XMLSchema#integer".to_string()),
            lang: None,
        };
        let canonical = node.to_canonical_string();
        let parsed = RdfNode::from_canonical_string(&canonical).unwrap();
        assert_eq!(node, parsed);
    }

    #[test]
    fn test_rdf_node_lang_literal_roundtrip() {
        let node = RdfNode::Literal {
            value: "bonjour".to_string(),
            datatype: None,
            lang: Some("fr".to_string()),
        };
        let canonical = node.to_canonical_string();
        let parsed = RdfNode::from_canonical_string(&canonical).unwrap();
        assert_eq!(node, parsed);
    }

    // ----- NodeDictionary -----

    #[test]
    fn test_dictionary_encode_decode_roundtrip() {
        let mut dict = NodeDictionary::new();
        let node = RdfNode::Iri("http://example.org/subject".to_string());
        let id = dict.encode(&node);
        assert_ne!(id, 0);
        let decoded = dict.decode(id).unwrap();
        assert_eq!(decoded, &node);
    }

    #[test]
    fn test_dictionary_same_node_same_id() {
        let mut dict = NodeDictionary::new();
        let node = RdfNode::Iri("http://example.org/x".to_string());
        let id1 = dict.encode(&node);
        let id2 = dict.encode(&node);
        assert_eq!(id1, id2);
        assert_eq!(dict.size(), 1);
    }

    #[test]
    fn test_dictionary_different_nodes_different_ids() {
        let mut dict = NodeDictionary::new();
        let a = RdfNode::Iri("http://a.org/".to_string());
        let b = RdfNode::Iri("http://b.org/".to_string());
        let id_a = dict.encode(&a);
        let id_b = dict.encode(&b);
        assert_ne!(id_a, id_b);
        assert_eq!(dict.size(), 2);
    }

    #[test]
    fn test_dictionary_get_id_not_present() {
        let dict = NodeDictionary::new();
        let node = RdfNode::Iri("http://missing.org/".to_string());
        assert_eq!(dict.get_id(&node), None);
    }

    #[test]
    fn test_dictionary_decode_unknown_id_returns_none() {
        let dict = NodeDictionary::new();
        assert!(dict.decode(999).is_none());
        // 0 is the reserved "no id" value and must never decode.
        assert!(dict.decode(0).is_none());
    }

    // ----- VecTripleSource -----

    #[test]
    fn test_vec_triple_source_batching() {
        let triples: Vec<RawTriple> = (0..10)
            .map(|i| make_raw(&format!("http://s{i}"), "http://p", "http://o"))
            .collect();
        let mut source = VecTripleSource::new(triples);

        assert_eq!(source.estimated_total(), Some(10));
        assert!(!source.is_exhausted());

        let batch1 = source.next_batch(6);
        assert_eq!(batch1.len(), 6);
        assert!(!source.is_exhausted());

        // Final batch is short: only the 4 remaining triples.
        let batch2 = source.next_batch(6);
        assert_eq!(batch2.len(), 4);
        assert!(source.is_exhausted());
    }

    #[test]
    fn test_vec_triple_source_empty() {
        let mut source = VecTripleSource::new(vec![]);
        assert!(source.is_exhausted());
        assert!(source.next_batch(10).is_empty());
    }

    // ----- ParallelBulkLoader -----

    #[test]
    fn test_bulk_loader_basic() {
        let raw_triples = vec![
            make_raw("http://s1", "http://p1", "http://o1"),
            make_raw("http://s2", "http://p2", "http://o2"),
        ];
        let mut source = VecTripleSource::new(raw_triples);
        let dict = default_dict();
        let mut index = TripleIndexSet::new();
        let loader = ParallelBulkLoader::default();

        let stats = loader.load(&mut source, &dict, &mut index, None).unwrap();

        assert_eq!(stats.triples_loaded, 2);
        assert_eq!(stats.parse_errors, 0);
        assert_eq!(index.len(), 2);
        assert!(stats.triples_per_second >= 0.0);
    }

    #[test]
    fn test_bulk_loader_large_dataset() {
        let raw_triples: Vec<RawTriple> = (0..1000)
            .map(|i| make_raw(&format!("http://s{i}"), "http://p", &format!("http://o{i}")))
            .collect();
        let mut source = VecTripleSource::new(raw_triples);
        let dict = default_dict();
        let mut index = TripleIndexSet::new();
        let loader = ParallelBulkLoader::new(ParallelBulkLoadConfig {
            batch_size: 200,
            ..Default::default()
        });

        let stats = loader.load(&mut source, &dict, &mut index, None).unwrap();

        assert_eq!(stats.triples_loaded, 1000);
        assert_eq!(index.len(), 1000);
    }

    #[test]
    fn test_bulk_loader_duplicate_triples() {
        // Identical triples must be deduplicated by the index and
        // counted only once in the stats.
        let raw_triples = vec![
            make_raw("http://s", "http://p", "http://o"),
            make_raw("http://s", "http://p", "http://o"),
            make_raw("http://s", "http://p", "http://o"),
        ];
        let mut source = VecTripleSource::new(raw_triples);
        let dict = default_dict();
        let mut index = TripleIndexSet::new();
        let loader = ParallelBulkLoader::default();

        loader.load(&mut source, &dict, &mut index, None).unwrap();

        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_bulk_loader_parse_errors() {
        // A non-canonical subject is counted as a parse error and the
        // triple is skipped; the load itself still succeeds.
        let mut source = VecTripleSource::new(vec![
            make_raw("http://s1", "http://p1", "http://o1"),
            RawTriple::new("INVALID_NOT_CANONICAL", "<http://p>", "<http://o>"),
        ]);
        let dict = default_dict();
        let mut index = TripleIndexSet::new();
        let loader = ParallelBulkLoader::default();

        let stats = loader.load(&mut source, &dict, &mut index, None).unwrap();

        assert_eq!(stats.parse_errors, 1);
        assert_eq!(stats.triples_loaded, 1);
    }

    #[test]
    fn test_bulk_loader_progress_callback() {
        let raw_triples: Vec<RawTriple> = (0..50)
            .map(|i| make_raw(&format!("http://s{i}"), "http://p", &format!("http://o{i}")))
            .collect();
        let mut source = VecTripleSource::new(raw_triples);
        let dict = default_dict();
        let mut index = TripleIndexSet::new();

        let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let call_count_clone = call_count.clone();

        let loader = ParallelBulkLoader::new(ParallelBulkLoadConfig {
            batch_size: 10,
            progress_interval: 10,
            ..Default::default()
        });

        loader
            .load(
                &mut source,
                &dict,
                &mut index,
                Some(&|_loaded, _total| {
                    call_count_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                }),
            )
            .unwrap();

        assert_eq!(index.len(), 50);
        // With batch_size == progress_interval the callback fires on every
        // batch. This counter was previously collected but never checked,
        // so the callback path was not actually verified.
        assert!(
            call_count.load(std::sync::atomic::Ordering::Relaxed) >= 1,
            "progress callback was never invoked"
        );
    }

    #[test]
    fn test_bulk_loader_dictionary_consistency() {
        let raw_triples = vec![
            make_raw("http://subject", "http://predicate", "http://object1"),
            make_raw("http://subject", "http://predicate", "http://object2"),
        ];
        let mut source = VecTripleSource::new(raw_triples);
        let dict = default_dict();
        let mut index = TripleIndexSet::new();
        let loader = ParallelBulkLoader::default();

        loader.load(&mut source, &dict, &mut index, None).unwrap();

        let dict_guard = dict.lock().unwrap_or_else(|e| e.into_inner());
        // Subject and predicate are shared, so only 4 distinct nodes exist.
        assert_eq!(dict_guard.size(), 4);
    }

    #[test]
    fn test_bulk_loader_blank_nodes() {
        let raw_triples = vec![
            RawTriple::new("_:b0", "<http://p>", "<http://o>"),
            RawTriple::new("_:b1", "<http://p>", "<http://o>"),
        ];
        let mut source = VecTripleSource::new(raw_triples);
        let dict = default_dict();
        let mut index = TripleIndexSet::new();
        let loader = ParallelBulkLoader::default();

        let stats = loader.load(&mut source, &dict, &mut index, None).unwrap();
        assert_eq!(stats.triples_loaded, 2);
        assert_eq!(stats.parse_errors, 0);
    }

    // ----- Stats -----

    #[test]
    fn test_bulk_load_stats_throughput() {
        let stats = ParallelBulkLoadStats::new(1000, 0, Duration::from_secs(1));
        assert!((stats.triples_per_second - 1000.0).abs() < 0.01);
    }

    #[test]
    fn test_bulk_load_stats_zero_duration() {
        // Zero elapsed time must not divide by zero.
        let stats = ParallelBulkLoadStats::new(100, 0, Duration::from_nanos(0));
        assert_eq!(stats.triples_per_second, 0.0);
    }

    #[test]
    fn test_rdf_node_literal_with_quotes() {
        let node = RdfNode::Literal {
            value: "say \"hello\"".to_string(),
            datatype: None,
            lang: None,
        };
        let canonical = node.to_canonical_string();
        let parsed = RdfNode::from_canonical_string(&canonical).unwrap();
        assert_eq!(node, parsed);
    }
}