1#![deny(missing_docs)]
27#![allow(clippy::redundant_closure)]
28
29pub mod chunker;
30pub mod embedding;
31pub mod hybrid;
32pub mod rag_context;
33
34use async_trait::async_trait;
35use serde::{Deserialize, Serialize};
36
37pub use chunker::{Chunker, FixedSizeChunker, RecursiveChunker, SentenceChunker};
38pub use embedding::{EmbeddingProvider, EmbeddingRetriever};
39pub use hybrid::{CitationStrategy, ContextWindowStrategy, HybridRetriever};
40pub use rag_context::RagContextManager;
41
/// A retrieved document: its identity, text content, optional metadata,
/// and the relevance score assigned by a retriever.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Unique identifier for the document.
    pub id: String,
    /// Raw text content of the document.
    pub content: String,
    /// Optional arbitrary JSON metadata attached to the document.
    pub metadata: Option<serde_json::Value>,
    /// Relevance score assigned by a retriever; `0.0` until scored.
    pub score: f64,
}
54
55impl Document {
56 #[must_use]
58 pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
59 Self {
60 id: id.into(),
61 content: content.into(),
62 metadata: None,
63 score: 0.0,
64 }
65 }
66}
67
/// Asynchronous document retrieval interface.
#[async_trait]
pub trait Retriever: Send + Sync + 'static {
    /// Returns up to `limit` documents relevant to `query`, ordered by
    /// descending relevance score.
    ///
    /// # Errors
    /// Returns an error if the underlying retrieval backend fails.
    async fn retrieve(&self, query: &str, limit: usize) -> traitclaw_core::Result<Vec<Document>>;
}
76
/// Strategy for turning retrieved documents into a grounding context string
/// suitable for inclusion in a prompt.
pub trait GroundingStrategy: Send + Sync + 'static {
    /// Renders `documents` into a single context string.
    fn ground(&self, documents: &[Document]) -> String;
}
82
/// Grounding strategy that prepends numbered document excerpts under a
/// `Relevant context:` header.
pub struct PrependStrategy;
85
86impl GroundingStrategy for PrependStrategy {
87 fn ground(&self, documents: &[Document]) -> String {
88 if documents.is_empty() {
89 return String::new();
90 }
91 let mut ctx = String::from("Relevant context:\n\n");
92 for (i, doc) in documents.iter().enumerate() {
93 use std::fmt::Write;
94 let _ = write!(ctx, "[{}] {}\n\n", i + 1, doc.content);
95 }
96 ctx
97 }
98}
99
/// Simple in-memory retriever that scores documents by keyword
/// term frequency against the query.
pub struct KeywordRetriever {
    // In-memory corpus scanned on every retrieve call.
    documents: Vec<Document>,
}
106
107impl KeywordRetriever {
108 #[must_use]
110 pub fn new() -> Self {
111 Self {
112 documents: Vec::new(),
113 }
114 }
115
116 pub fn add(&mut self, doc: Document) {
118 self.documents.push(doc);
119 }
120
121 pub fn add_many(&mut self, docs: impl IntoIterator<Item = Document>) {
123 self.documents.extend(docs);
124 }
125
126 fn score(query_terms: &[String], content: &str) -> f64 {
128 let content_lower = content.to_lowercase();
129 let words: Vec<&str> = content_lower.split_whitespace().collect();
130 let doc_len = words.len() as f64;
131
132 if doc_len == 0.0 {
133 return 0.0;
134 }
135
136 let mut total_score = 0.0;
137 for term in query_terms {
138 let tf = words.iter().filter(|w| **w == term.as_str()).count() as f64;
139 let score = tf / (tf + 1.0);
142 total_score += score;
143 }
144
145 total_score
146 }
147}
148
impl Default for KeywordRetriever {
    /// Equivalent to [`KeywordRetriever::new`]: an empty corpus.
    fn default() -> Self {
        Self::new()
    }
}
154
155#[async_trait]
156impl Retriever for KeywordRetriever {
157 async fn retrieve(&self, query: &str, limit: usize) -> traitclaw_core::Result<Vec<Document>> {
158 let terms: Vec<String> = query
159 .to_lowercase()
160 .split_whitespace()
161 .map(String::from)
162 .collect();
163
164 let mut scored: Vec<Document> = self
165 .documents
166 .iter()
167 .map(|doc| {
168 let mut d = doc.clone();
169 d.score = Self::score(&terms, &doc.content);
170 d
171 })
172 .filter(|d| d.score > 0.0)
173 .collect();
174
175 scored.sort_by(|a, b| {
176 b.score
177 .partial_cmp(&a.score)
178 .unwrap_or(std::cmp::Ordering::Equal)
179 });
180 scored.truncate(limit);
181
182 Ok(scored)
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    // Matching documents are returned and the strongest match ranks first.
    #[tokio::test]
    async fn test_keyword_retriever_basic() {
        let mut r = KeywordRetriever::new();
        r.add(Document::new("1", "Rust is a systems programming language"));
        r.add(Document::new("2", "Python is great for data science"));
        r.add(Document::new("3", "Rust has zero-cost abstractions"));

        let results = r.retrieve("Rust programming", 10).await.unwrap();
        assert!(!results.is_empty());
        assert!(results.len() >= 2);
        // Doc 1 matches both query terms, so it should outrank doc 3.
        assert_eq!(results[0].id, "1");
    }

    // An empty query produces no terms, so nothing can score above zero.
    #[tokio::test]
    async fn test_keyword_retriever_empty_query() {
        let mut r = KeywordRetriever::new();
        r.add(Document::new("1", "Some content"));
        let results = r.retrieve("", 10).await.unwrap();
        assert!(results.is_empty());
    }

    // Documents with zero score are filtered out entirely.
    #[tokio::test]
    async fn test_keyword_retriever_no_match() {
        let mut r = KeywordRetriever::new();
        r.add(Document::new("1", "Hello world"));
        let results = r.retrieve("quantum computing", 10).await.unwrap();
        assert!(results.is_empty());
    }

    // The `limit` argument truncates the result set.
    #[tokio::test]
    async fn test_keyword_retriever_limit() {
        let mut r = KeywordRetriever::new();
        for i in 0..10 {
            r.add(Document::new(format!("{i}"), format!("rust item {i}")));
        }
        let results = r.retrieve("rust", 3).await.unwrap();
        assert_eq!(results.len(), 3);
    }

    // Documents are numbered 1-based in the grounded context string.
    #[test]
    fn test_prepend_strategy() {
        let docs = vec![
            Document::new("1", "First doc"),
            Document::new("2", "Second doc"),
        ];
        let ctx = PrependStrategy.ground(&docs);
        assert!(ctx.contains("[1] First doc"));
        assert!(ctx.contains("[2] Second doc"));
    }

    // No documents means no header either — the output is fully empty.
    #[test]
    fn test_prepend_strategy_empty() {
        let ctx = PrependStrategy.ground(&[]);
        assert!(ctx.is_empty());
    }

    // Constructor defaults: no metadata, zero score.
    #[test]
    fn test_document_new() {
        let doc = Document::new("id1", "content1");
        assert_eq!(doc.id, "id1");
        assert_eq!(doc.content, "content1");
        assert!(doc.metadata.is_none());
        assert!((doc.score - 0.0).abs() < f64::EPSILON);
    }
}