swiftide_integrations/duckdb/mod.rs

use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
};

use anyhow::{Context as _, Result};
use derive_builder::Builder;
use swiftide_core::{
    indexing::{Chunk, EmbeddedField},
    querying::search_strategies::HybridSearch,
};
use tera::Context;
use tokio::sync::RwLock;

pub mod node_cache;
pub mod persist;
pub mod retrieve;

const DEFAULT_INDEXING_SCHEMA: &str = include_str!("schema.sql");
const DEFAULT_UPSERT_QUERY: &str = include_str!("upsert.sql");
const DEFAULT_HYBRID_QUERY: &str = include_str!("hybrid_query.sql");

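/// DuckDB-backed storage for Swiftide: persists indexed nodes, caches
/// processed nodes, and serves (hybrid) retrieval queries.
///
/// Build instances through [`Duckdb::builder`]. A minimal, illustrative
/// sketch (the `EmbeddedField::Combined` field and the 384 dimension are
/// placeholder choices, not requirements):
///
/// ```ignore
/// let store = Duckdb::builder()
///     .connection(duckdb::Connection::open_in_memory()?)
///     .table_name("swiftide")
///     .with_vector(EmbeddedField::Combined, 384)
///     .build()?;
/// ```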
#[derive(Clone, Builder)]
#[builder(setter(into))]
pub struct Duckdb<T: Chunk = String> {
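    /// The DuckDB connection, shared behind an `Arc<Mutex<_>>`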
    #[builder(setter(custom))]
    connection: Arc<Mutex<duckdb::Connection>>,

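    /// Table that indexed nodes are stored in; defaults to `swiftide`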
    #[builder(default = "swiftide".into())]
    table_name: String,

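    /// Schema used to create the indexing table; rendered from `schema.sql`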
    #[builder(default = self.default_schema())]
    schema: String,

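    /// Embedded fields to store, mapped to their vector size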
    #[builder(default)]
    vectors: HashMap<EmbeddedField, usize>,

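    /// Number of nodes to store per batch; defaults to 256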
    #[builder(default = "256")]
    batch_size: usize,

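    /// Pre-rendered SQL for upserting nodes; generated from `upsert.sql`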
    #[builder(private, default = self.default_node_upsert_sql())]
    node_upsert_sql: String,

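    /// Table used by the node cache; defaults to `swiftide_cache`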
    #[builder(default = "swiftide_cache".into())]
    cache_table: String,

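    /// Tracks whether the cache table has been created, so the DDL runs only once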
    #[builder(private, default = Arc::new(false.into()))]
    cache_table_created: Arc<RwLock<bool>>,
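
    /// Prefix prepended to node ids when building cache keys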
    #[builder(default = "String::new()")]
    cache_key_prefix: String,

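    /// Whether vectors should also be updated when an existing node is
    /// upserted; passed to the upsert SQL template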
    #[builder(default)]
    #[allow(dead_code)]
    upsert_vectors: bool,

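    /// Marker for the chunk type `T` stored in the nodes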
    #[builder(default)]
    chunk_type: std::marker::PhantomData<T>,
}

impl<T: Chunk> std::fmt::Debug for Duckdb<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Duckdb")
            .field("connection", &"Arc<Mutex<duckdb::Connection>>")
            .field("table_name", &self.table_name)
            .field("batch_size", &self.batch_size)
            .finish()
    }
}

impl Duckdb<String> {
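    /// Creates a new builder for a `Duckdb` store over `String` chunks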
    pub fn builder() -> DuckdbBuilder<String> {
        DuckdbBuilder::<String>::default()
    }
}

impl<T: Chunk> Duckdb<T> {
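    /// Returns the name of the table nodes are stored in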
    pub fn table_name(&self) -> &str {
        &self.table_name
    }

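    /// Returns the name of the table used by the node cache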
    pub fn cache_table(&self) -> &str {
        &self.cache_table
    }

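    /// Returns a reference to the underlying DuckDB connection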
    pub fn connection(&self) -> &Mutex<duckdb::Connection> {
        &self.connection
    }

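    /// Creates an HNSW index on every configured vector column, using cosine
    /// distance. Requires DuckDB's `vss` extension.
    ///
    /// # Errors
    ///
    /// Errors if the transaction cannot be started or committed, or if an
    /// index cannot be created.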
    pub fn create_vector_indices(&self) -> Result<()> {
        let table_name = &self.table_name;
        let mut conn = self.connection.lock().unwrap();
        let tx = conn.transaction().context("Failed to start transaction")?;
        {
            for vector in self.vectors.keys() {
                tx.execute(
                    &format!(
                        "CREATE INDEX IF NOT EXISTS idx_{vector} ON {table_name} USING hnsw ({vector}) WITH (metric = 'cosine')",
                    ),
                    [],
                )
                .context("Could not create index")?;
            }
        }
        tx.commit().context("Failed to commit transaction")?;
        Ok(())
    }

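    /// Creates the cache table and its path index if they do not exist yet.
    ///
    /// The DDL only runs once per instance; a flag behind an `RwLock` guards
    /// subsequent calls.
    ///
    /// # Errors
    ///
    /// Errors if the table or index cannot be created.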
    pub async fn lazy_create_cache(&self) -> anyhow::Result<()> {
        if !*self.cache_table_created.read().await {
            let mut lock = self.cache_table_created.write().await;
            let conn = self.connection.lock().unwrap();
            conn.execute(
                &format!(
                    "CREATE TABLE IF NOT EXISTS {} (uuid TEXT PRIMARY KEY, path TEXT)",
                    self.cache_table
                ),
                [],
            )
            .context("Could not create table")?;

            conn.execute(
                &format!(
                    "CREATE INDEX IF NOT EXISTS idx_path ON {} (path)",
                    self.cache_table
                ),
                [],
            )
            .context("Could not create index")?;

            *lock = true;
        }
        Ok(())
    }

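    /// Builds the cache key for a node from the configured prefix and the node id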
    pub fn node_key(&self, node: &swiftide_core::indexing::Node<T>) -> String {
        format!("{}.{}", self.cache_key_prefix, node.id())
    }

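    /// Renders the hybrid search SQL from `hybrid_query.sql`, combining the
    /// text query with vector similarity over the first configured vector.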
    fn hybrid_query_sql(
        &self,
        search_strategy: &HybridSearch,
        query: &str,
        embedding: &[f32],
    ) -> Result<String> {
        let table_name = &self.table_name;

        let (field_name, embedding_size) = self
            .vectors
            .iter()
            .next()
            .context("No vectors configured")?;

        if self.vectors.len() > 1 {
            tracing::warn!(
                "Multiple vectors configured, but only the first one will be used: {:?}",
                self.vectors
            );
        }

        let embedding = embedding
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join(",");

        let context = Context::from_value(serde_json::json!({
            "table_name": table_name,
            "top_n": search_strategy.top_n(),
            "top_k": search_strategy.top_k(),
            "embedding_name": field_name,
            "embedding_size": embedding_size,
            "query": wrap_and_escape(query),
            "embedding": embedding,
        }))?;

        let rendered = tera::Tera::one_off(DEFAULT_HYBRID_QUERY, &context, false)?;
        Ok(rendered)
    }
}

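/// Wraps `s` in single quotes and escapes embedded single quotes by doubling
/// them, so the query can be spliced into the SQL template as a string literal.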
fn wrap_and_escape(s: &str) -> String {
    let quote = '\'';
    let mut buf = String::new();
    buf.push(quote);
    let chars = s.chars();
    for ch in chars {
        if ch == quote {
            buf.push(ch);
        }
        buf.push(ch);
    }
    buf.push(quote);

    buf
}

impl<T: Chunk> DuckdbBuilder<T> {
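    /// Sets the DuckDB connection, wrapping it in an `Arc<Mutex<_>>`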
    pub fn connection(&mut self, connection: impl Into<duckdb::Connection>) -> &mut Self {
        self.connection = Some(Arc::new(Mutex::new(connection.into())));
        self
    }

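    /// Registers an embedded field to store along with its vector size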
    pub fn with_vector(&mut self, field: EmbeddedField, size: usize) -> &mut Self {
        self.vectors
            .get_or_insert_with(HashMap::new)
            .insert(field, size);
        self
    }

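    /// Renders the default table schema from `schema.sql` with the configured
    /// table name and vectors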
    fn default_schema(&self) -> String {
        let mut context = Context::default();
        context.insert("table_name", &self.table_name);
        context.insert("vectors", &self.vectors.clone().unwrap_or_default());

        tera::Tera::one_off(DEFAULT_INDEXING_SCHEMA, &context, false)
            .expect("Could not render schema; infallible")
    }

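    /// Renders the default upsert query from `upsert.sql` with the configured
    /// table name and vectors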
    fn default_node_upsert_sql(&self) -> String {
        let mut context = Context::default();
        context.insert("table_name", &self.table_name);
        context.insert("vectors", &self.vectors.clone().unwrap_or_default());
        context.insert("upsert_vectors", &self.upsert_vectors);

        context.insert(
            "vector_field_names",
            &self
                .vectors
                .as_ref()
                .map(|v| v.keys().collect::<Vec<_>>())
                .unwrap_or_default(),
        );

        tracing::info!("Rendering upsert sql");
        tera::Tera::one_off(DEFAULT_UPSERT_QUERY, &context, false)
            .expect("could not render upsert query; infallible")
    }
}