sqlrite 1.0.2

RAG-oriented SQLite wrapper for AI agent workloads
Documentation
use sqlrite::{
    CustomHttpEmbeddingProvider, DeterministicEmbeddingProvider, OpenAiCompatibleEmbeddingProvider,
    ReindexOptions, RuntimeConfig, SqlRite, reindex_embeddings,
};
use std::path::PathBuf;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let args =
        parse_args(std::env::args().skip(1).collect::<Vec<_>>()).map_err(std::io::Error::other)?;
    let db = SqlRite::open_with_config(&args.db_path, RuntimeConfig::default())?;

    let mut options = ReindexOptions {
        batch_size: args.batch_size,
        tenant_id: args.tenant,
        target_model_version: args.target_model_version.clone(),
        only_if_model_mismatch: !args.force_all,
        continue_on_partial_failure: args.continue_on_partial_failure,
        checkpoint_path: args.checkpoint_path,
        ..ReindexOptions::default()
    };

    let report = match args.provider.as_str() {
        "openai" => {
            let endpoint = args.endpoint.ok_or_else(|| {
                std::io::Error::other("--endpoint is required for openai provider")
            })?;
            let model = args
                .model
                .ok_or_else(|| std::io::Error::other("--model is required for openai provider"))?;
            let api_key_env = args
                .api_key_env
                .unwrap_or_else(|| "OPENAI_API_KEY".to_string());
            let provider =
                OpenAiCompatibleEmbeddingProvider::from_env(endpoint, model, &api_key_env)?;
            reindex_embeddings(&db, provider, options)?
        }
        "custom" => {
            let endpoint = args.endpoint.ok_or_else(|| {
                std::io::Error::other("--endpoint is required for custom provider")
            })?;
            let mut provider =
                CustomHttpEmbeddingProvider::new(endpoint, &args.target_model_version)?;
            if let Some(model) = args.model {
                provider = provider.with_model(model);
            }
            for (key, value) in &args.headers {
                provider = provider.with_header(key.clone(), value.clone());
            }
            if let (Some(input_field), Some(embeddings_field)) =
                (args.input_field, args.embeddings_field)
            {
                provider = provider.with_fields(input_field, embeddings_field);
            }
            reindex_embeddings(&db, provider, options)?
        }
        _ => {
            let embedding_dim = db
                .vector_index_stats()
                .and_then(|stats| stats.dimension)
                .unwrap_or(args.embedding_dim);
            options.target_model_version = args.target_model_version;
            let provider = DeterministicEmbeddingProvider::new(
                embedding_dim,
                options.target_model_version.clone(),
            )?;
            reindex_embeddings(&db, provider, options)?
        }
    };

    println!("reindex complete");
    println!(
        "scanned={}, updated={}, skipped={}, failed={}, resumed_from={}",
        report.scanned_chunks,
        report.updated_chunks,
        report.skipped_chunks,
        report.failed_chunks,
        report.resumed_from_offset
    );
    println!(
        "provider={} model={}",
        report.provider, report.model_version
    );
    Ok(())
}

#[derive(Debug)]
struct Args {
    db_path: PathBuf,
    provider: String,
    endpoint: Option<String>,
    model: Option<String>,
    api_key_env: Option<String>,
    input_field: Option<String>,
    embeddings_field: Option<String>,
    headers: Vec<(String, String)>,
    target_model_version: String,
    embedding_dim: usize,
    batch_size: usize,
    tenant: Option<String>,
    force_all: bool,
    continue_on_partial_failure: bool,
    checkpoint_path: Option<PathBuf>,
}

fn parse_args(args: Vec<String>) -> Result<Args, String> {
    let mut out = Args {
        db_path: PathBuf::from("sqlrite_demo.db"),
        provider: "deterministic".to_string(),
        endpoint: None,
        model: None,
        api_key_env: None,
        input_field: None,
        embeddings_field: None,
        headers: Vec::new(),
        target_model_version: "det-v2".to_string(),
        embedding_dim: 256,
        batch_size: 256,
        tenant: None,
        force_all: false,
        continue_on_partial_failure: false,
        checkpoint_path: None,
    };

    let mut i = 0;
    while i < args.len() {
        match args[i].as_str() {
            "--db" => {
                i += 1;
                out.db_path = PathBuf::from(parse_string(&args, i, "--db")?);
            }
            "--provider" => {
                i += 1;
                out.provider = parse_string(&args, i, "--provider")?;
            }
            "--endpoint" => {
                i += 1;
                out.endpoint = Some(parse_string(&args, i, "--endpoint")?);
            }
            "--model" => {
                i += 1;
                out.model = Some(parse_string(&args, i, "--model")?);
            }
            "--api-key-env" => {
                i += 1;
                out.api_key_env = Some(parse_string(&args, i, "--api-key-env")?);
            }
            "--input-field" => {
                i += 1;
                out.input_field = Some(parse_string(&args, i, "--input-field")?);
            }
            "--embeddings-field" => {
                i += 1;
                out.embeddings_field = Some(parse_string(&args, i, "--embeddings-field")?);
            }
            "--header" => {
                i += 1;
                let raw = parse_string(&args, i, "--header")?;
                let Some((key, value)) = raw.split_once(':') else {
                    return Err(format!("invalid --header `{raw}`, expected key:value"));
                };
                out.headers
                    .push((key.trim().to_string(), value.trim().to_string()));
            }
            "--target-model-version" => {
                i += 1;
                out.target_model_version = parse_string(&args, i, "--target-model-version")?;
            }
            "--embedding-dim" => {
                i += 1;
                out.embedding_dim = parse_usize(&args, i, "--embedding-dim")?;
            }
            "--batch-size" => {
                i += 1;
                out.batch_size = parse_usize(&args, i, "--batch-size")?;
            }
            "--tenant" => {
                i += 1;
                out.tenant = Some(parse_string(&args, i, "--tenant")?);
            }
            "--checkpoint" => {
                i += 1;
                out.checkpoint_path = Some(PathBuf::from(parse_string(&args, i, "--checkpoint")?));
            }
            "--force-all" => {
                out.force_all = true;
            }
            "--continue-on-partial-failure" => {
                out.continue_on_partial_failure = true;
            }
            "--help" | "-h" => return Err(usage()),
            other => return Err(format!("unknown argument `{other}`\n{}", usage())),
        }
        i += 1;
    }

    Ok(out)
}

fn parse_string(args: &[String], index: usize, flag: &str) -> Result<String, String> {
    args.get(index)
        .cloned()
        .ok_or_else(|| format!("missing value for {flag}\n{}", usage()))
}

fn parse_usize(args: &[String], index: usize, flag: &str) -> Result<usize, String> {
    let raw = parse_string(args, index, flag)?;
    raw.parse::<usize>()
        .map_err(|_| format!("invalid integer for {flag}: `{raw}`\n{}", usage()))
}

fn usage() -> String {
    "usage: cargo run --bin sqlrite-reindex -- [--db PATH] [--provider deterministic|openai|custom] [--endpoint URL] [--model MODEL] [--api-key-env ENV] [--input-field FIELD] [--embeddings-field FIELD] [--header key:value] [--target-model-version VERSION] [--embedding-dim N] [--batch-size N] [--tenant TENANT] [--checkpoint PATH] [--force-all] [--continue-on-partial-failure]".to_string()
}