contrag_core/
context_builder.rs1use crate::entity::RagEntity;
2use crate::types::TextChunk;
3use crate::config::ChunkingConfig;
4
5pub struct ContextBuilder {
7 config: ChunkingConfig,
8}
9
10impl ContextBuilder {
11 pub fn new(config: ChunkingConfig) -> Self {
13 Self { config }
14 }
15
16 pub fn build_entity_context<T: RagEntity>(&self, entity: &T) -> String {
18 let context_map = entity.to_context_map();
19
20 let mut parts = vec![
21 format!("Entity: {}", T::entity_type()),
22 format!("ID: {}", entity.entity_id()),
23 String::from("---"),
24 ];
25
26 for (key, value) in context_map {
27 if self.config.include_field_names {
28 parts.push(format!("{}: {}", key, value));
29 } else {
30 parts.push(value);
31 }
32 }
33
34 parts.join("\n")
35 }
36
37 pub fn build_graph_context<T: RagEntity>(
39 &self,
40 root_entity: &T,
41 related_contexts: Vec<String>,
42 ) -> String {
43 let mut contexts = vec![self.build_entity_context(root_entity)];
44
45 let relationships = root_entity.relationships();
46
47 for (idx, related_ctx) in related_contexts.iter().enumerate() {
48 if let Some(rel) = relationships.get(idx) {
49 let annotated = format!(
50 "\n=== Relationship: {} ===\n{}\n",
51 rel.field_name,
52 related_ctx
53 );
54 contexts.push(annotated);
55 } else {
56 contexts.push(format!("\n{}\n", related_ctx));
57 }
58 }
59
60 contexts.join("\n")
61 }
62
63 pub fn chunk_text(&self, text: &str) -> Vec<TextChunk> {
65 if text.len() <= self.config.chunk_size {
66 return vec![TextChunk {
67 text: text.to_string(),
68 start_idx: 0,
69 end_idx: text.len(),
70 chunk_index: 0,
71 }];
72 }
73
74 let mut chunks = vec![];
75 let mut start = 0;
76 let mut chunk_index = 0;
77
78 while start < text.len() {
79 let end = (start + self.config.chunk_size).min(text.len());
80
81 let actual_end = if end < text.len() {
83 self.find_word_boundary(text, end)
84 } else {
85 end
86 };
87
88 chunks.push(TextChunk {
89 text: text[start..actual_end].to_string(),
90 start_idx: start,
91 end_idx: actual_end,
92 chunk_index,
93 });
94
95 if actual_end >= text.len() {
97 break;
98 }
99
100 start = actual_end.saturating_sub(self.config.overlap);
101 chunk_index += 1;
102 }
103
104 chunks
105 }
106
107 fn find_word_boundary(&self, text: &str, pos: usize) -> usize {
109 let chars: Vec<char> = text.chars().collect();
110
111 for i in (pos.saturating_sub(50)..pos).rev() {
113 if i >= chars.len() {
114 continue;
115 }
116 let c = chars[i];
117 if c.is_whitespace() || c == '.' || c == '!' || c == '?' || c == '\n' {
118 return i + 1;
119 }
120 }
121
122 pos
124 }
125
126 pub fn build_and_chunk<T: RagEntity>(&self, entity: &T) -> Vec<TextChunk> {
128 let context = self.build_entity_context(entity);
129 self.chunk_text(&context)
130 }
131
132 pub fn build_and_chunk_graph<T: RagEntity>(
134 &self,
135 root_entity: &T,
136 related_contexts: Vec<String>,
137 ) -> Vec<TextChunk> {
138 let context = self.build_graph_context(root_entity, related_contexts);
139 self.chunk_text(&context)
140 }
141
142 pub fn build_multi_entity_context<T: RagEntity>(&self, entities: &[T]) -> String {
144 entities
145 .iter()
146 .map(|entity| self.build_entity_context(entity))
147 .collect::<Vec<_>>()
148 .join("\n\n=== Next Entity ===\n\n")
149 }
150
151 pub fn get_chunk_stats(&self, text: &str) -> ChunkStats {
153 let chunks = self.chunk_text(text);
154 let total_chunks = chunks.len();
155 let avg_chunk_size = if total_chunks > 0 {
156 chunks.iter().map(|c| c.text.len()).sum::<usize>() / total_chunks
157 } else {
158 0
159 };
160
161 ChunkStats {
162 total_text_length: text.len(),
163 total_chunks,
164 avg_chunk_size,
165 chunk_size_config: self.config.chunk_size,
166 overlap_config: self.config.overlap,
167 }
168 }
169}
170
/// Summary of how a text was chunked, as returned by
/// [`ContextBuilder::get_chunk_stats`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChunkStats {
    /// Length of the input text in bytes.
    pub total_text_length: usize,
    /// Number of chunks produced.
    pub total_chunks: usize,
    /// Integer mean chunk length in bytes (`0` when there are no chunks).
    pub avg_chunk_size: usize,
    /// The `chunk_size` the builder was configured with.
    pub chunk_size_config: usize,
    /// The `overlap` the builder was configured with.
    pub overlap_config: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ContextBuilder` with field names enabled.
    fn builder(chunk_size: usize, overlap: usize) -> ContextBuilder {
        ContextBuilder::new(ChunkingConfig {
            chunk_size,
            overlap,
            include_field_names: true,
        })
    }

    #[test]
    fn test_chunk_text_small() {
        let chunks = builder(100, 20).chunk_text("Hello world");

        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].text, "Hello world");
        assert_eq!(chunks[0].start_idx, 0);
        assert_eq!(chunks[0].end_idx, 11);
    }

    #[test]
    fn test_chunk_text_large() {
        let text = "a".repeat(150);
        let chunks = builder(50, 10).chunk_text(&text);

        // No whitespace to break on: fixed 50-byte windows stepping back 10.
        let spans: Vec<_> = chunks.iter().map(|c| (c.start_idx, c.end_idx)).collect();
        assert_eq!(spans, vec![(0, 50), (40, 90), (80, 130), (120, 150)]);
        for chunk in &chunks {
            assert_eq!(chunk.text, &text[chunk.start_idx..chunk.end_idx]);
        }
    }

    #[test]
    fn test_chunk_text_breaks_on_whitespace() {
        let chunks = builder(8, 0).chunk_text("hello world foo");

        let texts: Vec<&str> = chunks.iter().map(|c| c.text.as_str()).collect();
        assert_eq!(texts, vec!["hello ", "world ", "foo"]);
    }
}