1use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::Instant;
19
20use argentor_core::ArgentorResult;
21use chrono::Utc;
22use serde::{Deserialize, Serialize};
23use uuid::Uuid;
24
25use crate::embedding::EmbeddingProvider;
26use crate::store::{MemoryEntry, VectorStore};
27
/// A source document to be chunked, embedded and indexed by a `RagPipeline`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Document identifier. It becomes every chunk's `document_id` and the prefix of its
    /// `chunk_id` (`"{id}_chunk_{n}"`).
    pub id: String,
    /// Human-readable title. It is stored as `document_title` metadata and shown in the
    /// formatted context headers.
    pub title: String,
    /// Full text that is split into chunks.
    pub content: String,
    /// Origin of the document. It is stored as `source` metadata and shown in the formatted
    /// context headers.
    pub source: String,
    /// Extra string metadata. It is copied into each chunk's vector-store metadata when
    /// `RagConfig::include_metadata` is enabled.
    pub metadata: HashMap<String, String>,
    /// Optional category, stored as `category` metadata when present.
    pub category: Option<String>,
}
48
/// One piece of a `Document` produced by a `ChunkingStrategy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Unique chunk id of the form `"{document_id}_chunk_{chunk_index}"`.
    pub chunk_id: String,
    /// `id` of the document this chunk was cut from.
    pub document_id: String,
    /// The chunk text (trimmed by the chunkers).
    pub content: String,
    /// Zero-based position of this chunk within its document.
    pub chunk_index: usize,
    /// Approximate token count of `content` (ceil(bytes / 4), see `estimate_tokens`).
    pub token_estimate: usize,
}
63
/// How a document's text is split into chunks before embedding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChunkingStrategy {
    /// Sliding windows of a fixed number of characters. Each window is trimmed and
    /// whitespace-only windows are dropped.
    FixedSize {
        /// Window length in characters. `0` produces no chunks.
        chunk_size: usize,
        /// Characters shared by consecutive windows. Clamped to `chunk_size - 1`.
        overlap: usize,
    },
    /// One chunk per paragraph, where paragraphs are separated by blank lines.
    Paragraph,
    /// One chunk per sentence. A sentence ends at `.`, `!` or `?` followed by whitespace or
    /// the end of the text.
    Sentence,
    /// Splits at Markdown-style heading lines (lines starting with `#`), then merges adjacent
    /// sections while they fit in the token budget.
    Semantic {
        /// Token budget used when merging sections, measured with `estimate_tokens`.
        /// `0` selects the default of 256.
        max_chunk_tokens: usize,
    },
}
85
/// Tunables for ingestion and retrieval in a `RagPipeline`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// How documents are split into chunks at ingestion time
    /// (default: `FixedSize { chunk_size: 512, overlap: 64 }`).
    pub chunking: ChunkingStrategy,
    /// Number of chunks returned by a query when the caller passes `None` (default 5).
    pub top_k: usize,
    /// Search hits scoring below this value are discarded (default 0.3).
    pub min_relevance_score: f32,
    /// Whether `Document::metadata` entries are copied into each chunk's stored metadata
    /// (default `true`).
    pub include_metadata: bool,
    /// Token budget, measured with `estimate_tokens`, for the context text built by
    /// `RagPipeline::query` (default 4096).
    pub max_context_tokens: usize,
}
100
101impl Default for RagConfig {
102 fn default() -> Self {
103 Self {
104 chunking: ChunkingStrategy::FixedSize {
105 chunk_size: 512,
106 overlap: 64,
107 },
108 top_k: 5,
109 min_relevance_score: 0.3,
110 include_metadata: true,
111 max_context_tokens: 4096,
112 }
113 }
114}
115
/// A retrieved chunk together with its relevance score and provenance.
#[derive(Debug, Clone)]
pub struct ScoredChunk {
    /// The retrieved chunk.
    pub chunk: DocumentChunk,
    /// Similarity score reported by the vector store for this chunk.
    pub score: f32,
    /// Title of the document the chunk came from.
    pub document_title: String,
    /// Source of the document the chunk came from.
    pub source: String,
}
128
/// Outcome of a RAG query.
#[derive(Debug, Clone)]
pub struct RagResult {
    /// Retrieved chunks, highest score first, at most `top_k` of them.
    pub chunks: Vec<ScoredChunk>,
    /// `chunks` rendered into a single prompt-ready string within the token budget.
    pub context_text: String,
    /// Number of entries reported by the vector store's `count()` at query time. This is
    /// the size of the whole store, not the number of hits.
    pub total_chunks_searched: usize,
    /// Wall-clock time spent answering the query, in milliseconds.
    pub query_time_ms: u64,
}
141
/// Cheap, model-agnostic token estimate: one token per started group of four UTF-8
/// bytes, i.e. `ceil(text.len() / 4)`.
///
/// This is only an approximation. It counts bytes rather than characters, so non-ASCII
/// text is estimated higher than ASCII text with the same number of characters.
fn estimate_tokens(text: &str) -> usize {
    let byte_len = text.len();
    let whole_groups = byte_len / 4;
    let has_partial_group = byte_len % 4 != 0;
    whole_groups + usize::from(has_partial_group)
}
151
152fn chunk_document(doc: &Document, strategy: &ChunkingStrategy) -> Vec<DocumentChunk> {
154 let raw_chunks = match strategy {
155 ChunkingStrategy::FixedSize {
156 chunk_size,
157 overlap,
158 } => chunk_fixed_size(&doc.content, *chunk_size, *overlap),
159 ChunkingStrategy::Paragraph => chunk_paragraph(&doc.content),
160 ChunkingStrategy::Sentence => chunk_sentence(&doc.content),
161 ChunkingStrategy::Semantic { max_chunk_tokens } => {
162 chunk_semantic(&doc.content, *max_chunk_tokens)
163 }
164 };
165
166 raw_chunks
167 .into_iter()
168 .enumerate()
169 .map(|(idx, text)| DocumentChunk {
170 chunk_id: format!("{}_chunk_{}", doc.id, idx),
171 document_id: doc.id.clone(),
172 content: text,
173 chunk_index: idx,
174 token_estimate: 0, })
176 .map(|mut c| {
177 c.token_estimate = estimate_tokens(&c.content);
178 c
179 })
180 .collect()
181}
182
/// Cuts `text` into windows of `chunk_size` characters, where consecutive windows share
/// `overlap` characters.
///
/// - `overlap` is clamped to `chunk_size - 1`, so the window always advances by at least
///   one character.
/// - Windows are counted in `char`s, not bytes, so multi-byte text is never split inside a
///   character.
/// - Each window is trimmed, and whitespace-only windows are dropped.
/// - Empty `text` or a zero `chunk_size` yields no chunks.
fn chunk_fixed_size(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
    if chunk_size == 0 || text.is_empty() {
        return Vec::new();
    }

    let overlap = overlap.min(chunk_size.saturating_sub(1));
    let stride = chunk_size.saturating_sub(overlap).max(1);
    let glyphs: Vec<char> = text.chars().collect();
    let total = glyphs.len();

    let mut out = Vec::new();
    let mut offset = 0;
    loop {
        let stop = total.min(offset + chunk_size);
        let window: String = glyphs[offset..stop].iter().collect();
        let window = window.trim();
        if !window.is_empty() {
            out.push(window.to_string());
        }
        // The window that reaches the end of the text is the last one.
        if stop >= total {
            break;
        }
        offset += stride;
    }
    out
}
207
/// Splits `text` into paragraphs separated by one or more blank lines.
///
/// - A line counts as blank when it is empty or contains only whitespace.
/// - Both `\n` and `\r\n` line endings are recognised (via `str::lines`), so
///   Windows-formatted text splits the same way as Unix text.
/// - Lines inside a paragraph are re-joined with `\n`.
/// - Each paragraph is trimmed at both ends.
/// - Empty or whitespace-only input yields no paragraphs.
fn chunk_paragraph(text: &str) -> Vec<String> {
    let mut paragraphs = Vec::new();
    let mut current: Vec<&str> = Vec::new();

    // A trailing `None` acts as a final blank line so the last paragraph is flushed too.
    for line in text.lines().map(Some).chain(std::iter::once(None)) {
        match line {
            Some(line) if !line.trim().is_empty() => current.push(line),
            _ => {
                if !current.is_empty() {
                    // The first and last collected lines are non-blank. Trimming therefore
                    // only strips outer indentation and trailing spaces, and can never yield
                    // an empty paragraph.
                    paragraphs.push(current.join("\n").trim().to_string());
                    current.clear();
                }
            }
        }
    }

    paragraphs
}
227
/// Splits `text` into sentences.
///
/// - A sentence ends at `.`, `!` or `?` when the next character is whitespace or there is
///   no next character. Punctuation inside tokens such as `v1.2` therefore does not split.
/// - Every sentence is trimmed, and empty results are dropped.
/// - Trailing text without terminal punctuation becomes a final sentence.
fn chunk_sentence(text: &str) -> Vec<String> {
    let mut sentences = Vec::new();
    let mut pending = String::new();
    let mut chars = text.chars().peekable();

    while let Some(ch) = chars.next() {
        pending.push(ch);
        if !matches!(ch, '.' | '!' | '?') {
            continue;
        }
        let at_boundary = match chars.peek() {
            Some(next) => next.is_whitespace(),
            None => true,
        };
        if at_boundary {
            let sentence = pending.trim();
            if !sentence.is_empty() {
                sentences.push(sentence.to_owned());
            }
            pending.clear();
        }
    }

    let tail = pending.trim();
    if !tail.is_empty() {
        sentences.push(tail.to_owned());
    }
    sentences
}
254
255fn chunk_semantic(text: &str, max_chunk_tokens: usize) -> Vec<String> {
258 let max_tokens = if max_chunk_tokens == 0 {
259 256
260 } else {
261 max_chunk_tokens
262 };
263
264 let mut sections: Vec<String> = Vec::new();
265 let mut current_section = String::new();
266
267 for line in text.lines() {
268 let is_heading = line.starts_with('#');
269 if is_heading && !current_section.trim().is_empty() {
270 sections.push(current_section.trim().to_string());
271 current_section.clear();
272 }
273 if !current_section.is_empty() {
274 current_section.push('\n');
275 }
276 current_section.push_str(line);
277 }
278 if !current_section.trim().is_empty() {
279 sections.push(current_section.trim().to_string());
280 }
281
282 let mut merged: Vec<String> = Vec::new();
284 let mut buffer = String::new();
285
286 for section in sections {
287 let combined_tokens = estimate_tokens(&buffer) + estimate_tokens(§ion);
288 if buffer.is_empty() {
289 buffer = section;
290 } else if combined_tokens <= max_tokens {
291 buffer.push_str("\n\n");
292 buffer.push_str(§ion);
293 } else {
294 merged.push(buffer.trim().to_string());
295 buffer = section;
296 }
297 }
298 if !buffer.trim().is_empty() {
299 merged.push(buffer.trim().to_string());
300 }
301
302 if merged.is_empty() {
303 let trimmed = text.trim().to_string();
304 if trimmed.is_empty() {
305 vec![]
306 } else {
307 vec![trimmed]
308 }
309 } else {
310 merged
311 }
312}
313
/// Retrieval-augmented generation pipeline.
///
/// Ingestion chunks and embeds documents into a `VectorStore`. A query embeds the
/// question, searches the store, and formats the best-scoring chunks into a context string.
pub struct RagPipeline {
    /// Backing store for chunk embeddings.
    vector_store: Arc<dyn VectorStore>,
    /// Embeds chunk contents at ingestion time and questions at query time.
    embedder: Arc<dyn EmbeddingProvider>,
    /// Chunking and retrieval settings.
    config: RagConfig,
    /// Maps each vector-store entry id created by this pipeline to
    /// `(chunk, document title, document source)`, so search hits can be turned back into
    /// `ScoredChunk`s. Hits whose id is not in this map are skipped by `query`.
    chunk_index: tokio::sync::RwLock<HashMap<Uuid, (DocumentChunk, String, String)>>,
}
329
330impl RagPipeline {
331 pub fn new(
333 vector_store: Arc<dyn VectorStore>,
334 embedder: Arc<dyn EmbeddingProvider>,
335 config: RagConfig,
336 ) -> Self {
337 Self {
338 vector_store,
339 embedder,
340 config,
341 chunk_index: tokio::sync::RwLock::new(HashMap::new()),
342 }
343 }
344
345 pub async fn ingest_document(&self, doc: &Document) -> ArgentorResult<Vec<DocumentChunk>> {
347 let chunks = chunk_document(doc, &self.config.chunking);
348
349 for chunk in &chunks {
350 let embedding = self.embedder.embed(&chunk.content).await?;
351
352 let mut metadata = HashMap::new();
353 metadata.insert(
354 "document_id".to_string(),
355 serde_json::Value::String(doc.id.clone()),
356 );
357 metadata.insert(
358 "document_title".to_string(),
359 serde_json::Value::String(doc.title.clone()),
360 );
361 metadata.insert(
362 "source".to_string(),
363 serde_json::Value::String(doc.source.clone()),
364 );
365 metadata.insert(
366 "chunk_id".to_string(),
367 serde_json::Value::String(chunk.chunk_id.clone()),
368 );
369 metadata.insert(
370 "chunk_index".to_string(),
371 serde_json::json!(chunk.chunk_index),
372 );
373 if let Some(cat) = &doc.category {
374 metadata.insert(
375 "category".to_string(),
376 serde_json::Value::String(cat.clone()),
377 );
378 }
379 if self.config.include_metadata {
380 for (k, v) in &doc.metadata {
381 metadata.insert(k.clone(), serde_json::Value::String(v.clone()));
382 }
383 }
384
385 let entry_id = Uuid::new_v4();
386 let entry = MemoryEntry {
387 id: entry_id,
388 content: chunk.content.clone(),
389 embedding,
390 metadata,
391 session_id: None,
392 created_at: Utc::now(),
393 };
394
395 self.vector_store.insert(entry).await?;
396
397 let mut idx = self.chunk_index.write().await;
399 idx.insert(
400 entry_id,
401 (chunk.clone(), doc.title.clone(), doc.source.clone()),
402 );
403 }
404
405 Ok(chunks)
406 }
407
408 pub async fn ingest_batch(&self, docs: &[Document]) -> ArgentorResult<Vec<Vec<DocumentChunk>>> {
410 let mut all_chunks = Vec::with_capacity(docs.len());
411 for doc in docs {
412 let chunks = self.ingest_document(doc).await?;
413 all_chunks.push(chunks);
414 }
415 Ok(all_chunks)
416 }
417
418 pub async fn query(&self, question: &str, top_k: Option<usize>) -> ArgentorResult<RagResult> {
420 let start = Instant::now();
421 let k = top_k.unwrap_or(self.config.top_k);
422
423 let query_embedding = self.embedder.embed(question).await?;
424
425 let total_chunks_searched = self.vector_store.count().await?;
426
427 let fetch_k = (k * 3).max(k);
429 let results = self
430 .vector_store
431 .search(&query_embedding, fetch_k, None)
432 .await?;
433
434 let idx = self.chunk_index.read().await;
435
436 let mut scored: Vec<ScoredChunk> = results
437 .into_iter()
438 .filter(|r| r.score >= self.config.min_relevance_score)
439 .filter_map(|r| {
440 let (chunk, title, source) = idx.get(&r.entry.id)?;
441 Some(ScoredChunk {
442 chunk: chunk.clone(),
443 score: r.score,
444 document_title: title.clone(),
445 source: source.clone(),
446 })
447 })
448 .collect();
449
450 scored.sort_by(|a, b| {
451 b.score
452 .partial_cmp(&a.score)
453 .unwrap_or(std::cmp::Ordering::Equal)
454 });
455 scored.truncate(k);
456
457 let context_text = format_context(&scored, self.config.max_context_tokens);
458
459 let elapsed = start.elapsed().as_millis() as u64;
460
461 Ok(RagResult {
462 chunks: scored,
463 context_text,
464 total_chunks_searched,
465 query_time_ms: elapsed,
466 })
467 }
468
469 pub async fn query_with_context(
472 &self,
473 question: &str,
474 top_k: Option<usize>,
475 context_window: usize,
476 ) -> ArgentorResult<RagResult> {
477 let mut result = self.query(question, top_k).await?;
478 result.context_text = format_context(&result.chunks, context_window);
480 Ok(result)
481 }
482
483 pub fn config(&self) -> &RagConfig {
485 &self.config
486 }
487}
488
489fn format_context(chunks: &[ScoredChunk], max_tokens: usize) -> String {
492 let mut parts: Vec<String> = Vec::new();
493 let mut token_budget = max_tokens;
494
495 for (i, sc) in chunks.iter().enumerate() {
496 let header = format!(
497 "[Source: {} | Document: {} | Score: {:.2}]",
498 sc.source, sc.document_title, sc.score
499 );
500 let section = format!("--- Chunk {} ---\n{}\n{}", i + 1, header, sc.chunk.content);
501 let section_tokens = estimate_tokens(§ion);
502 if section_tokens > token_budget {
503 if token_budget > 20 {
505 let available_chars = token_budget * 4;
506 let truncated: String = section.chars().take(available_chars).collect();
507 parts.push(truncated);
508 }
509 break;
510 }
511 token_budget = token_budget.saturating_sub(section_tokens);
512 parts.push(section);
513 }
514
515 parts.join("\n\n")
516}
517
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    //! Unit tests for the chunkers, token estimation and context formatting, plus
    //! ingest/query tests against an in-memory vector store with local embeddings.

    use super::*;
    use crate::embedding::LocalEmbedding;
    use crate::store::InMemoryVectorStore;

    /// Builds a document with source `"test"`, no metadata and no category.
    fn sample_doc(id: &str, title: &str, content: &str) -> Document {
        Document {
            id: id.into(),
            title: title.into(),
            content: content.into(),
            source: "test".into(),
            metadata: HashMap::new(),
            category: None,
        }
    }

    /// Builds a pipeline over a fresh in-memory store and local embedder.
    fn make_pipeline(config: RagConfig) -> RagPipeline {
        let store: Arc<dyn VectorStore> = Arc::new(InMemoryVectorStore::new());
        let embedder: Arc<dyn EmbeddingProvider> = Arc::new(LocalEmbedding::default());
        RagPipeline::new(store, embedder, config)
    }

    fn default_pipeline() -> RagPipeline {
        make_pipeline(RagConfig::default())
    }

    // ---- chunkers -------------------------------------------------------

    #[test]
    fn test_chunk_fixed_size_basic() {
        assert_eq!(chunk_fixed_size("abcdefghij", 4, 0), ["abcd", "efgh", "ij"]);
    }

    #[test]
    fn test_chunk_fixed_size_with_overlap() {
        // Stride is chunk_size - overlap = 3.
        assert_eq!(chunk_fixed_size("abcdefghij", 5, 2), ["abcde", "defgh", "ghij"]);
    }

    #[test]
    fn test_chunk_fixed_size_empty_text() {
        assert!(chunk_fixed_size("", 10, 0).is_empty());
    }

    #[test]
    fn test_chunk_fixed_size_zero_size() {
        assert!(chunk_fixed_size("hello", 0, 0).is_empty());
    }

    #[test]
    fn test_chunk_fixed_size_text_shorter_than_chunk() {
        assert_eq!(chunk_fixed_size("hi", 100, 0), ["hi"]);
    }

    #[test]
    fn test_chunk_paragraph_basic() {
        let input = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";
        assert_eq!(
            chunk_paragraph(input),
            ["First paragraph.", "Second paragraph.", "Third paragraph."]
        );
    }

    #[test]
    fn test_chunk_paragraph_empty() {
        assert!(chunk_paragraph("").is_empty());
    }

    #[test]
    fn test_chunk_paragraph_single() {
        let input = "Just one paragraph with no double newlines.";
        assert_eq!(chunk_paragraph(input), [input]);
    }

    #[test]
    fn test_chunk_sentence_basic() {
        let input = "First sentence. Second sentence! Third sentence?";
        assert_eq!(
            chunk_sentence(input),
            ["First sentence.", "Second sentence!", "Third sentence?"]
        );
    }

    #[test]
    fn test_chunk_sentence_no_terminal() {
        let input = "No terminal punctuation here";
        assert_eq!(chunk_sentence(input), [input]);
    }

    #[test]
    fn test_chunk_sentence_empty() {
        assert!(chunk_sentence("").is_empty());
    }

    #[test]
    fn test_chunk_semantic_headings() {
        let input = "# Heading 1\nParagraph one.\n# Heading 2\nParagraph two.";
        // Both heading sections fit in the generous budget, so they merge into one chunk.
        assert_eq!(chunk_semantic(input, 1000).len(), 1);
    }

    #[test]
    fn test_chunk_semantic_splits_on_budget() {
        let input = "# A\nLorem ipsum dolor sit amet.\n# B\nConsectetur adipiscing elit.";
        let pieces = chunk_semantic(input, 8);
        assert!(pieces.len() >= 2, "small budget should force split");
    }

    #[test]
    fn test_chunk_semantic_empty() {
        assert!(chunk_semantic("", 100).is_empty());
    }

    // ---- helpers and data types ------------------------------------------

    #[test]
    fn test_estimate_tokens() {
        for (input, expected) in [("", 0), ("abcd", 1), ("abcdefghijkl", 3)] {
            assert_eq!(estimate_tokens(input), expected);
        }
    }

    #[test]
    fn test_document_chunk_fields() {
        let dc = DocumentChunk {
            chunk_id: "doc1_chunk_0".into(),
            document_id: "doc1".into(),
            content: "hello world".into(),
            chunk_index: 0,
            token_estimate: 3,
        };
        assert_eq!(dc.chunk_id, "doc1_chunk_0");
        assert_eq!(dc.document_id, "doc1");
        assert_eq!(dc.chunk_index, 0);
        assert_eq!(dc.token_estimate, 3);
    }

    #[test]
    fn test_rag_config_default() {
        let defaults = RagConfig::default();
        assert_eq!(defaults.top_k, 5);
        assert!((defaults.min_relevance_score - 0.3).abs() < f32::EPSILON);
        assert!(defaults.include_metadata);
        assert_eq!(defaults.max_context_tokens, 4096);
    }

    // ---- pipeline ----------------------------------------------------------

    #[tokio::test]
    async fn test_ingest_document_creates_chunks() {
        let rag = default_pipeline();
        let doc = sample_doc("d1", "Test Doc", "Hello world. This is a test document.");
        let produced = rag.ingest_document(&doc).await.unwrap();
        assert!(!produced.is_empty(), "should produce at least one chunk");
        for chunk in &produced {
            assert_eq!(chunk.document_id, "d1");
            assert!(!chunk.content.is_empty());
            assert!(chunk.token_estimate > 0);
        }
    }

    #[tokio::test]
    async fn test_ingest_batch() {
        let rag = default_pipeline();
        let docs = vec![
            sample_doc("d1", "Doc One", "Content of document one."),
            sample_doc("d2", "Doc Two", "Content of document two."),
        ];
        let per_doc = rag.ingest_batch(&docs).await.unwrap();
        assert_eq!(per_doc.len(), 2);
        assert!(per_doc.iter().all(|chunks| !chunks.is_empty()));
    }

    #[tokio::test]
    async fn test_query_returns_results() {
        let rag = default_pipeline();
        let doc = sample_doc(
            "d1",
            "Rust Book",
            "Rust is a systems programming language focused on safety and performance.",
        );
        rag.ingest_document(&doc).await.unwrap();

        let outcome = rag.query("rust programming", None).await.unwrap();
        assert!(!outcome.chunks.is_empty(), "query should return results");
        assert!(outcome.query_time_ms < 10_000, "query should be fast");
        assert!(
            outcome.total_chunks_searched > 0,
            "should report chunks searched"
        );
    }

    #[tokio::test]
    async fn test_query_context_text_not_empty() {
        let rag = default_pipeline();
        let doc = sample_doc("d1", "FAQ", "How do I install Rust? Use rustup.");
        rag.ingest_document(&doc).await.unwrap();

        let outcome = rag.query("install rust", None).await.unwrap();
        assert!(
            !outcome.context_text.is_empty(),
            "context_text should be populated"
        );
        assert!(
            outcome.context_text.contains("Chunk 1"),
            "context should include chunk header"
        );
    }

    #[tokio::test]
    async fn test_query_with_context_window() {
        let rag = make_pipeline(RagConfig {
            min_relevance_score: 0.0,
            ..RagConfig::default()
        });
        let doc = sample_doc(
            "d1",
            "Long Doc",
            "Rust programming language. Memory safety without garbage collection. Zero-cost abstractions.",
        );
        rag.ingest_document(&doc).await.unwrap();

        let outcome = rag
            .query_with_context("rust programming language", None, 8192)
            .await
            .unwrap();
        assert!(!outcome.context_text.is_empty());
    }

    #[tokio::test]
    async fn test_query_min_relevance_filter() {
        let rag = make_pipeline(RagConfig {
            min_relevance_score: 0.99,
            ..RagConfig::default()
        });
        let doc = sample_doc("d1", "Doc", "some random content about various topics");
        rag.ingest_document(&doc).await.unwrap();

        let outcome = rag
            .query("completely unrelated xyz", None)
            .await
            .unwrap();
        // With a near-1.0 threshold, an unrelated query should keep at most one hit.
        assert!(
            outcome.chunks.len() <= 1,
            "high threshold should filter most results"
        );
    }

    #[tokio::test]
    async fn test_scored_chunk_has_metadata() {
        let rag = default_pipeline();
        let doc = sample_doc("d1", "My Title", "Content about Rust programming.");
        rag.ingest_document(&doc).await.unwrap();

        let outcome = rag.query("rust", None).await.unwrap();
        // Only checked when the hit clears the default relevance threshold.
        if let Some(top) = outcome.chunks.first() {
            assert_eq!(top.document_title, "My Title");
            assert_eq!(top.source, "test");
            assert!(top.score > 0.0);
        }
    }

    #[tokio::test]
    async fn test_config_accessor() {
        assert_eq!(default_pipeline().config().top_k, 5);
    }

    // ---- context formatting ------------------------------------------------

    #[test]
    fn test_format_context_empty() {
        assert!(format_context(&[], 4096).is_empty());
    }

    #[test]
    fn test_format_context_includes_source() {
        let hits = vec![ScoredChunk {
            chunk: DocumentChunk {
                chunk_id: "c1".into(),
                document_id: "d1".into(),
                content: "Hello".into(),
                chunk_index: 0,
                token_estimate: 2,
            },
            score: 0.95,
            document_title: "Title".into(),
            source: "kb".into(),
        }];
        let rendered = format_context(&hits, 4096);
        for needle in ["kb", "Title", "0.95", "Hello"] {
            assert!(rendered.contains(needle));
        }
    }
}