1#![cfg(feature = "lancedb-store")]
2use std::{
29 collections::HashSet,
30 path::Path,
31 sync::Arc,
32};
33
34use arrow_array::{
35 FixedSizeListArray, Float32Array, Int32Array, Int64Array, RecordBatch, RecordBatchIterator,
36 StringArray,
37};
38use arrow_schema::{DataType, Field, Schema};
39use futures::TryStreamExt as _;
40use lancedb::query::{ExecutableQuery, QueryBase};
41
42use crate::{
43 error::{Error, Result},
44 vector_store::{Chunk, ChunkSearchResult, VecInfo, VectorStore},
45};
46
/// Name of the LanceDB table that holds one row per chunk plus its embedding.
const TABLE_NAME: &str = "chunks";

/// Version number of the on-disk LanceDB layout.
///
/// It is embedded in the directory name built by `versioned_dir`, so bumping it
/// yields a different directory path.
/// NOTE(review): presumably bumped whenever `make_schema` changes — confirm.
pub const LANCEDB_SCHEMA_VERSION: i32 = 1;
56
57pub fn versioned_dir(root: &Path) -> std::path::PathBuf {
61 root.join(format!("lancedb_v{LANCEDB_SCHEMA_VERSION}"))
62}
63
/// Async wrapper around the LanceDB `chunks` table (see `TABLE_NAME`).
struct LanceStore {
    /// Handle to the opened (or newly created) `chunks` table.
    table: lancedb::Table,
    /// Embedding dimension given to `open`; used to build the table schema and
    /// the fixed-size embedding arrays written by `insert_embeddings`.
    dim: i32,
}
70
71impl LanceStore {
72 async fn open(data_dir: &Path, embedding_dim: u32) -> Result<Self> {
73 std::fs::create_dir_all(data_dir)?;
74 let db = lancedb::connect(data_dir.to_str().unwrap_or_default())
75 .execute()
76 .await
77 .map_err(|e| Error::Embed(e.to_string()))?;
78
79 let dim = embedding_dim as i32;
80 let names = db
81 .table_names()
82 .execute()
83 .await
84 .map_err(|e| Error::Embed(e.to_string()))?;
85
86 let table = if names.contains(&TABLE_NAME.to_string()) {
87 db.open_table(TABLE_NAME)
88 .execute()
89 .await
90 .map_err(|e| Error::Embed(e.to_string()))?
91 } else {
92 let schema = make_schema(dim);
93 let empty = RecordBatch::new_empty(schema.clone());
94 db.create_table(TABLE_NAME, RecordBatchIterator::new(vec![Ok(empty)], schema))
95 .execute()
96 .await
97 .map_err(|e| Error::Embed(e.to_string()))?
98 };
99
100 Ok(LanceStore { table, dim })
101 }
102
103 async fn embedded_chunk_keys(&self) -> Result<HashSet<(i64, usize)>> {
105 let batches: Vec<RecordBatch> = self
106 .table
107 .query()
108 .select(lancedb::query::Select::Columns(vec![
109 "entry_id".to_string(),
110 "chunk_index".to_string(),
111 ]))
112 .execute()
113 .await
114 .map_err(|e| Error::Embed(e.to_string()))?
115 .try_collect()
116 .await
117 .map_err(|e| Error::Embed(e.to_string()))?;
118
119 let mut keys = HashSet::new();
120 for batch in &batches {
121 let entry_ids = batch
122 .column_by_name("entry_id")
123 .and_then(|c| c.as_any().downcast_ref::<Int64Array>());
124 let chunk_idxs = batch
125 .column_by_name("chunk_index")
126 .and_then(|c| c.as_any().downcast_ref::<Int32Array>());
127 if let (Some(eids), Some(cidxs)) = (entry_ids, chunk_idxs) {
128 for i in 0..batch.num_rows() {
129 keys.insert((eids.value(i), cidxs.value(i) as usize));
130 }
131 }
132 }
133 Ok(keys)
134 }
135
136 async fn insert_embeddings(&self, chunks: &[Chunk], embeddings: &[Vec<f32>]) -> Result<()> {
138 if chunks.is_empty() {
139 return Ok(());
140 }
141 let schema = make_schema(self.dim);
142
143 let entry_ids: Vec<i64> = chunks.iter().map(|c| c.entry_id).collect();
144 let chunk_idxs: Vec<i32> = chunks.iter().map(|c| c.chunk_index as i32).collect();
145 let titles: Vec<&str> = chunks.iter().map(|c| c.entry_title.as_str()).collect();
146 let paths: Vec<&str> = chunks.iter().map(|c| c.entry_path.as_str()).collect();
147 let texts: Vec<&str> = chunks.iter().map(|c| c.text.as_str()).collect();
148
149 let batch = RecordBatch::try_new(
150 schema.clone(),
151 vec![
152 Arc::new(Int64Array::from(entry_ids)),
153 Arc::new(Int32Array::from(chunk_idxs)),
154 Arc::new(StringArray::from(titles)),
155 Arc::new(StringArray::from(paths)),
156 Arc::new(StringArray::from(texts)),
157 Arc::new(make_embedding_array(embeddings, self.dim)?),
158 ],
159 )
160 .map_err(|e| Error::Embed(e.to_string()))?;
161
162 self.table
163 .add(RecordBatchIterator::new(vec![Ok(batch)], schema))
164 .execute()
165 .await
166 .map_err(|e| Error::Embed(e.to_string()))?;
167 Ok(())
168 }
169
170 async fn search_similar(
172 &self,
173 query_vec: &[f32],
174 limit: usize,
175 ) -> Result<Vec<ChunkSearchResult>> {
176 let batches: Vec<RecordBatch> = self
177 .table
178 .vector_search(query_vec)
179 .map_err(|e| Error::Embed(e.to_string()))?
180 .column("embedding")
181 .limit(limit)
182 .execute()
183 .await
184 .map_err(|e| Error::Embed(e.to_string()))?
185 .try_collect()
186 .await
187 .map_err(|e| Error::Embed(e.to_string()))?;
188
189 let mut results = Vec::new();
190 for batch in &batches {
191 let entry_ids = batch
192 .column_by_name("entry_id")
193 .and_then(|c| c.as_any().downcast_ref::<Int64Array>())
194 .ok_or_else(|| Error::Embed("missing `entry_id` in search result".into()))?;
195 let chunk_idxs = batch
196 .column_by_name("chunk_index")
197 .and_then(|c| c.as_any().downcast_ref::<Int32Array>())
198 .ok_or_else(|| Error::Embed("missing `chunk_index` in search result".into()))?;
199 let titles = batch
200 .column_by_name("entry_title")
201 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
202 .ok_or_else(|| Error::Embed("missing `entry_title` in search result".into()))?;
203 let paths = batch
204 .column_by_name("entry_path")
205 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
206 .ok_or_else(|| Error::Embed("missing `entry_path` in search result".into()))?;
207 let texts = batch
208 .column_by_name("text")
209 .and_then(|c| c.as_any().downcast_ref::<StringArray>())
210 .ok_or_else(|| Error::Embed("missing `text` in search result".into()))?;
211 let dists = batch
212 .column_by_name("_distance")
213 .and_then(|c| c.as_any().downcast_ref::<Float32Array>())
214 .ok_or_else(|| Error::Embed("missing `_distance` in search result".into()))?;
215
216 for i in 0..batch.num_rows() {
217 results.push(ChunkSearchResult {
218 entry_id: entry_ids.value(i),
219 chunk_index: chunk_idxs.value(i) as usize,
220 entry_title: titles.value(i).to_owned(),
221 entry_path: paths.value(i).to_owned(),
222 chunk_text: texts.value(i).to_owned(),
223 score: dists.value(i) as f64,
224 });
225 }
226 }
227 Ok(results)
228 }
229
230 async fn embedded_count(&self) -> u64 {
231 self.table
232 .count_rows(None)
233 .await
234 .unwrap_or(0) as u64
235 }
236}
237
238pub struct LanceDbVectorStore {
245 inner: LanceStore,
246 rt: tokio::runtime::Runtime,
247}
248
249impl LanceDbVectorStore {
250 pub fn new(data_dir: &Path, embedding_dim: u32) -> Result<Self> {
253 let rt = tokio::runtime::Runtime::new()
254 .map_err(|e| Error::Embed(format!("failed to create Tokio runtime: {e}")))?;
255 let inner = rt.block_on(LanceStore::open(data_dir, embedding_dim))?;
256 Ok(Self { inner, rt })
257 }
258
259 pub fn vec_info(&self, sqlite_conn: &rusqlite::Connection) -> Result<VecInfo> {
261 let vector_count = self.rt.block_on(self.inner.embedded_count());
262 let chunk_count: u64 = sqlite_conn
263 .query_row("SELECT COUNT(*) FROM chunks", [], |row| row.get(0))
264 .unwrap_or(0);
265 Ok(VecInfo {
266 embedding_dim: self.inner.dim as u32,
267 vector_count,
268 pending_count: chunk_count.saturating_sub(vector_count),
269 })
270 }
271}
272
273impl VectorStore for LanceDbVectorStore {
274 fn embedded_chunk_keys(&self) -> Result<HashSet<(i64, usize)>> {
275 self.rt.block_on(self.inner.embedded_chunk_keys())
276 }
277
278 fn insert_embeddings(&self, chunks: &[Chunk], embeddings: &[Vec<f32>]) -> Result<()> {
279 self.rt.block_on(self.inner.insert_embeddings(chunks, embeddings))
280 }
281
282 fn search_similar(&self, query_vec: &[f32], limit: usize) -> Result<Vec<ChunkSearchResult>> {
283 self.rt.block_on(self.inner.search_similar(query_vec, limit))
284 }
285}
286
287fn make_schema(dim: i32) -> Arc<Schema> {
290 Arc::new(Schema::new(vec![
291 Field::new("entry_id", DataType::Int64, false),
292 Field::new("chunk_index", DataType::Int32, false),
293 Field::new("entry_title", DataType::Utf8, false),
294 Field::new("entry_path", DataType::Utf8, false),
295 Field::new("text", DataType::Utf8, false),
296 Field::new(
297 "embedding",
298 DataType::FixedSizeList(
299 Arc::new(Field::new("item", DataType::Float32, true)),
300 dim,
301 ),
302 false,
303 ),
304 ]))
305}
306
307fn make_embedding_array(embeddings: &[Vec<f32>], dim: i32) -> Result<FixedSizeListArray> {
308 let flat: Vec<f32> = embeddings.iter().flat_map(|v| v.iter().copied()).collect();
309 let values = Arc::new(Float32Array::from(flat));
310 FixedSizeListArray::try_new(
311 Arc::new(Field::new("item", DataType::Float32, true)),
312 dim,
313 values,
314 None,
315 )
316 .map_err(|e| Error::Embed(e.to_string()))
317}