// swiftide_integrations/duckdb/mod.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, Mutex},
4};
5
6use anyhow::{Context as _, Result};
7use derive_builder::Builder;
8use swiftide_core::{
9    indexing::{Chunk, EmbeddedField},
10    querying::search_strategies::HybridSearch,
11};
12use tera::Context;
13use tokio::sync::RwLock;
14
15pub mod node_cache;
16pub mod persist;
17pub mod retrieve;
18
/// Default table schema, rendered as a tera template (see `DuckdbBuilder::default_schema`).
const DEFAULT_INDEXING_SCHEMA: &str = include_str!("schema.sql");
/// Default node upsert statement, rendered as a tera template
/// (see `DuckdbBuilder::default_node_upsert_sql`).
const DEFAULT_UPSERT_QUERY: &str = include_str!("upsert.sql");
/// Hybrid search query template, rendered per query in `Duckdb::hybrid_query_sql`.
const DEFAULT_HYBRID_QUERY: &str = include_str!("hybrid_query.sql");
22
/// Provides `Persist`, `Retrieve`, and `NodeCache` for duckdb
///
/// Unfortunately Metadata is not stored.
///
/// Supports the following search strategies:
/// - `SimilaritySingleEmbedding`
/// - `HybridSearch` (<https://motherduck.com/blog/search-using-duckdb-part-3>/)
/// - Custom
///
/// NOTE: The integration is not optimized for ultra large datasets / load. It might work, if it
/// doesn't let us know <3.
#[derive(Clone, Builder)]
#[builder(setter(into))]
pub struct Duckdb<T: Chunk = String> {
    /// The connection to the database
    ///
    /// Note that this is a `std::sync::Mutex` (see the `use` at the top of this file), not the
    /// tokio one: the duckdb connection contains a `RefCell` and must be locked for every use.
    /// This is not ideal, but it is what it is.
    #[builder(setter(custom))]
    connection: Arc<Mutex<duckdb::Connection>>,

    /// The name of the table to use for storing nodes. Defaults to "swiftide".
    #[builder(default = "swiftide".into())]
    table_name: String,

    /// The schema to use for the table
    ///
    /// Note that if you change the schema, you probably also need to change the upsert query.
    ///
    /// Additionally, if you intend to use vectors, you must install and load the vss extension.
    #[builder(default = self.default_schema())]
    schema: String,

    /// The vectors to be stored, field name -> size
    #[builder(default)]
    vectors: HashMap<EmbeddedField, usize>,

    /// Batch size for storing nodes
    #[builder(default = "256")]
    batch_size: usize,

    /// Sql to upsert a node
    #[builder(private, default = self.default_node_upsert_sql())]
    node_upsert_sql: String,

    /// Name of the table to use for caching nodes. Defaults to `"swiftide_cache"`.
    #[builder(default = "swiftide_cache".into())]
    cache_table: String,

    /// Tracks if the cache table has been created
    ///
    /// A tokio `RwLock` because it is read/written from async code (`lazy_create_cache`).
    #[builder(private, default = Arc::new(false.into()))]
    cache_table_created: Arc<RwLock<bool>>, // note might need a mutex

    /// Prefix to be used for keys stored in the database to avoid collisions. Can be used to
    /// manually invalidate the cache.
    #[builder(default = "String::new()")]
    cache_key_prefix: String,

    /// If enabled, vectors will be upserted with an ON CONFLICT DO UPDATE. If disabled, ON
    /// conflict does nothing. Requires `duckdb` >= 1.2.1
    #[builder(default)]
    #[allow(dead_code)] // only consumed by the upsert SQL template, not read from Rust code
    upsert_vectors: bool,

    /// Zero-sized marker tying this store to its chunk type `T`.
    #[builder(default)]
    chunk_type: std::marker::PhantomData<T>,
}
90
91impl<T: Chunk> std::fmt::Debug for Duckdb<T> {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.debug_struct("Duckdb")
94            .field("connection", &"Arc<Mutex<duckdb::Connection>>")
95            .field("table_name", &self.table_name)
96            .field("batch_size", &self.batch_size)
97            .finish()
98    }
99}
100
impl Duckdb<String> {
    /// Creates a builder for a `Duckdb` store using the default `String` chunk type.
    pub fn builder() -> DuckdbBuilder<String> {
        DuckdbBuilder::<String>::default()
    }
}
106impl<T: Chunk> Duckdb<T> {
107    // pub fn builder() -> DuckdbBuilder<String> {
108    //     DuckdbBuilder::<String>::default()
109    // }
110
111    /// Name of the indexing table
112    pub fn table_name(&self) -> &str {
113        &self.table_name
114    }
115
116    /// Name of the cache table
117    pub fn cache_table(&self) -> &str {
118        &self.cache_table
119    }
120
121    /// Returns the connection to the database
122    pub fn connection(&self) -> &Mutex<duckdb::Connection> {
123        &self.connection
124    }
125
126    /// Creates HNSW indices on the vector fields
127    ///
128    /// These are *not* persisted. You must recreate them on startup.
129    ///
130    /// If you want to persist them, refer to the duckdb documentation.
131    ///
132    /// # Errors
133    ///
134    /// Errors if the connection or statement fails
135    ///
136    /// # Panics
137    ///
138    /// If the mutex locking the connection is poisoned
139    pub fn create_vector_indices(&self) -> Result<()> {
140        let table_name = &self.table_name;
141        let mut conn = self.connection.lock().unwrap();
142        let tx = conn.transaction().context("Failed to start transaction")?;
143        {
144            for vector in self.vectors.keys() {
145                tx.execute(
146                    &format!(
147                        "CREATE INDEX IF NOT EXISTS idx_{vector} ON {table_name} USING hnsw ({vector}) WITH (metric = 'cosine')",
148                    ),
149                    [],
150                )
151                .context("Could not create index")?;
152            }
153        }
154        tx.commit().context("Failed to commit transaction")?;
155        Ok(())
156    }
157
158    /// Safely creates the cache table if it does not exist. Can be used concurrently
159    ///
160    /// # Errors
161    ///
162    /// Errors if the table or index could not be created
163    ///
164    /// # Panics
165    ///
166    /// If the mutex locking the connection is poisoned
167    pub async fn lazy_create_cache(&self) -> anyhow::Result<()> {
168        if !*self.cache_table_created.read().await {
169            let mut lock = self.cache_table_created.write().await;
170            let conn = self.connection.lock().unwrap();
171            conn.execute(
172                &format!(
173                    "CREATE TABLE IF NOT EXISTS {} (uuid TEXT PRIMARY KEY, path TEXT)",
174                    self.cache_table
175                ),
176                [],
177            )
178            .context("Could not create table")?;
179            // Create an extra index on path
180            conn.execute(
181                &format!(
182                    "CREATE INDEX IF NOT EXISTS idx_path ON {} (path)",
183                    self.cache_table
184                ),
185                [],
186            )
187            .context("Could not create index")?;
188            *lock = true;
189        }
190        Ok(())
191    }
192
193    /// Formats a node key for the cache table
194    pub fn node_key(&self, node: &swiftide_core::indexing::Node<T>) -> String {
195        format!("{}.{}", self.cache_key_prefix, node.id())
196    }
197
198    fn hybrid_query_sql(
199        &self,
200        search_strategy: &HybridSearch,
201        query: &str,
202        embedding: &[f32],
203    ) -> Result<String> {
204        let table_name = &self.table_name;
205
206        // Silently ignores multiple vector fields
207        let (field_name, embedding_size) = self
208            .vectors
209            .iter()
210            .next()
211            .context("No vectors configured")?;
212
213        if self.vectors.len() > 1 {
214            tracing::warn!(
215                "Multiple vectors configured, but only the first one will be used: {:?}",
216                self.vectors
217            );
218        }
219
220        let embedding = embedding
221            .iter()
222            .map(ToString::to_string)
223            .collect::<Vec<_>>()
224            .join(",");
225
226        let context = Context::from_value(serde_json::json!({
227            "table_name": table_name,
228            "top_n": search_strategy.top_n(),
229            "top_k": search_strategy.top_k(),
230            "embedding_name": field_name,
231            "embedding_size": embedding_size,
232            "query": wrap_and_escape(query),
233            "embedding": embedding,
234
235
236        }))?;
237
238        let rendered = tera::Tera::one_off(DEFAULT_HYBRID_QUERY, &context, false)?;
239        Ok(rendered)
240    }
241}
242
/// Wraps `s` in single quotes for embedding as a SQL string literal, escaping any
/// embedded single quote by doubling it (standard SQL escaping).
///
/// e.g. `it's` becomes `'it''s'`.
fn wrap_and_escape(s: &str) -> String {
    let quote = '\'';
    // +2 for the surrounding quotes; escaped quotes may still grow the buffer,
    // but this avoids reallocation in the common no-quote case.
    let mut buf = String::with_capacity(s.len() + 2);
    buf.push(quote);
    for ch in s.chars() {
        // escape `quote` by doubling it
        if ch == quote {
            buf.push(ch);
        }
        buf.push(ch);
    }
    buf.push(quote);

    buf
}
259impl<T: Chunk> DuckdbBuilder<T> {
260    pub fn connection(&mut self, connection: impl Into<duckdb::Connection>) -> &mut Self {
261        self.connection = Some(Arc::new(Mutex::new(connection.into())));
262        self
263    }
264
265    pub fn with_vector(&mut self, field: EmbeddedField, size: usize) -> &mut Self {
266        self.vectors
267            .get_or_insert_with(HashMap::new)
268            .insert(field, size);
269        self
270    }
271
272    fn default_schema(&self) -> String {
273        let mut context = Context::default();
274        context.insert("table_name", &self.table_name);
275        context.insert("vectors", &self.vectors.clone().unwrap_or_default());
276
277        tera::Tera::one_off(DEFAULT_INDEXING_SCHEMA, &context, false)
278            .expect("Could not render schema; infalllible")
279    }
280
281    fn default_node_upsert_sql(&self) -> String {
282        let mut context = Context::default();
283        context.insert("table_name", &self.table_name);
284        context.insert("vectors", &self.vectors.clone().unwrap_or_default());
285        context.insert("upsert_vectors", &self.upsert_vectors);
286
287        context.insert(
288            "vector_field_names",
289            &self
290                .vectors
291                .as_ref()
292                .map(|v| v.keys().collect::<Vec<_>>())
293                .unwrap_or_default(),
294        );
295
296        tracing::info!("Rendering upsert sql");
297        tera::Tera::one_off(DEFAULT_UPSERT_QUERY, &context, false)
298            .expect("could not render upsert query; infallible")
299    }
300}