azoth_vector/extension.rs
1//! Extension loading and initialization
2
3use crate::types::VectorConfig;
4use azoth_core::{error::AzothError, Result};
5use rusqlite::Connection;
6use std::path::Path;
7
8/// Load the sqlite-vector extension
9///
10/// The extension binary should be available at the given path.
11/// Pre-built binaries can be downloaded from:
12/// <https://github.com/sqliteai/sqlite-vector/releases>
13///
14/// # Platform-specific defaults
15///
16/// If no path is provided, defaults to:
17/// - Linux: `./libsqlite_vector.so`
18/// - macOS: `./libsqlite_vector.dylib`
19/// - Windows: `./sqlite_vector.dll`
20///
21/// # Safety
22///
23/// This function uses unsafe code to load the extension. The extension must be
24/// a valid SQLite extension and should be from a trusted source.
25pub fn load_vector_extension(conn: &Connection, path: Option<&Path>) -> Result<()> {
26 let ext_path = path.unwrap_or_else(|| {
27 #[cfg(target_os = "linux")]
28 {
29 Path::new("./libsqlite_vector.so")
30 }
31 #[cfg(target_os = "macos")]
32 {
33 Path::new("./libsqlite_vector.dylib")
34 }
35 #[cfg(target_os = "windows")]
36 {
37 Path::new("./sqlite_vector.dll")
38 }
39 #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
40 {
41 Path::new("./libsqlite_vector.so")
42 }
43 });
44
45 unsafe {
46 let _guard = rusqlite::LoadExtensionGuard::new(conn)
47 .map_err(|e| AzothError::Projection(format!("Failed to enable extensions: {}", e)))?;
48
49 conn.load_extension(ext_path, None).map_err(|e| {
50 AzothError::Projection(format!(
51 "Failed to load vector extension from {}: {}",
52 ext_path.display(),
53 e
54 ))
55 })?;
56 }
57
58 tracing::info!("Loaded sqlite-vector extension successfully");
59 Ok(())
60}
61
62/// Extend SqliteProjectionStore with vector support
63pub trait VectorExtension {
64 /// Load the sqlite-vector extension
65 ///
66 /// # Example
67 ///
68 /// ```no_run
69 /// use azoth::prelude::*;
70 /// use azoth_vector::VectorExtension;
71 ///
72 /// # fn example() -> Result<()> {
73 /// let db = AzothDb::open("./data")?;
74 /// db.projection().load_vector_extension(None)?;
75 /// # Ok(())
76 /// # }
77 /// ```
78 fn load_vector_extension(&self, path: Option<&Path>) -> Result<()>;
79
80 /// Initialize a vector column
81 ///
82 /// Must be called after creating the table with a BLOB column.
83 ///
84 /// # Example
85 ///
86 /// ```no_run
87 /// use azoth::prelude::*;
88 /// use azoth_vector::{VectorExtension, VectorConfig};
89 ///
90 /// # fn example() -> Result<()> {
91 /// # let db = AzothDb::open("./data")?;
92 /// # db.projection().load_vector_extension(None)?;
93 /// // Create table with BLOB column
94 /// # db.projection().execute(|conn: &rusqlite::Connection| {
95 /// # conn.execute(
96 /// # "CREATE TABLE embeddings (id INTEGER PRIMARY KEY, vector BLOB)",
97 /// # [],
98 /// # ).map_err(|e| azoth::AzothError::Projection(e.to_string()))?;
99 /// # Ok(())
100 /// # })?;
101 /// // Initialize vector column
102 /// # db.projection().vector_init("embeddings", "vector", VectorConfig::default())?;
103 /// # Ok(())
104 /// # }
105 /// ```
106 fn vector_init(&self, table: &str, column: &str, config: VectorConfig) -> Result<()>;
107
108 /// Check if vector extension is loaded
109 ///
110 /// # Example
111 ///
112 /// ```no_run
113 /// use azoth::prelude::*;
114 /// use azoth_vector::VectorExtension;
115 ///
116 /// # fn example() -> Result<()> {
117 /// let db = AzothDb::open("./data")?;
118 /// if !db.projection().has_vector_support() {
119 /// db.projection().load_vector_extension(None)?;
120 /// }
121 /// # Ok(())
122 /// # }
123 /// ```
124 fn has_vector_support(&self) -> bool;
125
126 /// Get the version of the sqlite-vector extension
127 fn vector_version(&self) -> Result<String>;
128}
129
130impl VectorExtension for azoth_sqlite::SqliteProjectionStore {
131 fn load_vector_extension(&self, path: Option<&Path>) -> Result<()> {
132 let conn = self.conn().lock().unwrap();
133 load_vector_extension(&conn, path)
134 }
135
136 fn vector_init(&self, table: &str, column: &str, config: VectorConfig) -> Result<()> {
137 let conn = self.conn().lock().unwrap();
138 let config_str = config.to_config_string();
139
140 conn.query_row(
141 &format!(
142 "SELECT vector_init('{}', '{}', ?)",
143 table.replace('\'', "''"),
144 column.replace('\'', "''")
145 ),
146 [&config_str],
147 |_row| Ok(()),
148 )
149 .map_err(|e| {
150 AzothError::Projection(format!(
151 "Failed to init vector column {}.{}: {}",
152 table, column, e
153 ))
154 })?;
155
156 tracing::info!(
157 "Initialized vector column {}.{} ({})",
158 table,
159 column,
160 &config_str
161 );
162 Ok(())
163 }
164
165 fn has_vector_support(&self) -> bool {
166 let conn = self.conn().lock().unwrap();
167 let result = conn.prepare("SELECT vector_version()");
168 result.is_ok()
169 }
170
171 fn vector_version(&self) -> Result<String> {
172 let conn = self.conn().lock().unwrap();
173 let version: String = conn
174 .query_row("SELECT vector_version()", [], |row| row.get(0))
175 .map_err(|e| AzothError::Projection(format!("Failed to get vector version: {}", e)))?;
176 Ok(version)
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extension_not_loaded() {
        // A fresh in-memory connection is enough: nothing needs to persist on disk.
        let conn = Connection::open_in_memory().unwrap();

        // Without the extension, vector_version() is an unknown SQL function.
        let err = conn
            .prepare("SELECT vector_version()")
            .err()
            .expect("vector_version() should not exist without the extension");
        assert!(
            err.to_string().contains("no such function"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_load_missing_extension_fails() {
        let conn = Connection::open_in_memory().unwrap();
        let missing = Path::new("./this_sqlite_vector_extension_does_not_exist");

        let result = load_vector_extension(&conn, Some(missing));
        assert!(matches!(result, Err(AzothError::Projection(_))));
    }

    // Note: loading the real extension needs the .so/.dylib/.dll binary;
    // integration tests with the extension binary belong in tests/.
}