1use std::collections::{HashMap, HashSet, VecDeque};
9
/// A single subject–predicate–object fact stored in a [`KnowledgeGraph`].
///
/// Equality and hashing consider all three fields, which is what allows the context
/// builder to deduplicate triples through `HashSet`s.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ContextTriple {
    /// Entity the fact is about; indexed by [`KnowledgeGraph`] for neighbour lookups.
    pub subject: String,
    /// Relation linking `subject` to `object` (e.g. `"knows"`); indexed for predicate lookups.
    pub predicate: String,
    /// Entity or value the relation points to; also indexed for neighbour lookups.
    pub object: String,
}
21
22impl ContextTriple {
23 pub fn new(
24 subject: impl Into<String>,
25 predicate: impl Into<String>,
26 object: impl Into<String>,
27 ) -> Self {
28 Self {
29 subject: subject.into(),
30 predicate: predicate.into(),
31 object: object.into(),
32 }
33 }
34
35 pub fn estimated_tokens(&self) -> usize {
39 let text = format!("{} {} {}", self.subject, self.predicate, self.object);
40 text.split(|c: char| c.is_whitespace() || c == '/' || c == '#')
42 .filter(|s| !s.is_empty())
43 .count()
44 }
45}
46
/// A triple paired with the relevance score assigned by `ContextBuilder::rank_triples`.
#[derive(Debug, Clone)]
pub struct ScoredTriple {
    /// The triple being scored.
    pub triple: ContextTriple,
    /// Relevance score; higher is more relevant (`rank_triples` sorts descending by it).
    pub score: f64,
}
53
/// Result of `ContextBuilder::build` / `ContextBuilder::build_multi`.
#[derive(Debug, Clone)]
pub struct BuiltContext {
    /// Triples kept after ranking and truncation to the token budget, highest score first.
    pub triples: Vec<ContextTriple>,
    /// `triples` rendered with the configured template, joined by the configured separator.
    pub text: String,
    /// Sum of `ContextTriple::estimated_tokens` over `triples`.
    pub estimated_tokens: usize,
    /// Number of distinct candidate triples gathered before ranking and truncation.
    pub total_candidates: usize,
}
66
/// Tuning knobs for [`ContextBuilder`].
#[derive(Debug, Clone)]
pub struct ContextBuilderConfig {
    /// Breadth-first expansion depth used when no explicit hop count is given (default 2).
    /// A value of 0 still collects the triples that touch the start entity itself.
    pub max_hops: usize,
    /// Upper bound on the summed `ContextTriple::estimated_tokens` of emitted triples
    /// (default 2048).
    pub token_budget: usize,
    /// Per-triple rendering template; `{s}`, `{p}` and `{o}` are replaced by the subject,
    /// predicate and object (default `"{s} -- {p} --> {o}"`).
    pub triple_template: String,
    /// Text inserted between rendered triples (default `"\n"`).
    pub separator: String,
}
80
81impl Default for ContextBuilderConfig {
82 fn default() -> Self {
83 Self {
84 max_hops: 2,
85 token_budget: 2048,
86 triple_template: "{s} -- {p} --> {o}".to_string(),
87 separator: "\n".to_string(),
88 }
89 }
90}
91
/// In-memory triple store with hash indexes on subject, object and predicate.
pub struct KnowledgeGraph {
    /// All triples in insertion order; every index below stores positions into this vector.
    triples: Vec<ContextTriple>,
    /// Subject string -> positions of the triples with that subject.
    subject_index: HashMap<String, Vec<usize>>,
    /// Object string -> positions of the triples with that object.
    object_index: HashMap<String, Vec<usize>>,
    /// Predicate string -> positions of the triples with that predicate.
    predicate_index: HashMap<String, Vec<usize>>,
}
109
110impl KnowledgeGraph {
111 pub fn new() -> Self {
113 Self {
114 triples: Vec::new(),
115 subject_index: HashMap::new(),
116 object_index: HashMap::new(),
117 predicate_index: HashMap::new(),
118 }
119 }
120
121 pub fn from_triples(triples: &[ContextTriple]) -> Self {
123 let mut kg = Self::new();
124 for t in triples {
125 kg.add_triple(t.clone());
126 }
127 kg
128 }
129
130 pub fn add_triple(&mut self, triple: ContextTriple) {
132 let idx = self.triples.len();
133 self.subject_index
134 .entry(triple.subject.clone())
135 .or_default()
136 .push(idx);
137 self.object_index
138 .entry(triple.object.clone())
139 .or_default()
140 .push(idx);
141 self.predicate_index
142 .entry(triple.predicate.clone())
143 .or_default()
144 .push(idx);
145 self.triples.push(triple);
146 }
147
148 pub fn len(&self) -> usize {
150 self.triples.len()
151 }
152
153 pub fn is_empty(&self) -> bool {
155 self.triples.is_empty()
156 }
157
158 pub fn neighbors(&self, entity: &str) -> Vec<&ContextTriple> {
160 let mut result: Vec<&ContextTriple> = Vec::new();
161 if let Some(indices) = self.subject_index.get(entity) {
162 for &idx in indices {
163 result.push(&self.triples[idx]);
164 }
165 }
166 if let Some(indices) = self.object_index.get(entity) {
167 for &idx in indices {
168 result.push(&self.triples[idx]);
169 }
170 }
171 result
172 }
173
174 pub fn triples_by_predicate(&self, predicate: &str) -> Vec<&ContextTriple> {
176 self.predicate_index
177 .get(predicate)
178 .map(|indices| indices.iter().map(|&idx| &self.triples[idx]).collect())
179 .unwrap_or_default()
180 }
181
182 pub fn all_triples(&self) -> &[ContextTriple] {
184 &self.triples
185 }
186}
187
188impl Default for KnowledgeGraph {
189 fn default() -> Self {
190 Self::new()
191 }
192}
193
/// Turns the neighbourhood of one or more entities in a [`KnowledgeGraph`] into a ranked,
/// token-budgeted, text-rendered context.
pub struct ContextBuilder {
    /// Settings used by every build step (hop depth, budget, rendering).
    config: ContextBuilderConfig,
}
202
203impl ContextBuilder {
204 pub fn new() -> Self {
206 Self {
207 config: ContextBuilderConfig::default(),
208 }
209 }
210
211 pub fn with_config(config: ContextBuilderConfig) -> Self {
213 Self { config }
214 }
215
216 pub fn extract_neighborhood(
220 &self,
221 kg: &KnowledgeGraph,
222 entity: &str,
223 max_hops: Option<usize>,
224 ) -> Vec<ContextTriple> {
225 let hops = max_hops.unwrap_or(self.config.max_hops);
226 let mut visited: HashSet<String> = HashSet::new();
227 let mut queue: VecDeque<(String, usize)> = VecDeque::new();
228 let mut result_set: HashSet<ContextTriple> = HashSet::new();
229
230 queue.push_back((entity.to_string(), 0));
231 visited.insert(entity.to_string());
232
233 while let Some((current, depth)) = queue.pop_front() {
234 let neighbors = kg.neighbors(¤t);
235 for triple in neighbors {
236 result_set.insert(triple.clone());
237
238 if depth < hops {
239 let other = if triple.subject == current {
241 &triple.object
242 } else {
243 &triple.subject
244 };
245 if !visited.contains(other.as_str()) {
246 visited.insert(other.clone());
247 queue.push_back((other.clone(), depth + 1));
248 }
249 }
250 }
251 }
252
253 result_set.into_iter().collect()
254 }
255
256 pub fn entity_neighborhood(&self, kg: &KnowledgeGraph, entity: &str) -> Vec<ContextTriple> {
258 self.extract_neighborhood(kg, entity, Some(1))
259 }
260
261 pub fn relation_context(&self, kg: &KnowledgeGraph, predicate: &str) -> Vec<ContextTriple> {
263 kg.triples_by_predicate(predicate)
264 .into_iter()
265 .cloned()
266 .collect()
267 }
268
269 pub fn rank_triples(
275 &self,
276 triples: &[ContextTriple],
277 seed_entities: &[&str],
278 ) -> Vec<ScoredTriple> {
279 let seeds: HashSet<&str> = seed_entities.iter().copied().collect();
280 let mut scored: Vec<ScoredTriple> = triples
281 .iter()
282 .map(|t| {
283 let mut score = 0.0;
284 if seeds.contains(t.subject.as_str()) {
285 score += 1.0;
286 }
287 if seeds.contains(t.object.as_str()) {
288 score += 1.0;
289 }
290 if seeds.contains(t.subject.as_str()) && seeds.contains(t.object.as_str()) {
292 score += 0.5;
293 }
294 ScoredTriple {
295 triple: t.clone(),
296 score,
297 }
298 })
299 .collect();
300
301 scored.sort_by(|a, b| {
302 b.score
303 .partial_cmp(&a.score)
304 .unwrap_or(std::cmp::Ordering::Equal)
305 });
306 scored
307 }
308
309 pub fn truncate_to_budget(&self, triples: &[ContextTriple]) -> Vec<ContextTriple> {
313 let mut result = Vec::new();
314 let mut tokens_used = 0;
315
316 for t in triples {
317 let t_tokens = t.estimated_tokens();
318 if tokens_used + t_tokens > self.config.token_budget {
319 break;
320 }
321 tokens_used += t_tokens;
322 result.push(t.clone());
323 }
324
325 result
326 }
327
328 pub fn format_triples(&self, triples: &[ContextTriple]) -> String {
332 triples
333 .iter()
334 .map(|t| {
335 self.config
336 .triple_template
337 .replace("{s}", &t.subject)
338 .replace("{p}", &t.predicate)
339 .replace("{o}", &t.object)
340 })
341 .collect::<Vec<_>>()
342 .join(&self.config.separator)
343 }
344
345 pub fn merge_contexts(&self, contexts: &[Vec<ContextTriple>]) -> Vec<ContextTriple> {
349 let mut seen: HashSet<ContextTriple> = HashSet::new();
350 let mut merged: Vec<ContextTriple> = Vec::new();
351
352 for ctx in contexts {
353 for triple in ctx {
354 if seen.insert(triple.clone()) {
355 merged.push(triple.clone());
356 }
357 }
358 }
359
360 merged
361 }
362
363 pub fn deduplicate(&self, triples: &[ContextTriple]) -> Vec<ContextTriple> {
365 let mut seen: HashSet<&ContextTriple> = HashSet::new();
366 let mut result: Vec<ContextTriple> = Vec::new();
367 for t in triples {
368 if seen.insert(t) {
369 result.push(t.clone());
370 }
371 }
372 result
373 }
374
375 pub fn build(&self, kg: &KnowledgeGraph, entity: &str) -> BuiltContext {
379 let candidates = self.extract_neighborhood(kg, entity, None);
380 let total_candidates = candidates.len();
381 let ranked = self.rank_triples(&candidates, &[entity]);
382 let ranked_triples: Vec<ContextTriple> = ranked.into_iter().map(|st| st.triple).collect();
383 let truncated = self.truncate_to_budget(&ranked_triples);
384 let text = self.format_triples(&truncated);
385 let estimated_tokens: usize = truncated.iter().map(|t| t.estimated_tokens()).sum();
386
387 BuiltContext {
388 triples: truncated,
389 text,
390 estimated_tokens,
391 total_candidates,
392 }
393 }
394
395 pub fn build_multi(&self, kg: &KnowledgeGraph, entities: &[&str]) -> BuiltContext {
397 let mut all_contexts: Vec<Vec<ContextTriple>> = Vec::new();
398 for &entity in entities {
399 all_contexts.push(self.extract_neighborhood(kg, entity, None));
400 }
401 let merged = self.merge_contexts(&all_contexts);
402 let total_candidates = merged.len();
403 let ranked = self.rank_triples(&merged, entities);
404 let ranked_triples: Vec<ContextTriple> = ranked.into_iter().map(|st| st.triple).collect();
405 let truncated = self.truncate_to_budget(&ranked_triples);
406 let text = self.format_triples(&truncated);
407 let estimated_tokens: usize = truncated.iter().map(|t| t.estimated_tokens()).sum();
408
409 BuiltContext {
410 triples: truncated,
411 text,
412 estimated_tokens,
413 total_candidates,
414 }
415 }
416
417 pub fn config(&self) -> &ContextBuilderConfig {
419 &self.config
420 }
421}
422
423impl Default for ContextBuilder {
424 fn default() -> Self {
425 Self::new()
426 }
427}
428
#[cfg(test)]
mod tests {
    //! Unit tests for the triple store and the context builder.
    //!
    //! Neighbourhood results are checked by exact count and membership, never by order.
    use super::*;

    /// Chain Alice -> Bob -> Charlie -> Dave via `knows`, plus one `likes` edge off Alice,
    /// Bob and Dave. Every triple estimates to exactly 3 tokens.
    fn sample_kg() -> KnowledgeGraph {
        KnowledgeGraph::from_triples(&[
            ContextTriple::new("Alice", "knows", "Bob"),
            ContextTriple::new("Bob", "knows", "Charlie"),
            ContextTriple::new("Charlie", "knows", "Dave"),
            ContextTriple::new("Alice", "likes", "Music"),
            ContextTriple::new("Bob", "likes", "Sports"),
            ContextTriple::new("Dave", "likes", "Art"),
        ])
    }

    fn builder() -> ContextBuilder {
        ContextBuilder::new()
    }

    /// `true` when `triples` contains the triple `(s, p, o)`.
    fn has(triples: &[ContextTriple], s: &str, p: &str, o: &str) -> bool {
        triples.contains(&ContextTriple::new(s, p, o))
    }

    // ---- ContextTriple ----

    #[test]
    fn test_context_triple_new() {
        let t = ContextTriple::new("s", "p", "o");
        assert_eq!(t.subject, "s");
        assert_eq!(t.predicate, "p");
        assert_eq!(t.object, "o");
    }

    #[test]
    fn test_context_triple_estimated_tokens() {
        let t = ContextTriple::new("Alice", "knows", "Bob");
        assert_eq!(t.estimated_tokens(), 3);
        // `/` and `#` also split tokens.
        let iri = ContextTriple::new("ex:a/b", "p", "c#d");
        assert_eq!(iri.estimated_tokens(), 5);
    }

    #[test]
    fn test_context_triple_equality() {
        let a = ContextTriple::new("s", "p", "o");
        let b = ContextTriple::new("s", "p", "o");
        assert_eq!(a, b);
    }

    #[test]
    fn test_context_triple_inequality() {
        let a = ContextTriple::new("s", "p", "o1");
        let b = ContextTriple::new("s", "p", "o2");
        assert_ne!(a, b);
    }

    // ---- KnowledgeGraph ----

    #[test]
    fn test_kg_new_empty() {
        let kg = KnowledgeGraph::new();
        assert!(kg.is_empty());
        assert_eq!(kg.len(), 0);
    }

    #[test]
    fn test_kg_from_triples() {
        let kg = sample_kg();
        assert_eq!(kg.len(), 6);
        assert!(!kg.is_empty());
    }

    #[test]
    fn test_kg_add_triple() {
        let mut kg = KnowledgeGraph::new();
        kg.add_triple(ContextTriple::new("A", "r", "B"));
        assert_eq!(kg.len(), 1);
        assert_eq!(kg.neighbors("A").len(), 1);
        assert_eq!(kg.neighbors("B").len(), 1);
        assert_eq!(kg.triples_by_predicate("r").len(), 1);
    }

    #[test]
    fn test_kg_neighbors() {
        let kg = sample_kg();
        let n = kg.neighbors("Alice");
        assert_eq!(n.len(), 2);
        assert!(n.iter().all(|t| t.subject == "Alice"));
    }

    #[test]
    fn test_kg_neighbors_as_object() {
        let kg = sample_kg();
        let n = kg.neighbors("Bob");
        assert_eq!(n.len(), 3);
        assert_eq!(n.iter().filter(|t| t.object == "Bob").count(), 1);
    }

    #[test]
    fn test_kg_neighbors_unknown_entity() {
        let kg = sample_kg();
        let n = kg.neighbors("Unknown");
        assert!(n.is_empty());
    }

    #[test]
    fn test_kg_triples_by_predicate() {
        let kg = sample_kg();
        let knows = kg.triples_by_predicate("knows");
        assert_eq!(knows.len(), 3);
        assert!(knows.iter().all(|t| t.predicate == "knows"));
    }

    #[test]
    fn test_kg_triples_by_predicate_unknown() {
        let kg = sample_kg();
        assert!(kg.triples_by_predicate("unknown").is_empty());
    }

    #[test]
    fn test_kg_all_triples() {
        let kg = sample_kg();
        assert_eq!(kg.all_triples().len(), 6);
    }

    #[test]
    fn test_kg_default() {
        let kg = KnowledgeGraph::default();
        assert!(kg.is_empty());
    }

    // ---- Neighbourhood extraction ----

    #[test]
    fn test_extract_neighborhood_1_hop() {
        let kg = sample_kg();
        let b = builder();
        let ctx = b.extract_neighborhood(&kg, "Alice", Some(1));
        // Alice's own triples plus those touching Bob and Music.
        assert_eq!(ctx.len(), 4);
        assert!(has(&ctx, "Bob", "knows", "Charlie"));
        assert!(has(&ctx, "Bob", "likes", "Sports"));
        assert!(!has(&ctx, "Charlie", "knows", "Dave"));
    }

    #[test]
    fn test_extract_neighborhood_2_hops() {
        let kg = sample_kg();
        let b = builder();
        let ctx = b.extract_neighborhood(&kg, "Alice", Some(2));
        assert_eq!(ctx.len(), 5);
        assert!(has(&ctx, "Charlie", "knows", "Dave"));
        // Dave is three hops away, so his own triples are not reached.
        assert!(!has(&ctx, "Dave", "likes", "Art"));
    }

    #[test]
    fn test_extract_neighborhood_0_hops() {
        let kg = sample_kg();
        let b = builder();
        let ctx = b.extract_neighborhood(&kg, "Alice", Some(0));
        assert_eq!(ctx.len(), 2);
        assert!(ctx.iter().all(|t| t.subject == "Alice" || t.object == "Alice"));
    }

    #[test]
    fn test_extract_neighborhood_no_duplicates() {
        let kg = sample_kg();
        let b = builder();
        let ctx = b.extract_neighborhood(&kg, "Bob", Some(2));
        assert_eq!(b.deduplicate(&ctx).len(), ctx.len());
    }

    #[test]
    fn test_entity_neighborhood() {
        let kg = sample_kg();
        let b = builder();
        let ctx = b.entity_neighborhood(&kg, "Bob");
        assert_eq!(ctx.len(), 5);
        assert!(!has(&ctx, "Dave", "likes", "Art"));
    }

    #[test]
    fn test_relation_context() {
        let kg = sample_kg();
        let b = builder();
        let ctx = b.relation_context(&kg, "likes");
        assert_eq!(ctx.len(), 3);
        assert!(ctx.iter().all(|t| t.predicate == "likes"));
    }

    #[test]
    fn test_relation_context_unknown() {
        let kg = sample_kg();
        let b = builder();
        let ctx = b.relation_context(&kg, "hates");
        assert!(ctx.is_empty());
    }

    // ---- Ranking ----

    #[test]
    fn test_rank_triples_seed_first() {
        let b = builder();
        let triples = vec![
            ContextTriple::new("X", "r", "Y"),
            ContextTriple::new("Alice", "r", "Bob"),
        ];
        let ranked = b.rank_triples(&triples, &["Alice"]);
        assert_eq!(ranked[0].triple.subject, "Alice");
        assert!((ranked[0].score - 1.0).abs() < 1e-10);
        assert!(ranked[1].score.abs() < 1e-10);
    }

    #[test]
    fn test_rank_triples_both_seeds_highest() {
        let b = builder();
        let triples = vec![
            ContextTriple::new("Alice", "r", "X"),
            ContextTriple::new("Alice", "r", "Bob"),
        ];
        let ranked = b.rank_triples(&triples, &["Alice", "Bob"]);
        assert_eq!(ranked[0].triple.object, "Bob");
        assert!((ranked[0].score - 2.5).abs() < 1e-10);
        assert!((ranked[1].score - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_rank_triples_empty() {
        let b = builder();
        let ranked = b.rank_triples(&[], &["Alice"]);
        assert!(ranked.is_empty());
    }

    // ---- Budget truncation ----

    #[test]
    fn test_truncate_to_budget() {
        let b = ContextBuilder::with_config(ContextBuilderConfig {
            token_budget: 10,
            ..ContextBuilderConfig::default()
        });
        let triples: Vec<ContextTriple> = (0..100)
            .map(|i| ContextTriple::new(format!("s{i}"), "p", format!("o{i}")))
            .collect();
        let truncated = b.truncate_to_budget(&triples);
        let total_tokens: usize = truncated.iter().map(|t| t.estimated_tokens()).sum();
        assert!(total_tokens <= 10);
        // Three 3-token triples fit (9); a fourth would need 12.
        assert_eq!(truncated.len(), 3);
        assert_eq!(truncated[..], triples[..3]);
    }

    #[test]
    fn test_truncate_to_budget_all_fit() {
        let b = ContextBuilder::with_config(ContextBuilderConfig {
            token_budget: 100_000,
            ..ContextBuilderConfig::default()
        });
        let triples = vec![
            ContextTriple::new("A", "r", "B"),
            ContextTriple::new("C", "r", "D"),
        ];
        let truncated = b.truncate_to_budget(&triples);
        assert_eq!(truncated, triples);
    }

    // ---- Formatting ----

    #[test]
    fn test_format_triples_default_template() {
        let b = builder();
        let triples = vec![ContextTriple::new("Alice", "knows", "Bob")];
        let text = b.format_triples(&triples);
        assert_eq!(text, "Alice -- knows --> Bob");
    }

    #[test]
    fn test_format_triples_custom_template() {
        let b = ContextBuilder::with_config(ContextBuilderConfig {
            triple_template: "({s}, {p}, {o})".to_string(),
            separator: "; ".to_string(),
            ..ContextBuilderConfig::default()
        });
        let triples = vec![
            ContextTriple::new("A", "r", "B"),
            ContextTriple::new("C", "r", "D"),
        ];
        let text = b.format_triples(&triples);
        assert_eq!(text, "(A, r, B); (C, r, D)");
    }

    #[test]
    fn test_format_triples_empty() {
        let b = builder();
        let text = b.format_triples(&[]);
        assert!(text.is_empty());
    }

    // ---- Merging and deduplication ----

    #[test]
    fn test_merge_contexts_deduplicates() {
        let b = builder();
        let t = ContextTriple::new("A", "r", "B");
        let ctx1 = vec![t.clone()];
        let ctx2 = vec![t.clone()];
        let merged = b.merge_contexts(&[ctx1, ctx2]);
        assert_eq!(merged, vec![t]);
    }

    #[test]
    fn test_merge_contexts_combines() {
        let b = builder();
        let ctx1 = vec![ContextTriple::new("A", "r", "B")];
        let ctx2 = vec![ContextTriple::new("C", "r", "D")];
        let merged = b.merge_contexts(&[ctx1, ctx2]);
        assert_eq!(merged.len(), 2);
        // First-seen order is preserved.
        assert_eq!(merged[0].subject, "A");
        assert_eq!(merged[1].subject, "C");
    }

    #[test]
    fn test_merge_contexts_empty() {
        let b = builder();
        let merged = b.merge_contexts(&[]);
        assert!(merged.is_empty());
    }

    #[test]
    fn test_deduplicate() {
        let b = builder();
        let t = ContextTriple::new("A", "r", "B");
        let triples = vec![t.clone(), t.clone(), t];
        let deduped = b.deduplicate(&triples);
        assert_eq!(deduped.len(), 1);
    }

    #[test]
    fn test_deduplicate_preserves_order() {
        let b = builder();
        let triples = vec![
            ContextTriple::new("C", "r", "D"),
            ContextTriple::new("A", "r", "B"),
            ContextTriple::new("C", "r", "D"),
        ];
        let deduped = b.deduplicate(&triples);
        assert_eq!(deduped.len(), 2);
        assert_eq!(deduped[0].subject, "C");
        assert_eq!(deduped[1].subject, "A");
    }

    // ---- End-to-end build ----

    #[test]
    fn test_build_single_entity() {
        let kg = sample_kg();
        let b = builder();
        let ctx = b.build(&kg, "Alice");
        assert_eq!(ctx.total_candidates, 5);
        assert_eq!(ctx.triples.len(), 5);
        assert_eq!(ctx.estimated_tokens, 15);
        assert_eq!(ctx.text.lines().count(), ctx.triples.len());
        // Triples touching the seed outrank everything else.
        assert!(ctx.triples[..2]
            .iter()
            .all(|t| t.subject == "Alice" || t.object == "Alice"));
    }

    #[test]
    fn test_build_unknown_entity() {
        let kg = sample_kg();
        let b = builder();
        let ctx = b.build(&kg, "Unknown");
        assert!(ctx.triples.is_empty());
        assert!(ctx.text.is_empty());
        assert_eq!(ctx.estimated_tokens, 0);
        assert_eq!(ctx.total_candidates, 0);
    }

    #[test]
    fn test_build_multi_entity() {
        let kg = sample_kg();
        let b = builder();
        let ctx = b.build_multi(&kg, &["Alice", "Dave"]);
        // Union of both 2-hop neighbourhoods covers the whole sample graph.
        assert_eq!(ctx.total_candidates, 6);
        assert_eq!(ctx.triples.len(), 6);
        assert_eq!(ctx.estimated_tokens, 18);
    }

    #[test]
    fn test_build_multi_empty_entities() {
        let kg = sample_kg();
        let b = builder();
        let ctx = b.build_multi(&kg, &[]);
        assert!(ctx.triples.is_empty());
        assert_eq!(ctx.total_candidates, 0);
    }

    // ---- Configuration ----

    #[test]
    fn test_config_default() {
        let cfg = ContextBuilderConfig::default();
        assert_eq!(cfg.max_hops, 2);
        assert_eq!(cfg.token_budget, 2048);
        assert_eq!(cfg.triple_template, "{s} -- {p} --> {o}");
        assert_eq!(cfg.separator, "\n");
    }

    #[test]
    fn test_config_access() {
        let b = builder();
        assert_eq!(b.config().max_hops, 2);
    }

    #[test]
    fn test_builder_default() {
        let b = ContextBuilder::default();
        assert_eq!(b.config().max_hops, 2);
    }

    // ---- Edge cases ----

    #[test]
    fn test_zero_token_budget() {
        let b = ContextBuilder::with_config(ContextBuilderConfig {
            token_budget: 0,
            ..ContextBuilderConfig::default()
        });
        let triples = vec![ContextTriple::new("A", "r", "B")];
        let truncated = b.truncate_to_budget(&triples);
        assert!(truncated.is_empty());
    }

    #[test]
    fn test_scored_triple_fields() {
        let st = ScoredTriple {
            triple: ContextTriple::new("A", "r", "B"),
            score: 0.9,
        };
        assert_eq!(st.triple.subject, "A");
        assert!((st.score - 0.9).abs() < 1e-10);
    }

    #[test]
    fn test_built_context_total_candidates() {
        let kg = sample_kg();
        let b = builder();
        let ctx = b.build(&kg, "Bob");
        // Bob's 2-hop neighbourhood reaches every triple in the sample graph.
        assert_eq!(ctx.total_candidates, 6);
        assert_eq!(ctx.triples.len(), ctx.total_candidates);
    }
}