1use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::Instant;
19
20use argentor_core::ArgentorResult;
21use chrono::Utc;
22use serde::{Deserialize, Serialize};
23use tracing::instrument;
24use uuid::Uuid;
25
26use crate::embedding::EmbeddingProvider;
27use crate::store::{MemoryEntry, VectorStore};
28
/// A source document handed to the RAG pipeline for ingestion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Caller-supplied identifier; used as the prefix of every generated chunk ID.
    pub id: String,
    /// Title; stored alongside each chunk and shown in the formatted context header.
    pub title: String,
    /// Full text that is split into chunks.
    pub content: String,
    /// Origin label; stored alongside each chunk and shown in the context header.
    pub source: String,
    /// Extra string key/value pairs, copied into each stored entry's metadata
    /// when `RagConfig::include_metadata` is set.
    pub metadata: HashMap<String, String>,
    /// Optional category; stored under the `category` metadata key when present.
    pub category: Option<String>,
}
49
/// One piece of a [`Document`] produced by a chunking strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Chunk identifier, built as `"<document id>_chunk_<index>"`.
    pub chunk_id: String,
    /// ID of the document this chunk was cut from.
    pub document_id: String,
    /// Chunk text.
    pub content: String,
    /// Zero-based position of the chunk within its document.
    pub chunk_index: usize,
    /// Approximate token count of `content` (see `estimate_tokens`).
    pub token_estimate: usize,
}
64
/// How a document's text is split into chunks before embedding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChunkingStrategy {
    /// Sliding window over the text, measured in characters.
    FixedSize {
        /// Window length in characters; `0` produces no chunks.
        chunk_size: usize,
        /// Characters shared between consecutive windows; clamped to
        /// `chunk_size - 1` so the window always advances.
        overlap: usize,
    },
    /// Split the text at paragraph breaks (blank lines).
    Paragraph,
    /// Split after `.`, `!` or `?` when followed by whitespace or the end of the text.
    Sentence,
    /// Split at lines starting with `#`, then merge adjacent sections while they
    /// fit the token budget.
    Semantic {
        /// Token budget used when merging sections; `0` is treated as `256`.
        max_chunk_tokens: usize,
    },
}
86
/// Settings that control ingestion and querying in [`RagPipeline`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// Strategy used to split ingested documents.
    pub chunking: ChunkingStrategy,
    /// Number of chunks a query returns when the caller does not pass its own `top_k`.
    pub top_k: usize,
    /// Search hits scoring below this value are dropped from query results.
    pub min_relevance_score: f32,
    /// Whether `Document::metadata` is copied into each stored entry's metadata.
    pub include_metadata: bool,
    /// Token budget for the context text assembled by `RagPipeline::query`.
    pub max_context_tokens: usize,
}
101
102impl Default for RagConfig {
103 fn default() -> Self {
104 Self {
105 chunking: ChunkingStrategy::FixedSize {
106 chunk_size: 512,
107 overlap: 64,
108 },
109 top_k: 5,
110 min_relevance_score: 0.3,
111 include_metadata: true,
112 max_context_tokens: 4096,
113 }
114 }
115}
116
/// A retrieved chunk together with its relevance score and provenance.
#[derive(Debug, Clone)]
pub struct ScoredChunk {
    /// The chunk as recorded at ingestion time.
    pub chunk: DocumentChunk,
    /// Score reported by the vector store search for this chunk.
    pub score: f32,
    /// Title of the document the chunk came from.
    pub document_title: String,
    /// Source label of the document the chunk came from.
    pub source: String,
}
129
/// Outcome of a single query against the pipeline.
#[derive(Debug, Clone)]
pub struct RagResult {
    /// Matching chunks, highest score first, at most `top_k` of them.
    pub chunks: Vec<ScoredChunk>,
    /// The chunks rendered into one text block, limited by a token budget.
    pub context_text: String,
    /// Entry count reported by the vector store at query time.
    pub total_chunks_searched: usize,
    /// Elapsed time of the query in milliseconds.
    pub query_time_ms: u64,
}
142
/// Approximates the token count of `text` as one token per four bytes of
/// UTF-8, rounded up. Empty text yields `0`.
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;

    let byte_len = text.len();
    let whole_tokens = byte_len / BYTES_PER_TOKEN;
    if byte_len % BYTES_PER_TOKEN == 0 {
        whole_tokens
    } else {
        // A trailing partial group of bytes still counts as one token.
        whole_tokens + 1
    }
}
152
153fn chunk_document(doc: &Document, strategy: &ChunkingStrategy) -> Vec<DocumentChunk> {
155 let raw_chunks = match strategy {
156 ChunkingStrategy::FixedSize {
157 chunk_size,
158 overlap,
159 } => chunk_fixed_size(&doc.content, *chunk_size, *overlap),
160 ChunkingStrategy::Paragraph => chunk_paragraph(&doc.content),
161 ChunkingStrategy::Sentence => chunk_sentence(&doc.content),
162 ChunkingStrategy::Semantic { max_chunk_tokens } => {
163 chunk_semantic(&doc.content, *max_chunk_tokens)
164 }
165 };
166
167 raw_chunks
168 .into_iter()
169 .enumerate()
170 .map(|(idx, text)| DocumentChunk {
171 chunk_id: format!("{}_chunk_{}", doc.id, idx),
172 document_id: doc.id.clone(),
173 content: text,
174 chunk_index: idx,
175 token_estimate: 0, })
177 .map(|mut c| {
178 c.token_estimate = estimate_tokens(&c.content);
179 c
180 })
181 .collect()
182}
183
/// Cuts `text` into windows of `chunk_size` characters, each starting
/// `chunk_size - overlap` characters after the previous one.
///
/// `overlap` is clamped to `chunk_size - 1` so the window always advances.
/// Every window is trimmed, and windows that are blank after trimming are
/// dropped. Returns an empty vector for empty text or a zero `chunk_size`.
fn chunk_fixed_size(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
    if chunk_size == 0 || text.is_empty() {
        return Vec::new();
    }

    // Byte offset of every character plus the end of the text, so that
    // `offsets[i]..offsets[j]` is the byte range covering characters `i..j`.
    let offsets: Vec<usize> = text
        .char_indices()
        .map(|(at, _)| at)
        .chain(std::iter::once(text.len()))
        .collect();
    let char_count = offsets.len() - 1;

    // `chunk_size >= 1` here, so the clamped overlap leaves a stride of at least 1.
    let stride = chunk_size - overlap.min(chunk_size - 1);

    let mut pieces = Vec::new();
    let mut begin = 0;
    loop {
        let finish = char_count.min(begin + chunk_size);
        let piece = text[offsets[begin]..offsets[finish]].trim();
        if !piece.is_empty() {
            pieces.push(piece.to_owned());
        }
        if finish == char_count {
            return pieces;
        }
        begin += stride;
    }
}
208
/// Splits `text` into paragraphs separated by one or more empty lines.
///
/// Both `\n` and `\r\n` line endings are recognised, so a CRLF document is
/// split the same way as an LF one (line breaks inside a paragraph are
/// emitted as `\n`). Each paragraph is trimmed, and paragraphs that are blank
/// after trimming are dropped, so empty or whitespace-only text yields an
/// empty vector.
fn chunk_paragraph(text: &str) -> Vec<String> {
    /// Moves the trimmed contents of `pending` into `out` (if non-blank) and
    /// resets `pending` for the next paragraph.
    fn flush(pending: &mut String, out: &mut Vec<String>) {
        let paragraph = pending.trim();
        if !paragraph.is_empty() {
            out.push(paragraph.to_string());
        }
        pending.clear();
    }

    let mut paragraphs = Vec::new();
    let mut pending = String::new();

    // `str::lines` strips both `\n` and `\r\n`, which is what makes the
    // blank-line test below independent of the line-ending convention.
    for line in text.lines() {
        if line.is_empty() {
            flush(&mut pending, &mut paragraphs);
        } else {
            if !pending.is_empty() {
                pending.push('\n');
            }
            pending.push_str(line);
        }
    }
    flush(&mut pending, &mut paragraphs);

    paragraphs
}
228
/// Splits `text` into sentences.
///
/// A sentence ends at `.`, `!` or `?` when that character is followed by
/// whitespace or is the last character of the text. Each sentence is trimmed;
/// blank results are dropped. Text left over after the last terminator is
/// returned as a final sentence.
fn chunk_sentence(text: &str) -> Vec<String> {
    let mut sentences = Vec::new();
    let mut pending = String::new();
    let mut stream = text.chars().peekable();

    while let Some(ch) = stream.next() {
        pending.push(ch);
        if !matches!(ch, '.' | '!' | '?') {
            continue;
        }

        // A terminator only closes the sentence at whitespace or end of input,
        // so "1.5" or "..." do not split mid-token.
        let at_boundary = match stream.peek() {
            Some(next) => next.is_whitespace(),
            None => true,
        };
        if at_boundary {
            let sentence = pending.trim();
            if !sentence.is_empty() {
                sentences.push(sentence.to_owned());
            }
            pending.clear();
        }
    }

    let tail = pending.trim();
    if !tail.is_empty() {
        sentences.push(tail.to_owned());
    }
    sentences
}
255
256fn chunk_semantic(text: &str, max_chunk_tokens: usize) -> Vec<String> {
259 let max_tokens = if max_chunk_tokens == 0 {
260 256
261 } else {
262 max_chunk_tokens
263 };
264
265 let mut sections: Vec<String> = Vec::new();
266 let mut current_section = String::new();
267
268 for line in text.lines() {
269 let is_heading = line.starts_with('#');
270 if is_heading && !current_section.trim().is_empty() {
271 sections.push(current_section.trim().to_string());
272 current_section.clear();
273 }
274 if !current_section.is_empty() {
275 current_section.push('\n');
276 }
277 current_section.push_str(line);
278 }
279 if !current_section.trim().is_empty() {
280 sections.push(current_section.trim().to_string());
281 }
282
283 let mut merged: Vec<String> = Vec::new();
285 let mut buffer = String::new();
286
287 for section in sections {
288 let combined_tokens = estimate_tokens(&buffer) + estimate_tokens(§ion);
289 if buffer.is_empty() {
290 buffer = section;
291 } else if combined_tokens <= max_tokens {
292 buffer.push_str("\n\n");
293 buffer.push_str(§ion);
294 } else {
295 merged.push(buffer.trim().to_string());
296 buffer = section;
297 }
298 }
299 if !buffer.trim().is_empty() {
300 merged.push(buffer.trim().to_string());
301 }
302
303 if merged.is_empty() {
304 let trimmed = text.trim().to_string();
305 if trimmed.is_empty() {
306 vec![]
307 } else {
308 vec![trimmed]
309 }
310 } else {
311 merged
312 }
313}
314
/// Retrieval pipeline that chunks documents, embeds and stores the chunks,
/// and answers queries with the best-scoring chunks.
pub struct RagPipeline {
    /// Store that holds the embedded chunks and performs the similarity search.
    vector_store: Arc<dyn VectorStore>,
    /// Provider used to embed both chunk text and query text.
    embedder: Arc<dyn EmbeddingProvider>,
    /// Chunking and retrieval settings.
    config: RagConfig,
    /// Maps the vector-store entry ID of each ingested chunk to
    /// `(chunk, document title, document source)`.
    chunk_index: tokio::sync::RwLock<HashMap<Uuid, (DocumentChunk, String, String)>>,
}
330
331impl RagPipeline {
332 pub fn new(
334 vector_store: Arc<dyn VectorStore>,
335 embedder: Arc<dyn EmbeddingProvider>,
336 config: RagConfig,
337 ) -> Self {
338 Self {
339 vector_store,
340 embedder,
341 config,
342 chunk_index: tokio::sync::RwLock::new(HashMap::new()),
343 }
344 }
345
346 pub async fn ingest_document(&self, doc: &Document) -> ArgentorResult<Vec<DocumentChunk>> {
348 let chunks = chunk_document(doc, &self.config.chunking);
349
350 for chunk in &chunks {
351 let embedding = self.embedder.embed(&chunk.content).await?;
352
353 let mut metadata = HashMap::new();
354 metadata.insert(
355 "document_id".to_string(),
356 serde_json::Value::String(doc.id.clone()),
357 );
358 metadata.insert(
359 "document_title".to_string(),
360 serde_json::Value::String(doc.title.clone()),
361 );
362 metadata.insert(
363 "source".to_string(),
364 serde_json::Value::String(doc.source.clone()),
365 );
366 metadata.insert(
367 "chunk_id".to_string(),
368 serde_json::Value::String(chunk.chunk_id.clone()),
369 );
370 metadata.insert(
371 "chunk_index".to_string(),
372 serde_json::json!(chunk.chunk_index),
373 );
374 if let Some(cat) = &doc.category {
375 metadata.insert(
376 "category".to_string(),
377 serde_json::Value::String(cat.clone()),
378 );
379 }
380 if self.config.include_metadata {
381 for (k, v) in &doc.metadata {
382 metadata.insert(k.clone(), serde_json::Value::String(v.clone()));
383 }
384 }
385
386 let entry_id = Uuid::new_v4();
387 let entry = MemoryEntry {
388 id: entry_id,
389 content: chunk.content.clone(),
390 embedding,
391 metadata,
392 session_id: None,
393 created_at: Utc::now(),
394 };
395
396 self.vector_store.insert(entry).await?;
397
398 let mut idx = self.chunk_index.write().await;
400 idx.insert(
401 entry_id,
402 (chunk.clone(), doc.title.clone(), doc.source.clone()),
403 );
404 }
405
406 Ok(chunks)
407 }
408
409 pub async fn ingest_batch(&self, docs: &[Document]) -> ArgentorResult<Vec<Vec<DocumentChunk>>> {
411 let mut all_chunks = Vec::with_capacity(docs.len());
412 for doc in docs {
413 let chunks = self.ingest_document(doc).await?;
414 all_chunks.push(chunks);
415 }
416 Ok(all_chunks)
417 }
418
419 #[instrument(
421 skip(self, question),
422 fields(
423 query_length = question.len(),
424 chunks_searched = tracing::field::Empty,
425 results_count = tracing::field::Empty,
426 )
427 )]
428 pub async fn query(&self, question: &str, top_k: Option<usize>) -> ArgentorResult<RagResult> {
429 let start = Instant::now();
430 let k = top_k.unwrap_or(self.config.top_k);
431
432 let query_embedding = self.embedder.embed(question).await?;
433
434 let total_chunks_searched = self.vector_store.count().await?;
435
436 let fetch_k = (k * 3).max(k);
438 let results = self
439 .vector_store
440 .search(&query_embedding, fetch_k, None)
441 .await?;
442
443 let idx = self.chunk_index.read().await;
444
445 let mut scored: Vec<ScoredChunk> = results
446 .into_iter()
447 .filter(|r| r.score >= self.config.min_relevance_score)
448 .filter_map(|r| {
449 let (chunk, title, source) = idx.get(&r.entry.id)?;
450 Some(ScoredChunk {
451 chunk: chunk.clone(),
452 score: r.score,
453 document_title: title.clone(),
454 source: source.clone(),
455 })
456 })
457 .collect();
458
459 scored.sort_by(|a, b| {
460 b.score
461 .partial_cmp(&a.score)
462 .unwrap_or(std::cmp::Ordering::Equal)
463 });
464 scored.truncate(k);
465
466 let context_text = format_context(&scored, self.config.max_context_tokens);
467
468 let elapsed = start.elapsed().as_millis() as u64;
469
470 tracing::Span::current()
471 .record("chunks_searched", total_chunks_searched)
472 .record("results_count", scored.len());
473
474 Ok(RagResult {
475 chunks: scored,
476 context_text,
477 total_chunks_searched,
478 query_time_ms: elapsed,
479 })
480 }
481
482 pub async fn query_with_context(
485 &self,
486 question: &str,
487 top_k: Option<usize>,
488 context_window: usize,
489 ) -> ArgentorResult<RagResult> {
490 let mut result = self.query(question, top_k).await?;
491 result.context_text = format_context(&result.chunks, context_window);
493 Ok(result)
494 }
495
496 pub fn config(&self) -> &RagConfig {
498 &self.config
499 }
500}
501
502fn format_context(chunks: &[ScoredChunk], max_tokens: usize) -> String {
505 let mut parts: Vec<String> = Vec::new();
506 let mut token_budget = max_tokens;
507
508 for (i, sc) in chunks.iter().enumerate() {
509 let header = format!(
510 "[Source: {} | Document: {} | Score: {:.2}]",
511 sc.source, sc.document_title, sc.score
512 );
513 let section = format!("--- Chunk {} ---\n{}\n{}", i + 1, header, sc.chunk.content);
514 let section_tokens = estimate_tokens(§ion);
515 if section_tokens > token_budget {
516 if token_budget > 20 {
518 let available_chars = token_budget * 4;
519 let truncated: String = section.chars().take(available_chars).collect();
520 parts.push(truncated);
521 }
522 break;
523 }
524 token_budget = token_budget.saturating_sub(section_tokens);
525 parts.push(section);
526 }
527
528 parts.join("\n\n")
529}
530
// Unit tests for the chunking helpers, context formatting, and the pipeline
// wired to the in-memory store and the local embedding provider.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::embedding::LocalEmbedding;
    use crate::store::InMemoryVectorStore;

    /// Builds a document with source `"test"`, no metadata and no category.
    fn sample_doc(id: &str, title: &str, content: &str) -> Document {
        Document {
            id: id.to_string(),
            title: title.to_string(),
            content: content.to_string(),
            source: "test".to_string(),
            metadata: HashMap::new(),
            category: None,
        }
    }

    /// Builds a pipeline backed by `InMemoryVectorStore` and `LocalEmbedding`.
    fn make_pipeline(config: RagConfig) -> RagPipeline {
        let store = Arc::new(InMemoryVectorStore::new()) as Arc<dyn VectorStore>;
        let embedder = Arc::new(LocalEmbedding::default()) as Arc<dyn EmbeddingProvider>;
        RagPipeline::new(store, embedder, config)
    }

    /// Pipeline with `RagConfig::default()`.
    fn default_pipeline() -> RagPipeline {
        make_pipeline(RagConfig::default())
    }

    // --- Fixed-size chunking ---

    #[test]
    fn test_chunk_fixed_size_basic() {
        // 10 characters in windows of 4 without overlap: 4 + 4 + 2.
        let chunks = chunk_fixed_size("abcdefghij", 4, 0);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], "abcd");
        assert_eq!(chunks[1], "efgh");
        assert_eq!(chunks[2], "ij");
    }

    #[test]
    fn test_chunk_fixed_size_with_overlap() {
        // Window 5, overlap 2 -> the window advances 3 characters each step.
        let chunks = chunk_fixed_size("abcdefghij", 5, 2);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], "abcde");
        assert_eq!(chunks[1], "defgh");
        assert_eq!(chunks[2], "ghij");
    }

    #[test]
    fn test_chunk_fixed_size_empty_text() {
        let chunks = chunk_fixed_size("", 10, 0);
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_chunk_fixed_size_zero_size() {
        // A zero window size yields no chunks.
        let chunks = chunk_fixed_size("hello", 0, 0);
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_chunk_fixed_size_text_shorter_than_chunk() {
        let chunks = chunk_fixed_size("hi", 100, 0);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], "hi");
    }

    // --- Paragraph chunking ---

    #[test]
    fn test_chunk_paragraph_basic() {
        let text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";
        let chunks = chunk_paragraph(text);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], "First paragraph.");
        assert_eq!(chunks[1], "Second paragraph.");
        assert_eq!(chunks[2], "Third paragraph.");
    }

    #[test]
    fn test_chunk_paragraph_empty() {
        let chunks = chunk_paragraph("");
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_chunk_paragraph_single() {
        let text = "Just one paragraph with no double newlines.";
        let chunks = chunk_paragraph(text);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], text);
    }

    // --- Sentence chunking ---

    #[test]
    fn test_chunk_sentence_basic() {
        let text = "First sentence. Second sentence! Third sentence?";
        let chunks = chunk_sentence(text);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], "First sentence.");
        assert_eq!(chunks[1], "Second sentence!");
        assert_eq!(chunks[2], "Third sentence?");
    }

    #[test]
    fn test_chunk_sentence_no_terminal() {
        // Text without a terminator comes back as one sentence.
        let text = "No terminal punctuation here";
        let chunks = chunk_sentence(text);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], text);
    }

    #[test]
    fn test_chunk_sentence_empty() {
        let chunks = chunk_sentence("");
        assert!(chunks.is_empty());
    }

    // --- Semantic chunking ---

    #[test]
    fn test_chunk_semantic_headings() {
        let text = "# Heading 1\nParagraph one.\n# Heading 2\nParagraph two.";
        let chunks = chunk_semantic(text, 1000);
        assert!(!chunks.is_empty());
        // Both heading sections fit the 1000-token budget, so they are merged.
        assert_eq!(chunks.len(), 1);
    }

    #[test]
    fn test_chunk_semantic_splits_on_budget() {
        let text = "# A\nLorem ipsum dolor sit amet.\n# B\nConsectetur adipiscing elit.";
        let chunks = chunk_semantic(text, 8);
        // The two sections together exceed 8 tokens, so they stay separate.
        assert!(chunks.len() >= 2, "small budget should force split");
    }

    #[test]
    fn test_chunk_semantic_empty() {
        let chunks = chunk_semantic("", 100);
        assert!(chunks.is_empty());
    }

    // --- Token estimation ---

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("abcd"), 1);
        assert_eq!(estimate_tokens("abcdefghijkl"), 3);
    }

    // --- Plain data types ---

    #[test]
    fn test_document_chunk_fields() {
        let chunk = DocumentChunk {
            chunk_id: "doc1_chunk_0".into(),
            document_id: "doc1".into(),
            content: "hello world".into(),
            chunk_index: 0,
            token_estimate: 3,
        };
        assert_eq!(chunk.chunk_id, "doc1_chunk_0");
        assert_eq!(chunk.document_id, "doc1");
        assert_eq!(chunk.chunk_index, 0);
        assert_eq!(chunk.token_estimate, 3);
    }

    #[test]
    fn test_rag_config_default() {
        let cfg = RagConfig::default();
        assert_eq!(cfg.top_k, 5);
        assert!((cfg.min_relevance_score - 0.3).abs() < f32::EPSILON);
        assert!(cfg.include_metadata);
        assert_eq!(cfg.max_context_tokens, 4096);
    }

    // --- Pipeline: ingestion and querying ---

    #[tokio::test]
    async fn test_ingest_document_creates_chunks() {
        let pipeline = default_pipeline();
        let doc = sample_doc("d1", "Test Doc", "Hello world. This is a test document.");
        let chunks = pipeline.ingest_document(&doc).await.unwrap();
        assert!(!chunks.is_empty(), "should produce at least one chunk");
        for c in &chunks {
            // Every chunk must point back at its document and carry content.
            assert_eq!(c.document_id, "d1");
            assert!(!c.content.is_empty());
            assert!(c.token_estimate > 0);
        }
    }

    #[tokio::test]
    async fn test_ingest_batch() {
        let pipeline = default_pipeline();
        let docs = vec![
            sample_doc("d1", "Doc One", "Content of document one."),
            sample_doc("d2", "Doc Two", "Content of document two."),
        ];
        let all = pipeline.ingest_batch(&docs).await.unwrap();
        assert_eq!(all.len(), 2);
        assert!(!all[0].is_empty());
        assert!(!all[1].is_empty());
    }

    #[tokio::test]
    async fn test_query_returns_results() {
        let pipeline = default_pipeline();
        let doc = sample_doc(
            "d1",
            "Rust Book",
            "Rust is a systems programming language focused on safety and performance.",
        );
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline.query("rust programming", None).await.unwrap();
        assert!(!result.chunks.is_empty(), "query should return results");
        assert!(result.query_time_ms < 10_000, "query should be fast");
        assert!(
            result.total_chunks_searched > 0,
            "should report chunks searched"
        );
    }

    #[tokio::test]
    async fn test_query_context_text_not_empty() {
        let pipeline = default_pipeline();
        let doc = sample_doc("d1", "FAQ", "How do I install Rust? Use rustup.");
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline.query("install rust", None).await.unwrap();
        assert!(
            !result.context_text.is_empty(),
            "context_text should be populated"
        );
        assert!(
            result.context_text.contains("Chunk 1"),
            "context should include chunk header"
        );
    }

    #[tokio::test]
    async fn test_query_with_context_window() {
        // Relevance floor of 0.0 so the score filter cannot empty the result.
        let cfg = RagConfig {
            min_relevance_score: 0.0,
            ..RagConfig::default()
        };
        let pipeline = make_pipeline(cfg);
        let doc = sample_doc(
            "d1",
            "Long Doc",
            "Rust programming language. Memory safety without garbage collection. Zero-cost abstractions.",
        );
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline
            .query_with_context("rust programming language", None, 8192)
            .await
            .unwrap();
        assert!(!result.context_text.is_empty());
    }

    #[tokio::test]
    async fn test_query_min_relevance_filter() {
        // Near-1.0 relevance floor: an unrelated query should match little or nothing.
        let cfg = RagConfig {
            min_relevance_score: 0.99,
            ..RagConfig::default()
        };
        let pipeline = make_pipeline(cfg);
        let doc = sample_doc("d1", "Doc", "some random content about various topics");
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline
            .query("completely unrelated xyz", None)
            .await
            .unwrap();
        // NOTE(review): only one chunk is ingested, so `<= 1` holds even if the
        // filter drops nothing; consider asserting `is_empty()` if that is the intent.
        assert!(
            result.chunks.len() <= 1,
            "high threshold should filter most results"
        );
    }

    #[tokio::test]
    async fn test_scored_chunk_has_metadata() {
        let pipeline = default_pipeline();
        let doc = sample_doc("d1", "My Title", "Content about Rust programming.");
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline.query("rust", None).await.unwrap();
        // The assertions run only when the query returned at least one chunk.
        if let Some(sc) = result.chunks.first() {
            assert_eq!(sc.document_title, "My Title");
            assert_eq!(sc.source, "test");
            assert!(sc.score > 0.0);
        }
    }

    #[tokio::test]
    async fn test_config_accessor() {
        let pipeline = default_pipeline();
        assert_eq!(pipeline.config().top_k, 5);
    }

    // --- Context formatting ---

    #[test]
    fn test_format_context_empty() {
        let ctx = format_context(&[], 4096);
        assert!(ctx.is_empty());
    }

    #[test]
    fn test_format_context_includes_source() {
        let chunks = vec![ScoredChunk {
            chunk: DocumentChunk {
                chunk_id: "c1".into(),
                document_id: "d1".into(),
                content: "Hello".into(),
                chunk_index: 0,
                token_estimate: 2,
            },
            score: 0.95,
            document_title: "Title".into(),
            source: "kb".into(),
        }];
        let ctx = format_context(&chunks, 4096);
        // Source, title, two-decimal score and chunk text must all be rendered.
        assert!(ctx.contains("kb"));
        assert!(ctx.contains("Title"));
        assert!(ctx.contains("0.95"));
        assert!(ctx.contains("Hello"));
    }
}