//! vein-database 0.1.0
//!
//! Database layer for Vein, a shared memory system for AI agents and tools.
//!
//! Documentation
use anyhow::{bail, Context, Result};
use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use futures_util::StreamExt;
use lancedb::{index::Index, query::ExecutableQuery, query::QueryBase, Table};
use std::sync::Arc;
use std::path::PathBuf;
use crate::paths::Paths;

/// Row count the table must reach before an insert triggers creation of an index on
/// the `vector` column.
const MIN_ROWS_FOR_INDEX: usize = 300;

/// One stored record, as returned by a nearest-neighbour search.
#[derive(Debug, Clone)]
pub struct Entry {
    /// Value of the record's `id` column.
    pub id: String,
    /// Value of the record's `content` column.
    pub content: String,
}

/// Returns the path handed to `lancedb::connect`, as provided by
/// `Paths::get_insights_db`.
fn get_db_path() -> PathBuf {
    Paths::get_insights_db()
}

/// Checks that `id` may be embedded in a single-quoted filter expression.
///
/// An id is accepted when it is non-empty, at most 100 characters long and contains
/// none of `'`, `;` or `"`. The quote check matters because ids are interpolated into
/// `id = '...'` filters (see `LanceDb::delete`).
///
/// # Errors
///
/// Returns an error describing which of the two rules was violated.
fn validate_id(id: &str) -> Result<()> {
    // The limit is expressed in characters, so count Unicode scalar values rather
    // than bytes; `str::len` would reject short ids written in non-ASCII scripts.
    if id.is_empty() || id.chars().count() > 100 {
        bail!("Invalid id: must be non-empty and <= 100 characters");
    }
    if id.contains(['\'', ';', '"']) {
        bail!("Invalid id: contains forbidden characters");
    }
    Ok(())
}

/// Handle to one LanceDB table with `id`, `content` and `vector` columns.
pub struct LanceDb {
    /// The opened (or newly created) table every operation runs against.
    table: Table,
    /// Whether the table already had at least one index when this handle was created.
    indexed: bool,
    /// Number of `f32` components every stored and queried vector must have.
    vector_dim: usize,
}

impl LanceDb {
    /// Opens the table `table_name`, creating it when it does not exist yet.
    ///
    /// The database location comes from `get_db_path`. A newly created table gets three
    /// non-nullable columns: `id` (UTF-8), `content` (UTF-8) and `vector` (fixed-size
    /// list of `vector_dim` `f32` values).
    ///
    /// # Errors
    ///
    /// Fails when `vector_dim` does not fit in an `i32`, when the database path is not
    /// valid Unicode, when an existing table stores vectors of a different dimension,
    /// or when any LanceDB call fails.
    pub async fn new(table_name: &str, vector_dim: usize) -> Result<Self> {
        // Arrow expresses the list width as `i32`; refuse values that would wrap
        // instead of truncating them with `as`.
        let list_size = i32::try_from(vector_dim)
            .context("vector dimension does not fit in a 32-bit list size")?;

        let db_path = get_db_path();
        let db = lancedb::connect(db_path.to_str().context("Invalid DB path")?)
            .execute()
            .await?;

        let names = db.table_names().execute().await?;
        let (table, indexed) = if names.iter().any(|name| name.as_str() == table_name) {
            let tbl = db.open_table(table_name).execute().await?;
            let existing_schema = tbl.schema().await?;

            // Check if existing schema matches requested vector dimension
            if let Ok(vector_field) = existing_schema.field_with_name("vector")
                && let DataType::FixedSizeList(_, existing_dim) = vector_field.data_type()
                && *existing_dim != list_size
            {
                bail!(
                    "Table '{}' exists with vector dimension {} but current embedding model produces dimension {}. \
                         Please delete the table and recreate it to switch models.",
                    table_name, existing_dim, vector_dim
                );
            }

            let indexed = !tbl.list_indices().await?.is_empty();
            (tbl, indexed)
        } else {
            // The schema is only needed when the table has to be created.
            let schema = Arc::new(Schema::new(vec![
                Field::new("id", DataType::Utf8, false),
                Field::new("content", DataType::Utf8, false),
                Field::new(
                    "vector",
                    DataType::FixedSizeList(
                        Arc::new(Field::new("item", DataType::Float32, false)),
                        list_size,
                    ),
                    false,
                ),
            ]));

            let tbl = db
                .create_table(table_name, vec![RecordBatch::new_empty(schema)])
                .execute()
                .await?;
            (tbl, false)
        };

        Ok(Self { table, indexed, vector_dim })
    }

    /// Stores `content` and its embedding under `id` and returns `id`.
    ///
    /// When a row with exactly the same `content` already exists, nothing is written
    /// and `id` is returned unchanged.
    ///
    /// # Errors
    ///
    /// Fails when `vector` does not have `vector_dim` components, when `id` is rejected
    /// by `validate_id`, or when a LanceDB/Arrow call fails. An index-creation failure
    /// is reported even though the row itself has already been added.
    pub async fn post(&self, id: &str, content: &str, vector: Vec<f32>) -> Result<String> {
        if vector.len() != self.vector_dim {
            bail!("vector dimension must be {}, got {}", self.vector_dim, vector.len());
        }
        validate_id(id)?;

        if self.exists_by_content(content).await? {
            return Ok(id.to_string());
        }

        self.insert_row(id, content, vector).await?;
        Ok(id.to_string())
    }

    /// Returns up to `limit` entries nearest to `query_vector`.
    ///
    /// # Errors
    ///
    /// Fails when `query_vector` does not have `vector_dim` components, when the query
    /// fails, or when a result batch lacks a UTF-8 `id` or `content` column.
    pub async fn get(&self, query_vector: &[f32], limit: usize) -> Result<Vec<Entry>> {
        if query_vector.len() != self.vector_dim {
            bail!("query vector dimension must be {}, got {}", self.vector_dim, query_vector.len());
        }

        let mut stream = self
            .table
            .query()
            .nearest_to(query_vector)?
            .limit(limit)
            .execute()
            .await?;

        let mut entries = Vec::new();
        while let Some(batch) = stream.next().await {
            let batch: RecordBatch = batch?;
            // Resolve and downcast each column once per batch rather than once per row,
            // and report a malformed batch instead of emitting empty strings.
            let ids = Self::string_column(&batch, "id")?;
            let contents = Self::string_column(&batch, "content")?;

            entries.extend((0..batch.num_rows()).map(|row| Entry {
                id: ids.value(row).to_string(),
                content: contents.value(row).to_string(),
            }));
        }

        Ok(entries)
    }

    /// Reports whether any row stores exactly `content`.
    ///
    /// # Errors
    ///
    /// Fails when the row count query fails.
    pub async fn exists_by_content(&self, content: &str) -> Result<bool> {
        // Double embedded single quotes so the value stays inside the quoted literal.
        let escaped = content.replace('\'', "''");
        let count = self.table.count_rows(Some(format!("content = '{}'", escaped))).await?;
        Ok(count > 0)
    }

    /// Replaces the row stored under `id` with `new_content` and `new_vector`.
    ///
    /// When some row already stores exactly `new_content`, the call is a no-op.
    /// Otherwise the old row is deleted before the new one is written; if the write
    /// fails, the deleted row is not restored.
    ///
    /// # Errors
    ///
    /// Fails when `new_vector` does not have `vector_dim` components, when `id` is
    /// rejected by `validate_id`, or when a LanceDB/Arrow call fails.
    pub async fn patch(&self, id: &str, new_content: &str, new_vector: Vec<f32>) -> Result<()> {
        if new_vector.len() != self.vector_dim {
            bail!("vector dimension must be {}, got {}", self.vector_dim, new_vector.len());
        }
        // Reject a bad id before touching the table at all.
        validate_id(id)?;

        if self.exists_by_content(new_content).await? {
            return Ok(());
        }

        self.delete(id).await?;
        // The duplicate-content check has just been done, so write directly instead of
        // going through `post` and repeating the query.
        self.insert_row(id, new_content, new_vector).await
    }

    /// Deletes every row whose `id` column equals `id`.
    ///
    /// # Errors
    ///
    /// Fails when `id` is rejected by `validate_id` or when the delete call fails.
    pub async fn delete(&self, id: &str) -> Result<()> {
        // `validate_id` rejects quotes, which keeps the interpolated filter well-formed.
        validate_id(id)?;
        self.table.delete(&format!("id = '{}'", id)).await?;
        Ok(())
    }

    /// Requests creation of an index on the `vector` column with `Index::Auto`.
    ///
    /// # Errors
    ///
    /// Fails when the index creation call fails.
    pub async fn rebuild_index(&self) -> Result<()> {
        self.table
            .create_index(&["vector"], Index::Auto)
            .execute()
            .await?;
        Ok(())
    }

    /// Appends one already-validated row and then makes sure the table is indexed once
    /// it is large enough.
    async fn insert_row(&self, id: &str, content: &str, vector: Vec<f32>) -> Result<()> {
        let schema = self.table.schema().await?;

        // Take the item field and list width from the stored schema so the column type
        // matches it exactly, whatever nullability the stored item field carries.
        let (item_field, list_size) = match schema.field_with_name("vector")?.data_type() {
            DataType::FixedSizeList(item, size) => (Arc::clone(item), *size),
            other => bail!("column 'vector' must be a fixed-size list, found {:?}", other),
        };

        let vector_array = FixedSizeListArray::try_new(
            item_field,
            list_size,
            Arc::new(Float32Array::from(vector)),
            None,
        )?;

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![id])),
                Arc::new(StringArray::from(vec![content])),
                Arc::new(vector_array),
            ],
        )?;

        self.table.add(vec![batch]).execute().await?;

        if !self.indexed {
            self.ensure_index().await?;
        }
        Ok(())
    }

    /// Creates the vector index when the table has reached `MIN_ROWS_FOR_INDEX` rows
    /// and does not have any index yet.
    async fn ensure_index(&self) -> Result<()> {
        if self.table.count_rows(None).await? < MIN_ROWS_FOR_INDEX {
            return Ok(());
        }
        // `self.indexed` is a snapshot taken in `new` and cannot be updated through
        // `&self`; without this check every insert past the threshold would request
        // index creation again.
        if !self.table.list_indices().await?.is_empty() {
            return Ok(());
        }
        self.rebuild_index().await
    }

    /// Looks up column `name` in `batch` and views it as a UTF-8 string array.
    fn string_column<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a StringArray> {
        batch
            .column_by_name(name)
            .with_context(|| format!("search result is missing column '{}'", name))?
            .as_any()
            .downcast_ref::<StringArray>()
            .with_context(|| format!("column '{}' is not a UTF-8 string column", name))
    }
}