// sayr_engine/knowledge.rs

1use std::collections::HashSet;
2use std::hash::{Hash, Hasher};
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use serde_json::{json, Value};
7use tokio::sync::RwLock;
8
9use crate::error::Result;
10
11#[derive(Clone, Debug)]
12pub struct Document {
13    pub id: String,
14    pub text: String,
15    pub metadata: Value,
16}
17
18#[derive(Clone, Debug)]
19pub struct ScoredDocument {
20    pub document: Document,
21    pub score: f32,
22}
23
24#[async_trait]
25pub trait Embedder: Send + Sync {
26    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
27}
28
29#[async_trait]
30pub trait OpenAiEmbeddingClient: Send + Sync {
31    async fn embed(&self, model: &str, input: &str) -> Result<Vec<f32>>;
32}
33
/// Embedder backed by an OpenAI-compatible embedding client.
///
/// Holds a shared client handle and the model name passed on every request.
pub struct OpenAiEmbedder<C> {
    /// Shared handle to the underlying embedding client.
    client: Arc<C>,
    /// Model name forwarded to the client on each call.
    model: String,
}
39
40impl<C> OpenAiEmbedder<C> {
41    pub fn new(client: Arc<C>, model: impl Into<String>) -> Self {
42        Self {
43            client,
44            model: model.into(),
45        }
46    }
47}
48
49#[async_trait]
50impl<C> Embedder for OpenAiEmbedder<C>
51where
52    C: OpenAiEmbeddingClient,
53{
54    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
55        self.client.embed(&self.model, text).await
56    }
57}
58
59#[async_trait]
60pub trait TransformerClient: Send + Sync {
61    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
62}
63
/// Embedder backed by a transformer runtime (e.g., candle, ort, ggml).
pub struct TransformerEmbedder<C> {
    /// Shared handle to the runtime client that performs the embedding.
    client: Arc<C>,
}
68
69impl<C> TransformerEmbedder<C> {
70    pub fn new(client: Arc<C>) -> Self {
71        Self { client }
72    }
73}
74
75#[async_trait]
76impl<C> Embedder for TransformerEmbedder<C>
77where
78    C: TransformerClient,
79{
80    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
81        self.client.embed(text).await
82    }
83}
84
85#[async_trait]
86pub trait VectorStore: Send + Sync {
87    async fn add(&self, document: Document, embedding: Vec<f32>) -> Result<()>;
88    async fn search(
89        &self,
90        embedding: Vec<f32>,
91        params: SearchParams,
92    ) -> Result<Vec<ScoredDocument>>;
93}
94
/// Simple embedder that splits on whitespace and counts tokens into hashed
/// buckets, giving deterministic embeddings without any model.
pub struct WhitespaceEmbedder {
    /// Number of buckets, i.e. the length of every produced vector.
    buckets: usize,
}
99
100impl Default for WhitespaceEmbedder {
101    fn default() -> Self {
102        Self { buckets: 32 }
103    }
104}
105
106impl WhitespaceEmbedder {
107    pub fn new(buckets: usize) -> Self {
108        Self { buckets }
109    }
110}
111
112#[async_trait]
113impl Embedder for WhitespaceEmbedder {
114    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
115        let mut vector = vec![0.0; self.buckets];
116
117        for token in text.split_whitespace() {
118            let mut hasher = std::collections::hash_map::DefaultHasher::new();
119            token.hash(&mut hasher);
120            let idx = (hasher.finish() as usize) % self.buckets;
121            vector[idx] += 1.0;
122        }
123
124        Ok(vector)
125    }
126}
127
128#[derive(Default)]
129pub struct InMemoryVectorStore {
130    entries: RwLock<Vec<(Document, Vec<f32>)>>,
131}
132
133#[async_trait]
134impl VectorStore for InMemoryVectorStore {
135    async fn add(&self, document: Document, embedding: Vec<f32>) -> Result<()> {
136        self.entries.write().await.push((document, embedding));
137        Ok(())
138    }
139
140    async fn search(
141        &self,
142        embedding: Vec<f32>,
143        params: SearchParams,
144    ) -> Result<Vec<ScoredDocument>> {
145        let entries = self.entries.read().await;
146        let mut scored: Vec<ScoredDocument> = entries
147            .iter()
148            .map(|(doc, stored)| ScoredDocument {
149                document: doc.clone(),
150                score: similarity(stored, &embedding, params.similarity),
151            })
152            .collect();
153
154        scored.sort_by(|a, b| {
155            b.score
156                .partial_cmp(&a.score)
157                .unwrap_or(std::cmp::Ordering::Equal)
158        });
159        scored.truncate(params.top_k);
160        Ok(scored)
161    }
162}
163
164fn similarity(a: &[f32], b: &[f32], metric: SimilarityMetric) -> f32 {
165    let (mut dot, mut norm_a, mut norm_b) = (0.0, 0.0, 0.0);
166    for (x, y) in a.iter().zip(b.iter()) {
167        dot += x * y;
168        norm_a += x * x;
169        norm_b += y * y;
170    }
171
172    match metric {
173        SimilarityMetric::Cosine => {
174            if norm_a == 0.0 || norm_b == 0.0 {
175                0.0
176            } else {
177                dot / (norm_a.sqrt() * norm_b.sqrt())
178            }
179        }
180        SimilarityMetric::DotProduct => dot,
181        SimilarityMetric::Euclidean => {
182            // Invert distance so higher is better while keeping the return type consistent.
183            let mut squared_distance = 0.0;
184            for (x, y) in a.iter().zip(b.iter()) {
185                let diff = x - y;
186                squared_distance += diff * diff;
187            }
188            1.0 / (1.0 + squared_distance.sqrt())
189        }
190    }
191}
192
/// Scoring function used to compare a query embedding with a stored embedding.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SimilarityMetric {
    /// Cosine of the angle between the two vectors.
    Cosine,
    /// Plain dot product of the two vectors.
    DotProduct,
    /// Score derived from the Euclidean distance between the two vectors.
    Euclidean,
}
199
200#[derive(Clone, Debug)]
201pub struct SearchParams {
202    pub top_k: usize,
203    pub similarity: SimilarityMetric,
204}
205
206impl Default for SearchParams {
207    fn default() -> Self {
208        Self {
209            top_k: 5,
210            similarity: SimilarityMetric::Cosine,
211        }
212    }
213}
214
215#[async_trait]
216pub trait PgVectorClient: Send + Sync {
217    async fn upsert(&self, document: &Document, embedding: &[f32]) -> Result<()>;
218    async fn query(&self, embedding: &[f32], params: SearchParams) -> Result<Vec<ScoredDocument>>;
219}
220
/// Vector store adapter for Postgres/pgvector style databases.
pub struct PgVectorStore<C> {
    /// Shared handle to the database client.
    client: Arc<C>,
}
225
226impl<C> PgVectorStore<C> {
227    pub fn new(client: Arc<C>) -> Self {
228        Self { client }
229    }
230}
231
232#[async_trait]
233impl<C> VectorStore for PgVectorStore<C>
234where
235    C: PgVectorClient,
236{
237    async fn add(&self, document: Document, embedding: Vec<f32>) -> Result<()> {
238        self.client.upsert(&document, &embedding).await
239    }
240
241    async fn search(
242        &self,
243        embedding: Vec<f32>,
244        params: SearchParams,
245    ) -> Result<Vec<ScoredDocument>> {
246        self.client.query(&embedding, params).await
247    }
248}
249
250#[async_trait]
251pub trait QdrantClient: Send + Sync {
252    async fn upsert(&self, document: &Document, embedding: &[f32]) -> Result<()>;
253    async fn query(&self, embedding: &[f32], params: SearchParams) -> Result<Vec<ScoredDocument>>;
254}
255
/// Vector store adapter for Qdrant (or other HTTP/gRPC vector databases).
pub struct QdrantStore<C> {
    /// Shared handle to the database client.
    client: Arc<C>,
}
260
261impl<C> QdrantStore<C> {
262    pub fn new(client: Arc<C>) -> Self {
263        Self { client }
264    }
265}
266
267#[async_trait]
268impl<C> VectorStore for QdrantStore<C>
269where
270    C: QdrantClient,
271{
272    async fn add(&self, document: Document, embedding: Vec<f32>) -> Result<()> {
273        self.client.upsert(&document, &embedding).await
274    }
275
276    async fn search(
277        &self,
278        embedding: Vec<f32>,
279        params: SearchParams,
280    ) -> Result<Vec<ScoredDocument>> {
281        self.client.query(&embedding, params).await
282    }
283}
284
285pub trait DocumentChunker: Send + Sync {
286    fn chunk(&self, document: &Document) -> Vec<Document>;
287}
288
/// Chunker that slides a fixed-size window over whitespace-separated tokens,
/// with consecutive windows sharing some tokens.
pub struct SlidingWindowChunker {
    /// Largest number of tokens placed in a single chunk.
    pub max_tokens: usize,
    /// Number of trailing tokens of one chunk repeated at the start of the next.
    pub overlap: usize,
}
294
295impl Default for SlidingWindowChunker {
296    fn default() -> Self {
297        Self {
298            max_tokens: 256,
299            overlap: 32,
300        }
301    }
302}
303
304impl DocumentChunker for SlidingWindowChunker {
305    fn chunk(&self, document: &Document) -> Vec<Document> {
306        if document.text.is_empty() {
307            return vec![document.clone()];
308        }
309
310        let tokens: Vec<&str> = document.text.split_whitespace().collect();
311        if tokens.len() <= self.max_tokens {
312            return vec![document.clone()];
313        }
314
315        let mut chunks = Vec::new();
316        let mut start = 0usize;
317        let mut chunk_index = 0usize;
318
319        while start < tokens.len() {
320            let end = usize::min(start + self.max_tokens, tokens.len());
321            let text = tokens[start..end].join(" ");
322            let mut metadata = document.metadata.clone();
323
324            if let Value::Object(map) = &mut metadata {
325                map.insert("chunk_index".to_string(), Value::from(chunk_index as u64));
326                map.insert("source_id".to_string(), Value::from(document.id.clone()));
327            } else {
328                metadata = json!({
329                    "chunk_index": chunk_index,
330                    "source_id": document.id
331                });
332            }
333
334            chunks.push(Document {
335                id: format!("{}::{}", document.id, chunk_index),
336                text,
337                metadata,
338            });
339
340            if end == tokens.len() {
341                break;
342            }
343
344            start = end.saturating_sub(self.overlap.min(end - start));
345            chunk_index += 1;
346        }
347
348        chunks
349    }
350}
351
352pub type Reranker = Arc<dyn Fn(&ScoredDocument) -> f32 + Send + Sync>;
353
354pub struct KnowledgeBase<E: Embedder, S: VectorStore> {
355    embedder: Arc<E>,
356    store: Arc<S>,
357    config: RetrievalConfig,
358    chunker: Option<Arc<dyn DocumentChunker>>,
359}
360
361impl<E: Embedder, S: VectorStore> KnowledgeBase<E, S> {
362    pub fn new(embedder: Arc<E>, store: Arc<S>) -> Self {
363        Self {
364            embedder,
365            store,
366            config: RetrievalConfig::default(),
367            chunker: None,
368        }
369    }
370
371    pub fn with_reranker(mut self, reranker: Reranker) -> Self {
372        self.config.reranker = Some(reranker);
373        self
374    }
375
376    pub fn with_chunker(mut self, chunker: Arc<dyn DocumentChunker>) -> Self {
377        self.chunker = Some(chunker);
378        self
379    }
380
381    pub fn with_config(mut self, config: RetrievalConfig) -> Self {
382        self.config = config;
383        self
384    }
385
386    pub fn config(&self) -> &RetrievalConfig {
387        &self.config
388    }
389
390    pub async fn add_document(&self, document: Document) -> Result<()> {
391        let chunks = if let Some(chunker) = &self.chunker {
392            chunker.chunk(&document)
393        } else {
394            vec![document]
395        };
396
397        for chunk in chunks {
398            let embedding = self.embedder.embed(&chunk.text).await?;
399            self.store.add(chunk, embedding).await?;
400        }
401
402        Ok(())
403    }
404
405    pub async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<ScoredDocument>> {
406        let overrides = RetrievalOverrides {
407            top_k: Some(top_k),
408            ..Default::default()
409        };
410        self.retrieve_with_overrides(query, overrides).await
411    }
412
413    pub async fn retrieve_with_overrides(
414        &self,
415        query: &str,
416        overrides: RetrievalOverrides,
417    ) -> Result<Vec<ScoredDocument>> {
418        let embedding = self.embedder.embed(query).await?;
419        let params = SearchParams {
420            top_k: overrides.top_k.unwrap_or(self.config.top_k),
421            similarity: overrides.similarity.unwrap_or(self.config.similarity),
422        };
423        let mut scored = self.store.search(embedding, params).await?;
424
425        if let Some(reranker) = overrides.reranker.or_else(|| self.config.reranker.clone()) {
426            for doc in scored.iter_mut() {
427                doc.score = reranker(doc);
428            }
429            scored.sort_by(|a, b| {
430                b.score
431                    .partial_cmp(&a.score)
432                    .unwrap_or(std::cmp::Ordering::Equal)
433            });
434        }
435
436        Ok(scored)
437    }
438
439    pub async fn evaluate(
440        &self,
441        query: &str,
442        relevant_document_ids: &[String],
443        overrides: RetrievalOverrides,
444    ) -> Result<RetrievalEvaluation> {
445        let retrieved = self.retrieve_with_overrides(query, overrides).await?;
446        let retrieved_ids: HashSet<String> =
447            retrieved.iter().map(|d| d.document.id.clone()).collect();
448        let relevant: HashSet<String> = relevant_document_ids.iter().cloned().collect();
449
450        let hits = relevant.intersection(&retrieved_ids).count() as f32;
451        let precision = if retrieved.is_empty() {
452            0.0
453        } else {
454            hits / retrieved.len() as f32
455        };
456        let recall = if relevant.is_empty() {
457            0.0
458        } else {
459            hits / relevant.len() as f32
460        };
461
462        Ok(RetrievalEvaluation {
463            retrieved,
464            precision,
465            recall,
466        })
467    }
468}
469
470#[async_trait]
471pub trait Retriever: Send + Sync {
472    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<String>>;
473}
474
475#[async_trait]
476impl<E, S> Retriever for KnowledgeBase<E, S>
477where
478    E: Embedder,
479    S: VectorStore,
480{
481    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<String>> {
482        let docs = KnowledgeBase::retrieve(self, query, top_k).await?;
483        Ok(docs.into_iter().map(|d| d.document.text).collect())
484    }
485}
486
487#[derive(Clone)]
488pub struct RetrievalConfig {
489    pub top_k: usize,
490    pub similarity: SimilarityMetric,
491    pub reranker: Option<Reranker>,
492}
493
494impl Default for RetrievalConfig {
495    fn default() -> Self {
496        Self {
497            top_k: 5,
498            similarity: SimilarityMetric::Cosine,
499            reranker: None,
500        }
501    }
502}
503
504#[derive(Clone, Default)]
505pub struct RetrievalOverrides {
506    pub top_k: Option<usize>,
507    pub similarity: Option<SimilarityMetric>,
508    pub reranker: Option<Reranker>,
509}
510
511pub struct RetrievalEvaluation {
512    pub retrieved: Vec<ScoredDocument>,
513    pub precision: f32,
514    pub recall: f32,
515}
516
#[cfg(test)]
mod tests {
    use super::*;

    /// Embeds text as a one-component vector holding its byte length.
    struct TestEmbedder;

    #[async_trait]
    impl Embedder for TestEmbedder {
        async fn embed(&self, text: &str) -> Result<Vec<f32>> {
            let length = text.len() as f32;
            Ok(vec![length])
        }
    }

    /// Builds a document with null metadata.
    fn plain_document(id: &str, text: &str) -> Document {
        Document {
            id: id.into(),
            text: text.into(),
            metadata: Value::Null,
        }
    }

    /// A four-token document with a two-token window is stored as two chunks.
    #[tokio::test]
    async fn chunks_documents() {
        let chunker = SlidingWindowChunker {
            max_tokens: 2,
            overlap: 0,
        };
        let kb: KnowledgeBase<_, _> = KnowledgeBase::new(
            Arc::new(TestEmbedder),
            Arc::new(InMemoryVectorStore::default()),
        )
        .with_chunker(Arc::new(chunker));

        kb.add_document(plain_document("doc", "a b c d"))
            .await
            .unwrap();

        let scored = kb.retrieve("a b", 10).await.unwrap();
        assert_eq!(scored.len(), 2);
    }

    /// With one expected id and one returned result, precision and recall are both 1.
    #[tokio::test]
    async fn evaluates_precision_recall() {
        let kb: KnowledgeBase<_, _> = KnowledgeBase::new(
            Arc::new(TestEmbedder),
            Arc::new(InMemoryVectorStore::default()),
        );

        kb.add_document(plain_document("d1", "hello world"))
            .await
            .unwrap();
        kb.add_document(plain_document("d2", "other"))
            .await
            .unwrap();

        let overrides = RetrievalOverrides {
            top_k: Some(1),
            ..Default::default()
        };
        let relevant = [String::from("d1")];
        let report = kb.evaluate("hello", &relevant, overrides).await.unwrap();

        assert_eq!(report.recall, 1.0);
        assert_eq!(report.precision, 1.0);
    }
}