sqlite-vector-rs 0.2.2

SQLite extension providing PGVector-like native vector types with HNSW indexing
Documentation
use std::fmt;

use crate::distance::DistanceMetric;
use crate::index::HnswParams;
use crate::types::VectorType;

/// Error describing an invalid virtual-table configuration.
///
/// The wrapped string is a human-readable explanation of what was wrong. The
/// `Display` implementation renders it behind a fixed `config error: ` prefix.
#[derive(Debug)]
pub struct ConfigError(pub String);

impl fmt::Display for ConfigError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ConfigError(message) = self;
        formatter.write_str("config error: ")?;
        formatter.write_str(message)
    }
}

/// Lets `ConfigError` be used wherever a `dyn std::error::Error` is expected.
impl std::error::Error for ConfigError {}

/// Configuration of one vector virtual table, built from its
/// `CREATE VIRTUAL TABLE` arguments by `VectorTableConfig::parse`.
#[derive(Clone, Debug)]
pub struct VectorTableConfig {
    /// Database name (the second argument, e.g. `main`).
    pub db_name: String,
    /// Virtual table name (the third argument).
    pub table_name: String,
    /// Number of components per vector (`dim=`); required and always positive.
    pub dim: usize,
    /// Vector element type (`type=`); defaults to `VectorType::Float4`.
    pub vtype: VectorType,
    /// Distance metric (`metric=`); defaults to `DistanceMetric::L2`.
    pub metric: DistanceMetric,
    /// HNSW parameters (`m=`, `ef_construction=`, `ef_search=`), starting from
    /// `HnswParams::default()`.
    pub hnsw_params: HnswParams,
    /// Extra columns from `metadata=`, as `(name, SQL type)` pairs in the order
    /// they were declared.
    pub metadata_columns: Vec<(String, String)>,
}

impl VectorTableConfig {
    /// Parses the argument list of a `CREATE VIRTUAL TABLE ... USING <module>(...)`.
    ///
    /// `args[0]` is the module name, `args[1]` the database name and `args[2]` the
    /// table name. Every later entry must be a `key=value` pair. Recognised keys are
    /// matched case-insensitively:
    ///
    /// - `dim`: required, a positive integer;
    /// - `type`: vector element type, default `VectorType::Float4`;
    /// - `metric`: distance metric, default `DistanceMetric::L2`;
    /// - `m`, `ef_construction`, `ef_search`: HNSW parameters, defaulting to
    ///   `HnswParams::default()`;
    /// - `metadata`: comma-separated `name TYPE` column definitions.
    ///
    /// Keys and values are trimmed. A value may be wrapped in double quotes or in a
    /// pair of single quotes, and the quotes are removed before parsing. When a key
    /// appears more than once, the last occurrence wins.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError`] in any of these cases:
    /// - fewer than three arguments are given;
    /// - an argument has no `=`;
    /// - a key is unknown;
    /// - a value fails to parse;
    /// - `dim` is missing, not positive, or does not fit in `usize`.
    pub fn parse(args: &[&str]) -> Result<Self, ConfigError> {
        if args.len() < 3 {
            return Err(ConfigError(
                "expected at least module, db, and table name".into(),
            ));
        }

        let db_name = args[1].to_string();
        let table_name = args[2].to_string();

        let mut dim: Option<usize> = None;
        let mut vtype = VectorType::Float4;
        let mut metric = DistanceMetric::L2;
        let mut hnsw_params = HnswParams::default();
        let mut metadata_columns = Vec::new();

        for &arg in &args[3..] {
            let (raw_key, raw_value) = arg
                .split_once('=')
                .ok_or_else(|| ConfigError(format!("invalid argument: {arg}")))?;
            let key = raw_key.trim();
            let value = Self::unquote(raw_value.trim());

            match key.to_ascii_lowercase().as_str() {
                "dim" => {
                    let d: i64 = Self::parse_number("dim", value)?;
                    if d <= 0 {
                        return Err(ConfigError(format!("dim must be positive, got {d}")));
                    }
                    // A plain `as usize` cast would silently truncate on 32-bit targets.
                    let d = usize::try_from(d)
                        .map_err(|_| ConfigError(format!("dim too large: {d}")))?;
                    dim = Some(d);
                }
                "type" => {
                    vtype = VectorType::from_name(value).map_err(|e| ConfigError(e.to_string()))?;
                }
                "metric" => {
                    metric =
                        DistanceMetric::from_name(value).map_err(|e| ConfigError(e.to_string()))?;
                }
                "m" => hnsw_params.m = Self::parse_number("m", value)?,
                "ef_construction" => {
                    hnsw_params.ef_construction = Self::parse_number("ef_construction", value)?;
                }
                "ef_search" => hnsw_params.ef_search = Self::parse_number("ef_search", value)?,
                "metadata" => {
                    metadata_columns = parse_metadata_columns(value)?;
                }
                // Report the key as the user wrote it, not the lower-cased form.
                _ => return Err(ConfigError(format!("unknown parameter: {key}"))),
            }
        }

        let dim = dim.ok_or_else(|| ConfigError("dim is required".into()))?;

        Ok(Self {
            db_name,
            table_name,
            dim,
            vtype,
            metric,
            hnsw_params,
            metadata_columns,
        })
    }

    /// Parses `value` as a number for parameter `param`.
    ///
    /// On failure this reports `invalid {param}: {value}`.
    fn parse_number<T: std::str::FromStr>(param: &str, value: &str) -> Result<T, ConfigError> {
        value
            .parse()
            .map_err(|_| ConfigError(format!("invalid {param}: {value}")))
    }

    /// Removes quoting around a module-argument value.
    ///
    /// All leading and trailing double quotes are trimmed first. After that, one
    /// matching pair of surrounding single quotes is removed as well. Other values
    /// are returned unchanged.
    fn unquote(value: &str) -> &str {
        let value = value.trim_matches('"');
        value
            .strip_prefix('\'')
            .and_then(|inner| inner.strip_suffix('\''))
            .unwrap_or(value)
    }

    /// Builds the `CREATE TABLE` statement describing this table's columns.
    ///
    /// The columns appear in this order:
    /// - `id INTEGER PRIMARY KEY`;
    /// - `vector BLOB`;
    /// - each metadata column as `name TYPE`, in declaration order;
    /// - a `distance REAL HIDDEN` column.
    ///
    /// Names and types are interpolated unquoted. The table name `x` is a
    /// placeholder; presumably this string is handed to `sqlite3_declare_vtab`,
    /// which ignores the table name. Confirm against the caller.
    pub fn vtab_schema(&self) -> String {
        let mut cols = vec![
            "id INTEGER PRIMARY KEY".to_string(),
            "vector BLOB".to_string(),
        ];
        for (name, sql_type) in &self.metadata_columns {
            cols.push(format!("{name} {sql_type}"));
        }
        cols.push("distance REAL HIDDEN".to_string());
        format!("CREATE TABLE x({})", cols.join(", "))
    }
}

/// Parses a `metadata=` specification into `(column name, SQL type)` pairs.
///
/// The spec is a comma-separated list of `name TYPE` definitions, for example
/// `"label TEXT,score REAL"`. Entries that are empty after trimming (such as a
/// trailing comma) are skipped. Pairs are returned in the order they appear.
/// Any tokens after the type are ignored, so previously accepted specs keep
/// parsing.
///
/// # Errors
///
/// Returns [`ConfigError`] in any of these cases:
/// - a definition has no type;
/// - a name collides with a built-in column (`id`, `vector`, `distance`);
/// - a name repeats an earlier metadata column.
///
/// Names are compared ASCII case-insensitively, as SQLite does for column names.
fn parse_metadata_columns(spec: &str) -> Result<Vec<(String, String)>, ConfigError> {
    // `vtab_schema` always declares these columns. A metadata column reusing one
    // of them would produce a schema with duplicate columns.
    const RESERVED: [&str; 3] = ["id", "vector", "distance"];

    let mut columns: Vec<(String, String)> = Vec::new();
    for part in spec.split(',').map(str::trim).filter(|p| !p.is_empty()) {
        let mut tokens = part.split_whitespace();
        let name = tokens
            .next()
            .ok_or_else(|| ConfigError("empty metadata column definition".to_string()))?;
        let sql_type = tokens
            .next()
            .ok_or_else(|| ConfigError(format!("missing type for metadata column {name}")))?;

        if RESERVED.iter().any(|r| r.eq_ignore_ascii_case(name)) {
            return Err(ConfigError(format!(
                "metadata column name {name} is reserved"
            )));
        }
        if columns
            .iter()
            .any(|(existing, _)| existing.eq_ignore_ascii_case(name))
        {
            return Err(ConfigError(format!("duplicate metadata column {name}")));
        }

        columns.push((name.to_string(), sql_type.to_string()));
    }
    Ok(columns)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;
    use crate::types::VectorType;

    // ----------------------------------------------------------------
    // parse — minimal args (dim only, defaults for type/metric)
    // ----------------------------------------------------------------

    #[test]
    fn parse_minimal_args_uses_defaults() {
        let cfg = VectorTableConfig::parse(&["vector", "main", "test", "dim=3"]).unwrap();
        assert_eq!(cfg.db_name, "main");
        assert_eq!(cfg.table_name, "test");
        assert_eq!(cfg.dim, 3);
        assert_eq!(cfg.vtype, VectorType::Float4);
        assert_eq!(cfg.metric, DistanceMetric::L2);
        // HNSW defaults
        assert_eq!(cfg.hnsw_params.m, 16);
        assert_eq!(cfg.hnsw_params.ef_construction, 200);
        assert_eq!(cfg.hnsw_params.ef_search, 64);
        assert!(cfg.metadata_columns.is_empty());
    }

    // ----------------------------------------------------------------
    // parse — all parameters specified
    // ----------------------------------------------------------------

    #[test]
    fn parse_all_params_specified() {
        let cfg = VectorTableConfig::parse(&[
            "vector",
            "main",
            "embeddings",
            "dim=128",
            "type=float4",
            "metric=l2",
            "m=32",
            "ef_construction=400",
            "ef_search=128",
        ])
        .unwrap();
        assert_eq!(cfg.dim, 128);
        assert_eq!(cfg.vtype, VectorType::Float4);
        assert_eq!(cfg.metric, DistanceMetric::L2);
        assert_eq!(cfg.hnsw_params.m, 32);
        assert_eq!(cfg.hnsw_params.ef_construction, 400);
        assert_eq!(cfg.hnsw_params.ef_search, 128);
    }

    // ----------------------------------------------------------------
    // parse — type=float8, metric=cosine
    // ----------------------------------------------------------------

    #[test]
    fn parse_float8_cosine() {
        let cfg = VectorTableConfig::parse(&[
            "vector",
            "main",
            "vecs",
            "dim=64",
            "type=float8",
            "metric=cosine",
        ])
        .unwrap();
        assert_eq!(cfg.vtype, VectorType::Float8);
        assert_eq!(cfg.metric, DistanceMetric::Cosine);
    }

    // ----------------------------------------------------------------
    // parse — custom HNSW params (m, ef_construction, ef_search)
    // ----------------------------------------------------------------

    #[test]
    fn parse_custom_hnsw_params() {
        let cfg = VectorTableConfig::parse(&[
            "vector",
            "main",
            "idx",
            "dim=16",
            "m=8",
            "ef_construction=100",
            "ef_search=50",
        ])
        .unwrap();
        assert_eq!(cfg.hnsw_params.m, 8);
        assert_eq!(cfg.hnsw_params.ef_construction, 100);
        assert_eq!(cfg.hnsw_params.ef_search, 50);
    }

    // ----------------------------------------------------------------
    // parse — metadata columns
    // ----------------------------------------------------------------

    #[test]
    fn parse_metadata_columns_parsed_correctly() {
        let cfg = VectorTableConfig::parse(&[
            "vector",
            "main",
            "docs",
            "dim=4",
            "metadata=label TEXT,score REAL",
        ])
        .unwrap();
        assert_eq!(cfg.metadata_columns.len(), 2);
        assert_eq!(
            cfg.metadata_columns[0],
            ("label".to_string(), "TEXT".to_string())
        );
        assert_eq!(
            cfg.metadata_columns[1],
            ("score".to_string(), "REAL".to_string())
        );
    }

    #[test]
    fn parse_metadata_skips_empty_entries() {
        let cfg = VectorTableConfig::parse(&[
            "vector",
            "main",
            "docs",
            "dim=4",
            "metadata=label TEXT,,score REAL,",
        ])
        .unwrap();
        assert_eq!(
            cfg.metadata_columns,
            vec![
                ("label".to_string(), "TEXT".to_string()),
                ("score".to_string(), "REAL".to_string()),
            ]
        );
    }

    // ----------------------------------------------------------------
    // parse — whitespace, quoting and repeated keys
    // ----------------------------------------------------------------

    #[test]
    fn parse_trims_whitespace_and_strips_double_quotes() {
        let cfg = VectorTableConfig::parse(&[
            "vector",
            "main",
            "docs",
            " dim = \"3\" ",
            "metadata=\"label TEXT,score REAL\"",
        ])
        .unwrap();
        assert_eq!(cfg.dim, 3);
        assert_eq!(cfg.metadata_columns.len(), 2);
        assert_eq!(
            cfg.metadata_columns[0],
            ("label".to_string(), "TEXT".to_string())
        );
    }

    #[test]
    fn parse_later_argument_overrides_earlier() {
        let cfg = VectorTableConfig::parse(&["vector", "main", "t", "dim=3", "dim=5"]).unwrap();
        assert_eq!(cfg.dim, 5);
    }

    // ----------------------------------------------------------------
    // parse errors — too few args (< 3)
    // ----------------------------------------------------------------

    #[test]
    fn parse_error_too_few_args_zero() {
        let err = VectorTableConfig::parse(&[]).unwrap_err();
        assert!(
            err.0.contains("at least"),
            "unexpected error message: {}",
            err.0
        );
    }

    #[test]
    fn parse_error_too_few_args_two() {
        // Only module + db name; table name is absent.
        let err = VectorTableConfig::parse(&["vector", "main"]).unwrap_err();
        assert!(
            err.0.contains("at least"),
            "unexpected error message: {}",
            err.0
        );
    }

    // ----------------------------------------------------------------
    // parse errors — missing dim
    // ----------------------------------------------------------------

    #[test]
    fn parse_error_missing_dim() {
        let err = VectorTableConfig::parse(&["vector", "main", "tbl", "type=float4"]).unwrap_err();
        assert!(
            err.0.contains("dim"),
            "expected error mentioning 'dim', got: {}",
            err.0
        );
    }

    // ----------------------------------------------------------------
    // parse errors — invalid dim values
    // ----------------------------------------------------------------

    #[test]
    fn parse_error_dim_zero() {
        let err = VectorTableConfig::parse(&["vector", "main", "tbl", "dim=0"]).unwrap_err();
        assert!(
            err.0.contains("positive") || err.0.contains("dim"),
            "unexpected error: {}",
            err.0
        );
    }

    #[test]
    fn parse_error_dim_negative() {
        let err = VectorTableConfig::parse(&["vector", "main", "tbl", "dim=-5"]).unwrap_err();
        assert!(
            err.0.contains("positive") || err.0.contains("dim"),
            "unexpected error: {}",
            err.0
        );
    }

    #[test]
    fn parse_error_dim_non_numeric() {
        let err = VectorTableConfig::parse(&["vector", "main", "tbl", "dim=abc"]).unwrap_err();
        assert!(
            err.0.contains("dim"),
            "expected error mentioning 'dim', got: {}",
            err.0
        );
    }

    // ----------------------------------------------------------------
    // parse errors — invalid HNSW parameter values
    // ----------------------------------------------------------------

    #[test]
    fn parse_error_invalid_hnsw_params() {
        for (arg, expected) in [
            ("m=abc", "invalid m: abc"),
            ("ef_construction=x", "invalid ef_construction: x"),
            ("ef_search=1.5", "invalid ef_search: 1.5"),
        ] {
            let err =
                VectorTableConfig::parse(&["vector", "main", "tbl", "dim=4", arg]).unwrap_err();
            assert!(
                err.0.contains(expected),
                "for {arg}: unexpected error: {}",
                err.0
            );
        }
    }

    // ----------------------------------------------------------------
    // parse errors — malformed metadata
    // ----------------------------------------------------------------

    #[test]
    fn parse_error_metadata_missing_type() {
        let err = VectorTableConfig::parse(&["vector", "main", "tbl", "dim=4", "metadata=label"])
            .unwrap_err();
        assert!(
            err.0.contains("missing type") && err.0.contains("label"),
            "unexpected error: {}",
            err.0
        );
    }

    // ----------------------------------------------------------------
    // parse errors — unknown parameter
    // ----------------------------------------------------------------

    #[test]
    fn parse_error_unknown_parameter() {
        let err =
            VectorTableConfig::parse(&["vector", "main", "tbl", "dim=4", "foo=bar"]).unwrap_err();
        assert!(
            err.0.contains("unknown") && err.0.contains("foo"),
            "unexpected error: {}",
            err.0
        );
    }

    // ----------------------------------------------------------------
    // parse errors — arg without '='
    // ----------------------------------------------------------------

    #[test]
    fn parse_error_arg_without_equals() {
        let err = VectorTableConfig::parse(&["vector", "main", "tbl", "dim=4", "invalidarg"])
            .unwrap_err();
        assert!(
            err.0.contains("invalid argument") || err.0.contains("invalidarg"),
            "unexpected error: {}",
            err.0
        );
    }

    // ----------------------------------------------------------------
    // ConfigError — Display
    // ----------------------------------------------------------------

    #[test]
    fn config_error_display_has_prefix() {
        assert_eq!(ConfigError("boom".into()).to_string(), "config error: boom");
    }

    // ----------------------------------------------------------------
    // vtab_schema — no metadata
    // ----------------------------------------------------------------

    #[test]
    fn vtab_schema_no_metadata() {
        let cfg = VectorTableConfig::parse(&["vector", "main", "tbl", "dim=3"]).unwrap();
        let schema = cfg.vtab_schema();
        assert_eq!(
            schema,
            "CREATE TABLE x(id INTEGER PRIMARY KEY, vector BLOB, distance REAL HIDDEN)"
        );
    }

    // ----------------------------------------------------------------
    // vtab_schema — with metadata columns
    // ----------------------------------------------------------------

    #[test]
    fn vtab_schema_with_metadata_columns_before_distance() {
        let cfg = VectorTableConfig::parse(&[
            "vector",
            "main",
            "tbl",
            "dim=4",
            "metadata=label TEXT,score REAL",
        ])
        .unwrap();
        let schema = cfg.vtab_schema();
        assert_eq!(
            schema,
            "CREATE TABLE x(id INTEGER PRIMARY KEY, vector BLOB, label TEXT, score REAL, distance REAL HIDDEN)"
        );
        // Verify ordering: metadata must appear before distance.
        let label_pos = schema.find("label").unwrap();
        let distance_pos = schema.find("distance").unwrap();
        assert!(
            label_pos < distance_pos,
            "metadata columns must precede distance in schema"
        );
    }
}