Skip to main content

argentor_memory/
rag.rs

1//! Retrieval-Augmented Generation (RAG) pipeline for knowledge base search.
2//!
3//! Provides document ingestion (chunking + embedding + storage) and
4//! context-aware retrieval for LLM injection.
5//!
6//! # Main types
7//!
8//! - [`RagPipeline`] — Orchestrates ingestion and retrieval.
9//! - [`Document`] — A document to ingest into the knowledge base.
10//! - [`DocumentChunk`] — A chunk of a document after splitting.
11//! - [`ChunkingStrategy`] — How to split documents into chunks.
12//! - [`RagConfig`] — Pipeline configuration.
13//! - [`RagResult`] — Query result containing scored chunks and formatted context.
14//! - [`ScoredChunk`] — A chunk paired with its relevance score.
15
16use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::Instant;
19
20use argentor_core::ArgentorResult;
21use chrono::Utc;
22use serde::{Deserialize, Serialize};
23use tracing::instrument;
24use uuid::Uuid;
25
26use crate::embedding::EmbeddingProvider;
27use crate::store::{MemoryEntry, VectorStore};
28
29// ---------------------------------------------------------------------------
30// Data types
31// ---------------------------------------------------------------------------
32
/// A document to ingest into the RAG knowledge base.
///
/// During ingestion the pipeline splits [`Document::content`] into chunks and
/// stores each chunk together with the identifying fields below as entry
/// metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Unique identifier for this document. Chunk ids are derived from it as
    /// `{id}_chunk_{index}`.
    pub id: String,
    /// Human-readable title; shown in the header of every formatted context
    /// section built from this document.
    pub title: String,
    /// Full text content to be chunked.
    pub content: String,
    /// Origin of the document (e.g. "knowledge_base", "faq", "docs").
    pub source: String,
    /// Arbitrary key-value metadata. Copied into each stored entry's metadata
    /// only when [`RagConfig::include_metadata`] is `true`.
    pub metadata: HashMap<String, String>,
    /// Optional classification category; stored under the `category` metadata
    /// key when present.
    pub category: Option<String>,
}
49
/// A chunk produced by splitting a [`Document`].
///
/// Chunk text is whitespace-trimmed by every chunking strategy, and empty
/// chunks are never produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Unique identifier for this chunk (typically `{document_id}_chunk_{index}`).
    pub chunk_id: String,
    /// The document this chunk belongs to.
    pub document_id: String,
    /// The text content of this chunk.
    pub content: String,
    /// Zero-based index of this chunk within the parent document.
    pub chunk_index: usize,
    /// Rough estimate of the number of tokens in this chunk (about one token
    /// per four bytes of UTF-8 text, rounded up).
    pub token_estimate: usize,
}
64
/// Strategy for splitting documents into chunks.
///
/// Whatever the strategy, each resulting chunk is trimmed and empty chunks
/// are discarded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChunkingStrategy {
    /// Fixed-size chunks measured in characters, with optional overlap.
    FixedSize {
        /// Maximum number of characters per chunk. A value of `0` produces no
        /// chunks at all.
        chunk_size: usize,
        /// Number of overlapping characters between consecutive chunks.
        /// Clamped to `chunk_size - 1` so the window always advances.
        overlap: usize,
    },
    /// Split on paragraph boundaries (double newlines).
    Paragraph,
    /// Split on sentence boundaries (`.` / `!` / `?` followed by whitespace
    /// or by the end of the text).
    Sentence,
    /// Split on heading boundaries (lines starting with `#`), with a
    /// maximum token budget per chunk.
    ///
    /// Adjacent sections are merged while their combined estimate stays
    /// within the budget. A single section that is larger than the budget is
    /// kept whole rather than split further.
    Semantic {
        /// Maximum estimated tokens per chunk. A value of `0` falls back to
        /// a budget of 256.
        max_chunk_tokens: usize,
    },
}
86
/// Configuration for the RAG pipeline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// How to split ingested documents.
    pub chunking: ChunkingStrategy,
    /// Default number of top results to return; used when a query does not
    /// pass its own `top_k`.
    pub top_k: usize,
    /// Minimum relevance score (0.0 – 1.0) for a chunk to be included.
    /// Search hits scoring below this value are dropped.
    pub min_relevance_score: f32,
    /// Whether the key-value pairs in [`Document::metadata`] are copied into
    /// each stored entry's metadata during ingestion.
    // NOTE(review): this flag is only consulted at ingestion time; the
    // formatted context text never includes document metadata.
    pub include_metadata: bool,
    /// Maximum estimated tokens for the combined context window.
    pub max_context_tokens: usize,
}
101
102impl Default for RagConfig {
103    fn default() -> Self {
104        Self {
105            chunking: ChunkingStrategy::FixedSize {
106                chunk_size: 512,
107                overlap: 64,
108            },
109            top_k: 5,
110            min_relevance_score: 0.3,
111            include_metadata: true,
112            max_context_tokens: 4096,
113        }
114    }
115}
116
/// A scored chunk returned by the RAG query.
///
/// Title and source are carried alongside the chunk so a result can be
/// attributed without looking the parent document up again.
#[derive(Debug, Clone)]
pub struct ScoredChunk {
    /// The document chunk.
    pub chunk: DocumentChunk,
    /// Relevance score (higher is better), as reported by the vector store
    /// search.
    pub score: f32,
    /// Title of the parent document.
    pub document_title: String,
    /// Source of the parent document.
    pub source: String,
}
129
/// Result of a RAG query.
#[derive(Debug, Clone)]
pub struct RagResult {
    /// Scored chunks ordered by relevance (highest score first), limited to
    /// the requested `top_k`.
    pub chunks: Vec<ScoredChunk>,
    /// Pre-formatted context text ready for LLM injection. May cover fewer
    /// chunks than `chunks` when the token budget runs out.
    pub context_text: String,
    /// Total number of chunks that were searched: the entry count reported
    /// by the vector store at query time.
    pub total_chunks_searched: usize,
    /// Wall-clock time in milliseconds for the query.
    pub query_time_ms: u64,
}
142
143// ---------------------------------------------------------------------------
144// Chunking helpers
145// ---------------------------------------------------------------------------
146
/// Roughly estimate how many tokens `text` occupies.
///
/// Applies the usual "about four bytes per token" rule of thumb for English
/// text: the UTF-8 byte length divided by four, rounded up. An empty string
/// yields zero.
fn estimate_tokens(text: &str) -> usize {
    let byte_len = text.len();
    let whole_groups = byte_len / 4;
    // A partially filled trailing group still counts as one token.
    if byte_len % 4 == 0 {
        whole_groups
    } else {
        whole_groups + 1
    }
}
152
153/// Split a document into chunks according to the given strategy.
154fn chunk_document(doc: &Document, strategy: &ChunkingStrategy) -> Vec<DocumentChunk> {
155    let raw_chunks = match strategy {
156        ChunkingStrategy::FixedSize {
157            chunk_size,
158            overlap,
159        } => chunk_fixed_size(&doc.content, *chunk_size, *overlap),
160        ChunkingStrategy::Paragraph => chunk_paragraph(&doc.content),
161        ChunkingStrategy::Sentence => chunk_sentence(&doc.content),
162        ChunkingStrategy::Semantic { max_chunk_tokens } => {
163            chunk_semantic(&doc.content, *max_chunk_tokens)
164        }
165    };
166
167    raw_chunks
168        .into_iter()
169        .enumerate()
170        .map(|(idx, text)| DocumentChunk {
171            chunk_id: format!("{}_chunk_{}", doc.id, idx),
172            document_id: doc.id.clone(),
173            content: text,
174            chunk_index: idx,
175            token_estimate: 0, // filled below
176        })
177        .map(|mut c| {
178            c.token_estimate = estimate_tokens(&c.content);
179            c
180        })
181        .collect()
182}
183
/// Slide a window of `chunk_size` characters over `text`, advancing by
/// `chunk_size - overlap` characters each step.
///
/// Windows are measured in `char`s, not bytes. `overlap` is clamped to
/// `chunk_size - 1` so the window always moves forward by at least one
/// character. Each window is trimmed and whitespace-only windows are dropped.
/// Empty input or a zero `chunk_size` yields no chunks.
fn chunk_fixed_size(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
    if chunk_size == 0 || text.is_empty() {
        return Vec::new();
    }

    // `chunk_size >= 1` here, so the subtraction cannot underflow and the
    // stride is always at least one.
    let stride = chunk_size - overlap.min(chunk_size - 1);
    let symbols: Vec<char> = text.chars().collect();
    let total = symbols.len();

    let mut pieces = Vec::new();
    for begin in (0..total).step_by(stride) {
        let finish = total.min(begin + chunk_size);
        let window: String = symbols[begin..finish].iter().collect();
        let cleaned = window.trim();
        if !cleaned.is_empty() {
            pieces.push(cleaned.to_string());
        }
        // Once a window reaches the end of the text, later windows would only
        // repeat its tail.
        if finish == total {
            break;
        }
    }
    pieces
}
208
/// Split `text` into paragraphs separated by blank lines.
///
/// A paragraph break is two consecutive line terminators, in either Unix
/// (`\n\n`) or Windows (`\r\n\r\n`) form. Line breaks inside a paragraph are
/// left untouched. Each paragraph is trimmed and empty ones are dropped, so
/// runs of three or more newlines do not produce empty chunks. Text with no
/// paragraph break comes back as a single chunk; empty or whitespace-only
/// text yields no chunks.
fn chunk_paragraph(text: &str) -> Vec<String> {
    text.split("\n\n")
        // `\r\n\r\n` never contains `\n\n`, so CRLF-separated paragraphs need
        // their own split pass.
        .flat_map(|block| block.split("\r\n\r\n"))
        .map(str::trim)
        .filter(|paragraph| !paragraph.is_empty())
        .map(String::from)
        .collect()
}
228
/// Break `text` into sentences.
///
/// A sentence ends at `.`, `!` or `?` when that character is followed by
/// whitespace or is the last character of the text. Punctuation in the middle
/// of a token (for example the dot in `3.14`) therefore does not split. Each
/// sentence is trimmed, and any trailing text without terminal punctuation
/// becomes a final chunk.
fn chunk_sentence(text: &str) -> Vec<String> {
    let mut sentences: Vec<String> = Vec::new();
    let mut pending = String::new();
    let mut stream = text.chars().peekable();

    while let Some(ch) = stream.next() {
        pending.push(ch);
        if !matches!(ch, '.' | '!' | '?') {
            continue;
        }

        // A terminal mark only closes the sentence at a whitespace boundary
        // or at the very end of the input.
        let at_boundary = match stream.peek() {
            Some(next) => next.is_whitespace(),
            None => true,
        };
        if at_boundary {
            let sentence = pending.trim();
            if !sentence.is_empty() {
                sentences.push(sentence.to_string());
            }
            pending.clear();
        }
    }

    // Whatever is left had no terminal punctuation.
    let leftover = pending.trim();
    if !leftover.is_empty() {
        sentences.push(leftover.to_string());
    }
    sentences
}
255
256/// Semantic chunking: split on Markdown heading boundaries (`# ...`),
257/// merging small sections up to `max_chunk_tokens`.
258fn chunk_semantic(text: &str, max_chunk_tokens: usize) -> Vec<String> {
259    let max_tokens = if max_chunk_tokens == 0 {
260        256
261    } else {
262        max_chunk_tokens
263    };
264
265    let mut sections: Vec<String> = Vec::new();
266    let mut current_section = String::new();
267
268    for line in text.lines() {
269        let is_heading = line.starts_with('#');
270        if is_heading && !current_section.trim().is_empty() {
271            sections.push(current_section.trim().to_string());
272            current_section.clear();
273        }
274        if !current_section.is_empty() {
275            current_section.push('\n');
276        }
277        current_section.push_str(line);
278    }
279    if !current_section.trim().is_empty() {
280        sections.push(current_section.trim().to_string());
281    }
282
283    // Merge tiny sections that fit within the token budget.
284    let mut merged: Vec<String> = Vec::new();
285    let mut buffer = String::new();
286
287    for section in sections {
288        let combined_tokens = estimate_tokens(&buffer) + estimate_tokens(&section);
289        if buffer.is_empty() {
290            buffer = section;
291        } else if combined_tokens <= max_tokens {
292            buffer.push_str("\n\n");
293            buffer.push_str(&section);
294        } else {
295            merged.push(buffer.trim().to_string());
296            buffer = section;
297        }
298    }
299    if !buffer.trim().is_empty() {
300        merged.push(buffer.trim().to_string());
301    }
302
303    if merged.is_empty() {
304        let trimmed = text.trim().to_string();
305        if trimmed.is_empty() {
306            vec![]
307        } else {
308            vec![trimmed]
309        }
310    } else {
311        merged
312    }
313}
314
315// ---------------------------------------------------------------------------
316// RagPipeline
317// ---------------------------------------------------------------------------
318
/// Retrieval-Augmented Generation pipeline.
///
/// Orchestrates document ingestion (chunking, embedding, storage) and
/// context-aware retrieval for LLM injection.
///
/// Chunk bookkeeping lives in an in-memory index owned by this instance.
/// Search hits whose entry id is not present in that index (for example
/// entries written to the same vector store by other code) are skipped when
/// query results are assembled.
pub struct RagPipeline {
    /// Backend that stores embedded chunks and performs similarity search.
    vector_store: Arc<dyn VectorStore>,
    /// Provider used to embed both document chunks and query text.
    embedder: Arc<dyn EmbeddingProvider>,
    /// Chunking, filtering, and context-budget settings.
    config: RagConfig,
    /// In-memory index mapping `MemoryEntry.id` → `(DocumentChunk, doc_title, doc_source)`.
    chunk_index: tokio::sync::RwLock<HashMap<Uuid, (DocumentChunk, String, String)>>,
}
330
331impl RagPipeline {
332    /// Create a new RAG pipeline.
333    pub fn new(
334        vector_store: Arc<dyn VectorStore>,
335        embedder: Arc<dyn EmbeddingProvider>,
336        config: RagConfig,
337    ) -> Self {
338        Self {
339            vector_store,
340            embedder,
341            config,
342            chunk_index: tokio::sync::RwLock::new(HashMap::new()),
343        }
344    }
345
346    /// Ingest a single document: chunk it, embed each chunk, and store.
347    pub async fn ingest_document(&self, doc: &Document) -> ArgentorResult<Vec<DocumentChunk>> {
348        let chunks = chunk_document(doc, &self.config.chunking);
349
350        for chunk in &chunks {
351            let embedding = self.embedder.embed(&chunk.content).await?;
352
353            let mut metadata = HashMap::new();
354            metadata.insert(
355                "document_id".to_string(),
356                serde_json::Value::String(doc.id.clone()),
357            );
358            metadata.insert(
359                "document_title".to_string(),
360                serde_json::Value::String(doc.title.clone()),
361            );
362            metadata.insert(
363                "source".to_string(),
364                serde_json::Value::String(doc.source.clone()),
365            );
366            metadata.insert(
367                "chunk_id".to_string(),
368                serde_json::Value::String(chunk.chunk_id.clone()),
369            );
370            metadata.insert(
371                "chunk_index".to_string(),
372                serde_json::json!(chunk.chunk_index),
373            );
374            if let Some(cat) = &doc.category {
375                metadata.insert(
376                    "category".to_string(),
377                    serde_json::Value::String(cat.clone()),
378                );
379            }
380            if self.config.include_metadata {
381                for (k, v) in &doc.metadata {
382                    metadata.insert(k.clone(), serde_json::Value::String(v.clone()));
383                }
384            }
385
386            let entry_id = Uuid::new_v4();
387            let entry = MemoryEntry {
388                id: entry_id,
389                content: chunk.content.clone(),
390                embedding,
391                metadata,
392                session_id: None,
393                created_at: Utc::now(),
394            };
395
396            self.vector_store.insert(entry).await?;
397
398            // Track the chunk in our local index.
399            let mut idx = self.chunk_index.write().await;
400            idx.insert(
401                entry_id,
402                (chunk.clone(), doc.title.clone(), doc.source.clone()),
403            );
404        }
405
406        Ok(chunks)
407    }
408
409    /// Batch-ingest multiple documents.
410    pub async fn ingest_batch(&self, docs: &[Document]) -> ArgentorResult<Vec<Vec<DocumentChunk>>> {
411        let mut all_chunks = Vec::with_capacity(docs.len());
412        for doc in docs {
413            let chunks = self.ingest_document(doc).await?;
414            all_chunks.push(chunks);
415        }
416        Ok(all_chunks)
417    }
418
419    /// Query the knowledge base and return scored, filtered chunks.
420    #[instrument(
421        skip(self, question),
422        fields(
423            query_length = question.len(),
424            chunks_searched = tracing::field::Empty,
425            results_count = tracing::field::Empty,
426        )
427    )]
428    pub async fn query(&self, question: &str, top_k: Option<usize>) -> ArgentorResult<RagResult> {
429        let start = Instant::now();
430        let k = top_k.unwrap_or(self.config.top_k);
431
432        let query_embedding = self.embedder.embed(question).await?;
433
434        let total_chunks_searched = self.vector_store.count().await?;
435
436        // Retrieve more candidates than requested so we can filter by score.
437        let fetch_k = (k * 3).max(k);
438        let results = self
439            .vector_store
440            .search(&query_embedding, fetch_k, None)
441            .await?;
442
443        let idx = self.chunk_index.read().await;
444
445        let mut scored: Vec<ScoredChunk> = results
446            .into_iter()
447            .filter(|r| r.score >= self.config.min_relevance_score)
448            .filter_map(|r| {
449                let (chunk, title, source) = idx.get(&r.entry.id)?;
450                Some(ScoredChunk {
451                    chunk: chunk.clone(),
452                    score: r.score,
453                    document_title: title.clone(),
454                    source: source.clone(),
455                })
456            })
457            .collect();
458
459        scored.sort_by(|a, b| {
460            b.score
461                .partial_cmp(&a.score)
462                .unwrap_or(std::cmp::Ordering::Equal)
463        });
464        scored.truncate(k);
465
466        let context_text = format_context(&scored, self.config.max_context_tokens);
467
468        let elapsed = start.elapsed().as_millis() as u64;
469
470        tracing::Span::current()
471            .record("chunks_searched", total_chunks_searched)
472            .record("results_count", scored.len());
473
474        Ok(RagResult {
475            chunks: scored,
476            context_text,
477            total_chunks_searched,
478            query_time_ms: elapsed,
479        })
480    }
481
482    /// Query and return a formatted context string sized to fit a given
483    /// context window (in estimated tokens).
484    pub async fn query_with_context(
485        &self,
486        question: &str,
487        top_k: Option<usize>,
488        context_window: usize,
489    ) -> ArgentorResult<RagResult> {
490        let mut result = self.query(question, top_k).await?;
491        // Re-format the context with the caller-specified window size.
492        result.context_text = format_context(&result.chunks, context_window);
493        Ok(result)
494    }
495
496    /// Return a reference to the pipeline configuration.
497    pub fn config(&self) -> &RagConfig {
498        &self.config
499    }
500}
501
502/// Format scored chunks into a context string for LLM injection,
503/// respecting a maximum token budget.
504fn format_context(chunks: &[ScoredChunk], max_tokens: usize) -> String {
505    let mut parts: Vec<String> = Vec::new();
506    let mut token_budget = max_tokens;
507
508    for (i, sc) in chunks.iter().enumerate() {
509        let header = format!(
510            "[Source: {} | Document: {} | Score: {:.2}]",
511            sc.source, sc.document_title, sc.score
512        );
513        let section = format!("--- Chunk {} ---\n{}\n{}", i + 1, header, sc.chunk.content);
514        let section_tokens = estimate_tokens(&section);
515        if section_tokens > token_budget {
516            // Try to include a truncated version if there is room.
517            if token_budget > 20 {
518                let available_chars = token_budget * 4;
519                let truncated: String = section.chars().take(available_chars).collect();
520                parts.push(truncated);
521            }
522            break;
523        }
524        token_budget = token_budget.saturating_sub(section_tokens);
525        parts.push(section);
526    }
527
528    parts.join("\n\n")
529}
530
531// ---------------------------------------------------------------------------
532// Tests
533// ---------------------------------------------------------------------------
534
#[cfg(test)]
// Tests unwrap freely: a failed unwrap is simply a failed test.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    //! Unit tests for the chunking helpers and context formatting, plus
    //! pipeline tests that run against the in-memory vector store and the
    //! local embedding provider.

    use super::*;
    use crate::embedding::LocalEmbedding;
    use crate::store::InMemoryVectorStore;

    // -- Helpers --

    /// Build a document with the given id, title and content; source is
    /// always `"test"`, with no metadata and no category.
    fn sample_doc(id: &str, title: &str, content: &str) -> Document {
        Document {
            id: id.to_string(),
            title: title.to_string(),
            content: content.to_string(),
            source: "test".to_string(),
            metadata: HashMap::new(),
            category: None,
        }
    }

    /// Build a pipeline backed by a fresh in-memory store and the local
    /// embedding provider.
    fn make_pipeline(config: RagConfig) -> RagPipeline {
        let store = Arc::new(InMemoryVectorStore::new()) as Arc<dyn VectorStore>;
        let embedder = Arc::new(LocalEmbedding::default()) as Arc<dyn EmbeddingProvider>;
        RagPipeline::new(store, embedder, config)
    }

    /// Pipeline using `RagConfig::default()`.
    fn default_pipeline() -> RagPipeline {
        make_pipeline(RagConfig::default())
    }

    // -----------------------------------------------------------------------
    // Chunking unit tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_chunk_fixed_size_basic() {
        let chunks = chunk_fixed_size("abcdefghij", 4, 0);
        assert_eq!(chunks.len(), 3); // "abcd", "efgh", "ij"
        assert_eq!(chunks[0], "abcd");
        assert_eq!(chunks[1], "efgh");
        assert_eq!(chunks[2], "ij");
    }

    #[test]
    fn test_chunk_fixed_size_with_overlap() {
        let chunks = chunk_fixed_size("abcdefghij", 5, 2);
        // step = 5 - 2 = 3, windows: [0..5], [3..8], [6..10]
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], "abcde");
        assert_eq!(chunks[1], "defgh");
        assert_eq!(chunks[2], "ghij");
    }

    #[test]
    fn test_chunk_fixed_size_empty_text() {
        let chunks = chunk_fixed_size("", 10, 0);
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_chunk_fixed_size_zero_size() {
        // A zero chunk size must return nothing rather than loop forever.
        let chunks = chunk_fixed_size("hello", 0, 0);
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_chunk_fixed_size_text_shorter_than_chunk() {
        let chunks = chunk_fixed_size("hi", 100, 0);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], "hi");
    }

    #[test]
    fn test_chunk_paragraph_basic() {
        let text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";
        let chunks = chunk_paragraph(text);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], "First paragraph.");
        assert_eq!(chunks[1], "Second paragraph.");
        assert_eq!(chunks[2], "Third paragraph.");
    }

    #[test]
    fn test_chunk_paragraph_empty() {
        let chunks = chunk_paragraph("");
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_chunk_paragraph_single() {
        let text = "Just one paragraph with no double newlines.";
        let chunks = chunk_paragraph(text);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], text);
    }

    #[test]
    fn test_chunk_sentence_basic() {
        let text = "First sentence. Second sentence! Third sentence?";
        let chunks = chunk_sentence(text);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], "First sentence.");
        assert_eq!(chunks[1], "Second sentence!");
        assert_eq!(chunks[2], "Third sentence?");
    }

    #[test]
    fn test_chunk_sentence_no_terminal() {
        let text = "No terminal punctuation here";
        let chunks = chunk_sentence(text);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], text);
    }

    #[test]
    fn test_chunk_sentence_empty() {
        let chunks = chunk_sentence("");
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_chunk_semantic_headings() {
        let text = "# Heading 1\nParagraph one.\n# Heading 2\nParagraph two.";
        let chunks = chunk_semantic(text, 1000);
        // Two sections: "# Heading 1\nParagraph one." and "# Heading 2\nParagraph two."
        // They fit within 1000 tokens, so they may be merged.
        assert!(!chunks.is_empty());
        // With a huge token budget they get merged into one.
        assert_eq!(chunks.len(), 1);
    }

    #[test]
    fn test_chunk_semantic_splits_on_budget() {
        let text = "# A\nLorem ipsum dolor sit amet.\n# B\nConsectetur adipiscing elit.";
        // Very small budget forces a split.
        let chunks = chunk_semantic(text, 8);
        assert!(chunks.len() >= 2, "small budget should force split");
    }

    #[test]
    fn test_chunk_semantic_empty() {
        let chunks = chunk_semantic("", 100);
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("abcd"), 1);
        // 12 chars -> 3 tokens
        assert_eq!(estimate_tokens("abcdefghijkl"), 3);
    }

    // -----------------------------------------------------------------------
    // Document model tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_document_chunk_fields() {
        let chunk = DocumentChunk {
            chunk_id: "doc1_chunk_0".into(),
            document_id: "doc1".into(),
            content: "hello world".into(),
            chunk_index: 0,
            token_estimate: 3,
        };
        assert_eq!(chunk.chunk_id, "doc1_chunk_0");
        assert_eq!(chunk.document_id, "doc1");
        assert_eq!(chunk.chunk_index, 0);
        assert_eq!(chunk.token_estimate, 3);
    }

    #[test]
    fn test_rag_config_default() {
        let cfg = RagConfig::default();
        assert_eq!(cfg.top_k, 5);
        assert!((cfg.min_relevance_score - 0.3).abs() < f32::EPSILON);
        assert!(cfg.include_metadata);
        assert_eq!(cfg.max_context_tokens, 4096);
    }

    // -----------------------------------------------------------------------
    // Integration / pipeline tests
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn test_ingest_document_creates_chunks() {
        let pipeline = default_pipeline();
        let doc = sample_doc("d1", "Test Doc", "Hello world. This is a test document.");
        let chunks = pipeline.ingest_document(&doc).await.unwrap();
        assert!(!chunks.is_empty(), "should produce at least one chunk");
        // Each chunk should reference the parent document.
        for c in &chunks {
            assert_eq!(c.document_id, "d1");
            assert!(!c.content.is_empty());
            assert!(c.token_estimate > 0);
        }
    }

    #[tokio::test]
    async fn test_ingest_batch() {
        let pipeline = default_pipeline();
        let docs = vec![
            sample_doc("d1", "Doc One", "Content of document one."),
            sample_doc("d2", "Doc Two", "Content of document two."),
        ];
        let all = pipeline.ingest_batch(&docs).await.unwrap();
        // One chunk list per input document, in input order.
        assert_eq!(all.len(), 2);
        assert!(!all[0].is_empty());
        assert!(!all[1].is_empty());
    }

    #[tokio::test]
    async fn test_query_returns_results() {
        let pipeline = default_pipeline();
        let doc = sample_doc(
            "d1",
            "Rust Book",
            "Rust is a systems programming language focused on safety and performance.",
        );
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline.query("rust programming", None).await.unwrap();
        assert!(!result.chunks.is_empty(), "query should return results");
        assert!(result.query_time_ms < 10_000, "query should be fast");
        assert!(
            result.total_chunks_searched > 0,
            "should report chunks searched"
        );
    }

    #[tokio::test]
    async fn test_query_context_text_not_empty() {
        let pipeline = default_pipeline();
        let doc = sample_doc("d1", "FAQ", "How do I install Rust? Use rustup.");
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline.query("install rust", None).await.unwrap();
        assert!(
            !result.context_text.is_empty(),
            "context_text should be populated"
        );
        assert!(
            result.context_text.contains("Chunk 1"),
            "context should include chunk header"
        );
    }

    #[tokio::test]
    async fn test_query_with_context_window() {
        let cfg = RagConfig {
            min_relevance_score: 0.0, // accept any score so the test is deterministic
            ..RagConfig::default()
        };
        let pipeline = make_pipeline(cfg);
        let doc = sample_doc(
            "d1",
            "Long Doc",
            "Rust programming language. Memory safety without garbage collection. Zero-cost abstractions.",
        );
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline
            .query_with_context("rust programming language", None, 8192)
            .await
            .unwrap();
        assert!(!result.context_text.is_empty());
    }

    #[tokio::test]
    async fn test_query_min_relevance_filter() {
        let cfg = RagConfig {
            min_relevance_score: 0.99, // very high threshold
            ..RagConfig::default()
        };
        let pipeline = make_pipeline(cfg);
        let doc = sample_doc("d1", "Doc", "some random content about various topics");
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline
            .query("completely unrelated xyz", None)
            .await
            .unwrap();
        // With a 0.99 threshold most results should be filtered out.
        // We don't assert exact count because LocalEmbedding is approximate.
        // NOTE(review): the document is shorter than one default chunk, so at
        // most one chunk exists and this bound holds even without filtering.
        assert!(
            result.chunks.len() <= 1,
            "high threshold should filter most results"
        );
    }

    #[tokio::test]
    async fn test_scored_chunk_has_metadata() {
        let pipeline = default_pipeline();
        let doc = sample_doc("d1", "My Title", "Content about Rust programming.");
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline.query("rust", None).await.unwrap();
        // NOTE(review): passes vacuously when the query returns no chunks.
        if let Some(sc) = result.chunks.first() {
            assert_eq!(sc.document_title, "My Title");
            assert_eq!(sc.source, "test");
            assert!(sc.score > 0.0);
        }
    }

    #[tokio::test]
    async fn test_config_accessor() {
        let pipeline = default_pipeline();
        assert_eq!(pipeline.config().top_k, 5);
    }

    #[test]
    fn test_format_context_empty() {
        let ctx = format_context(&[], 4096);
        assert!(ctx.is_empty());
    }

    #[test]
    fn test_format_context_includes_source() {
        let chunks = vec![ScoredChunk {
            chunk: DocumentChunk {
                chunk_id: "c1".into(),
                document_id: "d1".into(),
                content: "Hello".into(),
                chunk_index: 0,
                token_estimate: 2,
            },
            score: 0.95,
            document_title: "Title".into(),
            source: "kb".into(),
        }];
        let ctx = format_context(&chunks, 4096);
        // Source, title, two-decimal score and chunk text must all appear.
        assert!(ctx.contains("kb"));
        assert!(ctx.contains("Title"));
        assert!(ctx.contains("0.95"));
        assert!(ctx.contains("Hello"));
    }
}