1use std::collections::HashSet;
2use std::hash::{Hash, Hasher};
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use serde_json::{json, Value};
7use tokio::sync::RwLock;
8
9use crate::error::Result;
10
/// A unit of text handled by the knowledge base, together with its
/// identifier and a free-form JSON metadata value.
#[derive(Clone, Debug)]
pub struct Document {
    /// Identifier of the document; chunk ids are derived from it.
    pub id: String,
    /// Text content; this is the value handed to the embedder.
    pub text: String,
    /// Arbitrary JSON metadata carried alongside the text.
    pub metadata: Value,
}
17
/// A [`Document`] paired with the score a search or reranker assigned to it.
#[derive(Clone, Debug)]
pub struct ScoredDocument {
    /// The matched document.
    pub document: Document,
    /// Relevance score; result lists in this file are sorted highest first.
    pub score: f32,
}
23
/// Converts text into an embedding vector.
#[async_trait]
pub trait Embedder: Send + Sync {
    /// Returns the embedding computed for `text`, or the implementation's error.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
}
28
/// Client abstraction that [`OpenAiEmbedder`] uses to request embeddings.
#[async_trait]
pub trait OpenAiEmbeddingClient: Send + Sync {
    /// Requests an embedding of `input` from the model named `model`.
    async fn embed(&self, model: &str, input: &str) -> Result<Vec<f32>>;
}
33
/// Embedder that forwards every request to a shared client under a fixed
/// model name.
pub struct OpenAiEmbedder<C> {
    /// Shared client that performs the embedding request.
    client: Arc<C>,
    /// Model name sent along with every request.
    model: String,
}

impl<C> OpenAiEmbedder<C> {
    /// Creates an embedder that uses `client` and always asks for `model`.
    ///
    /// `model` accepts anything convertible into a `String`.
    pub fn new(client: Arc<C>, model: impl Into<String>) -> Self {
        let model: String = model.into();
        OpenAiEmbedder { client, model }
    }
}
48
49#[async_trait]
50impl<C> Embedder for OpenAiEmbedder<C>
51where
52 C: OpenAiEmbeddingClient,
53{
54 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
55 self.client.embed(&self.model, text).await
56 }
57}
58
/// Client abstraction that [`TransformerEmbedder`] uses to compute embeddings.
#[async_trait]
pub trait TransformerClient: Send + Sync {
    /// Returns the embedding computed for `text`, or the client's error.
    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
}
63
/// Embedder that forwards every request to a shared transformer client.
pub struct TransformerEmbedder<C> {
    /// Shared client that performs the embedding.
    client: Arc<C>,
}

impl<C> TransformerEmbedder<C> {
    /// Creates an embedder backed by `client`.
    pub fn new(client: Arc<C>) -> Self {
        TransformerEmbedder { client }
    }
}
74
75#[async_trait]
76impl<C> Embedder for TransformerEmbedder<C>
77where
78 C: TransformerClient,
79{
80 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
81 self.client.embed(text).await
82 }
83}
84
/// Storage backend that keeps documents with their embeddings and answers
/// similarity queries over them.
#[async_trait]
pub trait VectorStore: Send + Sync {
    /// Stores `document` together with its `embedding`.
    async fn add(&self, document: Document, embedding: Vec<f32>) -> Result<()>;
    /// Searches for documents matching the query `embedding`, as configured
    /// by `params`, and returns them with their scores.
    async fn search(
        &self,
        embedding: Vec<f32>,
        params: SearchParams,
    ) -> Result<Vec<ScoredDocument>>;
}
94
/// Embedder that counts whitespace-separated tokens into a fixed number of
/// hash buckets.
pub struct WhitespaceEmbedder {
    /// Length of the vectors this embedder produces.
    buckets: usize,
}

impl Default for WhitespaceEmbedder {
    /// Uses 32 buckets.
    fn default() -> Self {
        Self::new(32)
    }
}

impl WhitespaceEmbedder {
    /// Creates an embedder whose vectors have `buckets` components.
    pub fn new(buckets: usize) -> Self {
        WhitespaceEmbedder { buckets }
    }
}
111
112#[async_trait]
113impl Embedder for WhitespaceEmbedder {
114 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
115 let mut vector = vec![0.0; self.buckets];
116
117 for token in text.split_whitespace() {
118 let mut hasher = std::collections::hash_map::DefaultHasher::new();
119 token.hash(&mut hasher);
120 let idx = (hasher.finish() as usize) % self.buckets;
121 vector[idx] += 1.0;
122 }
123
124 Ok(vector)
125 }
126}
127
/// [`VectorStore`] that keeps every document and embedding in a `Vec`
/// guarded by an async `RwLock`.
#[derive(Default)]
pub struct InMemoryVectorStore {
    /// Stored `(document, embedding)` pairs, in insertion order.
    entries: RwLock<Vec<(Document, Vec<f32>)>>,
}
132
133#[async_trait]
134impl VectorStore for InMemoryVectorStore {
135 async fn add(&self, document: Document, embedding: Vec<f32>) -> Result<()> {
136 self.entries.write().await.push((document, embedding));
137 Ok(())
138 }
139
140 async fn search(
141 &self,
142 embedding: Vec<f32>,
143 params: SearchParams,
144 ) -> Result<Vec<ScoredDocument>> {
145 let entries = self.entries.read().await;
146 let mut scored: Vec<ScoredDocument> = entries
147 .iter()
148 .map(|(doc, stored)| ScoredDocument {
149 document: doc.clone(),
150 score: similarity(stored, &embedding, params.similarity),
151 })
152 .collect();
153
154 scored.sort_by(|a, b| {
155 b.score
156 .partial_cmp(&a.score)
157 .unwrap_or(std::cmp::Ordering::Equal)
158 });
159 scored.truncate(params.top_k);
160 Ok(scored)
161 }
162}
163
164fn similarity(a: &[f32], b: &[f32], metric: SimilarityMetric) -> f32 {
165 let (mut dot, mut norm_a, mut norm_b) = (0.0, 0.0, 0.0);
166 for (x, y) in a.iter().zip(b.iter()) {
167 dot += x * y;
168 norm_a += x * x;
169 norm_b += y * y;
170 }
171
172 match metric {
173 SimilarityMetric::Cosine => {
174 if norm_a == 0.0 || norm_b == 0.0 {
175 0.0
176 } else {
177 dot / (norm_a.sqrt() * norm_b.sqrt())
178 }
179 }
180 SimilarityMetric::DotProduct => dot,
181 SimilarityMetric::Euclidean => {
182 let mut squared_distance = 0.0;
184 for (x, y) in a.iter().zip(b.iter()) {
185 let diff = x - y;
186 squared_distance += diff * diff;
187 }
188 1.0 / (1.0 + squared_distance.sqrt())
189 }
190 }
191}
192
/// Scoring function used to compare two embeddings.
#[derive(Clone, Copy, Debug)]
pub enum SimilarityMetric {
    /// Cosine of the angle between the vectors; `0.0` when either has zero norm.
    Cosine,
    /// Plain dot product of the vectors.
    DotProduct,
    /// `1 / (1 + d)` with `d` the Euclidean distance, so closer vectors score higher.
    Euclidean,
}
199
200#[derive(Clone, Debug)]
201pub struct SearchParams {
202 pub top_k: usize,
203 pub similarity: SimilarityMetric,
204}
205
206impl Default for SearchParams {
207 fn default() -> Self {
208 Self {
209 top_k: 5,
210 similarity: SimilarityMetric::Cosine,
211 }
212 }
213}
214
/// Client abstraction that [`PgVectorStore`] delegates to.
#[async_trait]
pub trait PgVectorClient: Send + Sync {
    /// Writes `document` with its `embedding` to the backend.
    // NOTE(review): the name suggests insert-or-replace semantics; that is up
    // to the implementor and not enforced here.
    async fn upsert(&self, document: &Document, embedding: &[f32]) -> Result<()>;
    /// Runs a similarity query for `embedding` configured by `params`.
    async fn query(&self, embedding: &[f32], params: SearchParams) -> Result<Vec<ScoredDocument>>;
}
220
/// Vector store that forwards every operation to a shared pgvector client.
pub struct PgVectorStore<C> {
    /// Shared client that talks to the backend.
    client: Arc<C>,
}

impl<C> PgVectorStore<C> {
    /// Creates a store backed by `client`.
    pub fn new(client: Arc<C>) -> Self {
        PgVectorStore { client }
    }
}
231
232#[async_trait]
233impl<C> VectorStore for PgVectorStore<C>
234where
235 C: PgVectorClient,
236{
237 async fn add(&self, document: Document, embedding: Vec<f32>) -> Result<()> {
238 self.client.upsert(&document, &embedding).await
239 }
240
241 async fn search(
242 &self,
243 embedding: Vec<f32>,
244 params: SearchParams,
245 ) -> Result<Vec<ScoredDocument>> {
246 self.client.query(&embedding, params).await
247 }
248}
249
/// Client abstraction that [`QdrantStore`] delegates to.
#[async_trait]
pub trait QdrantClient: Send + Sync {
    /// Writes `document` with its `embedding` to the backend.
    // NOTE(review): the name suggests insert-or-replace semantics; that is up
    // to the implementor and not enforced here.
    async fn upsert(&self, document: &Document, embedding: &[f32]) -> Result<()>;
    /// Runs a similarity query for `embedding` configured by `params`.
    async fn query(&self, embedding: &[f32], params: SearchParams) -> Result<Vec<ScoredDocument>>;
}
255
/// Vector store that forwards every operation to a shared Qdrant client.
pub struct QdrantStore<C> {
    /// Shared client that talks to the backend.
    client: Arc<C>,
}

impl<C> QdrantStore<C> {
    /// Creates a store backed by `client`.
    pub fn new(client: Arc<C>) -> Self {
        QdrantStore { client }
    }
}
266
267#[async_trait]
268impl<C> VectorStore for QdrantStore<C>
269where
270 C: QdrantClient,
271{
272 async fn add(&self, document: Document, embedding: Vec<f32>) -> Result<()> {
273 self.client.upsert(&document, &embedding).await
274 }
275
276 async fn search(
277 &self,
278 embedding: Vec<f32>,
279 params: SearchParams,
280 ) -> Result<Vec<ScoredDocument>> {
281 self.client.query(&embedding, params).await
282 }
283}
284
/// Splits a document into the pieces that are embedded and stored individually.
pub trait DocumentChunker: Send + Sync {
    /// Returns the chunks produced from `document`.
    fn chunk(&self, document: &Document) -> Vec<Document>;
}
288
/// Chunker that slides a fixed-size window over a document's
/// whitespace-separated tokens.
pub struct SlidingWindowChunker {
    /// Maximum number of tokens in one chunk.
    pub max_tokens: usize,
    /// Number of tokens a chunk shares with the chunk before it.
    pub overlap: usize,
}

impl Default for SlidingWindowChunker {
    /// Defaults to 256-token windows that overlap by 32 tokens.
    fn default() -> Self {
        let (max_tokens, overlap) = (256, 32);
        SlidingWindowChunker {
            max_tokens,
            overlap,
        }
    }
}
303
304impl DocumentChunker for SlidingWindowChunker {
305 fn chunk(&self, document: &Document) -> Vec<Document> {
306 if document.text.is_empty() {
307 return vec![document.clone()];
308 }
309
310 let tokens: Vec<&str> = document.text.split_whitespace().collect();
311 if tokens.len() <= self.max_tokens {
312 return vec![document.clone()];
313 }
314
315 let mut chunks = Vec::new();
316 let mut start = 0usize;
317 let mut chunk_index = 0usize;
318
319 while start < tokens.len() {
320 let end = usize::min(start + self.max_tokens, tokens.len());
321 let text = tokens[start..end].join(" ");
322 let mut metadata = document.metadata.clone();
323
324 if let Value::Object(map) = &mut metadata {
325 map.insert("chunk_index".to_string(), Value::from(chunk_index as u64));
326 map.insert("source_id".to_string(), Value::from(document.id.clone()));
327 } else {
328 metadata = json!({
329 "chunk_index": chunk_index,
330 "source_id": document.id
331 });
332 }
333
334 chunks.push(Document {
335 id: format!("{}::{}", document.id, chunk_index),
336 text,
337 metadata,
338 });
339
340 if end == tokens.len() {
341 break;
342 }
343
344 start = end.saturating_sub(self.overlap.min(end - start));
345 chunk_index += 1;
346 }
347
348 chunks
349 }
350}
351
/// Shared callback that computes a replacement score for a retrieved document.
pub type Reranker = Arc<dyn Fn(&ScoredDocument) -> f32 + Send + Sync>;
353
/// Ties together an embedder, a vector store, retrieval defaults and an
/// optional chunker.
pub struct KnowledgeBase<E: Embedder, S: VectorStore> {
    /// Embedder used for both documents and queries.
    embedder: Arc<E>,
    /// Store that holds the embedded documents (or their chunks).
    store: Arc<S>,
    /// Defaults applied when a retrieval call supplies no override.
    config: RetrievalConfig,
    /// Optional chunker applied to documents before they are embedded.
    chunker: Option<Arc<dyn DocumentChunker>>,
}
360
361impl<E: Embedder, S: VectorStore> KnowledgeBase<E, S> {
362 pub fn new(embedder: Arc<E>, store: Arc<S>) -> Self {
363 Self {
364 embedder,
365 store,
366 config: RetrievalConfig::default(),
367 chunker: None,
368 }
369 }
370
371 pub fn with_reranker(mut self, reranker: Reranker) -> Self {
372 self.config.reranker = Some(reranker);
373 self
374 }
375
376 pub fn with_chunker(mut self, chunker: Arc<dyn DocumentChunker>) -> Self {
377 self.chunker = Some(chunker);
378 self
379 }
380
381 pub fn with_config(mut self, config: RetrievalConfig) -> Self {
382 self.config = config;
383 self
384 }
385
386 pub fn config(&self) -> &RetrievalConfig {
387 &self.config
388 }
389
390 pub async fn add_document(&self, document: Document) -> Result<()> {
391 let chunks = if let Some(chunker) = &self.chunker {
392 chunker.chunk(&document)
393 } else {
394 vec![document]
395 };
396
397 for chunk in chunks {
398 let embedding = self.embedder.embed(&chunk.text).await?;
399 self.store.add(chunk, embedding).await?;
400 }
401
402 Ok(())
403 }
404
405 pub async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<ScoredDocument>> {
406 let overrides = RetrievalOverrides {
407 top_k: Some(top_k),
408 ..Default::default()
409 };
410 self.retrieve_with_overrides(query, overrides).await
411 }
412
413 pub async fn retrieve_with_overrides(
414 &self,
415 query: &str,
416 overrides: RetrievalOverrides,
417 ) -> Result<Vec<ScoredDocument>> {
418 let embedding = self.embedder.embed(query).await?;
419 let params = SearchParams {
420 top_k: overrides.top_k.unwrap_or(self.config.top_k),
421 similarity: overrides.similarity.unwrap_or(self.config.similarity),
422 };
423 let mut scored = self.store.search(embedding, params).await?;
424
425 if let Some(reranker) = overrides.reranker.or_else(|| self.config.reranker.clone()) {
426 for doc in scored.iter_mut() {
427 doc.score = reranker(doc);
428 }
429 scored.sort_by(|a, b| {
430 b.score
431 .partial_cmp(&a.score)
432 .unwrap_or(std::cmp::Ordering::Equal)
433 });
434 }
435
436 Ok(scored)
437 }
438
439 pub async fn evaluate(
440 &self,
441 query: &str,
442 relevant_document_ids: &[String],
443 overrides: RetrievalOverrides,
444 ) -> Result<RetrievalEvaluation> {
445 let retrieved = self.retrieve_with_overrides(query, overrides).await?;
446 let retrieved_ids: HashSet<String> =
447 retrieved.iter().map(|d| d.document.id.clone()).collect();
448 let relevant: HashSet<String> = relevant_document_ids.iter().cloned().collect();
449
450 let hits = relevant.intersection(&retrieved_ids).count() as f32;
451 let precision = if retrieved.is_empty() {
452 0.0
453 } else {
454 hits / retrieved.len() as f32
455 };
456 let recall = if relevant.is_empty() {
457 0.0
458 } else {
459 hits / relevant.len() as f32
460 };
461
462 Ok(RetrievalEvaluation {
463 retrieved,
464 precision,
465 recall,
466 })
467 }
468}
469
/// Text-only retrieval interface: returns the matching texts for a query.
#[async_trait]
pub trait Retriever: Send + Sync {
    /// Returns the texts retrieved for `query`; `top_k` is the requested
    /// number of results.
    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<String>>;
}
474
475#[async_trait]
476impl<E, S> Retriever for KnowledgeBase<E, S>
477where
478 E: Embedder,
479 S: VectorStore,
480{
481 async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<String>> {
482 let docs = KnowledgeBase::retrieve(self, query, top_k).await?;
483 Ok(docs.into_iter().map(|d| d.document.text).collect())
484 }
485}
486
487#[derive(Clone)]
488pub struct RetrievalConfig {
489 pub top_k: usize,
490 pub similarity: SimilarityMetric,
491 pub reranker: Option<Reranker>,
492}
493
494impl Default for RetrievalConfig {
495 fn default() -> Self {
496 Self {
497 top_k: 5,
498 similarity: SimilarityMetric::Cosine,
499 reranker: None,
500 }
501 }
502}
503
/// Per-call overrides for retrieval; a field left as `None` falls back to
/// the knowledge base's configured default.
#[derive(Clone, Default)]
pub struct RetrievalOverrides {
    /// Overrides the number of results to request.
    pub top_k: Option<usize>,
    /// Overrides the similarity metric.
    pub similarity: Option<SimilarityMetric>,
    /// Overrides the reranker.
    pub reranker: Option<Reranker>,
}
510
511pub struct RetrievalEvaluation {
512 pub retrieved: Vec<ScoredDocument>,
513 pub precision: f32,
514 pub recall: f32,
515}
516
// Unit tests that exercise the knowledge base against the in-memory store.
#[cfg(test)]
mod tests {
    use super::*;

    // Embedder stub: maps a text to a one-component vector holding its byte length.
    struct TestEmbedder;

    #[async_trait]
    impl Embedder for TestEmbedder {
        async fn embed(&self, text: &str) -> Result<Vec<f32>> {
            Ok(vec![text.len() as f32])
        }
    }

    // Four tokens with a two-token window and no overlap are stored as two chunks.
    #[tokio::test]
    async fn chunks_documents() {
        let embedder = Arc::new(TestEmbedder);
        let store = Arc::new(InMemoryVectorStore::default());
        let kb: KnowledgeBase<_, _> =
            KnowledgeBase::new(embedder, store).with_chunker(Arc::new(SlidingWindowChunker {
                max_tokens: 2,
                overlap: 0,
            }));

        kb.add_document(Document {
            id: "doc".into(),
            text: "a b c d".into(),
            metadata: Value::Null,
        })
        .await
        .unwrap();

        // top_k of 10 exceeds the number of stored chunks, so every chunk comes back.
        let scored = kb.retrieve("a b", 10).await.unwrap();
        assert_eq!(scored.len(), 2);
    }

    // Retrieving a single result that is the one relevant document gives
    // precision and recall of 1.0.
    #[tokio::test]
    async fn evaluates_precision_recall() {
        let embedder = Arc::new(TestEmbedder);
        let store = Arc::new(InMemoryVectorStore::default());
        let kb: KnowledgeBase<_, _> = KnowledgeBase::new(embedder, store);

        kb.add_document(Document {
            id: "d1".into(),
            text: "hello world".into(),
            metadata: Value::Null,
        })
        .await
        .unwrap();
        kb.add_document(Document {
            id: "d2".into(),
            text: "other".into(),
            metadata: Value::Null,
        })
        .await
        .unwrap();

        // NOTE(review): with one-component positive embeddings the cosine
        // score looks like 1.0 for both documents, so "d1" winning the single
        // slot appears to depend on the stable sort keeping insertion order
        // for ties — confirm before changing the store's sort.
        let report = kb
            .evaluate(
                "hello",
                &[String::from("d1")],
                RetrievalOverrides {
                    top_k: Some(1),
                    ..Default::default()
                },
            )
            .await
            .unwrap();

        assert_eq!(report.recall, 1.0);
        assert_eq!(report.precision, 1.0);
    }
}