Skip to main content

argentor_memory/
rag.rs

1//! Retrieval-Augmented Generation (RAG) pipeline for knowledge base search.
2//!
3//! Provides document ingestion (chunking + embedding + storage) and
4//! context-aware retrieval for LLM injection.
5//!
6//! # Main types
7//!
8//! - [`RagPipeline`] — Orchestrates ingestion and retrieval.
9//! - [`Document`] — A document to ingest into the knowledge base.
10//! - [`DocumentChunk`] — A chunk of a document after splitting.
11//! - [`ChunkingStrategy`] — How to split documents into chunks.
12//! - [`RagConfig`] — Pipeline configuration.
13//! - [`RagResult`] — Query result containing scored chunks and formatted context.
14//! - [`ScoredChunk`] — A chunk paired with its relevance score.
15
16use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::Instant;
19
20use argentor_core::ArgentorResult;
21use chrono::Utc;
22use serde::{Deserialize, Serialize};
23use uuid::Uuid;
24
25use crate::embedding::EmbeddingProvider;
26use crate::store::{MemoryEntry, VectorStore};
27
28// ---------------------------------------------------------------------------
29// Data types
30// ---------------------------------------------------------------------------
31
/// A document to ingest into the RAG knowledge base.
///
/// Documents are split into [`DocumentChunk`]s by the configured
/// [`ChunkingStrategy`] when passed to [`RagPipeline::ingest_document`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Identifier for this document; also the prefix of every chunk id.
    /// Uniqueness is not enforced: ingesting the same id twice stores a
    /// second, independent set of chunks.
    pub id: String,
    /// Human-readable title, carried into query results and formatted context.
    pub title: String,
    /// Full text content to be chunked.
    pub content: String,
    /// Origin of the document (e.g. "knowledge_base", "faq", "docs"); carried
    /// into query results and formatted context.
    pub source: String,
    /// Arbitrary key-value metadata, copied into each stored chunk entry when
    /// [`RagConfig::include_metadata`] is enabled.
    pub metadata: HashMap<String, String>,
    /// Optional classification category, stored under the `category`
    /// metadata key of each chunk entry when present.
    pub category: Option<String>,
}
48
/// A chunk produced by splitting a [`Document`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Identifier for this chunk; `chunk_document` builds it as
    /// `{document_id}_chunk_{index}`.
    pub chunk_id: String,
    /// The [`Document::id`] this chunk belongs to.
    pub document_id: String,
    /// The text content of this chunk (trimmed and non-empty when produced by
    /// the built-in chunking strategies).
    pub content: String,
    /// Zero-based index of this chunk within the parent document.
    pub chunk_index: usize,
    /// Rough token count of `content` (~4 UTF-8 bytes per token, rounded up).
    pub token_estimate: usize,
}
63
/// Strategy for splitting documents into chunks.
///
/// Every built-in strategy trims surrounding whitespace from the chunks it
/// produces and never emits empty chunks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChunkingStrategy {
    /// Fixed-size windows measured in `char`s, with optional overlap.
    FixedSize {
        /// Maximum number of characters per chunk. `0` yields no chunks.
        chunk_size: usize,
        /// Number of characters shared by consecutive windows; clamped to
        /// `chunk_size - 1` so every window advances by at least one char.
        overlap: usize,
    },
    /// Split on paragraph boundaries (blank lines between paragraphs).
    Paragraph,
    /// Split on sentence boundaries: `.` / `!` / `?` followed by whitespace
    /// or by the end of the text.
    Sentence,
    /// Split on Markdown heading boundaries (lines starting with `#`) and
    /// merge adjacent sections while the combined estimate stays within
    /// `max_chunk_tokens`.
    Semantic {
        /// Maximum estimated tokens per chunk; `0` falls back to 256.
        max_chunk_tokens: usize,
    },
}
85
86/// Configuration for the RAG pipeline.
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct RagConfig {
89    /// How to split ingested documents.
90    pub chunking: ChunkingStrategy,
91    /// Default number of top results to return.
92    pub top_k: usize,
93    /// Minimum relevance score (0.0 – 1.0) for a chunk to be included.
94    pub min_relevance_score: f32,
95    /// Whether to include document metadata in the formatted context.
96    pub include_metadata: bool,
97    /// Maximum estimated tokens for the combined context window.
98    pub max_context_tokens: usize,
99}
100
101impl Default for RagConfig {
102    fn default() -> Self {
103        Self {
104            chunking: ChunkingStrategy::FixedSize {
105                chunk_size: 512,
106                overlap: 64,
107            },
108            top_k: 5,
109            min_relevance_score: 0.3,
110            include_metadata: true,
111            max_context_tokens: 4096,
112        }
113    }
114}
115
/// A chunk returned by a RAG query, paired with its relevance score.
#[derive(Debug, Clone)]
pub struct ScoredChunk {
    /// The matching document chunk.
    pub chunk: DocumentChunk,
    /// Relevance score reported by the vector store search (higher is better).
    pub score: f32,
    /// Title of the parent document.
    pub document_title: String,
    /// Source of the parent document.
    pub source: String,
}
128
/// Result of a RAG query.
#[derive(Debug, Clone)]
pub struct RagResult {
    /// Matching chunks, highest score first, at most `top_k` of them.
    pub chunks: Vec<ScoredChunk>,
    /// Pre-formatted context text ready for LLM injection, built from
    /// `chunks` within a token budget.
    pub context_text: String,
    /// Number of entries in the vector store at query time. This counts the
    /// whole store, so it may include entries not ingested by the pipeline.
    pub total_chunks_searched: usize,
    /// Wall-clock time in milliseconds spent in `RagPipeline::query`.
    pub query_time_ms: u64,
}
141
142// ---------------------------------------------------------------------------
143// Chunking helpers
144// ---------------------------------------------------------------------------
145
/// Rough token estimate for `text`: one token per 4 UTF-8 bytes, rounded up.
///
/// Cheap heuristic tuned for English prose (~4 characters per token). It
/// counts bytes rather than characters, so non-ASCII text gets a
/// proportionally higher estimate. The empty string is 0 tokens.
fn estimate_tokens(text: &str) -> usize {
    const BYTES_PER_TOKEN: usize = 4;
    let byte_len = text.len();
    byte_len.div_ceil(BYTES_PER_TOKEN)
}
151
152/// Split a document into chunks according to the given strategy.
153fn chunk_document(doc: &Document, strategy: &ChunkingStrategy) -> Vec<DocumentChunk> {
154    let raw_chunks = match strategy {
155        ChunkingStrategy::FixedSize {
156            chunk_size,
157            overlap,
158        } => chunk_fixed_size(&doc.content, *chunk_size, *overlap),
159        ChunkingStrategy::Paragraph => chunk_paragraph(&doc.content),
160        ChunkingStrategy::Sentence => chunk_sentence(&doc.content),
161        ChunkingStrategy::Semantic { max_chunk_tokens } => {
162            chunk_semantic(&doc.content, *max_chunk_tokens)
163        }
164    };
165
166    raw_chunks
167        .into_iter()
168        .enumerate()
169        .map(|(idx, text)| DocumentChunk {
170            chunk_id: format!("{}_chunk_{}", doc.id, idx),
171            document_id: doc.id.clone(),
172            content: text,
173            chunk_index: idx,
174            token_estimate: 0, // filled below
175        })
176        .map(|mut c| {
177            c.token_estimate = estimate_tokens(&c.content);
178            c
179        })
180        .collect()
181}
182
/// Fixed-size chunking over `char`s with overlapping windows.
///
/// Windows are `chunk_size` characters long and start every
/// `chunk_size - overlap` characters (the overlap is clamped so each window
/// advances by at least one). Each window is trimmed, and whitespace-only
/// windows are dropped. Empty `text` or a zero `chunk_size` yields no chunks.
fn chunk_fixed_size(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
    if chunk_size == 0 || text.is_empty() {
        return Vec::new();
    }
    // `chunk_size >= 1` here, so the clamped overlap leaves a stride of at least 1.
    let stride = chunk_size - overlap.min(chunk_size - 1);
    let chars: Vec<char> = text.chars().collect();
    let total = chars.len();

    let mut pieces = Vec::new();
    let mut offset = 0;
    loop {
        let stop = total.min(offset + chunk_size);
        let window: String = chars[offset..stop].iter().collect();
        let window = window.trim();
        if !window.is_empty() {
            pieces.push(window.to_owned());
        }
        if stop == total {
            break;
        }
        offset += stride;
    }
    pieces
}
207
/// Split on paragraph boundaries: runs of one or more blank lines.
///
/// A line is blank when it is empty or contains only whitespace. Both `\n`
/// and `\r\n` line endings are recognised (via [`str::lines`]), so
/// Windows-style text splits the same way as Unix-style text. The lines of a
/// paragraph are re-joined with `\n`, and each paragraph is trimmed.
/// Returns an empty vector when `text` has no non-whitespace content.
fn chunk_paragraph(text: &str) -> Vec<String> {
    let mut paragraphs = Vec::new();
    let mut lines: Vec<&str> = Vec::new();

    // A trailing sentinel blank line makes the loop flush the final paragraph.
    for line in text.lines().chain(std::iter::once("")) {
        if line.trim().is_empty() {
            if !lines.is_empty() {
                paragraphs.push(lines.join("\n").trim().to_string());
                lines.clear();
            }
        } else {
            lines.push(line);
        }
    }
    paragraphs
}
227
/// Split `text` into sentences.
///
/// A sentence ends at `.`, `!` or `?` when the next character is whitespace
/// or the text ends there (so `v1.2` is not a boundary). Sentences are
/// trimmed, and any trailing text without terminal punctuation becomes the
/// final sentence.
fn chunk_sentence(text: &str) -> Vec<String> {
    let mut out: Vec<String> = Vec::new();
    let mut buf = String::new();
    let mut iter = text.chars().peekable();

    while let Some(ch) = iter.next() {
        buf.push(ch);
        if !matches!(ch, '.' | '!' | '?') {
            continue;
        }
        let at_boundary = match iter.peek() {
            Some(next) => next.is_whitespace(),
            None => true,
        };
        if at_boundary {
            let sentence = buf.trim();
            if !sentence.is_empty() {
                out.push(sentence.to_string());
            }
            buf.clear();
        }
    }

    let tail = buf.trim();
    if !tail.is_empty() {
        out.push(tail.to_string());
    }
    out
}
254
/// Semantic chunking: split on Markdown heading boundaries, then merge
/// adjacent sections while the result stays within `max_chunk_tokens`.
///
/// * A heading is a line starting with `#` that is *not* inside a fenced
///   code block (opened and closed by a line starting with three backticks
///   or three tildes). `# comments` in shell or Python snippets therefore do
///   not start new sections.
/// * A section that exceeds the budget on its own is split further. Splits
///   fall on line boundaries where possible; over-long single lines are split
///   on UTF-8 character boundaries. Every returned chunk respects the budget.
/// * The budget is checked against the exact merged text, including the
///   `"\n\n"` separator, using the same ~4-bytes-per-token heuristic as
///   `estimate_tokens`.
///
/// A `max_chunk_tokens` of `0` falls back to 256 tokens. Returns an empty
/// vector when `text` has no non-whitespace content.
fn chunk_semantic(text: &str, max_chunk_tokens: usize) -> Vec<String> {
    /// Bytes per estimated token; must stay in sync with `estimate_tokens`.
    const BYTES_PER_TOKEN: usize = 4;
    /// Separator inserted between merged sections.
    const SECTION_SEPARATOR: &str = "\n\n";

    /// Split `section` into trimmed, non-empty pieces of at most `max_bytes`
    /// bytes. Breaks on line boundaries where possible and on `char`
    /// boundaries inside lines that are too long on their own.
    fn split_to_budget(section: &str, max_bytes: usize) -> Vec<String> {
        /// Push `piece` trimmed, skipping it if nothing is left.
        fn push_piece(out: &mut Vec<String>, piece: &str) {
            let piece = piece.trim();
            if !piece.is_empty() {
                out.push(piece.to_string());
            }
        }

        let mut pieces = Vec::new();
        let mut current = String::new();
        for line in section.lines() {
            let mut rest = line;
            while rest.len() > max_bytes {
                // Largest prefix that fits and ends on a char boundary. Chars
                // are at most 4 bytes and `max_bytes >= 4`, so `cut > 0`.
                let mut cut = max_bytes;
                while !rest.is_char_boundary(cut) {
                    cut -= 1;
                }
                push_piece(&mut pieces, &current);
                current.clear();
                push_piece(&mut pieces, &rest[..cut]);
                rest = &rest[cut..];
            }
            let joined_len = if current.is_empty() {
                rest.len()
            } else {
                current.len() + 1 + rest.len()
            };
            if joined_len > max_bytes {
                push_piece(&mut pieces, &current);
                current.clear();
            }
            if !current.is_empty() {
                current.push('\n');
            }
            current.push_str(rest);
        }
        push_piece(&mut pieces, &current);
        pieces
    }

    let max_tokens = if max_chunk_tokens == 0 {
        256
    } else {
        max_chunk_tokens
    };
    // `estimate_tokens(s) <= max_tokens` is equivalent to `s.len() <= max_bytes`.
    let max_bytes = max_tokens.saturating_mul(BYTES_PER_TOKEN);

    // 1. Split into heading-delimited sections, ignoring headings in code fences.
    let mut sections: Vec<String> = Vec::new();
    let mut current = String::new();
    let mut open_fence: Option<&str> = None;
    for line in text.lines() {
        let stripped = line.trim_start();
        open_fence = match open_fence {
            None if stripped.starts_with("```") => Some("```"),
            None if stripped.starts_with("~~~") => Some("~~~"),
            Some(marker) if stripped.starts_with(marker) => None,
            other => other,
        };
        let is_heading = open_fence.is_none() && line.starts_with('#');
        if is_heading && !current.trim().is_empty() {
            sections.push(current.trim().to_string());
            current.clear();
        }
        if !current.is_empty() {
            current.push('\n');
        }
        current.push_str(line);
    }
    if !current.trim().is_empty() {
        sections.push(current.trim().to_string());
    }

    // 2. Enforce the budget per section, then greedily merge neighbours that fit.
    let mut merged: Vec<String> = Vec::new();
    for section in sections {
        let pieces = if section.len() > max_bytes {
            split_to_budget(&section, max_bytes)
        } else {
            vec![section]
        };
        for piece in pieces {
            if let Some(last) = merged.last_mut() {
                if last.len() + SECTION_SEPARATOR.len() + piece.len() <= max_bytes {
                    last.push_str(SECTION_SEPARATOR);
                    last.push_str(&piece);
                    continue;
                }
            }
            merged.push(piece);
        }
    }
    merged
}
313
314// ---------------------------------------------------------------------------
315// RagPipeline
316// ---------------------------------------------------------------------------
317
/// Retrieval-Augmented Generation pipeline.
///
/// Orchestrates document ingestion (chunking, embedding, storage) and
/// context-aware retrieval for LLM injection.
///
/// Chunk bookkeeping (chunk data, document title and source) lives in an
/// in-memory index that starts empty in [`RagPipeline::new`]. Queries only
/// return vector-store hits whose entry id is in that index. Store entries
/// not ingested through this instance (e.g. inserted by other code sharing
/// the store) are skipped.
pub struct RagPipeline {
    /// Backing store receiving one entry per ingested chunk.
    vector_store: Arc<dyn VectorStore>,
    /// Embeds chunk text at ingestion and questions at query time.
    embedder: Arc<dyn EmbeddingProvider>,
    /// Chunking, retrieval and context-formatting settings.
    config: RagConfig,
    /// In-memory index mapping `MemoryEntry.id` → `(DocumentChunk, doc_title, doc_source)`.
    chunk_index: tokio::sync::RwLock<HashMap<Uuid, (DocumentChunk, String, String)>>,
}
329
330impl RagPipeline {
331    /// Create a new RAG pipeline.
332    pub fn new(
333        vector_store: Arc<dyn VectorStore>,
334        embedder: Arc<dyn EmbeddingProvider>,
335        config: RagConfig,
336    ) -> Self {
337        Self {
338            vector_store,
339            embedder,
340            config,
341            chunk_index: tokio::sync::RwLock::new(HashMap::new()),
342        }
343    }
344
345    /// Ingest a single document: chunk it, embed each chunk, and store.
346    pub async fn ingest_document(&self, doc: &Document) -> ArgentorResult<Vec<DocumentChunk>> {
347        let chunks = chunk_document(doc, &self.config.chunking);
348
349        for chunk in &chunks {
350            let embedding = self.embedder.embed(&chunk.content).await?;
351
352            let mut metadata = HashMap::new();
353            metadata.insert(
354                "document_id".to_string(),
355                serde_json::Value::String(doc.id.clone()),
356            );
357            metadata.insert(
358                "document_title".to_string(),
359                serde_json::Value::String(doc.title.clone()),
360            );
361            metadata.insert(
362                "source".to_string(),
363                serde_json::Value::String(doc.source.clone()),
364            );
365            metadata.insert(
366                "chunk_id".to_string(),
367                serde_json::Value::String(chunk.chunk_id.clone()),
368            );
369            metadata.insert(
370                "chunk_index".to_string(),
371                serde_json::json!(chunk.chunk_index),
372            );
373            if let Some(cat) = &doc.category {
374                metadata.insert(
375                    "category".to_string(),
376                    serde_json::Value::String(cat.clone()),
377                );
378            }
379            if self.config.include_metadata {
380                for (k, v) in &doc.metadata {
381                    metadata.insert(k.clone(), serde_json::Value::String(v.clone()));
382                }
383            }
384
385            let entry_id = Uuid::new_v4();
386            let entry = MemoryEntry {
387                id: entry_id,
388                content: chunk.content.clone(),
389                embedding,
390                metadata,
391                session_id: None,
392                created_at: Utc::now(),
393            };
394
395            self.vector_store.insert(entry).await?;
396
397            // Track the chunk in our local index.
398            let mut idx = self.chunk_index.write().await;
399            idx.insert(
400                entry_id,
401                (chunk.clone(), doc.title.clone(), doc.source.clone()),
402            );
403        }
404
405        Ok(chunks)
406    }
407
408    /// Batch-ingest multiple documents.
409    pub async fn ingest_batch(&self, docs: &[Document]) -> ArgentorResult<Vec<Vec<DocumentChunk>>> {
410        let mut all_chunks = Vec::with_capacity(docs.len());
411        for doc in docs {
412            let chunks = self.ingest_document(doc).await?;
413            all_chunks.push(chunks);
414        }
415        Ok(all_chunks)
416    }
417
418    /// Query the knowledge base and return scored, filtered chunks.
419    pub async fn query(&self, question: &str, top_k: Option<usize>) -> ArgentorResult<RagResult> {
420        let start = Instant::now();
421        let k = top_k.unwrap_or(self.config.top_k);
422
423        let query_embedding = self.embedder.embed(question).await?;
424
425        let total_chunks_searched = self.vector_store.count().await?;
426
427        // Retrieve more candidates than requested so we can filter by score.
428        let fetch_k = (k * 3).max(k);
429        let results = self
430            .vector_store
431            .search(&query_embedding, fetch_k, None)
432            .await?;
433
434        let idx = self.chunk_index.read().await;
435
436        let mut scored: Vec<ScoredChunk> = results
437            .into_iter()
438            .filter(|r| r.score >= self.config.min_relevance_score)
439            .filter_map(|r| {
440                let (chunk, title, source) = idx.get(&r.entry.id)?;
441                Some(ScoredChunk {
442                    chunk: chunk.clone(),
443                    score: r.score,
444                    document_title: title.clone(),
445                    source: source.clone(),
446                })
447            })
448            .collect();
449
450        scored.sort_by(|a, b| {
451            b.score
452                .partial_cmp(&a.score)
453                .unwrap_or(std::cmp::Ordering::Equal)
454        });
455        scored.truncate(k);
456
457        let context_text = format_context(&scored, self.config.max_context_tokens);
458
459        let elapsed = start.elapsed().as_millis() as u64;
460
461        Ok(RagResult {
462            chunks: scored,
463            context_text,
464            total_chunks_searched,
465            query_time_ms: elapsed,
466        })
467    }
468
469    /// Query and return a formatted context string sized to fit a given
470    /// context window (in estimated tokens).
471    pub async fn query_with_context(
472        &self,
473        question: &str,
474        top_k: Option<usize>,
475        context_window: usize,
476    ) -> ArgentorResult<RagResult> {
477        let mut result = self.query(question, top_k).await?;
478        // Re-format the context with the caller-specified window size.
479        result.context_text = format_context(&result.chunks, context_window);
480        Ok(result)
481    }
482
483    /// Return a reference to the pipeline configuration.
484    pub fn config(&self) -> &RagConfig {
485        &self.config
486    }
487}
488
489/// Format scored chunks into a context string for LLM injection,
490/// respecting a maximum token budget.
491fn format_context(chunks: &[ScoredChunk], max_tokens: usize) -> String {
492    let mut parts: Vec<String> = Vec::new();
493    let mut token_budget = max_tokens;
494
495    for (i, sc) in chunks.iter().enumerate() {
496        let header = format!(
497            "[Source: {} | Document: {} | Score: {:.2}]",
498            sc.source, sc.document_title, sc.score
499        );
500        let section = format!("--- Chunk {} ---\n{}\n{}", i + 1, header, sc.chunk.content);
501        let section_tokens = estimate_tokens(&section);
502        if section_tokens > token_budget {
503            // Try to include a truncated version if there is room.
504            if token_budget > 20 {
505                let available_chars = token_budget * 4;
506                let truncated: String = section.chars().take(available_chars).collect();
507                parts.push(truncated);
508            }
509            break;
510        }
511        token_budget = token_budget.saturating_sub(section_tokens);
512        parts.push(section);
513    }
514
515    parts.join("\n\n")
516}
517
518// ---------------------------------------------------------------------------
519// Tests
520// ---------------------------------------------------------------------------
521
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    // Unit tests for the chunking helpers and context formatting, plus
    // end-to-end pipeline tests backed by the in-crate `InMemoryVectorStore`
    // and `LocalEmbedding`.
    use super::*;
    use crate::embedding::LocalEmbedding;
    use crate::store::InMemoryVectorStore;

    // -- Helpers --

    /// Build a `Document` with source "test", no metadata and no category.
    fn sample_doc(id: &str, title: &str, content: &str) -> Document {
        Document {
            id: id.to_string(),
            title: title.to_string(),
            content: content.to_string(),
            source: "test".to_string(),
            metadata: HashMap::new(),
            category: None,
        }
    }

    /// Build a pipeline over a fresh in-memory store and local embedder.
    fn make_pipeline(config: RagConfig) -> RagPipeline {
        let store = Arc::new(InMemoryVectorStore::new()) as Arc<dyn VectorStore>;
        let embedder = Arc::new(LocalEmbedding::default()) as Arc<dyn EmbeddingProvider>;
        RagPipeline::new(store, embedder, config)
    }

    /// Pipeline using `RagConfig::default()`.
    fn default_pipeline() -> RagPipeline {
        make_pipeline(RagConfig::default())
    }

    // -----------------------------------------------------------------------
    // Chunking unit tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_chunk_fixed_size_basic() {
        let chunks = chunk_fixed_size("abcdefghij", 4, 0);
        assert_eq!(chunks.len(), 3); // "abcd", "efgh", "ij"
        assert_eq!(chunks[0], "abcd");
        assert_eq!(chunks[1], "efgh");
        assert_eq!(chunks[2], "ij");
    }

    #[test]
    fn test_chunk_fixed_size_with_overlap() {
        let chunks = chunk_fixed_size("abcdefghij", 5, 2);
        // step = 5 - 2 = 3, windows: [0..5], [3..8], [6..10] (last clipped to text end)
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], "abcde");
        assert_eq!(chunks[1], "defgh");
        assert_eq!(chunks[2], "ghij");
    }

    #[test]
    fn test_chunk_fixed_size_empty_text() {
        let chunks = chunk_fixed_size("", 10, 0);
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_chunk_fixed_size_zero_size() {
        // A zero chunk size is rejected up front rather than looping forever.
        let chunks = chunk_fixed_size("hello", 0, 0);
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_chunk_fixed_size_text_shorter_than_chunk() {
        let chunks = chunk_fixed_size("hi", 100, 0);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], "hi");
    }

    #[test]
    fn test_chunk_paragraph_basic() {
        let text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";
        let chunks = chunk_paragraph(text);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], "First paragraph.");
        assert_eq!(chunks[1], "Second paragraph.");
        assert_eq!(chunks[2], "Third paragraph.");
    }

    #[test]
    fn test_chunk_paragraph_empty() {
        let chunks = chunk_paragraph("");
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_chunk_paragraph_single() {
        let text = "Just one paragraph with no double newlines.";
        let chunks = chunk_paragraph(text);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], text);
    }

    #[test]
    fn test_chunk_sentence_basic() {
        let text = "First sentence. Second sentence! Third sentence?";
        let chunks = chunk_sentence(text);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0], "First sentence.");
        assert_eq!(chunks[1], "Second sentence!");
        assert_eq!(chunks[2], "Third sentence?");
    }

    #[test]
    fn test_chunk_sentence_no_terminal() {
        // Text without terminal punctuation becomes a single trailing sentence.
        let text = "No terminal punctuation here";
        let chunks = chunk_sentence(text);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], text);
    }

    #[test]
    fn test_chunk_sentence_empty() {
        let chunks = chunk_sentence("");
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_chunk_semantic_headings() {
        let text = "# Heading 1\nParagraph one.\n# Heading 2\nParagraph two.";
        let chunks = chunk_semantic(text, 1000);
        // Two heading sections: "# Heading 1\nParagraph one." and
        // "# Heading 2\nParagraph two.". Together they fit easily within a
        // 1000-token budget, so they are merged into a single chunk.
        assert!(!chunks.is_empty());
        assert_eq!(chunks.len(), 1);
    }

    #[test]
    fn test_chunk_semantic_splits_on_budget() {
        let text = "# A\nLorem ipsum dolor sit amet.\n# B\nConsectetur adipiscing elit.";
        // An 8-token (~32-byte) budget is too small to hold both sections, so
        // they must not be merged.
        let chunks = chunk_semantic(text, 8);
        assert!(chunks.len() >= 2, "small budget should force split");
    }

    #[test]
    fn test_chunk_semantic_empty() {
        let chunks = chunk_semantic("", 100);
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("abcd"), 1);
        // 12 ASCII bytes -> 3 tokens
        assert_eq!(estimate_tokens("abcdefghijkl"), 3);
    }

    // -----------------------------------------------------------------------
    // Document model tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_document_chunk_fields() {
        let chunk = DocumentChunk {
            chunk_id: "doc1_chunk_0".into(),
            document_id: "doc1".into(),
            content: "hello world".into(),
            chunk_index: 0,
            token_estimate: 3,
        };
        assert_eq!(chunk.chunk_id, "doc1_chunk_0");
        assert_eq!(chunk.document_id, "doc1");
        assert_eq!(chunk.chunk_index, 0);
        assert_eq!(chunk.token_estimate, 3);
    }

    #[test]
    fn test_rag_config_default() {
        let cfg = RagConfig::default();
        assert_eq!(cfg.top_k, 5);
        assert!((cfg.min_relevance_score - 0.3).abs() < f32::EPSILON);
        assert!(cfg.include_metadata);
        assert_eq!(cfg.max_context_tokens, 4096);
    }

    // -----------------------------------------------------------------------
    // Integration / pipeline tests
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn test_ingest_document_creates_chunks() {
        let pipeline = default_pipeline();
        let doc = sample_doc("d1", "Test Doc", "Hello world. This is a test document.");
        let chunks = pipeline.ingest_document(&doc).await.unwrap();
        assert!(!chunks.is_empty(), "should produce at least one chunk");
        // Each chunk should reference the parent document.
        for c in &chunks {
            assert_eq!(c.document_id, "d1");
            assert!(!c.content.is_empty());
            assert!(c.token_estimate > 0);
        }
    }

    #[tokio::test]
    async fn test_ingest_batch() {
        let pipeline = default_pipeline();
        let docs = vec![
            sample_doc("d1", "Doc One", "Content of document one."),
            sample_doc("d2", "Doc Two", "Content of document two."),
        ];
        let all = pipeline.ingest_batch(&docs).await.unwrap();
        // One chunk list per input document, in input order.
        assert_eq!(all.len(), 2);
        assert!(!all[0].is_empty());
        assert!(!all[1].is_empty());
    }

    #[tokio::test]
    async fn test_query_returns_results() {
        let pipeline = default_pipeline();
        let doc = sample_doc(
            "d1",
            "Rust Book",
            "Rust is a systems programming language focused on safety and performance.",
        );
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline.query("rust programming", None).await.unwrap();
        assert!(!result.chunks.is_empty(), "query should return results");
        // Loose sanity bound only; not a performance benchmark.
        assert!(result.query_time_ms < 10_000, "query should be fast");
        assert!(
            result.total_chunks_searched > 0,
            "should report chunks searched"
        );
    }

    #[tokio::test]
    async fn test_query_context_text_not_empty() {
        let pipeline = default_pipeline();
        let doc = sample_doc("d1", "FAQ", "How do I install Rust? Use rustup.");
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline.query("install rust", None).await.unwrap();
        assert!(
            !result.context_text.is_empty(),
            "context_text should be populated"
        );
        assert!(
            result.context_text.contains("Chunk 1"),
            "context should include chunk header"
        );
    }

    #[tokio::test]
    async fn test_query_with_context_window() {
        let cfg = RagConfig {
            min_relevance_score: 0.0, // accept any score so the test is deterministic
            ..RagConfig::default()
        };
        let pipeline = make_pipeline(cfg);
        let doc = sample_doc(
            "d1",
            "Long Doc",
            "Rust programming language. Memory safety without garbage collection. Zero-cost abstractions.",
        );
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline
            .query_with_context("rust programming language", None, 8192)
            .await
            .unwrap();
        assert!(!result.context_text.is_empty());
    }

    #[tokio::test]
    async fn test_query_min_relevance_filter() {
        let cfg = RagConfig {
            min_relevance_score: 0.99, // very high threshold
            ..RagConfig::default()
        };
        let pipeline = make_pipeline(cfg);
        let doc = sample_doc("d1", "Doc", "some random content about various topics");
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline
            .query("completely unrelated xyz", None)
            .await
            .unwrap();
        // With a 0.99 threshold most results should be filtered out.
        // We don't assert exact count because LocalEmbedding is approximate.
        assert!(
            result.chunks.len() <= 1,
            "high threshold should filter most results"
        );
    }

    #[tokio::test]
    async fn test_scored_chunk_has_metadata() {
        let pipeline = default_pipeline();
        let doc = sample_doc("d1", "My Title", "Content about Rust programming.");
        pipeline.ingest_document(&doc).await.unwrap();

        let result = pipeline.query("rust", None).await.unwrap();
        // Only checked when a hit clears the default relevance threshold.
        if let Some(sc) = result.chunks.first() {
            assert_eq!(sc.document_title, "My Title");
            assert_eq!(sc.source, "test");
            assert!(sc.score > 0.0);
        }
    }

    #[tokio::test]
    async fn test_config_accessor() {
        let pipeline = default_pipeline();
        assert_eq!(pipeline.config().top_k, 5);
    }

    #[test]
    fn test_format_context_empty() {
        let ctx = format_context(&[], 4096);
        assert!(ctx.is_empty());
    }

    #[test]
    fn test_format_context_includes_source() {
        let chunks = vec![ScoredChunk {
            chunk: DocumentChunk {
                chunk_id: "c1".into(),
                document_id: "d1".into(),
                content: "Hello".into(),
                chunk_index: 0,
                token_estimate: 2,
            },
            score: 0.95,
            document_title: "Title".into(),
            source: "kb".into(),
        }];
        let ctx = format_context(&chunks, 4096);
        // Header fields and the chunk body must all appear in the output.
        assert!(ctx.contains("kb"));
        assert!(ctx.contains("Title"));
        assert!(ctx.contains("0.95"));
        assert!(ctx.contains("Hello"));
    }
}