Skip to main content

azoth_vector/
extension.rs

1//! Extension loading and initialization
2
3use crate::types::VectorConfig;
4use azoth_core::{error::AzothError, Result};
5use rusqlite::Connection;
6use std::path::Path;
7
8/// Load the sqlite-vector extension
9///
10/// The extension binary should be available at the given path.
11/// Pre-built binaries can be downloaded from:
12/// <https://github.com/sqliteai/sqlite-vector/releases>
13///
14/// # Platform-specific defaults
15///
16/// If no path is provided, defaults to:
17/// - Linux: `./libsqlite_vector.so`
18/// - macOS: `./libsqlite_vector.dylib`
19/// - Windows: `./sqlite_vector.dll`
20///
21/// # Safety
22///
23/// This function uses unsafe code to load the extension. The extension must be
24/// a valid SQLite extension and should be from a trusted source.
25pub fn load_vector_extension(conn: &Connection, path: Option<&Path>) -> Result<()> {
26    let ext_path = path.unwrap_or_else(|| {
27        #[cfg(target_os = "linux")]
28        {
29            Path::new("./libsqlite_vector.so")
30        }
31        #[cfg(target_os = "macos")]
32        {
33            Path::new("./libsqlite_vector.dylib")
34        }
35        #[cfg(target_os = "windows")]
36        {
37            Path::new("./sqlite_vector.dll")
38        }
39        #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
40        {
41            Path::new("./libsqlite_vector.so")
42        }
43    });
44
45    unsafe {
46        let _guard = rusqlite::LoadExtensionGuard::new(conn)
47            .map_err(|e| AzothError::Projection(format!("Failed to enable extensions: {}", e)))?;
48
49        conn.load_extension(ext_path, None).map_err(|e| {
50            AzothError::Projection(format!(
51                "Failed to load vector extension from {}: {}",
52                ext_path.display(),
53                e
54            ))
55        })?;
56    }
57
58    tracing::info!("Loaded sqlite-vector extension successfully");
59    Ok(())
60}
61
62/// Extend SqliteProjectionStore with vector support
63pub trait VectorExtension {
64    /// Load the sqlite-vector extension
65    ///
66    /// # Example
67    ///
68    /// ```no_run
69    /// use azoth::prelude::*;
70    /// use azoth_vector::VectorExtension;
71    ///
72    /// # fn example() -> Result<()> {
73    /// let db = AzothDb::open("./data")?;
74    /// db.projection().load_vector_extension(None)?;
75    /// # Ok(())
76    /// # }
77    /// ```
78    fn load_vector_extension(&self, path: Option<&Path>) -> Result<()>;
79
80    /// Initialize a vector column
81    ///
82    /// Must be called after creating the table with a BLOB column.
83    ///
84    /// # Example
85    ///
86    /// ```no_run
87    /// use azoth::prelude::*;
88    /// use azoth_vector::{VectorExtension, VectorConfig};
89    ///
90    /// # fn example() -> Result<()> {
91    /// # let db = AzothDb::open("./data")?;
92    /// # db.projection().load_vector_extension(None)?;
93    /// // Create table with BLOB column
94    /// # db.projection().execute(|conn: &rusqlite::Connection| {
95    /// #     conn.execute(
96    /// #         "CREATE TABLE embeddings (id INTEGER PRIMARY KEY, vector BLOB)",
97    /// #         [],
98    /// #     ).map_err(|e| azoth::AzothError::Projection(e.to_string()))?;
99    /// #     Ok(())
100    /// # })?;
101    /// // Initialize vector column
102    /// # db.projection().vector_init("embeddings", "vector", VectorConfig::default())?;
103    /// # Ok(())
104    /// # }
105    /// ```
106    fn vector_init(&self, table: &str, column: &str, config: VectorConfig) -> Result<()>;
107
108    /// Check if vector extension is loaded
109    ///
110    /// # Example
111    ///
112    /// ```no_run
113    /// use azoth::prelude::*;
114    /// use azoth_vector::VectorExtension;
115    ///
116    /// # fn example() -> Result<()> {
117    /// let db = AzothDb::open("./data")?;
118    /// if !db.projection().has_vector_support() {
119    ///     db.projection().load_vector_extension(None)?;
120    /// }
121    /// # Ok(())
122    /// # }
123    /// ```
124    fn has_vector_support(&self) -> bool;
125
126    /// Get the version of the sqlite-vector extension
127    fn vector_version(&self) -> Result<String>;
128}
129
130impl VectorExtension for azoth_sqlite::SqliteProjectionStore {
131    fn load_vector_extension(&self, path: Option<&Path>) -> Result<()> {
132        let conn = self.conn().lock().unwrap();
133        load_vector_extension(&conn, path)
134    }
135
136    fn vector_init(&self, table: &str, column: &str, config: VectorConfig) -> Result<()> {
137        let conn = self.conn().lock().unwrap();
138        let config_str = config.to_config_string();
139
140        conn.query_row(
141            &format!(
142                "SELECT vector_init('{}', '{}', ?)",
143                table.replace('\'', "''"),
144                column.replace('\'', "''")
145            ),
146            [&config_str],
147            |_row| Ok(()),
148        )
149        .map_err(|e| {
150            AzothError::Projection(format!(
151                "Failed to init vector column {}.{}: {}",
152                table, column, e
153            ))
154        })?;
155
156        tracing::info!(
157            "Initialized vector column {}.{} ({})",
158            table,
159            column,
160            &config_str
161        );
162        Ok(())
163    }
164
165    fn has_vector_support(&self) -> bool {
166        let conn = self.conn().lock().unwrap();
167        let result = conn.prepare("SELECT vector_version()");
168        result.is_ok()
169    }
170
171    fn vector_version(&self) -> Result<String> {
172        let conn = self.conn().lock().unwrap();
173        let version: String = conn
174            .query_row("SELECT vector_version()", [], |row| row.get(0))
175            .map_err(|e| AzothError::Projection(format!("Failed to get vector version: {}", e)))?;
176        Ok(version)
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extension_not_loaded() {
        // A fresh connection has no sqlite-vector functions registered; an
        // in-memory database is enough to show that.
        let conn = Connection::open_in_memory().unwrap();

        let result = conn.prepare("SELECT vector_version()");
        assert!(result.is_err());
    }

    #[test]
    fn test_load_missing_extension_fails() {
        let conn = Connection::open_in_memory().unwrap();
        let missing = Path::new("./this/path/does/not/exist/libsqlite_vector");

        let result = load_vector_extension(&conn, Some(missing));
        assert!(result.is_err());

        // A failed load must not leave the extension's functions registered.
        assert!(conn.prepare("SELECT vector_version()").is_err());
    }

    // Loading a real extension needs the platform .so/.dylib/.dll binary, so
    // those cases belong in integration tests under tests/.
}
195    }
196
197    // Note: We can't easily test extension loading without the actual .so/.dylib file
198    // Integration tests with the extension binary should be in tests/
199}