Skip to main content

azoth_vector/
migration.rs

1//! Migration helpers for creating vector tables
2
3use crate::types::VectorConfig;
4use azoth_core::{error::AzothError, Result};
5use rusqlite::Connection;
6
7/// Helper to create a table with a vector column
8///
9/// This creates the table and initializes the vector column in one step.
10///
11/// # Example
12///
13/// ```no_run
14/// use azoth::prelude::*;
15/// use azoth_vector::{create_vector_table, VectorConfig};
16/// use rusqlite::Connection;
17///
18/// struct CreateEmbeddingsTable;
19///
20/// impl azoth::Migration for CreateEmbeddingsTable {
21///     fn version(&self) -> u32 { 2 }
22///     fn name(&self) -> &str { "create_embeddings_table" }
23///
24///     fn up(&self, conn: &Connection) -> Result<()> {
25///         create_vector_table(
26///             conn,
27///             "embeddings",
28///             "id INTEGER PRIMARY KEY, text TEXT, vector BLOB",
29///             "vector",
30///             VectorConfig::default(),
31///         )?;
32///         Ok(())
33///     }
34///
35///     fn down(&self, conn: &Connection) -> Result<()> {
36///         conn.execute("DROP TABLE embeddings", [])
37///             .map_err(|e| azoth::AzothError::Projection(e.to_string()))?;
38///         Ok(())
39///     }
40/// }
41/// ```
42pub fn create_vector_table(
43    conn: &Connection,
44    table_name: &str,
45    schema: &str,
46    vector_column: &str,
47    config: VectorConfig,
48) -> Result<()> {
49    // Validate table name and column name (prevent SQL injection)
50    if !is_valid_identifier(table_name) {
51        return Err(AzothError::InvalidState(format!(
52            "Invalid table name: {}",
53            table_name
54        )));
55    }
56    if !is_valid_identifier(vector_column) {
57        return Err(AzothError::InvalidState(format!(
58            "Invalid column name: {}",
59            vector_column
60        )));
61    }
62
63    // Create table
64    conn.execute(&format!("CREATE TABLE {} ({})", table_name, schema), [])
65        .map_err(|e| AzothError::Projection(format!("Failed to create table: {}", e)))?;
66
67    // Initialize vector column (SELECT returns a row)
68    let config_str = config.to_config_string();
69    conn.query_row(
70        &format!(
71            "SELECT vector_init('{}', '{}', ?)",
72            table_name.replace('\'', "''"),
73            vector_column.replace('\'', "''")
74        ),
75        [&config_str],
76        |_row| Ok(()),
77    )
78    .map_err(|e| AzothError::Projection(format!("Failed to init vector column: {}", e)))?;
79
80    tracing::info!(
81        "Created table {} with vector column {} ({})",
82        table_name,
83        vector_column,
84        &config_str
85    );
86
87    Ok(())
88}
89
90/// Helper to add a vector column to an existing table
91///
92/// This adds a BLOB column and initializes it for vector search.
93///
94/// # Example
95///
96/// ```no_run
97/// use azoth::prelude::*;
98/// use azoth_vector::{add_vector_column, VectorConfig};
99/// use rusqlite::Connection;
100///
101/// struct AddVectorToExistingTable;
102///
103/// impl azoth::Migration for AddVectorToExistingTable {
104///     fn version(&self) -> u32 { 3 }
105///     fn name(&self) -> &str { "add_vector_column" }
106///
107///     fn up(&self, conn: &Connection) -> Result<()> {
108///         add_vector_column(
109///             conn,
110///             "documents",
111///             "embedding",
112///             VectorConfig::default(),
113///         )?;
114///         Ok(())
115///     }
116///
117///     fn down(&self, conn: &Connection) -> Result<()> {
118///         // Note: SQLite doesn't support DROP COLUMN in older versions
119///         // You may need to recreate the table
120///         Ok(())
121///     }
122/// }
123/// ```
124pub fn add_vector_column(
125    conn: &Connection,
126    table_name: &str,
127    column_name: &str,
128    config: VectorConfig,
129) -> Result<()> {
130    // Validate identifiers
131    if !is_valid_identifier(table_name) {
132        return Err(AzothError::InvalidState(format!(
133            "Invalid table name: {}",
134            table_name
135        )));
136    }
137    if !is_valid_identifier(column_name) {
138        return Err(AzothError::InvalidState(format!(
139            "Invalid column name: {}",
140            column_name
141        )));
142    }
143
144    // Add BLOB column
145    conn.execute(
146        &format!("ALTER TABLE {} ADD COLUMN {} BLOB", table_name, column_name),
147        [],
148    )
149    .map_err(|e| AzothError::Projection(format!("Failed to add column: {}", e)))?;
150
151    // Initialize vector column
152    let config_str = config.to_config_string();
153    conn.execute(
154        &format!("SELECT vector_init('{}', '{}', ?)", table_name, column_name),
155        [&config_str],
156    )
157    .map_err(|e| AzothError::Projection(format!("Failed to init vector column: {}", e)))?;
158
159    tracing::info!(
160        "Added vector column {}.{} ({})",
161        table_name,
162        column_name,
163        &config_str
164    );
165
166    Ok(())
167}
168
/// Returns `true` when `name` is safe to splice into SQL as a bare identifier.
///
/// A name is accepted only if it is non-empty, starts with an ASCII letter or `_`, and
/// every following character is an ASCII letter, digit, or `_`. Quotes, whitespace, `-`,
/// `.` and non-ASCII characters are all rejected. SQL keywords are not checked for.
fn is_valid_identifier(name: &str) -> bool {
    let mut chars = name.chars();
    match chars.next() {
        Some(head) if head == '_' || head.is_ascii_alphabetic() => {
            chars.all(|c| c == '_' || c.is_ascii_alphanumeric())
        }
        // Empty input, or a leading digit / punctuation.
        _ => false,
    }
}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a throwaway in-memory database; these validation tests never reach the
    /// vector extension, so no file on disk is needed.
    fn memory_conn() -> Connection {
        Connection::open_in_memory().expect("in-memory SQLite database should open")
    }

    /// Returns true if `sqlite_master` lists a table called `name`.
    fn table_exists(conn: &Connection, name: &str) -> bool {
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name = ?1",
                [name],
                |row| row.get(0),
            )
            .expect("sqlite_master should be queryable");
        count > 0
    }

    #[test]
    fn test_valid_identifier() {
        assert!(is_valid_identifier("table_name"));
        assert!(is_valid_identifier("_private"));
        assert!(is_valid_identifier("Table123"));
        assert!(is_valid_identifier("_"));

        assert!(!is_valid_identifier(""));
        assert!(!is_valid_identifier("123table"));
        assert!(!is_valid_identifier("table-name"));
        assert!(!is_valid_identifier("table.name"));
        assert!(!is_valid_identifier("table name"));
        assert!(!is_valid_identifier("table'name"));
    }

    #[test]
    fn test_create_table_invalid_name() {
        let conn = memory_conn();

        let result = create_vector_table(
            &conn,
            "invalid-name",
            "id INTEGER",
            "vector",
            VectorConfig::default(),
        );

        // Must be the validation error, not an SQL failure from running the DDL.
        assert!(matches!(result, Err(AzothError::InvalidState(_))));
    }

    #[test]
    fn test_create_table_invalid_column_creates_nothing() {
        let conn = memory_conn();

        let result = create_vector_table(
            &conn,
            "embeddings",
            "id INTEGER PRIMARY KEY, vector BLOB",
            "vector; DROP TABLE x",
            VectorConfig::default(),
        );

        assert!(matches!(result, Err(AzothError::InvalidState(_))));
        // Validation has to happen before any DDL is executed.
        assert!(!table_exists(&conn, "embeddings"));
    }

    #[test]
    fn test_add_column_invalid_names() {
        let conn = memory_conn();
        conn.execute("CREATE TABLE documents (id INTEGER)", [])
            .expect("fixture table should be created");

        let bad_table =
            add_vector_column(&conn, "bad table", "embedding", VectorConfig::default());
        assert!(matches!(bad_table, Err(AzothError::InvalidState(_))));

        let bad_column =
            add_vector_column(&conn, "documents", "bad-col", VectorConfig::default());
        assert!(matches!(bad_column, Err(AzothError::InvalidState(_))));
    }

    // Full integration tests with vector extension in tests/ directory
}