Skip to main content

azoth_vector/
extension.rs

1//! Extension loading and initialization
2
3use crate::search::validate_sql_identifier;
4use crate::types::VectorConfig;
5use azoth_core::{error::AzothError, Result};
6use rusqlite::Connection;
7use std::path::Path;
8
9/// Load the sqlite-vector extension
10///
11/// The extension binary should be available at the given path.
12/// Pre-built binaries can be downloaded from:
13/// <https://github.com/sqliteai/sqlite-vector/releases>
14///
15/// # Platform-specific defaults
16///
17/// If no path is provided, defaults to:
18/// - Linux: `./libsqlite_vector.so`
19/// - macOS: `./libsqlite_vector.dylib`
20/// - Windows: `./sqlite_vector.dll`
21///
22/// # Safety
23///
24/// This function uses `unsafe` because `rusqlite::Connection::load_extension`
25/// calls `sqlite3_load_extension`, which loads a native shared library (`.so` / `.dylib` / `.dll`)
26/// into the current process. Loading an untrusted library can execute arbitrary code.
27///
28/// **Requirements for safe usage:**
29/// - The extension binary **must** come from a trusted, verified source (e.g., official releases).
30/// - The `path` argument should **never** be derived from user-supplied input.
31/// - Consider validating file checksums before loading in production deployments.
32pub fn load_vector_extension(conn: &Connection, path: Option<&Path>) -> Result<()> {
33    let ext_path = path.unwrap_or_else(|| {
34        #[cfg(target_os = "linux")]
35        {
36            Path::new("./libsqlite_vector.so")
37        }
38        #[cfg(target_os = "macos")]
39        {
40            Path::new("./libsqlite_vector.dylib")
41        }
42        #[cfg(target_os = "windows")]
43        {
44            Path::new("./sqlite_vector.dll")
45        }
46        #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
47        {
48            Path::new("./libsqlite_vector.so")
49        }
50    });
51
52    // Reject paths that contain path-traversal components to prevent loading
53    // arbitrary libraries from unexpected locations.
54    validate_extension_path(ext_path)?;
55
56    unsafe {
57        let _guard = rusqlite::LoadExtensionGuard::new(conn)
58            .map_err(|e| AzothError::Projection(format!("Failed to enable extensions: {}", e)))?;
59
60        conn.load_extension(ext_path, None).map_err(|e| {
61            AzothError::Projection(format!(
62                "Failed to load vector extension from {}: {}",
63                ext_path.display(),
64                e
65            ))
66        })?;
67    }
68
69    tracing::info!("Loaded sqlite-vector extension successfully");
70    Ok(())
71}
72
73/// Validate that an extension path does not contain path-traversal components
74/// or point through symlinks.
75///
76/// Rejects paths containing `..` components and paths that are symlinks,
77/// as these could be used to load arbitrary shared libraries.
78fn validate_extension_path(path: &Path) -> Result<()> {
79    // Reject ".." components anywhere in the path
80    for component in path.components() {
81        if let std::path::Component::ParentDir = component {
82            return Err(AzothError::Config(format!(
83                "Extension path '{}' contains '..' component. \
84                 Path traversal is not allowed for extension loading.",
85                path.display()
86            )));
87        }
88    }
89
90    // If the file exists, reject symlinks
91    if path.exists()
92        && path
93            .symlink_metadata()
94            .map(|m| m.file_type().is_symlink())
95            .unwrap_or(false)
96    {
97        return Err(AzothError::Config(format!(
98            "Extension path '{}' is a symbolic link. \
99             Symlinks are not allowed for extension loading.",
100            path.display()
101        )));
102    }
103
104    Ok(())
105}
106
107/// Extend SqliteProjectionStore with vector support
108pub trait VectorExtension {
109    /// Load the sqlite-vector extension
110    ///
111    /// # Example
112    ///
113    /// ```no_run
114    /// use azoth::prelude::*;
115    /// use azoth_vector::VectorExtension;
116    ///
117    /// # fn example() -> Result<()> {
118    /// let db = AzothDb::open("./data")?;
119    /// db.projection().load_vector_extension(None)?;
120    /// # Ok(())
121    /// # }
122    /// ```
123    fn load_vector_extension(&self, path: Option<&Path>) -> Result<()>;
124
125    /// Initialize a vector column
126    ///
127    /// Must be called after creating the table with a BLOB column.
128    ///
129    /// # Example
130    ///
131    /// ```no_run
132    /// use azoth::prelude::*;
133    /// use azoth_vector::{VectorExtension, VectorConfig};
134    ///
135    /// # fn example() -> Result<()> {
136    /// # let db = AzothDb::open("./data")?;
137    /// # db.projection().load_vector_extension(None)?;
138    /// // Create table with BLOB column
139    /// # db.projection().execute(|conn: &rusqlite::Connection| {
140    /// #     conn.execute(
141    /// #         "CREATE TABLE embeddings (id INTEGER PRIMARY KEY, vector BLOB)",
142    /// #         [],
143    /// #     ).map_err(|e| azoth::AzothError::Projection(e.to_string()))?;
144    /// #     Ok(())
145    /// # })?;
146    /// // Initialize vector column
147    /// # db.projection().vector_init("embeddings", "vector", VectorConfig::default())?;
148    /// # Ok(())
149    /// # }
150    /// ```
151    fn vector_init(&self, table: &str, column: &str, config: VectorConfig) -> Result<()>;
152
153    /// Check if vector extension is loaded
154    ///
155    /// # Example
156    ///
157    /// ```no_run
158    /// use azoth::prelude::*;
159    /// use azoth_vector::VectorExtension;
160    ///
161    /// # fn example() -> Result<()> {
162    /// let db = AzothDb::open("./data")?;
163    /// if !db.projection().has_vector_support() {
164    ///     db.projection().load_vector_extension(None)?;
165    /// }
166    /// # Ok(())
167    /// # }
168    /// ```
169    fn has_vector_support(&self) -> bool;
170
171    /// Get the version of the sqlite-vector extension
172    fn vector_version(&self) -> Result<String>;
173}
174
175impl VectorExtension for azoth_sqlite::SqliteProjectionStore {
176    fn load_vector_extension(&self, path: Option<&Path>) -> Result<()> {
177        let conn = self.conn().lock();
178        load_vector_extension(&conn, path)
179    }
180
181    fn vector_init(&self, table: &str, column: &str, config: VectorConfig) -> Result<()> {
182        // Validate identifiers to prevent SQL injection (same check used in VectorSearch::new)
183        validate_sql_identifier(table, "Table")?;
184        validate_sql_identifier(column, "Column")?;
185
186        let conn = self.conn().lock();
187        let config_str = config.to_config_string();
188
189        conn.query_row(
190            &format!("SELECT vector_init('{table}', '{column}', ?)"),
191            [&config_str],
192            |_row| Ok(()),
193        )
194        .map_err(|e| {
195            AzothError::Projection(format!(
196                "Failed to init vector column {}.{}: {}",
197                table, column, e
198            ))
199        })?;
200
201        tracing::info!(
202            "Initialized vector column {}.{} ({})",
203            table,
204            column,
205            &config_str
206        );
207        Ok(())
208    }
209
210    fn has_vector_support(&self) -> bool {
211        let conn = self.conn().lock();
212        let result = conn.prepare("SELECT vector_version()");
213        result.is_ok()
214    }
215
216    fn vector_version(&self) -> Result<String> {
217        let conn = self.conn().lock();
218        let version: String = conn
219            .query_row("SELECT vector_version()", [], |row| row.get(0))
220            .map_err(|e| AzothError::Projection(format!("Failed to get vector version: {}", e)))?;
221        Ok(version)
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A freshly opened SQLite database has no `vector_version()` function,
    /// so preparing a statement that calls it must fail.
    #[test]
    fn test_extension_not_loaded() {
        let scratch = tempdir().unwrap();
        let conn = Connection::open(scratch.path().join("test.db")).unwrap();

        // Should fail without extension
        assert!(conn.prepare("SELECT vector_version()").is_err());
    }

    // Note: We can't easily test extension loading without the actual .so/.dylib file
    // Integration tests with the extension binary should be in tests/
}