1use async_trait::async_trait;
43use serde_json::Value;
44use std::collections::HashMap;
45use std::sync::Arc;
46use synaptic_core::{Document, Embeddings, SynapticError, VectorStore};
47use tokio::sync::RwLock;
48
/// Settings used to construct a [`LanceDbVectorStore`].
///
/// Equality compares all three fields, which lets callers and tests compare configurations
/// directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LanceDbConfig {
    /// Database location.
    ///
    /// NOTE(review): presumably a LanceDB connection URI or local directory; the current
    /// in-memory store only keeps it and never opens it.
    pub uri: String,
    /// Name of the table that holds the vectors (stored, not otherwise used by the store).
    pub table_name: String,
    /// Expected embedding dimensionality.
    ///
    /// NOTE(review): not validated against stored or query vectors by the store.
    pub dim: usize,
}
61
62impl LanceDbConfig {
63 pub fn new(uri: impl Into<String>, table_name: impl Into<String>, dim: usize) -> Self {
65 Self {
66 uri: uri.into(),
67 table_name: table_name.into(),
68 dim,
69 }
70 }
71}
72
73#[derive(Clone)]
75struct Row {
76 id: String,
77 content: String,
78 metadata: HashMap<String, Value>,
79 embedding: Vec<f32>,
80}
81
82pub struct LanceDbVectorStore {
92 config: LanceDbConfig,
93 rows: Arc<RwLock<Vec<Row>>>,
94}
95
96impl LanceDbVectorStore {
97 pub async fn new(config: LanceDbConfig) -> Result<Self, SynapticError> {
102 Ok(Self {
103 config,
104 rows: Arc::new(RwLock::new(Vec::new())),
105 })
106 }
107
108 pub fn config(&self) -> &LanceDbConfig {
110 &self.config
111 }
112
113 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
115 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
116 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
117 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
118 if norm_a == 0.0 || norm_b == 0.0 {
119 0.0
120 } else {
121 dot / (norm_a * norm_b)
122 }
123 }
124
125 async fn knn_search(
127 &self,
128 query: &[f32],
129 k: usize,
130 ) -> Result<Vec<(Document, f32)>, SynapticError> {
131 let rows = self.rows.read().await;
132 let mut scored: Vec<(f32, usize)> = rows
133 .iter()
134 .enumerate()
135 .map(|(i, row)| (Self::cosine_similarity(query, &row.embedding), i))
136 .collect();
137
138 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
140
141 let results = scored
142 .into_iter()
143 .take(k)
144 .map(|(score, i)| {
145 let row = &rows[i];
146 (
147 Document::with_metadata(
148 row.id.clone(),
149 row.content.clone(),
150 row.metadata.clone(),
151 ),
152 score,
153 )
154 })
155 .collect();
156
157 Ok(results)
158 }
159}
160
161#[async_trait]
162impl VectorStore for LanceDbVectorStore {
163 async fn add_documents(
164 &self,
165 docs: Vec<Document>,
166 embeddings: &dyn Embeddings,
167 ) -> Result<Vec<String>, SynapticError> {
168 if docs.is_empty() {
169 return Ok(vec![]);
170 }
171
172 let texts: Vec<&str> = docs.iter().map(|d| d.content.as_str()).collect();
173 let vectors = embeddings.embed_documents(&texts).await?;
174
175 let mut rows = self.rows.write().await;
176 let ids: Vec<String> = docs
177 .into_iter()
178 .zip(vectors)
179 .map(|(doc, vec)| {
180 let id = doc.id.clone();
181 rows.push(Row {
182 id: doc.id,
183 content: doc.content,
184 metadata: doc.metadata,
185 embedding: vec,
186 });
187 id
188 })
189 .collect();
190
191 Ok(ids)
192 }
193
194 async fn similarity_search(
195 &self,
196 query: &str,
197 k: usize,
198 embeddings: &dyn Embeddings,
199 ) -> Result<Vec<Document>, SynapticError> {
200 let results = self
201 .similarity_search_with_score(query, k, embeddings)
202 .await?;
203 Ok(results.into_iter().map(|(doc, _)| doc).collect())
204 }
205
206 async fn similarity_search_with_score(
207 &self,
208 query: &str,
209 k: usize,
210 embeddings: &dyn Embeddings,
211 ) -> Result<Vec<(Document, f32)>, SynapticError> {
212 let qvec = embeddings.embed_query(query).await?;
213 self.knn_search(&qvec, k).await
214 }
215
216 async fn similarity_search_by_vector(
217 &self,
218 embedding: &[f32],
219 k: usize,
220 ) -> Result<Vec<Document>, SynapticError> {
221 let results = self.knn_search(embedding, k).await?;
222 Ok(results.into_iter().map(|(doc, _)| doc).collect())
223 }
224
225 async fn delete(&self, ids: &[&str]) -> Result<(), SynapticError> {
226 let id_set: std::collections::HashSet<&str> = ids.iter().copied().collect();
227 let mut rows = self.rows.write().await;
228 rows.retain(|row| !id_set.contains(row.id.as_str()));
229 Ok(())
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a row with placeholder content and empty metadata.
    fn row(id: &str, embedding: Vec<f32>) -> Row {
        Row {
            id: id.to_string(),
            content: format!("content of {id}"),
            metadata: HashMap::new(),
            embedding,
        }
    }

    /// Creates a store and seeds it through the private `rows` field, so search and delete
    /// can be tested without an `Embeddings` implementation.
    async fn store_with(rows: Vec<Row>) -> LanceDbVectorStore {
        let store = LanceDbVectorStore::new(LanceDbConfig::new("/tmp/db", "tbl", 2))
            .await
            .unwrap();
        store.rows.write().await.extend(rows);
        store
    }

    #[test]
    fn config_new_sets_fields() {
        let config = LanceDbConfig::new("/tmp/test_db", "test_table", 1536);
        assert_eq!(config.uri, "/tmp/test_db");
        assert_eq!(config.table_name, "test_table");
        assert_eq!(config.dim, 1536);
    }

    #[tokio::test]
    async fn store_new_creates_instance() {
        let config = LanceDbConfig::new("/tmp/db", "tbl", 4);
        let store = LanceDbVectorStore::new(config).await.unwrap();
        assert_eq!(store.config().table_name, "tbl");
        assert_eq!(store.config().dim, 4);
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = vec![1.0_f32, 0.0, 0.0];
        let score = LanceDbVectorStore::cosine_similarity(&v, &v);
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = vec![1.0_f32, 0.0];
        let b = vec![0.0_f32, 1.0];
        let score = LanceDbVectorStore::cosine_similarity(&a, &b);
        assert!(score.abs() < 1e-6);
    }

    #[test]
    fn cosine_similarity_zero_vector_scores_zero() {
        let zero = vec![0.0_f32, 0.0];
        let v = vec![1.0_f32, 2.0];
        assert_eq!(LanceDbVectorStore::cosine_similarity(&zero, &v), 0.0);
        assert_eq!(LanceDbVectorStore::cosine_similarity(&v, &zero), 0.0);
    }

    #[tokio::test]
    async fn search_by_vector_returns_top_k_best_first() {
        let store = store_with(vec![
            row("far", vec![0.0, 1.0]),
            row("near", vec![1.0, 0.0]),
            row("mid", vec![1.0, 1.0]),
        ])
        .await;

        let docs = store
            .similarity_search_by_vector(&[1.0, 0.0], 2)
            .await
            .unwrap();
        let ids: Vec<&str> = docs.iter().map(|d| d.id.as_str()).collect();
        assert_eq!(ids, ["near", "mid"]);
    }

    #[tokio::test]
    async fn delete_removes_only_listed_ids() {
        let store = store_with(vec![row("a", vec![1.0, 0.0]), row("b", vec![0.0, 1.0])]).await;

        store.delete(&["a", "missing"]).await.unwrap();

        let remaining: Vec<String> = store
            .rows
            .read()
            .await
            .iter()
            .map(|r| r.id.clone())
            .collect();
        assert_eq!(remaining, ["b"]);
    }
}