Skip to main content

cognis_rag/retrievers/
vector.rs

//! Vector retriever: wraps any [`VectorStore`] so it can be used as a `Runnable`.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use tokio::sync::RwLock;
7
8use cognis_core::{Result, Runnable, RunnableConfig};
9
10use crate::document::Document;
11use crate::vectorstore::VectorStore;
12
13/// Adapts a [`VectorStore`] to a `Runnable<String, Vec<Document>>`.
14///
15/// Holds the store behind an `Arc<RwLock>` so concurrent invocations can
16/// share a single store instance.
17pub struct VectorRetriever {
18    store: Arc<RwLock<dyn VectorStore>>,
19    k: usize,
20}
21
22impl VectorRetriever {
23    /// Wrap a vector store with a target top-k.
24    pub fn new(store: Arc<RwLock<dyn VectorStore>>, k: usize) -> Self {
25        Self { store, k }
26    }
27
28    /// Override `k` (builder-style).
29    pub fn with_k(mut self, k: usize) -> Self {
30        self.k = k;
31        self
32    }
33}
34
35#[async_trait]
36impl Runnable<String, Vec<Document>> for VectorRetriever {
37    async fn invoke(&self, query: String, _: RunnableConfig) -> Result<Vec<Document>> {
38        let results = self
39            .store
40            .read()
41            .await
42            .similarity_search(&query, self.k)
43            .await?;
44        Ok(results
45            .into_iter()
46            .map(|r| Document {
47                id: Some(r.id),
48                content: r.text,
49                metadata: r.metadata,
50            })
51            .collect())
52    }
53
54    fn name(&self) -> &str {
55        "VectorRetriever"
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::FakeEmbeddings;
    use crate::vectorstore::InMemoryVectorStore;

    /// Builds a shared in-memory store pre-loaded with two short texts.
    async fn seeded_store() -> Arc<RwLock<dyn VectorStore>> {
        let mut store = InMemoryVectorStore::new(Arc::new(FakeEmbeddings::new(8)));
        store
            .add_texts(vec!["hello world".into(), "rust programming".into()], None)
            .await
            .unwrap();
        let shared: Arc<RwLock<dyn VectorStore>> = Arc::new(RwLock::new(store));
        shared
    }

    #[tokio::test]
    async fn retrieves_via_vector_store() {
        let retriever = VectorRetriever::new(seeded_store().await, 2);
        let docs = retriever
            .invoke("hello".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(docs.len(), 2);

        // The search hits must be mapped into Documents that keep the store's
        // id and the original text (order depends on the fake embeddings, so
        // compare the contents sorted).
        assert!(docs.iter().all(|d| d.id.is_some()));
        let mut contents: Vec<String> = docs.iter().map(|d| d.content.clone()).collect();
        contents.sort();
        assert_eq!(
            contents,
            vec!["hello world".to_string(), "rust programming".to_string()]
        );
    }

    #[tokio::test]
    async fn with_k_overrides_result_count() {
        // Built with k = 2, then narrowed to 1 via the builder.
        let retriever = VectorRetriever::new(seeded_store().await, 2).with_k(1);
        let docs = retriever
            .invoke("hello".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(docs.len(), 1);
    }
}