kiromi-ai-memory 0.2.2

Local-first multi-tenant memory store engine: Markdown/text content on object storage, metadata in SQLite, plugin-shaped embedder/storage/metadata, hybrid text+vector search.
Documentation
// SPDX-License-Identifier: Apache-2.0 OR MIT
//! Plan 18 phase C — default `VectorIndex` implementation backed by
//! `sqlite-vec`'s `vec0` virtual table.
//!
//! Shares the `SqlitePool` with `SqliteMetadata` so vector ops can run inside
//! the same SQL transaction as the catalog row writes (the headline
//! "single-transaction append" property of Plan 18).
//!
//! # Atomicity
//!
//! Each `upsert_*` is a DELETE followed by an INSERT against the underlying
//! `vec0` virtual table — `INSERT OR REPLACE` does not seed ANN segments
//! cleanly on existing rows, so the explicit DELETE is required. Direct
//! invocations of these methods (i.e. without an enclosing SQL transaction)
//! do **not** atomically pair the index update with the catalog row.
//!
//! In the engine's authoritative writers
//! ([`crate::metadata::MetadataStore::append_memory`],
//! [`crate::metadata::MetadataStore::insert_summary`],
//! [`crate::metadata::MetadataStore::regenerate_memory_embedding`]) the
//! catalog and the vec0 row writes are folded into a single `tx.commit()`,
//! so the per-mutation single-transaction property of Plan 18 holds.
//!
//! Callers writing custom reindex tooling that calls into this trait
//! directly must therefore tolerate transient out-of-sync windows or wrap
//! their work in `MetadataStore::regenerate_memory_embedding` instead.

use std::str::FromStr;
use std::sync::Arc;

use async_trait::async_trait;
use sqlx::{Row, SqlitePool};

use crate::attribute::AttributeValue;
use crate::error::{Error, Result};
use crate::index::vector_trait::{
    DistanceMetric, VectorFilter, VectorIndex, VectorIndexCapabilities, VectorScope,
};
use crate::memory::{MemoryId, MemoryKind};
use crate::partition::PartitionPath;
use crate::summarizer::SummaryStyle;
use crate::summary::SummaryId;

/// Default `VectorIndex` impl for SQLite-backed stores.
///
/// Holds nothing but a shared handle to the SQLite pool; cloning the index
/// clones the `Arc`, not the pool.
#[derive(Debug, Clone)]
pub struct SqliteVecIndex {
    // Pool every statement in this index executes against. It must have the
    // `vec0` extension loaded before any method is called (see `new`).
    pool: Arc<SqlitePool>,
}

impl SqliteVecIndex {
    /// Build a new index over the supplied pool. The pool MUST already have
    /// the `vec0` extension auto-loaded (see
    /// `SqliteMetadata::ensure_sqlite_vec_loaded`).
    #[must_use]
    pub fn new(pool: Arc<SqlitePool>) -> Self {
        Self { pool }
    }
}

#[async_trait]
impl VectorIndex for SqliteVecIndex {
    async fn upsert_memory(
        &self,
        id: &MemoryId,
        partition_path: &PartitionPath,
        kind: Option<&MemoryKind>,
        embedding: &[f32],
    ) -> Result<()> {
        let blob: &[u8] = bytemuck::cast_slice(embedding);
        let kind_tag: Option<&'static str> = kind.map(|k| k.as_persisted_str());
        // sqlite-vec's vec0 virtual table does not honour `INSERT OR REPLACE`
        // on rows already present (the underlying ANN segments treat it as
        // a fresh insert and fail the PK uniqueness check). Always DELETE the
        // prior row first; idempotent on missing.
        sqlx::query("DELETE FROM memory_vec WHERE memory_id = ?")
            .bind(id.to_string())
            .execute(self.pool.as_ref())
            .await
            .map_err(|e| Error::metadata("SqliteVecIndex::upsert_memory delete", e))?;
        sqlx::query(
            "INSERT INTO memory_vec(memory_id, partition_path, kind, embedding) \
             VALUES (?, ?, ?, ?)",
        )
        .bind(id.to_string())
        .bind(partition_path.as_str())
        .bind(kind_tag)
        .bind(blob)
        .execute(self.pool.as_ref())
        .await
        .map_err(|e| Error::metadata("SqliteVecIndex::upsert_memory", e))?;
        Ok(())
    }

    async fn upsert_summary(
        &self,
        id: &SummaryId,
        parent_path: &str,
        style: &SummaryStyle,
        embedding: &[f32],
    ) -> Result<()> {
        let blob: &[u8] = bytemuck::cast_slice(embedding);
        sqlx::query("DELETE FROM summary_vec WHERE summary_id = ?")
            .bind(id.to_string())
            .execute(self.pool.as_ref())
            .await
            .map_err(|e| Error::metadata("SqliteVecIndex::upsert_summary delete", e))?;
        sqlx::query(
            "INSERT INTO summary_vec(summary_id, parent_path, style, embedding) \
             VALUES (?, ?, ?, ?)",
        )
        .bind(id.to_string())
        .bind(parent_path)
        .bind(style.as_str().as_ref())
        .bind(blob)
        .execute(self.pool.as_ref())
        .await
        .map_err(|e| Error::metadata("SqliteVecIndex::upsert_summary", e))?;
        Ok(())
    }

    async fn delete_memory(&self, id: &MemoryId) -> Result<()> {
        sqlx::query("DELETE FROM memory_vec WHERE memory_id = ?")
            .bind(id.to_string())
            .execute(self.pool.as_ref())
            .await
            .map_err(|e| Error::metadata("SqliteVecIndex::delete_memory", e))?;
        Ok(())
    }

    async fn delete_summary(&self, id: &SummaryId) -> Result<()> {
        sqlx::query("DELETE FROM summary_vec WHERE summary_id = ?")
            .bind(id.to_string())
            .execute(self.pool.as_ref())
            .await
            .map_err(|e| Error::metadata("SqliteVecIndex::delete_summary", e))?;
        Ok(())
    }

    async fn knn_memory(
        &self,
        query: &[f32],
        k: u32,
        scope: VectorScope,
        filter: Option<&VectorFilter>,
    ) -> Result<Vec<(MemoryId, f32)>> {
        let qblob: &[u8] = bytemuck::cast_slice(query);

        // Build the SQL based on scope + filter.
        // sqlite-vec idiomatic syntax: WHERE embedding MATCH ? AND k = ?
        // ORDER BY distance.
        let scope_clause: (&'static str, Vec<String>) = match &scope {
            VectorScope::Tenant => ("", Vec::new()),
            VectorScope::Partition(p) => {
                (" AND mv.partition_path = ?", vec![p.as_str().to_string()])
            }
            VectorScope::PartitionPrefix(prefix) => (
                " AND (mv.partition_path = ? OR mv.partition_path LIKE ?)",
                vec![prefix.clone(), format!("{prefix}/%")],
            ),
        };

        let (filter_join, filter_where, filter_binds) = match filter {
            None => ("", "", Vec::<String>::new()),
            Some(f) => {
                let (col, val_str) = attribute_to_filter_pair(&f.value);
                (
                    " JOIN memory_attribute ma ON ma.memory_id = mv.memory_id",
                    match col {
                        "v_string" => " AND ma.key = ? AND ma.v_string = ?",
                        "v_int" => " AND ma.key = ? AND ma.v_int = ?",
                        "v_decimal" => " AND ma.key = ? AND ma.v_decimal = ?",
                        "v_timestamp" => " AND ma.key = ? AND ma.v_timestamp = ?",
                        "v_bool" => " AND ma.key = ? AND ma.v_bool = ?",
                        _ => " AND ma.key = ? AND ma.v_string = ?",
                    },
                    vec![f.key.clone(), val_str],
                )
            }
        };

        let sql = format!(
            "SELECT mv.memory_id AS id, distance \
             FROM memory_vec mv{filter_join} \
             WHERE mv.embedding MATCH ? AND k = ?{scope_where}{filter_where} \
             ORDER BY distance",
            filter_join = filter_join,
            scope_where = scope_clause.0,
            filter_where = filter_where,
        );

        let mut q = sqlx::query(&sql).bind(qblob).bind(i64::from(k));
        for s in &scope_clause.1 {
            q = q.bind(s);
        }
        for s in &filter_binds {
            q = q.bind(s);
        }

        let rows = q
            .fetch_all(self.pool.as_ref())
            .await
            .map_err(|e| Error::metadata("SqliteVecIndex::knn_memory", e))?;

        let mut out = Vec::with_capacity(rows.len());
        for row in rows {
            let id_s: String = row
                .try_get("id")
                .map_err(|e| Error::metadata("read memory_id", e))?;
            let dist: f32 =
                row.try_get::<f64, _>("distance")
                    .map_err(|e| Error::metadata("read distance", e))? as f32;
            let id = MemoryId::from_str(&id_s)
                .map_err(|_| Error::metadata("parse memory_id", std::io::Error::other("bad id")))?;
            out.push((id, dist));
        }
        Ok(out)
    }

    async fn knn_summary(
        &self,
        query: &[f32],
        k: u32,
        parent_path_prefix: &str,
    ) -> Result<Vec<(SummaryId, f32)>> {
        let qblob: &[u8] = bytemuck::cast_slice(query);
        let prefix_eq = parent_path_prefix.to_string();
        let prefix_like = format!("{parent_path_prefix}/%");

        let sql = "SELECT summary_id AS id, distance \
                   FROM summary_vec \
                   WHERE embedding MATCH ? AND k = ? \
                     AND (parent_path = ? OR parent_path LIKE ?) \
                   ORDER BY distance";

        let rows = sqlx::query(sql)
            .bind(qblob)
            .bind(i64::from(k))
            .bind(&prefix_eq)
            .bind(&prefix_like)
            .fetch_all(self.pool.as_ref())
            .await
            .map_err(|e| Error::metadata("SqliteVecIndex::knn_summary", e))?;

        let mut out = Vec::with_capacity(rows.len());
        for row in rows {
            let id_s: String = row
                .try_get("id")
                .map_err(|e| Error::metadata("read summary_id", e))?;
            let dist: f32 =
                row.try_get::<f64, _>("distance")
                    .map_err(|e| Error::metadata("read distance", e))? as f32;
            let id = SummaryId::from_str(&id_s).map_err(|_| {
                Error::metadata("parse summary_id", std::io::Error::other("bad id"))
            })?;
            out.push((id, dist));
        }
        Ok(out)
    }

    fn id(&self) -> &str {
        "sqlite-vec:vec0"
    }

    fn capabilities(&self) -> VectorIndexCapabilities {
        VectorIndexCapabilities {
            knn_filtered: true,
            max_dimensions: 4096,
            distance_metric: DistanceMetric::CosineDistance,
        }
    }
}

/// Map an attribute value to the `memory_attribute` column it is compared
/// against and the textual form bound as the comparison value.
///
/// Scalars go to their typed column (`v_string`, `v_int`, `v_decimal`,
/// `v_bool`, `v_timestamp`); booleans are rendered as `"1"` / `"0"`.
/// `Array` values are JSON-encoded and compared against `v_string`
/// (best-effort — the engine does not yet support querying arrays by
/// element); if encoding fails the value falls back to the empty string.
fn attribute_to_filter_pair(v: &AttributeValue) -> (&'static str, String) {
    let column = match v {
        AttributeValue::String(_) | AttributeValue::Array(_) => "v_string",
        AttributeValue::Int(_) => "v_int",
        AttributeValue::Decimal(_) => "v_decimal",
        AttributeValue::Bool(_) => "v_bool",
        AttributeValue::Timestamp(_) => "v_timestamp",
    };
    let rendered = match v {
        AttributeValue::String(text) => text.clone(),
        AttributeValue::Int(number) => number.to_string(),
        AttributeValue::Decimal(decimal) => decimal.to_string(),
        AttributeValue::Bool(flag) => String::from(if *flag { "1" } else { "0" }),
        AttributeValue::Timestamp(stamp) => stamp.to_string(),
        AttributeValue::Array(_) => serde_json::to_string(v).unwrap_or_default(),
    };
    (column, rendered)
}