azoth_vector/extension.rs
1//! Extension loading and initialization
2
3use crate::search::validate_sql_identifier;
4use crate::types::VectorConfig;
5use azoth_core::{error::AzothError, Result};
6use rusqlite::Connection;
7use std::path::Path;
8
9/// Load the sqlite-vector extension
10///
11/// The extension binary should be available at the given path.
12/// Pre-built binaries can be downloaded from:
13/// <https://github.com/sqliteai/sqlite-vector/releases>
14///
15/// # Platform-specific defaults
16///
17/// If no path is provided, defaults to:
18/// - Linux: `./libsqlite_vector.so`
19/// - macOS: `./libsqlite_vector.dylib`
20/// - Windows: `./sqlite_vector.dll`
21///
22/// # Safety
23///
24/// This function uses `unsafe` because `rusqlite::Connection::load_extension`
25/// calls `sqlite3_load_extension`, which loads a native shared library (`.so` / `.dylib` / `.dll`)
26/// into the current process. Loading an untrusted library can execute arbitrary code.
27///
28/// **Requirements for safe usage:**
29/// - The extension binary **must** come from a trusted, verified source (e.g., official releases).
30/// - The `path` argument should **never** be derived from user-supplied input.
31/// - Consider validating file checksums before loading in production deployments.
32pub fn load_vector_extension(conn: &Connection, path: Option<&Path>) -> Result<()> {
33 let ext_path = path.unwrap_or_else(|| {
34 #[cfg(target_os = "linux")]
35 {
36 Path::new("./libsqlite_vector.so")
37 }
38 #[cfg(target_os = "macos")]
39 {
40 Path::new("./libsqlite_vector.dylib")
41 }
42 #[cfg(target_os = "windows")]
43 {
44 Path::new("./sqlite_vector.dll")
45 }
46 #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
47 {
48 Path::new("./libsqlite_vector.so")
49 }
50 });
51
52 // Reject paths that contain path-traversal components to prevent loading
53 // arbitrary libraries from unexpected locations.
54 validate_extension_path(ext_path)?;
55
56 unsafe {
57 let _guard = rusqlite::LoadExtensionGuard::new(conn)
58 .map_err(|e| AzothError::Projection(format!("Failed to enable extensions: {}", e)))?;
59
60 conn.load_extension(ext_path, None).map_err(|e| {
61 AzothError::Projection(format!(
62 "Failed to load vector extension from {}: {}",
63 ext_path.display(),
64 e
65 ))
66 })?;
67 }
68
69 tracing::info!("Loaded sqlite-vector extension successfully");
70 Ok(())
71}
72
73/// Validate that an extension path does not contain path-traversal components
74/// or point through symlinks.
75///
76/// Rejects paths containing `..` components and paths that are symlinks,
77/// as these could be used to load arbitrary shared libraries.
78fn validate_extension_path(path: &Path) -> Result<()> {
79 // Reject ".." components anywhere in the path
80 for component in path.components() {
81 if let std::path::Component::ParentDir = component {
82 return Err(AzothError::Config(format!(
83 "Extension path '{}' contains '..' component. \
84 Path traversal is not allowed for extension loading.",
85 path.display()
86 )));
87 }
88 }
89
90 // If the file exists, reject symlinks
91 if path.exists()
92 && path
93 .symlink_metadata()
94 .map(|m| m.file_type().is_symlink())
95 .unwrap_or(false)
96 {
97 return Err(AzothError::Config(format!(
98 "Extension path '{}' is a symbolic link. \
99 Symlinks are not allowed for extension loading.",
100 path.display()
101 )));
102 }
103
104 Ok(())
105}
106
/// Extend `SqliteProjectionStore` with vector support.
///
/// The methods cover loading the sqlite-vector extension, initializing a vector
/// column, and querying whether (and which version of) the extension is available.
pub trait VectorExtension {
    /// Load the sqlite-vector extension.
    ///
    /// `path` is the location of the extension binary. In the
    /// `SqliteProjectionStore` implementation, `None` falls back to a
    /// platform-specific file name in the current directory (see the free
    /// function `load_vector_extension` in this module).
    ///
    /// # Errors
    ///
    /// Returns an error if the path is rejected by validation or if the
    /// extension cannot be loaded.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use azoth::prelude::*;
    /// use azoth_vector::VectorExtension;
    ///
    /// # fn example() -> Result<()> {
    /// let db = AzothDb::open("./data")?;
    /// db.projection().load_vector_extension(None)?;
    /// # Ok(())
    /// # }
    /// ```
    fn load_vector_extension(&self, path: Option<&Path>) -> Result<()>;

    /// Initialize a vector column.
    ///
    /// Must be called after creating the table with a BLOB column.
    ///
    /// `table` and `column` name the target column; `config` describes the
    /// vector column and is handed to the extension's `vector_init` SQL function.
    ///
    /// # Errors
    ///
    /// In the `SqliteProjectionStore` implementation an error is returned when
    /// `table` or `column` fails SQL identifier validation, or when the
    /// `vector_init` query fails.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use azoth::prelude::*;
    /// use azoth_vector::{VectorExtension, VectorConfig};
    ///
    /// # fn example() -> Result<()> {
    /// # let db = AzothDb::open("./data")?;
    /// # db.projection().load_vector_extension(None)?;
    /// // Create table with BLOB column
    /// db.projection().execute(|conn: &rusqlite::Connection| {
    ///     conn.execute(
    ///         "CREATE TABLE embeddings (id INTEGER PRIMARY KEY, vector BLOB)",
    ///         [],
    ///     ).map_err(|e| azoth::AzothError::Projection(e.to_string()))?;
    ///     Ok(())
    /// })?;
    /// // Initialize vector column
    /// db.projection().vector_init("embeddings", "vector", VectorConfig::default())?;
    /// # Ok(())
    /// # }
    /// ```
    fn vector_init(&self, table: &str, column: &str, config: VectorConfig) -> Result<()>;

    /// Check if vector extension is loaded.
    ///
    /// The `SqliteProjectionStore` implementation answers this by trying to
    /// prepare `SELECT vector_version()` and reporting whether that succeeded.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use azoth::prelude::*;
    /// use azoth_vector::VectorExtension;
    ///
    /// # fn example() -> Result<()> {
    /// let db = AzothDb::open("./data")?;
    /// if !db.projection().has_vector_support() {
    ///     db.projection().load_vector_extension(None)?;
    /// }
    /// # Ok(())
    /// # }
    /// ```
    fn has_vector_support(&self) -> bool;

    /// Get the version of the sqlite-vector extension.
    ///
    /// # Errors
    ///
    /// Returns an error if the version query fails, for example because the
    /// extension has not been loaded.
    fn vector_version(&self) -> Result<String>;
}
174
175impl VectorExtension for azoth_sqlite::SqliteProjectionStore {
176 fn load_vector_extension(&self, path: Option<&Path>) -> Result<()> {
177 let conn = self.conn().lock();
178 load_vector_extension(&conn, path)
179 }
180
181 fn vector_init(&self, table: &str, column: &str, config: VectorConfig) -> Result<()> {
182 // Validate identifiers to prevent SQL injection (same check used in VectorSearch::new)
183 validate_sql_identifier(table, "Table")?;
184 validate_sql_identifier(column, "Column")?;
185
186 let conn = self.conn().lock();
187 let config_str = config.to_config_string();
188
189 conn.query_row(
190 &format!("SELECT vector_init('{table}', '{column}', ?)"),
191 [&config_str],
192 |_row| Ok(()),
193 )
194 .map_err(|e| {
195 AzothError::Projection(format!(
196 "Failed to init vector column {}.{}: {}",
197 table, column, e
198 ))
199 })?;
200
201 tracing::info!(
202 "Initialized vector column {}.{} ({})",
203 table,
204 column,
205 &config_str
206 );
207 Ok(())
208 }
209
210 fn has_vector_support(&self) -> bool {
211 let conn = self.conn().lock();
212 let result = conn.prepare("SELECT vector_version()");
213 result.is_ok()
214 }
215
216 fn vector_version(&self) -> Result<String> {
217 let conn = self.conn().lock();
218 let version: String = conn
219 .query_row("SELECT vector_version()", [], |row| row.get(0))
220 .map_err(|e| AzothError::Projection(format!("Failed to get vector version: {}", e)))?;
221 Ok(version)
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly opened database has no `vector_version()` SQL function, so
    /// preparing a statement that calls it must fail.
    #[test]
    fn test_extension_not_loaded() {
        let scratch = tempfile::tempdir().unwrap();
        let conn = Connection::open(scratch.path().join("test.db")).unwrap();

        // Without the extension the function is unknown to SQLite.
        let prepared = conn.prepare("SELECT vector_version()");
        assert!(prepared.is_err());
    }

    // Note: extension loading itself is not covered here because it needs the
    // actual .so/.dylib file. Integration tests that use the extension binary
    // belong in tests/.
}