synaptic_retrieval/lib.rs

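//! Retrieval strategies built on `synaptic_core`: BM25 keyword search,
//! contextual compression, ensemble, multi-query, parent-document, and
//! self-query retrievers, plus the `Document` and `Retriever` types they share.
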
mod bm25;
mod compression;
mod ensemble;
mod multi_query;
mod parent_document;
mod self_query;

pub use bm25::BM25Retriever;
pub use compression::{ContextualCompressionRetriever, DocumentCompressor, EmbeddingsFilter};
pub use ensemble::EnsembleRetriever;
pub use multi_query::MultiQueryRetriever;
pub use parent_document::ParentDocumentRetriever;
pub use self_query::{MetadataFieldInfo, SelfQueryRetriever};

use std::collections::{HashMap, HashSet};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use synaptic_core::SynapseError;

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
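/// A unit of retrievable text: a stable identifier, the content itself, and
/// arbitrary JSON metadata (skipped during serialization when empty).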
pub struct Document {
    pub id: String,
    pub content: String,
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, Value>,
}

impl Document {
    pub fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            id: id.into(),
            content: content.into(),
            metadata: HashMap::new(),
        }
    }

    pub fn with_metadata(
        id: impl Into<String>,
        content: impl Into<String>,
        metadata: HashMap<String, Value>,
    ) -> Self {
        Self {
            id: id.into(),
            content: content.into(),
            metadata,
        }
    }
}

#[async_trait]
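/// A source of documents that can be queried with free-form text.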
pub trait Retriever: Send + Sync {
    /// Returns up to `top_k` documents judged relevant to `query`.
    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapseError>;
}

#[derive(Debug, Clone)]
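/// A naive retriever that ranks an in-memory document list by the number of
/// distinct query terms each document's content shares with the query.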
pub struct InMemoryRetriever {
    documents: Vec<Document>,
}

impl InMemoryRetriever {
    pub fn new(documents: Vec<Document>) -> Self {
        Self { documents }
    }
}

#[async_trait]
impl Retriever for InMemoryRetriever {
    async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<Document>, SynapseError> {
        let query_terms = tokenize(query);
        // Score each document by how many distinct query terms it contains.
        let mut scored: Vec<(usize, &Document)> = self
            .documents
            .iter()
            .map(|doc| {
                let terms = tokenize(&doc.content);
                let score = query_terms.intersection(&terms).count();
                (score, doc)
            })
            .collect();

        // Highest overlap first; documents with no overlapping terms are dropped.
        scored.sort_by(|a, b| b.0.cmp(&a.0));
        Ok(scored
            .into_iter()
            .filter(|(score, _)| *score > 0)
            .take(top_k)
            .map(|(_, doc)| doc.clone())
            .collect())
    }
}

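/// Splits `input` on whitespace and lowercases each term into a deduplicated set.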
pub(crate) fn tokenize(input: &str) -> HashSet<String> {
    input
        .split_whitespace()
        .map(|term| term.to_ascii_lowercase())
        .collect()
}

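/// Like [`tokenize`], but preserves term order and duplicates.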
pub(crate) fn tokenize_to_vec(input: &str) -> Vec<String> {
    input
        .split_whitespace()
        .map(|term| term.to_ascii_lowercase())
        .collect()
}
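
// A minimal usage sketch, not part of the original source: it exercises
// `InMemoryRetriever` end to end and assumes `tokio` (with the `macros` and
// `rt-multi-thread` features) is available as a dev-dependency for the async
// test runtime.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn in_memory_retriever_ranks_by_term_overlap() {
        let docs = vec![
            Document::new("1", "Rust is a systems programming language"),
            Document::new("2", "Retrieval augmented generation combines search and LLMs"),
            Document::new("3", "The Rust borrow checker enforces memory safety"),
        ];
        let retriever = InMemoryRetriever::new(docs);

        let results = retriever.retrieve("rust memory safety", 2).await.unwrap();

        // Document 3 shares three query terms, document 1 shares one, and
        // document 2 shares none, so it is filtered out entirely.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "3");
        assert_eq!(results[1].id, "1");
    }
}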