kiromi-ai-memory 0.2.2

Local-first multi-tenant memory store engine: Markdown/text content on object storage, metadata in SQLite, plugin-shaped embedder/storage/metadata, hybrid text+vector search.
Documentation
// SPDX-License-Identifier: Apache-2.0 OR MIT
//! Embedder plugin trait + registry.

/// Sentinel embedder id used when the engine runs without a configured `Embedder`.
///
/// Written to `schema_meta.embedder_id` (and to every `memory.embedder_id`
/// row) when the engine is opened without an `Embedder`. It is persisted
/// together with the dimensionality of the first caller-supplied vector so
/// that downstream readers can detect the no-embedder configuration.
///
/// The surrounding angle brackets disambiguate this sentinel from any real
/// `<family>:<model>:<version>` id a third-party embedder might publish.
///
/// See spec § 12.6.
pub const CALLER_PROVIDED_EMBEDDER_ID: &str = "<caller-provided>";

use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use async_trait::async_trait;

use crate::error::{Error, Result};

/// Selects which half of an asymmetric embedding model a call targets.
///
/// Many production embedding models (E5, BGE, GTE, Jina families) apply
/// different prefixes / encoders to the document side (the indexed corpus)
/// and to the query side (the user's prompt). Symmetric models (CLIP, SBERT,
/// hash-mock) ignore the role and produce the same output for both.
///
/// See spec § 12.12.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum EmbedRole {
    /// Text that is being indexed (the `append` path).
    Document,
    /// Text that is a user-supplied query (the `search` / `related` path).
    Query,
}

impl EmbedRole {
    /// Returns the fixed lowercase tag for this role, for use in tracing /
    /// metrics: `"document"` or `"query"`.
    #[must_use]
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Document => "document",
            Self::Query => "query",
        }
    }
}

/// Embedder plugin contract. See spec § 12 for the rules every implementation must satisfy.
///
/// The supertraits require implementors to be `Send + Sync + 'static` and to
/// implement `Debug`, so a `Box<dyn Embedder>` can be moved between threads,
/// shared by reference across threads, and debug-printed.
#[async_trait]
pub trait Embedder: Send + Sync + fmt::Debug + 'static {
    /// Stable identifier persisted in `schema_meta.embedder_id`.
    /// Convention: `"<family>:<model>:<version>"`.
    fn id(&self) -> &str;

    /// Output vector dimensionality. Constant for the lifetime of the instance.
    fn dimensions(&self) -> usize;

    /// Recommended max input length in tokens (best-effort).
    ///
    /// The default implementation returns `None`.
    fn max_input_tokens(&self) -> Option<usize> {
        None
    }

    /// Self-declared capabilities. Default: symmetric, batched, no token cap.
    ///
    /// The default implementation returns `EmbedderCapabilities::default()`.
    fn capabilities(&self) -> crate::capabilities::EmbedderCapabilities {
        crate::capabilities::EmbedderCapabilities::default()
    }

    /// Embed a batch under a given role. Symmetric models ignore `role`;
    /// asymmetric models (E5, BGE, GTE, Jina families) use it to choose
    /// the right prefix / encoder. See spec §§ 12.1, 12.2 (rule 7), 12.12.
    ///
    /// # Errors
    ///
    /// Failures are reported through the crate's `Result` alias; which error
    /// variants are produced is up to the implementation.
    // NOTE(review): presumably the output holds one vector per entry of
    // `texts`, in input order, each `dimensions()` long — confirm against
    // spec § 12 before relying on it.
    async fn embed(&self, role: EmbedRole, texts: &[&str]) -> Result<Vec<Vec<f32>>>;

    /// Optional diagnostic config snapshot.
    ///
    /// The default implementation returns JSON `null`.
    fn describe(&self) -> serde_json::Value {
        serde_json::Value::Null
    }
}

/// Pinned, heap-allocated, `Send` future that yields `T` and may borrow for `'a`.
type BoxFut<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Type-erased embedder factory: takes a JSON config value and returns a
/// boxed `'static` future resolving to the built embedder or an error.
/// Wrapped in `Arc` so the same factory can be shared cheaply.
type FactoryFn = Arc<
    dyn Fn(serde_json::Value) -> BoxFut<'static, Result<Box<dyn Embedder>>> + Send + Sync + 'static,
>;

/// Config-driven registry. CLI / server use it; programmatic callers pass an
/// embedder directly to the builder.
///
/// Maps a family name to the factory that builds embedders of that family.
/// Cloning the registry clones the map; the factories themselves sit behind
/// `Arc`, so clones share the same underlying factory closures.
#[derive(Default, Clone)]
pub struct EmbedderRegistry {
    /// Registered factories, keyed by family name.
    factories: HashMap<String, FactoryFn>,
}

impl fmt::Debug for EmbedderRegistry {
    /// Formats the registry as `EmbedderRegistry { families: [...] }`.
    ///
    /// Only the registered family names are shown; the factories are opaque
    /// closures with nothing useful to print.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `HashMap` iteration order is unspecified and can differ between
        // runs, so sort the names to keep the output stable in logs and
        // snapshot comparisons.
        let mut families: Vec<&str> = self.factories.keys().map(String::as_str).collect();
        families.sort_unstable();
        f.debug_struct("EmbedderRegistry")
            .field("families", &families)
            .finish()
    }
}

impl EmbedderRegistry {
    /// Creates a registry with no families registered.
    #[must_use]
    pub fn empty() -> Self {
        Self::default()
    }

    /// Registers `factory` under the family name `family`.
    ///
    /// `family` is copied into the registry, so it only has to outlive this
    /// call: string literals and names computed at runtime (for example read
    /// from a config file) are both accepted. Registering a family that is
    /// already present replaces the previously registered factory.
    ///
    /// The factory receives the JSON config handed to [`build`](Self::build)
    /// and returns a future that resolves to the embedder or an error.
    pub fn register<F, Fut>(&mut self, family: &str, factory: F)
    where
        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<Box<dyn Embedder>>> + Send + 'static,
    {
        // Erase the concrete closure and future types so factories of
        // different types can live in the same map.
        let erased: FactoryFn = Arc::new(move |config| Box::pin(factory(config)));
        self.factories.insert(family.to_owned(), erased);
    }

    /// Builds an embedder by passing `config` to the factory registered
    /// under `family`.
    ///
    /// # Errors
    ///
    /// Returns `Error::Config` if `family` isn't registered. Otherwise any
    /// error returned by the factory is passed through unchanged.
    pub async fn build(
        &self,
        family: &str,
        config: serde_json::Value,
    ) -> Result<Box<dyn Embedder>> {
        let factory = self
            .factories
            .get(family)
            .ok_or_else(|| Error::Config(format!("embedder family {family:?} not registered")))?;
        factory(config).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Stub embedder: ignores both role and text content and returns one
    /// all-zero vector of width 4 per input.
    #[derive(Debug)]
    struct Fake;

    #[async_trait]
    impl Embedder for Fake {
        fn id(&self) -> &str {
            "fake:1"
        }

        fn dimensions(&self) -> usize {
            4
        }

        async fn embed(&self, _role: EmbedRole, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            let zeros: Vec<Vec<f32>> = vec![vec![0.0; 4]; texts.len()];
            Ok(zeros)
        }
    }

    /// A registered factory can be looked up by family and the embedder it
    /// builds is usable through the trait object.
    #[tokio::test]
    async fn registry_round_trip() {
        let mut registry = EmbedderRegistry::empty();
        registry.register("fake", |_cfg| async {
            let boxed: Box<dyn Embedder> = Box::new(Fake);
            Ok(boxed)
        });

        let embedder = registry
            .build("fake", serde_json::Value::Null)
            .await
            .unwrap();
        assert_eq!(embedder.id(), "fake:1");
        assert_eq!(embedder.dimensions(), 4);

        let vectors = embedder
            .embed(EmbedRole::Document, &["a", "b"])
            .await
            .unwrap();
        assert_eq!(vectors.len(), 2);
    }

    /// The stub is symmetric, so both roles must yield equal vectors.
    #[tokio::test]
    async fn embed_role_round_trip() {
        let stub = Fake;
        let as_document = stub.embed(EmbedRole::Document, &["x"]).await.unwrap();
        let as_query = stub.embed(EmbedRole::Query, &["x"]).await.unwrap();
        assert_eq!(as_document, as_query);
    }

    /// Building an unregistered family reports a configuration error.
    #[tokio::test]
    async fn registry_unknown_family() {
        let registry = EmbedderRegistry::empty();
        let failure = registry
            .build("nope", serde_json::Value::Null)
            .await
            .unwrap_err();
        assert!(matches!(failure, Error::Config(_)));
    }
}