cognis_rag/retrievers/
vector.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6use tokio::sync::RwLock;
7
8use cognis_core::{Result, Runnable, RunnableConfig};
9
10use crate::document::Document;
11use crate::vectorstore::VectorStore;
12
13pub struct VectorRetriever {
18 store: Arc<RwLock<dyn VectorStore>>,
19 k: usize,
20}
21
22impl VectorRetriever {
23 pub fn new(store: Arc<RwLock<dyn VectorStore>>, k: usize) -> Self {
25 Self { store, k }
26 }
27
28 pub fn with_k(mut self, k: usize) -> Self {
30 self.k = k;
31 self
32 }
33}
34
35#[async_trait]
36impl Runnable<String, Vec<Document>> for VectorRetriever {
37 async fn invoke(&self, query: String, _: RunnableConfig) -> Result<Vec<Document>> {
38 let results = self
39 .store
40 .read()
41 .await
42 .similarity_search(&query, self.k)
43 .await?;
44 Ok(results
45 .into_iter()
46 .map(|r| Document {
47 id: Some(r.id),
48 content: r.text,
49 metadata: r.metadata,
50 })
51 .collect())
52 }
53
54 fn name(&self) -> &str {
55 "VectorRetriever"
56 }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::FakeEmbeddings;
    use crate::vectorstore::InMemoryVectorStore;

    /// Indexes two texts in an in-memory store, queries through the retriever
    /// with `k = 2`, and expects two documents back.
    #[tokio::test]
    async fn retrieves_via_vector_store() {
        // Populate the concrete store before it is type-erased behind the lock.
        let mut in_memory = InMemoryVectorStore::new(Arc::new(FakeEmbeddings::new(8)));
        let texts = vec!["hello world".into(), "rust programming".into()];
        in_memory.add_texts(texts, None).await.unwrap();

        let shared: Arc<RwLock<dyn VectorStore>> = Arc::new(RwLock::new(in_memory));
        let retriever = VectorRetriever::new(shared, 2);

        let outcome = retriever
            .invoke("hello".into(), RunnableConfig::default())
            .await;
        let docs = outcome.unwrap();

        assert_eq!(docs.len(), 2);
    }
}