Skip to main content

next_plaid/
filtering.rs

1//! SQLite-based metadata filtering for next-plaid indices.
2//!
3//! This module provides functionality for storing, querying, and managing
4//! document metadata using SQLite, enabling efficient filtering during search.
5//!
6//! The API matches fast-plaid's `filtering.py` for compatibility.
7//!
8//! # Example
9//!
10//! ```ignore
//! use next_plaid::filtering;
12//! use serde_json::json;
13//!
14//! // Create metadata for documents
15//! let metadata = vec![
16//!     json!({"name": "Alice", "category": "A", "score": 95}),
17//!     json!({"name": "Bob", "category": "B", "score": 87}),
18//! ];
19//!
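//! // Create metadata database (one document ID per metadata entry)
//! let doc_ids: Vec<i64> = (0..metadata.len() as i64).collect();
//! filtering::create("my_index", &metadata, &doc_ids)?;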
22//!
23//! // Query documents matching a condition
24//! let subset = filtering::where_condition(
25//!     "my_index",
26//!     "category = ? AND score > ?",
27//!     &[json!("A"), json!(90)],
28//! )?;
29//!
30//! // Use subset in search
31//! let results = index.search(&query, &params, Some(&subset))?;
32//! ```
33
34use std::collections::{HashMap, HashSet};
35use std::fs;
36use std::path::Path;
37
38use regex::Regex;
39use rusqlite::{params_from_iter, Connection, Result as SqliteResult, ToSql};
40use serde_json::Value;
41
42use crate::error::{Error, Result};
43
/// File name of the SQLite metadata database stored inside an index directory.
const METADATA_DB_NAME: &str = "metadata.db";

/// Name of the integer primary-key column that holds each document's ID
/// (kept identical to the column name used by fast-plaid).
const SUBSET_COLUMN: &str = "_subset_";
49
50/// Validate that a column name is a safe SQL identifier.
51///
52/// Column names must start with a letter or underscore, followed by
53/// letters, digits, or underscores. This prevents SQL injection.
54fn is_valid_column_name(name: &str) -> bool {
55    lazy_static_regex().is_match(name)
56}
57
58fn lazy_static_regex() -> &'static Regex {
59    use std::sync::OnceLock;
60    static REGEX: OnceLock<Regex> = OnceLock::new();
61    REGEX.get_or_init(|| Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_]*$").unwrap())
62}
63
64// =============================================================================
65// SQL Condition Validator
66// =============================================================================
67//
68// This module provides a safe SQL condition validator using a tokenizer and
69// recursive descent parser. It whitelists safe SQL operators and validates
70// column names against the database schema to prevent SQL injection.
71
/// Lexical tokens produced when scanning a SQL condition string.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Column name, written bare or inside double quotes (quotes removed).
    Identifier(String),
    /// Bound-parameter marker `?`.
    Placeholder,

    // --- Comparison operators ---
    /// `=`
    Eq,
    /// `!=` or `<>`
    Ne,
    /// `<`
    Lt,
    /// `<=`
    Le,
    /// `>`
    Gt,
    /// `>=`
    Ge,

    // --- Keywords ---
    /// `LIKE`
    Like,
    /// `REGEXP`
    Regexp,
    /// `BETWEEN`
    Between,
    /// `IN`
    In,
    /// `AND`
    And,
    /// `OR`
    Or,
    /// `NOT`
    Not,
    /// `IS`
    Is,
    /// `NULL`
    Null,

    // --- Delimiters ---
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,`
    Comma,

    /// Sentinel marking the end of the input.
    Eof,
}
101
102/// Quick safety check to reject obviously dangerous patterns before tokenization.
103fn quick_safety_check(condition: &str) -> Result<()> {
104    let upper = condition.to_uppercase();
105
106    // Check for comment syntax
107    if condition.contains("--") || condition.contains("/*") || condition.contains("*/") {
108        return Err(Error::Filtering(
109            "SQL comments are not allowed in conditions".into(),
110        ));
111    }
112
113    // Check for statement terminators
114    if condition.contains(';') {
115        return Err(Error::Filtering(
116            "Semicolons are not allowed in conditions".into(),
117        ));
118    }
119
120    // Check for dangerous SQL keywords (must be whole words)
121    let dangerous_keywords = [
122        "SELECT", "UNION", "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "TRUNCATE",
123        "EXEC", "EXECUTE", "GRANT", "REVOKE",
124    ];
125
126    for keyword in dangerous_keywords {
127        // Check if keyword appears as a whole word
128        let pattern = format!(r"\b{}\b", keyword);
129        if Regex::new(&pattern).unwrap().is_match(&upper) {
130            return Err(Error::Filtering(format!(
131                "SQL keyword '{}' is not allowed in conditions",
132                keyword
133            )));
134        }
135    }
136
137    Ok(())
138}
139
140/// Tokenize a SQL condition string into tokens.
141fn tokenize(input: &str) -> Result<Vec<Token>> {
142    let mut tokens = Vec::new();
143    let chars: Vec<char> = input.chars().collect();
144    let mut pos = 0;
145
146    while pos < chars.len() {
147        // Skip whitespace
148        if chars[pos].is_whitespace() {
149            pos += 1;
150            continue;
151        }
152
153        // Single-character tokens
154        match chars[pos] {
155            '?' => {
156                tokens.push(Token::Placeholder);
157                pos += 1;
158                continue;
159            }
160            '(' => {
161                tokens.push(Token::LParen);
162                pos += 1;
163                continue;
164            }
165            ')' => {
166                tokens.push(Token::RParen);
167                pos += 1;
168                continue;
169            }
170            ',' => {
171                tokens.push(Token::Comma);
172                pos += 1;
173                continue;
174            }
175            '=' => {
176                tokens.push(Token::Eq);
177                pos += 1;
178                continue;
179            }
180            _ => {}
181        }
182
183        // Two-character operators
184        if pos + 1 < chars.len() {
185            let two_chars: String = chars[pos..pos + 2].iter().collect();
186            match two_chars.as_str() {
187                "!=" => {
188                    tokens.push(Token::Ne);
189                    pos += 2;
190                    continue;
191                }
192                "<>" => {
193                    tokens.push(Token::Ne);
194                    pos += 2;
195                    continue;
196                }
197                "<=" => {
198                    tokens.push(Token::Le);
199                    pos += 2;
200                    continue;
201                }
202                ">=" => {
203                    tokens.push(Token::Ge);
204                    pos += 2;
205                    continue;
206                }
207                _ => {}
208            }
209        }
210
211        // Single-character comparison operators (checked after two-char)
212        match chars[pos] {
213            '<' => {
214                tokens.push(Token::Lt);
215                pos += 1;
216                continue;
217            }
218            '>' => {
219                tokens.push(Token::Gt);
220                pos += 1;
221                continue;
222            }
223            _ => {}
224        }
225
226        // Identifiers and keywords
227        if chars[pos].is_alphabetic() || chars[pos] == '_' {
228            let start = pos;
229            while pos < chars.len() && (chars[pos].is_alphanumeric() || chars[pos] == '_') {
230                pos += 1;
231            }
232            let word: String = chars[start..pos].iter().collect();
233            let upper = word.to_uppercase();
234
235            let token = match upper.as_str() {
236                "AND" => Token::And,
237                "OR" => Token::Or,
238                "NOT" => Token::Not,
239                "IS" => Token::Is,
240                "NULL" => Token::Null,
241                "LIKE" => Token::Like,
242                "REGEXP" => Token::Regexp,
243                "BETWEEN" => Token::Between,
244                "IN" => Token::In,
245                _ => Token::Identifier(word),
246            };
247            tokens.push(token);
248            continue;
249        }
250
251        // Quoted identifier (double quotes)
252        if chars[pos] == '"' {
253            pos += 1; // skip opening quote
254            let start = pos;
255            while pos < chars.len() && chars[pos] != '"' {
256                pos += 1;
257            }
258            if pos >= chars.len() {
259                return Err(Error::Filtering("Unterminated quoted identifier".into()));
260            }
261            let word: String = chars[start..pos].iter().collect();
262            tokens.push(Token::Identifier(word));
263            pos += 1; // skip closing quote
264            continue;
265        }
266
267        // Reject unexpected characters
268        return Err(Error::Filtering(format!(
269            "Unexpected character '{}' in condition",
270            chars[pos]
271        )));
272    }
273
274    tokens.push(Token::Eof);
275    Ok(tokens)
276}
277
278/// Recursive descent parser/validator for SQL conditions.
279struct ConditionValidator<'a> {
280    tokens: &'a [Token],
281    pos: usize,
282    valid_columns: &'a HashSet<String>,
283    columns_used: Vec<String>,
284}
285
286impl<'a> ConditionValidator<'a> {
287    fn new(tokens: &'a [Token], valid_columns: &'a HashSet<String>) -> Self {
288        Self {
289            tokens,
290            pos: 0,
291            valid_columns,
292            columns_used: Vec::new(),
293        }
294    }
295
296    fn current(&self) -> &Token {
297        self.tokens.get(self.pos).unwrap_or(&Token::Eof)
298    }
299
300    fn advance(&mut self) {
301        if self.pos < self.tokens.len() {
302            self.pos += 1;
303        }
304    }
305
306    fn expect(&mut self, expected: &Token) -> Result<()> {
307        if self.current() == expected {
308            self.advance();
309            Ok(())
310        } else {
311            Err(Error::Filtering(format!(
312                "Expected {:?}, found {:?}",
313                expected,
314                self.current()
315            )))
316        }
317    }
318
319    /// Validate the entire condition.
320    fn validate(&mut self) -> Result<()> {
321        self.parse_expr()?;
322        if *self.current() != Token::Eof {
323            return Err(Error::Filtering(format!(
324                "Unexpected token {:?} after expression",
325                self.current()
326            )));
327        }
328        Ok(())
329    }
330
331    /// expr = and_expr (OR and_expr)*
332    fn parse_expr(&mut self) -> Result<()> {
333        self.parse_and_expr()?;
334        while *self.current() == Token::Or {
335            self.advance();
336            self.parse_and_expr()?;
337        }
338        Ok(())
339    }
340
341    /// and_expr = unary_expr (AND unary_expr)*
342    fn parse_and_expr(&mut self) -> Result<()> {
343        self.parse_unary_expr()?;
344        while *self.current() == Token::And {
345            self.advance();
346            self.parse_unary_expr()?;
347        }
348        Ok(())
349    }
350
351    /// unary_expr = NOT? primary_expr
352    fn parse_unary_expr(&mut self) -> Result<()> {
353        if *self.current() == Token::Not {
354            self.advance();
355        }
356        self.parse_primary_expr()
357    }
358
359    /// primary_expr = comparison | null_check | between_expr | in_expr | "(" expr ")"
360    fn parse_primary_expr(&mut self) -> Result<()> {
361        // Parenthesized expression
362        if *self.current() == Token::LParen {
363            self.advance();
364            self.parse_expr()?;
365            self.expect(&Token::RParen)?;
366            return Ok(());
367        }
368
369        // Must start with an identifier
370        let col_name = match self.current().clone() {
371            Token::Identifier(name) => name,
372            other => {
373                return Err(Error::Filtering(format!(
374                    "Expected column name, found {:?}",
375                    other
376                )))
377            }
378        };
379
380        // Validate column name against schema
381        // Case-insensitive comparison
382        let col_lower = col_name.to_lowercase();
383        let valid = self
384            .valid_columns
385            .iter()
386            .any(|c| c.to_lowercase() == col_lower);
387        if !valid {
388            return Err(Error::Filtering(format!(
389                "Unknown column '{}' in condition",
390                col_name
391            )));
392        }
393        self.columns_used.push(col_name);
394        self.advance();
395
396        // Determine what follows the identifier
397        match self.current() {
398            // IS [NOT] NULL
399            Token::Is => {
400                self.advance();
401                if *self.current() == Token::Not {
402                    self.advance();
403                }
404                self.expect(&Token::Null)?;
405            }
406
407            // [NOT] BETWEEN ? AND ?
408            Token::Not => {
409                self.advance();
410                match self.current() {
411                    Token::Between => {
412                        self.advance();
413                        self.expect(&Token::Placeholder)?;
414                        self.expect(&Token::And)?;
415                        self.expect(&Token::Placeholder)?;
416                    }
417                    Token::In => {
418                        self.advance();
419                        self.parse_in_list()?;
420                    }
421                    Token::Like => {
422                        self.advance();
423                        self.expect(&Token::Placeholder)?;
424                    }
425                    Token::Regexp => {
426                        self.advance();
427                        self.expect(&Token::Placeholder)?;
428                    }
429                    _ => {
430                        return Err(Error::Filtering(format!(
431                            "Expected BETWEEN, IN, LIKE, or REGEXP after NOT, found {:?}",
432                            self.current()
433                        )));
434                    }
435                }
436            }
437
438            Token::Between => {
439                self.advance();
440                self.expect(&Token::Placeholder)?;
441                self.expect(&Token::And)?;
442                self.expect(&Token::Placeholder)?;
443            }
444
445            // [NOT] IN (?, ?, ...)
446            Token::In => {
447                self.advance();
448                self.parse_in_list()?;
449            }
450
451            // [NOT] LIKE ?
452            Token::Like => {
453                self.advance();
454                self.expect(&Token::Placeholder)?;
455            }
456
457            // [NOT] REGEXP ?
458            Token::Regexp => {
459                self.advance();
460                self.expect(&Token::Placeholder)?;
461            }
462
463            // Comparison operators: = != <> < <= > >=
464            Token::Eq | Token::Ne | Token::Lt | Token::Le | Token::Gt | Token::Ge => {
465                self.advance();
466                self.expect(&Token::Placeholder)?;
467            }
468
469            other => {
470                return Err(Error::Filtering(format!(
471                    "Expected operator after column name, found {:?}",
472                    other
473                )));
474            }
475        }
476
477        Ok(())
478    }
479
480    /// Parse IN list: (?, ?, ...)
481    fn parse_in_list(&mut self) -> Result<()> {
482        self.expect(&Token::LParen)?;
483        self.expect(&Token::Placeholder)?;
484        while *self.current() == Token::Comma {
485            self.advance();
486            self.expect(&Token::Placeholder)?;
487        }
488        self.expect(&Token::RParen)?;
489        Ok(())
490    }
491}
492
493/// Get column names from the database schema.
494fn get_schema_columns(conn: &Connection) -> Result<HashSet<String>> {
495    let mut columns = HashSet::new();
496    let mut stmt = conn.prepare("PRAGMA table_info(METADATA)")?;
497    let rows = stmt.query_map([], |row| row.get::<_, String>(1))?;
498    for row in rows {
499        columns.insert(row?);
500    }
501    Ok(columns)
502}
503
504/// Validate a SQL WHERE condition against the allowed grammar and schema.
505///
506/// This function performs security validation on user-provided SQL conditions:
507/// 1. Quick safety check rejects dangerous patterns (comments, semicolons, DDL keywords)
508/// 2. Tokenization converts the condition to a safe token stream
509/// 3. Recursive descent parsing validates the condition against an allowlist grammar
510/// 4. Column validation ensures only known columns are referenced
511///
512/// # Allowed Grammar
513///
514/// ```text
515/// condition    = expr
516/// expr         = and_expr (OR and_expr)*
517/// and_expr     = unary_expr (AND unary_expr)*
518/// unary_expr   = NOT? primary_expr
519/// primary_expr = comparison | null_check | between_expr | in_expr | "(" expr ")"
520/// comparison   = identifier (comp_op | like_op | regexp_op) placeholder
521/// null_check   = identifier IS NOT? NULL
522/// between_expr = identifier NOT? BETWEEN placeholder AND placeholder
523/// in_expr      = identifier NOT? IN "(" placeholder ("," placeholder)* ")"
524/// ```
525/// Check if condition is a simple numeric equality like "1=1", "0=0", etc.
526/// These are common SQL idioms for "always true" or "always false" conditions.
527fn is_numeric_equality(condition: &str) -> bool {
528    lazy_static_numeric_eq_regex().is_match(condition.trim())
529}
530
531fn lazy_static_numeric_eq_regex() -> &'static Regex {
532    use std::sync::OnceLock;
533    static REGEX: OnceLock<Regex> = OnceLock::new();
534    REGEX.get_or_init(|| Regex::new(r"^(\d+)\s*=\s*(\d+)$").unwrap())
535}
536
537fn validate_condition(condition: &str, valid_columns: &HashSet<String>) -> Result<()> {
538    // Special case: numeric equality like "1=1", "0=0" are common SQL idioms
539    // for "always true" / "always false" conditions. Safe to allow.
540    if is_numeric_equality(condition) {
541        return Ok(());
542    }
543
544    // Step 1: Quick safety check
545    quick_safety_check(condition)?;
546
547    // Step 2: Tokenize
548    let tokens = tokenize(condition)?;
549
550    // Step 3: Parse and validate
551    let mut validator = ConditionValidator::new(&tokens, valid_columns);
552    validator.validate()?;
553
554    Ok(())
555}
556
557/// Infer SQL type from a JSON value.
558fn infer_sql_type(value: &Value) -> &'static str {
559    match value {
560        Value::Number(n) => {
561            if n.is_i64() || n.is_u64() {
562                "INTEGER"
563            } else {
564                "REAL"
565            }
566        }
567        Value::Bool(_) => "INTEGER",
568        Value::String(_) => "TEXT",
569        Value::Null => "TEXT",
570        Value::Array(_) | Value::Object(_) => "BLOB",
571    }
572}
573
574/// Convert a JSON value to a type that can be bound to SQLite.
575fn json_to_sql(value: &Value) -> Box<dyn ToSql> {
576    match value {
577        Value::Null => Box::new(None::<String>),
578        Value::Bool(b) => Box::new(if *b { 1i64 } else { 0i64 }),
579        Value::Number(n) => {
580            if let Some(i) = n.as_i64() {
581                Box::new(i)
582            } else if let Some(f) = n.as_f64() {
583                Box::new(f)
584            } else {
585                Box::new(n.to_string())
586            }
587        }
588        Value::String(s) => Box::new(s.clone()),
589        Value::Array(_) | Value::Object(_) => Box::new(serde_json::to_string(value).unwrap()),
590    }
591}
592
593/// Get the path to the metadata database for an index.
594fn get_db_path(index_path: &str) -> std::path::PathBuf {
595    Path::new(index_path).join(METADATA_DB_NAME)
596}
597
598/// Check if a metadata database exists for the given index.
599pub fn exists(index_path: &str) -> bool {
600    get_db_path(index_path).exists()
601}
602
603/// Create a new SQLite metadata database, replacing any existing one.
604///
605/// Each element in `metadata` is a JSON object representing a document's metadata.
606/// The `_subset_` column is automatically added as the primary key.
607///
608/// # Arguments
609///
610/// * `index_path` - Path to the index directory
611/// * `metadata` - Slice of JSON objects, one per document
612///
613/// # Returns
614///
615/// Number of rows inserted
616///
617/// # Errors
618///
619/// Returns an error if:
620/// - The index directory cannot be created
621/// - Column names are invalid (SQL injection prevention)
622/// - Database operations fail
623///
624/// # Example
625///
626/// ```ignore
627/// use next-plaid::filtering;
628/// use serde_json::json;
629///
630/// let metadata = vec![
631///     json!({"name": "Alice", "age": 30}),
632///     json!({"name": "Bob", "age": 25, "city": "NYC"}),
633/// ];
634/// let doc_ids: Vec<i64> = (0..2).collect();
635///
636/// filtering::create("my_index", &metadata, &doc_ids)?;
637/// ```
638pub fn create(index_path: &str, metadata: &[Value], doc_ids: &[i64]) -> Result<usize> {
639    // Validate doc_ids length matches metadata
640    if metadata.len() != doc_ids.len() {
641        return Err(Error::Filtering(format!(
642            "Metadata length ({}) must match doc_ids length ({})",
643            metadata.len(),
644            doc_ids.len()
645        )));
646    }
647
648    // Ensure index directory exists
649    let index_dir = Path::new(index_path);
650    if !index_dir.exists() {
651        fs::create_dir_all(index_dir)?;
652    }
653
654    // Remove existing database
655    let db_path = get_db_path(index_path);
656    if db_path.exists() {
657        fs::remove_file(&db_path)?;
658    }
659
660    if metadata.is_empty() {
661        return Ok(0);
662    }
663
664    // Collect all unique column names and infer types
665    let mut columns: Vec<String> = Vec::new();
666    let mut column_types: HashMap<String, &'static str> = HashMap::new();
667
668    for item in metadata {
669        if let Value::Object(obj) = item {
670            for (key, value) in obj {
671                if !columns.contains(key) {
672                    // Validate column name
673                    if !is_valid_column_name(key) {
674                        return Err(Error::Filtering(format!(
675                            "Invalid column name '{}'. Column names must start with a letter or \
676                             underscore, followed by letters, digits, or underscores.",
677                            key
678                        )));
679                    }
680                    columns.push(key.clone());
681                }
682                // Infer type from first non-null value
683                if !value.is_null() && !column_types.contains_key(key) {
684                    column_types.insert(key.clone(), infer_sql_type(value));
685                }
686            }
687        }
688    }
689
690    // Create connection
691    let conn = Connection::open(&db_path)?;
692
693    // Build CREATE TABLE statement
694    let mut col_defs = vec![format!("\"{}\" INTEGER PRIMARY KEY", SUBSET_COLUMN)];
695    for col in &columns {
696        let sql_type = column_types.get(col).copied().unwrap_or("TEXT");
697        col_defs.push(format!("\"{}\" {}", col, sql_type));
698    }
699
700    let create_sql = format!("CREATE TABLE METADATA ({})", col_defs.join(", "));
701    conn.execute(&create_sql, [])?;
702
703    // Prepare INSERT statement
704    let placeholders: Vec<&str> = std::iter::repeat_n("?", columns.len() + 1).collect();
705    let col_names: Vec<String> = columns.iter().map(|c| format!("\"{}\"", c)).collect();
706    let insert_sql = format!(
707        "INSERT INTO METADATA (\"{}\", {}) VALUES ({})",
708        SUBSET_COLUMN,
709        col_names.join(", "),
710        placeholders.join(", ")
711    );
712
713    // Insert rows
714    let mut stmt = conn.prepare(&insert_sql)?;
715    for (i, item) in metadata.iter().enumerate() {
716        let mut values: Vec<Box<dyn ToSql>> = vec![Box::new(doc_ids[i])];
717        if let Value::Object(obj) = item {
718            for col in &columns {
719                let value = obj.get(col).unwrap_or(&Value::Null);
720                values.push(json_to_sql(value));
721            }
722        } else {
723            // If not an object, insert nulls
724            for _ in &columns {
725                values.push(Box::new(None::<String>));
726            }
727        }
728        let params: Vec<&dyn ToSql> = values.iter().map(|v| v.as_ref()).collect();
729        stmt.execute(params_from_iter(params))?;
730    }
731
732    Ok(metadata.len())
733}
734
735/// Append new metadata rows to an existing database, adding columns if needed.
736///
737/// New columns found in the metadata are automatically added to the table.
738/// The `_subset_` IDs are provided explicitly via `doc_ids` to ensure sync with index.
739///
740/// # Arguments
741///
742/// * `index_path` - Path to the index directory
743/// * `metadata` - Slice of JSON objects for new documents
744/// * `doc_ids` - Document IDs to use as `_subset_` values (must match metadata length)
745///
746/// # Returns
747///
748/// Number of rows inserted
749///
750/// # Errors
751///
752/// Returns an error if:
753/// - The database doesn't exist
754/// - Column names are invalid
755/// - Database operations fail
756/// - metadata length doesn't match doc_ids length
757pub fn update(index_path: &str, metadata: &[Value], doc_ids: &[i64]) -> Result<usize> {
758    if metadata.is_empty() {
759        return Ok(0);
760    }
761
762    // Validate doc_ids length matches metadata
763    if metadata.len() != doc_ids.len() {
764        return Err(Error::Filtering(format!(
765            "Metadata length ({}) must match doc_ids length ({})",
766            metadata.len(),
767            doc_ids.len()
768        )));
769    }
770
771    let db_path = get_db_path(index_path);
772    if !db_path.exists() {
773        return Err(Error::Filtering(
774            "Metadata database does not exist. Use create() first.".into(),
775        ));
776    }
777
778    let conn = Connection::open(&db_path)?;
779
780    // Get existing columns
781    let mut existing_columns: Vec<String> = Vec::new();
782    {
783        let mut stmt = conn.prepare("PRAGMA table_info(METADATA)")?;
784        let rows = stmt.query_map([], |row| row.get::<_, String>(1))?;
785        for row in rows {
786            let col = row?;
787            if col != SUBSET_COLUMN {
788                existing_columns.push(col);
789            }
790        }
791    }
792
793    // Find new columns and add them
794    let mut new_columns: Vec<String> = Vec::new();
795    let mut column_types: HashMap<String, &'static str> = HashMap::new();
796
797    for item in metadata {
798        if let Value::Object(obj) = item {
799            for (key, value) in obj {
800                if !existing_columns.contains(key) && !new_columns.contains(key) {
801                    if !is_valid_column_name(key) {
802                        return Err(Error::Filtering(format!(
803                            "Invalid column name '{}'. Column names must start with a letter or \
804                             underscore, followed by letters, digits, or underscores.",
805                            key
806                        )));
807                    }
808                    new_columns.push(key.clone());
809                }
810                if !value.is_null() && !column_types.contains_key(key) {
811                    column_types.insert(key.clone(), infer_sql_type(value));
812                }
813            }
814        }
815    }
816
817    // Add new columns to table
818    for col in &new_columns {
819        let sql_type = column_types.get(col).copied().unwrap_or("TEXT");
820        let alter_sql = format!("ALTER TABLE METADATA ADD COLUMN \"{}\" {}", col, sql_type);
821        conn.execute(&alter_sql, [])?;
822    }
823
824    // Get all columns (existing + new)
825    let all_columns: Vec<String> = existing_columns.into_iter().chain(new_columns).collect();
826
827    // Prepare INSERT statement
828    let placeholders: Vec<&str> = std::iter::repeat_n("?", all_columns.len() + 1).collect();
829    let col_names: Vec<String> = all_columns.iter().map(|c| format!("\"{}\"", c)).collect();
830    let insert_sql = format!(
831        "INSERT INTO METADATA (\"{}\", {}) VALUES ({})",
832        SUBSET_COLUMN,
833        col_names.join(", "),
834        placeholders.join(", ")
835    );
836
837    // Insert rows
838    let mut stmt = conn.prepare(&insert_sql)?;
839    for (i, item) in metadata.iter().enumerate() {
840        let mut values: Vec<Box<dyn ToSql>> = vec![Box::new(doc_ids[i])];
841        if let Value::Object(obj) = item {
842            for col in &all_columns {
843                let value = obj.get(col).unwrap_or(&Value::Null);
844                values.push(json_to_sql(value));
845            }
846        } else {
847            for _ in &all_columns {
848                values.push(Box::new(None::<String>));
849            }
850        }
851        let params: Vec<&dyn ToSql> = values.iter().map(|v| v.as_ref()).collect();
852        stmt.execute(params_from_iter(params))?;
853    }
854
855    Ok(metadata.len())
856}
857
858/// Delete rows by subset IDs and re-index the _subset_ column to be sequential.
859///
860/// After deletion, remaining documents are re-indexed to maintain sequential
861/// `_subset_` IDs starting from 0. This matches fast-plaid behavior.
862///
863/// # Arguments
864///
865/// * `index_path` - Path to the index directory
866/// * `subset` - Slice of document IDs to delete (must be sorted ascending)
867///
868/// # Returns
869///
870/// Number of rows actually deleted
871///
872/// # Errors
873///
874/// Returns an error if the database operations fail.
875pub fn delete(index_path: &str, subset: &[i64]) -> Result<usize> {
876    if subset.is_empty() {
877        return Ok(0);
878    }
879
880    let db_path = get_db_path(index_path);
881    if !db_path.exists() {
882        return Ok(0);
883    }
884
885    let conn = Connection::open(&db_path)?;
886
887    // Start transaction
888    conn.execute("BEGIN", [])?;
889
890    // Delete specified rows
891    let placeholders: Vec<String> = subset.iter().map(|_| "?".to_string()).collect();
892    let delete_sql = format!(
893        "DELETE FROM METADATA WHERE \"{}\" IN ({})",
894        SUBSET_COLUMN,
895        placeholders.join(", ")
896    );
897    let subset_refs: Vec<&dyn ToSql> = subset.iter().map(|v| v as &dyn ToSql).collect();
898    let deleted = conn.execute(&delete_sql, params_from_iter(subset_refs))?;
899
900    // Get column names (excluding _subset_)
901    let mut columns: Vec<String> = Vec::new();
902    {
903        let mut stmt = conn.prepare("PRAGMA table_info(METADATA)")?;
904        let rows = stmt.query_map([], |row| row.get::<_, String>(1))?;
905        for row in rows {
906            let col = row?;
907            if col != SUBSET_COLUMN {
908                columns.push(col);
909            }
910        }
911    }
912
913    let col_str = columns
914        .iter()
915        .map(|c| format!("\"{}\"", c))
916        .collect::<Vec<_>>()
917        .join(", ");
918
919    // Create temp table with re-indexed _subset_ values
920    let create_temp_sql = format!(
921        "CREATE TEMP TABLE METADATA_TEMP AS \
922         SELECT (ROW_NUMBER() OVER (ORDER BY \"{}\")) - 1 AS new_subset_id, {} \
923         FROM METADATA",
924        SUBSET_COLUMN, col_str
925    );
926    conn.execute(&create_temp_sql, [])?;
927
928    // Clear original table
929    conn.execute("DELETE FROM METADATA", [])?;
930
931    // Copy back with new IDs
932    let insert_back_sql = format!(
933        "INSERT INTO METADATA (\"{}\", {}) \
934         SELECT new_subset_id, {} FROM METADATA_TEMP",
935        SUBSET_COLUMN, col_str, col_str
936    );
937    conn.execute(&insert_back_sql, [])?;
938
939    // Drop temp table
940    conn.execute("DROP TABLE METADATA_TEMP", [])?;
941
942    // Commit transaction
943    conn.execute("COMMIT", [])?;
944
945    Ok(deleted)
946}
947
948/// Query the database and return matching _subset_ IDs.
949///
950/// # Arguments
951///
952/// * `index_path` - Path to the index directory
953/// * `condition` - SQL WHERE clause with `?` placeholders (e.g., "category = ? AND score > ?")
954/// * `parameters` - Values to substitute for placeholders
955///
956/// # Returns
957///
958/// Vector of `_subset_` IDs matching the condition
959///
960/// # Example
961///
962/// ```ignore
963/// use next-plaid::filtering;
964/// use serde_json::json;
965///
966/// let subset = filtering::where_condition(
967///     "my_index",
968///     "category = ? AND score > ?",
969///     &[json!("A"), json!(90)],
970/// )?;
971/// ```
972pub fn where_condition(
973    index_path: &str,
974    condition: &str,
975    parameters: &[Value],
976) -> Result<Vec<i64>> {
977    let db_path = get_db_path(index_path);
978    if !db_path.exists() {
979        return Err(Error::Filtering(
980            "No metadata database found. Create it first by adding metadata during index creation."
981                .into(),
982        ));
983    }
984
985    let conn = Connection::open(&db_path)?;
986
987    // Validate condition against SQL injection
988    let valid_columns = get_schema_columns(&conn)?;
989    validate_condition(condition, &valid_columns)?;
990
991    let query = format!(
992        "SELECT \"{}\" FROM METADATA WHERE {}",
993        SUBSET_COLUMN, condition
994    );
995
996    let params: Vec<Box<dyn ToSql>> = parameters.iter().map(json_to_sql).collect();
997    let param_refs: Vec<&dyn ToSql> = params.iter().map(|v| v.as_ref()).collect();
998
999    let mut stmt = conn.prepare(&query)?;
1000    let rows = stmt.query_map(params_from_iter(param_refs), |row| row.get::<_, i64>(0))?;
1001
1002    let mut result = Vec::new();
1003    for row in rows {
1004        result.push(row?);
1005    }
1006
1007    Ok(result)
1008}
1009
1010/// Query document IDs with REGEXP support enabled.
1011///
1012/// This function is similar to `where_condition` but registers a REGEXP
1013/// function that uses Rust's regex crate for pattern matching.
1014///
1015/// # Arguments
1016///
1017/// * `index_path` - Path to the index directory
1018/// * `condition` - SQL WHERE clause (can use `column REGEXP ?`)
1019/// * `parameters` - Values for condition placeholders
1020///
1021/// # Example
1022///
1023/// ```ignore
1024/// // Find documents where code_preview matches a regex
1025/// let ids = filtering::where_condition_regexp(
1026///     "my_index",
1027///     "code_preview REGEXP ?",
1028///     &[json!("async|await")],
1029/// )?;
1030/// ```
1031///
1032/// # Security
1033///
1034/// The regex is compiled with size limits (10MB) to prevent ReDoS attacks.
1035/// Invalid regex patterns return an error with a descriptive message.
1036pub fn where_condition_regexp(
1037    index_path: &str,
1038    condition: &str,
1039    parameters: &[Value],
1040) -> Result<Vec<i64>> {
1041    let db_path = get_db_path(index_path);
1042    if !db_path.exists() {
1043        return Err(Error::Filtering(
1044            "No metadata database found. Create it first by adding metadata during index creation."
1045                .into(),
1046        ));
1047    }
1048
1049    // For REGEXP queries, extract and pre-compile the pattern once (not per-row)
1050    // This provides both performance and security benefits
1051    let regex_pattern = parameters
1052        .first()
1053        .and_then(|v| v.as_str())
1054        .ok_or_else(|| Error::Filtering("REGEXP requires a pattern parameter".into()))?;
1055
1056    // Compile regex with protections:
1057    // - size_limit: Prevents ReDoS by limiting compiled regex size (10MB)
1058    // - case_insensitive: Standard grep-like behavior
1059    let compiled_regex = std::sync::Arc::new(
1060        regex::RegexBuilder::new(regex_pattern)
1061            .case_insensitive(true)
1062            .size_limit(10 * (1 << 20)) // 10MB limit for ReDoS protection
1063            .build()
1064            .map_err(|e| {
1065                Error::Filtering(format!("Invalid regex pattern '{}': {}", regex_pattern, e))
1066            })?,
1067    );
1068
1069    let conn = Connection::open(&db_path)?;
1070
1071    // Validate condition against SQL injection
1072    let valid_columns = get_schema_columns(&conn)?;
1073    validate_condition(condition, &valid_columns)?;
1074
1075    // Register REGEXP function with pre-compiled regex (compiled once, used for all rows)
1076    let re = compiled_regex.clone();
1077    conn.create_scalar_function(
1078        "regexp",
1079        2,
1080        rusqlite::functions::FunctionFlags::SQLITE_UTF8
1081            | rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC,
1082        move |ctx| {
1083            // Pattern argument from SQL is ignored - we use the pre-compiled regex
1084            let _pattern: String = ctx.get(0)?;
1085            let text: String = ctx.get(1)?;
1086
1087            Ok(re.is_match(&text))
1088        },
1089    )?;
1090
1091    let query = format!(
1092        "SELECT \"{}\" FROM METADATA WHERE {}",
1093        SUBSET_COLUMN, condition
1094    );
1095
1096    let params: Vec<Box<dyn ToSql>> = parameters.iter().map(json_to_sql).collect();
1097    let param_refs: Vec<&dyn ToSql> = params.iter().map(|v| v.as_ref()).collect();
1098
1099    let mut stmt = conn.prepare(&query)?;
1100    let rows = stmt.query_map(params_from_iter(param_refs), |row| row.get::<_, i64>(0))?;
1101
1102    let mut result = Vec::new();
1103    for row in rows {
1104        result.push(row?);
1105    }
1106
1107    Ok(result)
1108}
1109
1110/// Get full metadata rows by condition or subset IDs.
1111///
1112/// Returns metadata as JSON objects with the `_subset_` field included.
1113///
1114/// # Arguments
1115///
1116/// * `index_path` - Path to the index directory
1117/// * `condition` - Optional SQL WHERE clause (mutually exclusive with `subset`)
1118/// * `parameters` - Values for condition placeholders
1119/// * `subset` - Optional list of `_subset_` IDs to retrieve (mutually exclusive with `condition`)
1120///
1121/// # Returns
1122///
1123/// Vector of JSON objects representing metadata rows
1124///
1125/// # Ordering
1126///
1127/// - If `subset` is provided: Returns rows in the order specified by `subset`
1128/// - If `condition` is provided: Returns rows ordered by `_subset_` ascending
1129pub fn get(
1130    index_path: &str,
1131    condition: Option<&str>,
1132    parameters: &[Value],
1133    subset: Option<&[i64]>,
1134) -> Result<Vec<Value>> {
1135    if condition.is_some() && subset.is_some() {
1136        return Err(Error::Filtering(
1137            "Please provide either a 'condition' or a 'subset', not both.".into(),
1138        ));
1139    }
1140
1141    let db_path = get_db_path(index_path);
1142    if !db_path.exists() {
1143        return Ok(Vec::new());
1144    }
1145
1146    let conn = Connection::open(&db_path)?;
1147
1148    // Validate condition against SQL injection if provided
1149    if let Some(cond) = condition {
1150        let valid_columns = get_schema_columns(&conn)?;
1151        validate_condition(cond, &valid_columns)?;
1152    }
1153
1154    // Get column names
1155    let mut columns: Vec<String> = Vec::new();
1156    {
1157        let mut stmt = conn.prepare("PRAGMA table_info(METADATA)")?;
1158        let rows = stmt.query_map([], |row| row.get::<_, String>(1))?;
1159        for row in rows {
1160            columns.push(row?);
1161        }
1162    }
1163
1164    // Build query
1165    let (query, params): (String, Vec<Box<dyn ToSql>>) = if let Some(cond) = condition {
1166        let query = format!(
1167            "SELECT * FROM METADATA WHERE {} ORDER BY \"{}\"",
1168            cond, SUBSET_COLUMN
1169        );
1170        let params = parameters.iter().map(json_to_sql).collect();
1171        (query, params)
1172    } else if let Some(ids) = subset {
1173        if ids.is_empty() {
1174            return Ok(Vec::new());
1175        }
1176        let placeholders: Vec<String> = ids.iter().map(|_| "?".to_string()).collect();
1177        let query = format!(
1178            "SELECT * FROM METADATA WHERE \"{}\" IN ({})",
1179            SUBSET_COLUMN,
1180            placeholders.join(", ")
1181        );
1182        let params: Vec<Box<dyn ToSql>> = ids
1183            .iter()
1184            .map(|&id| Box::new(id) as Box<dyn ToSql>)
1185            .collect();
1186        (query, params)
1187    } else {
1188        let query = format!("SELECT * FROM METADATA ORDER BY \"{}\"", SUBSET_COLUMN);
1189        (query, Vec::new())
1190    };
1191
1192    let param_refs: Vec<&dyn ToSql> = params.iter().map(|v| v.as_ref()).collect();
1193    let mut stmt = conn.prepare(&query)?;
1194    let mut rows = stmt.query(params_from_iter(param_refs))?;
1195
1196    let mut results: Vec<Value> = Vec::new();
1197    while let Some(row) = rows.next()? {
1198        let mut obj = serde_json::Map::new();
1199        for (i, col) in columns.iter().enumerate() {
1200            let value = row_to_json_value(row, i)?;
1201            obj.insert(col.clone(), value);
1202        }
1203        results.push(Value::Object(obj));
1204    }
1205
1206    // If subset was provided, reorder results to match subset order
1207    if let Some(ids) = subset {
1208        let mut results_map: HashMap<i64, Value> = HashMap::new();
1209        for result in results {
1210            if let Some(id) = result.get(SUBSET_COLUMN).and_then(|v| v.as_i64()) {
1211                results_map.insert(id, result);
1212            }
1213        }
1214        results = ids.iter().filter_map(|id| results_map.remove(id)).collect();
1215    }
1216
1217    Ok(results)
1218}
1219
1220/// Helper to convert a rusqlite row column to JSON value.
1221fn row_to_json_value(row: &rusqlite::Row, idx: usize) -> SqliteResult<Value> {
1222    // Try to get the value in order of most likely types
1223    if let Ok(i) = row.get::<_, i64>(idx) {
1224        return Ok(Value::Number(i.into()));
1225    }
1226    if let Ok(f) = row.get::<_, f64>(idx) {
1227        return Ok(serde_json::Number::from_f64(f)
1228            .map(Value::Number)
1229            .unwrap_or(Value::Null));
1230    }
1231    if let Ok(s) = row.get::<_, String>(idx) {
1232        return Ok(Value::String(s));
1233    }
1234    if let Ok(b) = row.get::<_, Vec<u8>>(idx) {
1235        // Try to parse as JSON first
1236        if let Ok(v) = serde_json::from_slice(&b) {
1237            return Ok(v);
1238        }
1239        // Otherwise return as base64 string
1240        return Ok(Value::String(base64_encode(&b)));
1241    }
1242    Ok(Value::Null)
1243}
1244
/// Encode `data` as base64 text using the `A-Z a-z 0-9 + /` alphabet with `=` padding.
///
/// Every group of up to three input bytes becomes four output characters;
/// a final group of one or two bytes is padded with `==` or `=` respectively.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    // Map the low six bits of `bits` to the corresponding alphabet character.
    let sextet = |bits: u32| ALPHABET[(bits & 0x3f) as usize] as char;

    let mut encoded = String::with_capacity(data.len() * 4 / 3 + 4);

    for group in data.chunks(3) {
        // Pack the group into the top of a 24-bit word; missing bytes stay zero.
        let word = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (pos, &byte)| {
                acc | (u32::from(byte) << (16 - 8 * pos))
            });

        encoded.push(sextet(word >> 18));
        encoded.push(sextet(word >> 12));
        encoded.push(if group.len() > 1 { sextet(word >> 6) } else { '=' });
        encoded.push(if group.len() > 2 { sextet(word) } else { '=' });
    }

    encoded
}
1272
1273/// Update metadata rows matching a SQL condition.
1274///
1275/// This function updates existing metadata rows that match the given condition.
1276/// The updates are provided as a JSON object where keys are column names and values
1277/// are the new values to set.
1278///
1279/// # Arguments
1280///
1281/// * `index_path` - Path to the index directory
1282/// * `condition` - SQL WHERE clause with `?` placeholders (e.g., "category = ? AND score > ?")
1283/// * `parameters` - Values to substitute for condition placeholders
1284/// * `updates` - JSON object with column names and new values
1285///
1286/// # Returns
1287///
1288/// Number of rows updated
1289///
1290/// # Example
1291///
1292/// ```ignore
1293/// use next-plaid::filtering;
1294/// use serde_json::json;
1295///
1296/// let updated = filtering::update_where(
1297///     "my_index",
1298///     "category = ?",
1299///     &[json!("A")],
1300///     &json!({"score": 100, "status": "reviewed"}),
1301/// )?;
1302/// ```
1303pub fn update_where(
1304    index_path: &str,
1305    condition: &str,
1306    parameters: &[Value],
1307    updates: &Value,
1308) -> Result<usize> {
1309    let db_path = get_db_path(index_path);
1310    if !db_path.exists() {
1311        return Err(Error::Filtering(
1312            "No metadata database found. Create it first by adding metadata during index creation."
1313                .into(),
1314        ));
1315    }
1316
1317    // Parse updates as an object
1318    let updates_obj = match updates {
1319        Value::Object(obj) => obj,
1320        _ => {
1321            return Err(Error::Filtering("Updates must be a JSON object".into()));
1322        }
1323    };
1324
1325    if updates_obj.is_empty() {
1326        return Ok(0);
1327    }
1328
1329    let conn = Connection::open(&db_path)?;
1330
1331    // Validate condition against SQL injection
1332    let valid_columns = get_schema_columns(&conn)?;
1333    validate_condition(condition, &valid_columns)?;
1334
1335    // Validate update column names against schema
1336    for col_name in updates_obj.keys() {
1337        if col_name == SUBSET_COLUMN {
1338            return Err(Error::Filtering("Cannot update the _subset_ column".into()));
1339        }
1340        if !is_valid_column_name(col_name) {
1341            return Err(Error::Filtering(format!(
1342                "Invalid column name '{}'. Column names must start with a letter or \
1343                 underscore, followed by letters, digits, or underscores.",
1344                col_name
1345            )));
1346        }
1347        // Check if column exists (case-insensitive)
1348        let col_lower = col_name.to_lowercase();
1349        let exists = valid_columns.iter().any(|c| c.to_lowercase() == col_lower);
1350        if !exists {
1351            return Err(Error::Filtering(format!(
1352                "Unknown column '{}' in updates",
1353                col_name
1354            )));
1355        }
1356    }
1357
1358    // Build SET clause
1359    let set_parts: Vec<String> = updates_obj
1360        .keys()
1361        .map(|col| format!("\"{}\" = ?", col))
1362        .collect();
1363    let set_clause = set_parts.join(", ");
1364
1365    // Build UPDATE query
1366    let query = format!("UPDATE METADATA SET {} WHERE {}", set_clause, condition);
1367
1368    // Build parameter list: first the update values, then the condition parameters
1369    let mut all_params: Vec<Box<dyn ToSql>> = updates_obj.values().map(json_to_sql).collect();
1370    all_params.extend(parameters.iter().map(json_to_sql));
1371
1372    let param_refs: Vec<&dyn ToSql> = all_params.iter().map(|v| v.as_ref()).collect();
1373
1374    let updated = conn.execute(&query, params_from_iter(param_refs))?;
1375
1376    Ok(updated)
1377}
1378
1379/// Get the number of documents in the metadata database.
1380pub fn count(index_path: &str) -> Result<usize> {
1381    let db_path = get_db_path(index_path);
1382    if !db_path.exists() {
1383        return Ok(0);
1384    }
1385
1386    let conn = Connection::open(&db_path)?;
1387    let count: i64 = conn.query_row("SELECT COUNT(*) FROM METADATA", [], |row| row.get(0))?;
1388    Ok(count as usize)
1389}
1390
1391#[cfg(test)]
1392mod tests {
1393    use super::*;
1394    use serde_json::json;
1395    use tempfile::TempDir;
1396
1397    fn setup_test_dir() -> TempDir {
1398        TempDir::new().unwrap()
1399    }
1400
1401    #[test]
1402    fn test_create_empty() {
1403        let dir = setup_test_dir();
1404        let path = dir.path().to_str().unwrap();
1405
1406        let result = create(path, &[], &[]).unwrap();
1407        assert_eq!(result, 0);
1408        assert!(!exists(path));
1409    }
1410
1411    #[test]
1412    fn test_create_with_metadata() {
1413        let dir = setup_test_dir();
1414        let path = dir.path().to_str().unwrap();
1415
1416        let metadata = vec![
1417            json!({"name": "Alice", "age": 30, "score": 95.5}),
1418            json!({"name": "Bob", "age": 25, "score": 87.0}),
1419            json!({"name": "Charlie", "age": 35}),
1420        ];
1421        let doc_ids: Vec<i64> = (0..3).collect();
1422
1423        let result = create(path, &metadata, &doc_ids).unwrap();
1424        assert_eq!(result, 3);
1425        assert!(exists(path));
1426
1427        // Verify count
1428        assert_eq!(count(path).unwrap(), 3);
1429    }
1430
1431    #[test]
1432    fn test_create_invalid_column_name() {
1433        let dir = setup_test_dir();
1434        let path = dir.path().to_str().unwrap();
1435
1436        let metadata = vec![json!({"valid_name": "Alice", "invalid name": 30})];
1437        let doc_ids = vec![0];
1438
1439        let result = create(path, &metadata, &doc_ids);
1440        assert!(result.is_err());
1441    }
1442
1443    #[test]
1444    fn test_where_condition() {
1445        let dir = setup_test_dir();
1446        let path = dir.path().to_str().unwrap();
1447
1448        let metadata = vec![
1449            json!({"name": "Alice", "category": "A", "score": 95}),
1450            json!({"name": "Bob", "category": "B", "score": 87}),
1451            json!({"name": "Charlie", "category": "A", "score": 92}),
1452        ];
1453        let doc_ids: Vec<i64> = (0..3).collect();
1454
1455        create(path, &metadata, &doc_ids).unwrap();
1456
1457        // Query by category
1458        let subset = where_condition(path, "category = ?", &[json!("A")]).unwrap();
1459        assert_eq!(subset, vec![0, 2]);
1460
1461        // Query with multiple conditions
1462        let subset =
1463            where_condition(path, "category = ? AND score > ?", &[json!("A"), json!(93)]).unwrap();
1464        assert_eq!(subset, vec![0]);
1465    }
1466
1467    #[test]
1468    fn test_get_all() {
1469        let dir = setup_test_dir();
1470        let path = dir.path().to_str().unwrap();
1471
1472        let metadata = vec![
1473            json!({"name": "Alice", "age": 30}),
1474            json!({"name": "Bob", "age": 25}),
1475        ];
1476        let doc_ids: Vec<i64> = (0..2).collect();
1477
1478        create(path, &metadata, &doc_ids).unwrap();
1479
1480        let results = get(path, None, &[], None).unwrap();
1481        assert_eq!(results.len(), 2);
1482        assert_eq!(results[0]["name"], "Alice");
1483        assert_eq!(results[1]["name"], "Bob");
1484    }
1485
1486    #[test]
1487    fn test_get_by_subset() {
1488        let dir = setup_test_dir();
1489        let path = dir.path().to_str().unwrap();
1490
1491        let metadata = vec![
1492            json!({"name": "Alice"}),
1493            json!({"name": "Bob"}),
1494            json!({"name": "Charlie"}),
1495        ];
1496        let doc_ids: Vec<i64> = (0..3).collect();
1497
1498        create(path, &metadata, &doc_ids).unwrap();
1499
1500        // Get specific subset in order
1501        let results = get(path, None, &[], Some(&[2, 0])).unwrap();
1502        assert_eq!(results.len(), 2);
1503        assert_eq!(results[0]["name"], "Charlie");
1504        assert_eq!(results[1]["name"], "Alice");
1505    }
1506
1507    #[test]
1508    fn test_update_adds_rows() {
1509        let dir = setup_test_dir();
1510        let path = dir.path().to_str().unwrap();
1511
1512        let metadata1 = vec![json!({"name": "Alice"}), json!({"name": "Bob"})];
1513        let doc_ids1: Vec<i64> = (0..2).collect();
1514
1515        create(path, &metadata1, &doc_ids1).unwrap();
1516        assert_eq!(count(path).unwrap(), 2);
1517
1518        let metadata2 = vec![json!({"name": "Charlie"})];
1519        let doc_ids2 = vec![2]; // Next ID after the first batch
1520
1521        update(path, &metadata2, &doc_ids2).unwrap();
1522        assert_eq!(count(path).unwrap(), 3);
1523
1524        // Verify the new row has correct _subset_ ID
1525        let results = get(path, None, &[], None).unwrap();
1526        assert_eq!(results[2]["_subset_"], 2);
1527        assert_eq!(results[2]["name"], "Charlie");
1528    }
1529
1530    #[test]
1531    fn test_update_adds_columns() {
1532        let dir = setup_test_dir();
1533        let path = dir.path().to_str().unwrap();
1534
1535        let metadata1 = vec![json!({"name": "Alice"})];
1536        let doc_ids1 = vec![0];
1537
1538        create(path, &metadata1, &doc_ids1).unwrap();
1539
1540        let metadata2 = vec![json!({"name": "Bob", "age": 25, "city": "NYC"})];
1541        let doc_ids2 = vec![1];
1542
1543        update(path, &metadata2, &doc_ids2).unwrap();
1544
1545        // Verify new columns exist
1546        let results = get(path, None, &[], None).unwrap();
1547        assert_eq!(results[0]["name"], "Alice");
1548        assert!(results[0]["age"].is_null()); // Old row has null for new column
1549        assert_eq!(results[1]["age"], 25);
1550        assert_eq!(results[1]["city"], "NYC");
1551    }
1552
1553    #[test]
1554    fn test_delete_and_reindex() {
1555        let dir = setup_test_dir();
1556        let path = dir.path().to_str().unwrap();
1557
1558        let metadata = vec![
1559            json!({"name": "Alice"}),
1560            json!({"name": "Bob"}),
1561            json!({"name": "Charlie"}),
1562            json!({"name": "Diana"}),
1563        ];
1564        let doc_ids: Vec<i64> = (0..4).collect();
1565
1566        create(path, &metadata, &doc_ids).unwrap();
1567
1568        // Delete Bob (1) and Charlie (2)
1569        let deleted = delete(path, &[1, 2]).unwrap();
1570        assert_eq!(deleted, 2);
1571        assert_eq!(count(path).unwrap(), 2);
1572
1573        // Verify remaining rows have re-indexed _subset_ IDs
1574        let results = get(path, None, &[], None).unwrap();
1575        assert_eq!(results.len(), 2);
1576        assert_eq!(results[0]["_subset_"], 0);
1577        assert_eq!(results[0]["name"], "Alice");
1578        assert_eq!(results[1]["_subset_"], 1);
1579        assert_eq!(results[1]["name"], "Diana");
1580    }
1581
1582    #[test]
1583    fn test_where_with_like() {
1584        let dir = setup_test_dir();
1585        let path = dir.path().to_str().unwrap();
1586
1587        let metadata = vec![
1588            json!({"name": "Alice"}),
1589            json!({"name": "Alex"}),
1590            json!({"name": "Bob"}),
1591        ];
1592        let doc_ids: Vec<i64> = (0..3).collect();
1593
1594        create(path, &metadata, &doc_ids).unwrap();
1595
1596        let subset = where_condition(path, "name LIKE ?", &[json!("Al%")]).unwrap();
1597        assert_eq!(subset, vec![0, 1]);
1598    }
1599
1600    #[test]
1601    fn test_is_valid_column_name() {
1602        assert!(is_valid_column_name("name"));
1603        assert!(is_valid_column_name("_private"));
1604        assert!(is_valid_column_name("column123"));
1605        assert!(is_valid_column_name("Col_Name_2"));
1606
1607        assert!(!is_valid_column_name("123column")); // starts with number
1608        assert!(!is_valid_column_name("column name")); // space
1609        assert!(!is_valid_column_name("column-name")); // hyphen
1610        assert!(!is_valid_column_name("")); // empty
1611        assert!(!is_valid_column_name("col;drop")); // SQL injection attempt
1612    }
1613
1614    #[test]
1615    fn test_type_inference() {
1616        let dir = setup_test_dir();
1617        let path = dir.path().to_str().unwrap();
1618
1619        let metadata = vec![json!({
1620            "int_val": 42,
1621            "float_val": 3.125,
1622            "str_val": "hello",
1623            "bool_val": true,
1624            "null_val": null
1625        })];
1626        let doc_ids = vec![0];
1627
1628        create(path, &metadata, &doc_ids).unwrap();
1629
1630        let results = get(path, None, &[], None).unwrap();
1631        assert_eq!(results[0]["int_val"], 42);
1632        assert!((results[0]["float_val"].as_f64().unwrap() - 3.125).abs() < 0.001);
1633        assert_eq!(results[0]["str_val"], "hello");
1634        assert_eq!(results[0]["bool_val"], 1); // Bool stored as INTEGER
1635        assert!(results[0]["null_val"].is_null());
1636    }
1637
1638    // =============================================================================
1639    // SQL Condition Validator Tests
1640    // =============================================================================
1641
1642    fn test_columns() -> HashSet<String> {
1643        ["name", "category", "score", "status", "_subset_"]
1644            .iter()
1645            .map(|s| s.to_string())
1646            .collect()
1647    }
1648
1649    #[test]
1650    fn test_validator_simple_equality() {
1651        let cols = test_columns();
1652        assert!(validate_condition("name = ?", &cols).is_ok());
1653        assert!(validate_condition("score = ?", &cols).is_ok());
1654    }
1655
1656    #[test]
1657    fn test_validator_comparison_operators() {
1658        let cols = test_columns();
1659        assert!(validate_condition("score > ?", &cols).is_ok());
1660        assert!(validate_condition("score >= ?", &cols).is_ok());
1661        assert!(validate_condition("score < ?", &cols).is_ok());
1662        assert!(validate_condition("score <= ?", &cols).is_ok());
1663        assert!(validate_condition("score != ?", &cols).is_ok());
1664        assert!(validate_condition("score <> ?", &cols).is_ok());
1665    }
1666
1667    #[test]
1668    fn test_validator_and_or() {
1669        let cols = test_columns();
1670        assert!(validate_condition("name = ? AND score > ?", &cols).is_ok());
1671        assert!(validate_condition("category = ? OR status = ?", &cols).is_ok());
1672        assert!(validate_condition("name = ? AND score > ? OR category = ?", &cols).is_ok());
1673    }
1674
1675    #[test]
1676    fn test_validator_like() {
1677        let cols = test_columns();
1678        assert!(validate_condition("name LIKE ?", &cols).is_ok());
1679        assert!(validate_condition("name NOT LIKE ?", &cols).is_ok());
1680    }
1681
1682    #[test]
1683    fn test_validator_regexp() {
1684        let cols = test_columns();
1685        assert!(validate_condition("name REGEXP ?", &cols).is_ok());
1686        assert!(validate_condition("name NOT REGEXP ?", &cols).is_ok());
1687    }
1688
1689    #[test]
1690    fn test_validator_between() {
1691        let cols = test_columns();
1692        assert!(validate_condition("score BETWEEN ? AND ?", &cols).is_ok());
1693        assert!(validate_condition("score NOT BETWEEN ? AND ?", &cols).is_ok());
1694    }
1695
1696    #[test]
1697    fn test_validator_in() {
1698        let cols = test_columns();
1699        assert!(validate_condition("category IN (?)", &cols).is_ok());
1700        assert!(validate_condition("category IN (?, ?)", &cols).is_ok());
1701        assert!(validate_condition("category IN (?, ?, ?)", &cols).is_ok());
1702        assert!(validate_condition("category NOT IN (?, ?)", &cols).is_ok());
1703    }
1704
1705    #[test]
1706    fn test_validator_is_null() {
1707        let cols = test_columns();
1708        assert!(validate_condition("name IS NULL", &cols).is_ok());
1709        assert!(validate_condition("name IS NOT NULL", &cols).is_ok());
1710    }
1711
1712    #[test]
1713    fn test_validator_parentheses() {
1714        let cols = test_columns();
1715        assert!(validate_condition("(name = ?)", &cols).is_ok());
1716        assert!(validate_condition("(name = ? AND score > ?)", &cols).is_ok());
1717        assert!(validate_condition("(name = ? OR category = ?) AND score > ?", &cols).is_ok());
1718        assert!(validate_condition("name = ? AND (category = ? OR status = ?)", &cols).is_ok());
1719    }
1720
1721    #[test]
1722    fn test_validator_not() {
1723        let cols = test_columns();
1724        assert!(validate_condition("NOT name = ?", &cols).is_ok());
1725        assert!(validate_condition("NOT (name = ? AND score > ?)", &cols).is_ok());
1726    }
1727
1728    #[test]
1729    fn test_validator_quoted_identifiers() {
1730        let cols = test_columns();
1731        assert!(validate_condition("\"name\" = ?", &cols).is_ok());
1732        assert!(validate_condition("\"score\" > ?", &cols).is_ok());
1733    }
1734
1735    #[test]
1736    fn test_validator_case_insensitive_keywords() {
1737        let cols = test_columns();
1738        assert!(validate_condition("name = ? and score > ?", &cols).is_ok());
1739        assert!(validate_condition("name = ? AND score > ?", &cols).is_ok());
1740        assert!(validate_condition("name LIKE ? or category = ?", &cols).is_ok());
1741        assert!(validate_condition("score between ? and ?", &cols).is_ok());
1742    }
1743
1744    #[test]
1745    fn test_validator_allows_numeric_equality() {
1746        // Special case: numeric equality patterns are common SQL idioms
1747        // "1=1" for "always true", "1=0" for "always false", etc.
1748        let cols = test_columns();
1749        assert!(validate_condition("1=1", &cols).is_ok());
1750        assert!(validate_condition(" 1=1 ", &cols).is_ok()); // with whitespace
1751        assert!(validate_condition("0=0", &cols).is_ok());
1752        assert!(validate_condition("1 = 1", &cols).is_ok()); // with spaces around =
1753        assert!(validate_condition("42=42", &cols).is_ok());
1754        assert!(validate_condition("1=0", &cols).is_ok()); // "always false"
1755    }
1756
1757    // SQL injection tests
1758
1759    #[test]
1760    fn test_validator_rejects_semicolon() {
1761        let cols = test_columns();
1762        let result = validate_condition("name = ?; DROP TABLE METADATA", &cols);
1763        assert!(result.is_err());
1764        assert!(result.unwrap_err().to_string().contains("Semicolon"));
1765    }
1766
1767    #[test]
1768    fn test_validator_rejects_comments() {
1769        let cols = test_columns();
1770        assert!(validate_condition("name = ? -- comment", &cols).is_err());
1771        assert!(validate_condition("name = ? /* comment */", &cols).is_err());
1772    }
1773
1774    #[test]
1775    fn test_validator_rejects_union() {
1776        let cols = test_columns();
1777        // UNION is rejected by quick_safety_check (SELECT may be rejected first if present)
1778        let result = validate_condition("name = ? UNION SELECT * FROM users", &cols);
1779        assert!(result.is_err());
1780        // Both UNION and SELECT are dangerous keywords, either error message is acceptable
1781        let err_msg = result.unwrap_err().to_string();
1782        assert!(
1783            err_msg.contains("UNION") || err_msg.contains("SELECT"),
1784            "Expected error about UNION or SELECT, got: {}",
1785            err_msg
1786        );
1787    }
1788
1789    #[test]
1790    fn test_validator_rejects_subqueries() {
1791        let cols = test_columns();
1792        // SELECT is rejected by quick_safety_check
1793        let result = validate_condition("name = (SELECT name FROM users)", &cols);
1794        assert!(result.is_err());
1795    }
1796
1797    #[test]
1798    fn test_validator_rejects_ddl_keywords() {
1799        let cols = test_columns();
1800        assert!(validate_condition("DROP TABLE METADATA", &cols).is_err());
1801        assert!(validate_condition("DELETE FROM METADATA", &cols).is_err());
1802        assert!(validate_condition("INSERT INTO METADATA VALUES (?)", &cols).is_err());
1803        assert!(validate_condition("UPDATE METADATA SET name = ?", &cols).is_err());
1804        assert!(validate_condition("CREATE TABLE foo (id INT)", &cols).is_err());
1805        assert!(validate_condition("ALTER TABLE METADATA ADD x INT", &cols).is_err());
1806        assert!(validate_condition("TRUNCATE TABLE METADATA", &cols).is_err());
1807    }
1808
1809    #[test]
1810    fn test_validator_rejects_unknown_columns() {
1811        let cols = test_columns();
1812        let result = validate_condition("unknown_column = ?", &cols);
1813        assert!(result.is_err());
1814        assert!(result.unwrap_err().to_string().contains("Unknown column"));
1815    }
1816
1817    #[test]
1818    fn test_validator_rejects_string_literals() {
1819        let cols = test_columns();
1820        // String literals are rejected as unexpected characters
1821        let result = validate_condition("name = 'Alice'", &cols);
1822        assert!(result.is_err());
1823    }
1824
1825    #[test]
1826    fn test_validator_rejects_malformed_syntax() {
1827        let cols = test_columns();
1828        // Missing placeholder
1829        assert!(validate_condition("name =", &cols).is_err());
1830        // Unbalanced parentheses
1831        assert!(validate_condition("(name = ?", &cols).is_err());
1832        assert!(validate_condition("name = ?)", &cols).is_err());
1833        // Double operators
1834        assert!(validate_condition("name = = ?", &cols).is_err());
1835        // Missing column
1836        assert!(validate_condition("= ?", &cols).is_err());
1837    }
1838
1839    #[test]
1840    fn test_validator_rejects_function_calls() {
1841        let cols = test_columns();
1842        // Function calls result in unexpected tokens
1843        let result = validate_condition("LENGTH(name) > ?", &cols);
1844        // LENGTH is parsed as identifier, then ( is unexpected after it
1845        assert!(result.is_err());
1846    }
1847
1848    #[test]
1849    fn test_validator_integration() {
1850        // Test that validation works end-to-end with actual database
1851        let dir = setup_test_dir();
1852        let path = dir.path().to_str().unwrap();
1853
1854        let metadata = vec![
1855            json!({"name": "Alice", "category": "A", "score": 95}),
1856            json!({"name": "Bob", "category": "B", "score": 87}),
1857        ];
1858        let doc_ids: Vec<i64> = (0..2).collect();
1859        create(path, &metadata, &doc_ids).unwrap();
1860
1861        // Valid condition should work
1862        let result = where_condition(path, "category = ? AND score > ?", &[json!("A"), json!(90)]);
1863        assert!(result.is_ok());
1864        assert_eq!(result.unwrap(), vec![0]);
1865
1866        // SQL injection attempt should be rejected
1867        let result = where_condition(path, "category = ?; DROP TABLE METADATA", &[json!("A")]);
1868        assert!(result.is_err());
1869
1870        // Unknown column should be rejected
1871        let result = where_condition(path, "unknown = ?", &[json!("test")]);
1872        assert!(result.is_err());
1873    }
1874
1875    #[test]
1876    fn test_validator_integration_get() {
1877        let dir = setup_test_dir();
1878        let path = dir.path().to_str().unwrap();
1879
1880        let metadata = vec![
1881            json!({"name": "Alice", "score": 95}),
1882            json!({"name": "Bob", "score": 87}),
1883        ];
1884        let doc_ids: Vec<i64> = (0..2).collect();
1885        create(path, &metadata, &doc_ids).unwrap();
1886
1887        // Valid condition should work
1888        let result = get(path, Some("score > ?"), &[json!(90)], None);
1889        assert!(result.is_ok());
1890        assert_eq!(result.unwrap().len(), 1);
1891
1892        // SQL injection should be rejected
1893        let result = get(path, Some("1=1 UNION SELECT * FROM users"), &[], None);
1894        assert!(result.is_err());
1895    }
1896}