yantrikdb-server 0.8.9

YantrikDB database server — multi-tenant cognitive memory with wire protocol, HTTP gateway, replication, auto-failover, and at-rest encryption
//! Built-in embedding via fastembed (all-MiniLM-L6-v2, ONNX).
//!
//! Implements the yantrikdb-core `Embedder` trait so engines can
//! auto-embed text without client-provided vectors.
//!
//! ## Caching layer
//!
//! Every embedding call goes through [`crate::cache::EmbeddingCache`]
//! BEFORE the ONNX model. On cache hit, we skip both the
//! `Mutex<TextEmbedding>` lock AND the ~10–30ms forward pass entirely.
//!
//! Why this matters: on the term=1423 incident traffic profile (5
//! concurrent pollers re-issuing the same query texts), most embed
//! calls are repeats. The cache turns "26 writes/sec serialized
//! through an ONNX mutex" into "cached calls return in microseconds,
//! only novel texts pay the model cost."
//!
//! Cache key includes the model version, so a future model upgrade
//! (RFC 013-B shadow index) automatically invalidates every entry on
//! lookup — no manual flush required.

use parking_lot::Mutex;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

use crate::cache::{EmbeddingCache, EmbeddingCacheConfig};

/// Version tag for the built-in embedding model.
///
/// This string is folded into every embedding-cache key. Changing it (for
/// example when the fastembed `EmbeddingModel` in use changes) makes every
/// previously cached vector unreachable, so no explicit flush is needed.
const BUILTIN_MODEL_VERSION: &str = "all-MiniLM-L6-v2";

/// State shared, behind a single `Arc`, by every clone of [`FastEmbedder`]:
/// the ONNX model, the embedding cache in front of it, and hit/miss counters.
struct FastEmbedInner {
    /// The ONNX model. It is locked around every forward pass, so inference
    /// is serialized across all clones of the embedder.
    model: Mutex<TextEmbedding>,
    /// Length of every vector this embedder reports via `dim()`.
    dim: usize,
    /// Text-to-vector cache, consulted before the model lock is taken.
    cache: EmbeddingCache,
    /// Identifier mixed into every cache key; exposed through
    /// [`FastEmbedder::model_version`].
    model_version: &'static str,
    /// Number of texts answered straight from `cache`.
    cache_hits: AtomicU64,
    /// Number of texts that had to go through the model.
    cache_misses: AtomicU64,
}

/// Cheaply clonable handle to the built-in embedder.
///
/// Every clone points at the same model, cache, and counters. Cloning it to
/// hand each engine its own `Box<dyn Embedder>` therefore shares one loaded
/// model across all of them.
#[derive(Clone)]
pub struct FastEmbedder {
    inner: Arc<FastEmbedInner>,
}

impl FastEmbedder {
    /// Output dimensionality of all-MiniLM-L6-v2 sentence embeddings.
    const DIM: usize = 384;

    /// Load the built-in model with the default [`EmbeddingCacheConfig`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::with_cache_config`].
    pub fn new() -> anyhow::Result<Self> {
        Self::with_cache_config(EmbeddingCacheConfig::default())
    }

    /// Construct with a non-default cache size. Operators with high
    /// memory budgets can dial up `max_entries`; tests can shrink it.
    ///
    /// # Errors
    ///
    /// Returns an error when `TextEmbedding::try_new` fails, or when it
    /// panics (the ONNX Runtime binding panics instead of returning an
    /// error when its shared library cannot be loaded). In either case, if
    /// the failure text points at a missing ONNX Runtime library, an
    /// installation hint is appended to the error message.
    pub fn with_cache_config(cache_cfg: EmbeddingCacheConfig) -> anyhow::Result<Self> {
        tracing::info!("loading embedding model (all-MiniLM-L6-v2)...");

        // ort (the underlying ONNX Runtime binding) panics on dlopen failure
        // instead of returning an error. Catch the panic and convert it into
        // an ordinary error carrying an actionable hint. Note that the
        // process panic hook still runs before `catch_unwind` returns, so the
        // raw panic message may also be printed to stderr.
        let init = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            TextEmbedding::try_new(
                InitOptions::new(EmbeddingModel::AllMiniLML6V2)
                    .with_show_download_progress(true),
            )
        }));

        let model = match init {
            Ok(Ok(model)) => model,
            Ok(Err(err)) => {
                // Non-panicking failure: propagate it untouched unless it
                // looks like a missing runtime library, in which case the
                // same hint as the panic path is attached.
                let detail = format!("{:#}", err);
                let hint = Self::onnx_runtime_hint(&detail);
                if hint.is_empty() {
                    return Err(err);
                }
                anyhow::bail!("embedder initialization failed: {}{}", detail, hint);
            }
            Err(payload) => {
                let msg = Self::panic_message(&*payload);
                anyhow::bail!(
                    "embedder initialization failed: {}{}",
                    msg,
                    Self::onnx_runtime_hint(&msg),
                );
            }
        };

        tracing::info!(
            "embedding model loaded ({} dim, cache_capacity={})",
            Self::DIM,
            cache_cfg.max_entries
        );

        Ok(Self {
            inner: Arc::new(FastEmbedInner {
                model: Mutex::new(model),
                dim: Self::DIM,
                cache: EmbeddingCache::new(cache_cfg),
                model_version: BUILTIN_MODEL_VERSION,
                cache_hits: AtomicU64::new(0),
                cache_misses: AtomicU64::new(0),
            }),
        })
    }

    /// Best-effort extraction of the human-readable text of a panic payload.
    ///
    /// `panic!` with a format string yields a `String` payload; a literal
    /// message yields `&'static str`; anything else is reported generically.
    fn panic_message(payload: &(dyn std::any::Any + Send)) -> String {
        if let Some(s) = payload.downcast_ref::<String>() {
            s.clone()
        } else if let Some(s) = payload.downcast_ref::<&str>() {
            (*s).to_string()
        } else {
            "unknown panic".to_string()
        }
    }

    /// Installation hint to append when `msg` looks like a failure to load
    /// the ONNX Runtime shared library; empty string otherwise.
    fn onnx_runtime_hint(msg: &str) -> &'static str {
        // Built with `concat!` rather than `\`-continued lines: a string
        // continuation escape strips the leading whitespace of the next
        // line, which would destroy the column alignment of the commands.
        const HINT: &str = concat!(
            "\n\n",
            "ONNX Runtime library not found. The built-in embedder requires it.\n",
            "\n",
            "Linux:   wget https://github.com/microsoft/onnxruntime/releases/download/v1.24.4/onnxruntime-linux-x64-1.24.4.tgz\n",
            "         tar xzf onnxruntime-linux-x64-1.24.4.tgz\n",
            "         sudo cp onnxruntime-linux-x64-1.24.4/lib/libonnxruntime*.so* /usr/local/lib/\n",
            "         sudo ldconfig\n",
            "         export ORT_DYLIB_PATH=/usr/local/lib/libonnxruntime.so.1.24.4\n",
            "\n",
            "macOS:   brew install onnxruntime\n",
            "\n",
            "Windows: Download from https://github.com/microsoft/onnxruntime/releases\n",
            "         and place onnxruntime.dll alongside the binary.\n",
            "\n",
            "Or use the Docker image (ghcr.io/yantrikos/yantrikdb) which bundles ONNX Runtime.\n",
            "\n",
            "Or skip the built-in embedder by setting [embedding] strategy = \"client_only\"\n",
            "in your config file (you'll need to provide pre-computed embeddings to remember()).\n",
        );
        const NEEDLES: [&str; 3] = ["dlopen", "libonnxruntime", "onnxruntime.dll"];

        if NEEDLES.iter().any(|needle| msg.contains(needle)) {
            HINT
        } else {
            ""
        }
    }

    /// Stable identifier of the loaded model. Useful for emitting
    /// `embedder_model{version}` Prometheus labels.
    pub fn model_version(&self) -> &'static str {
        self.inner.model_version
    }

    /// Cumulative cache hit count since startup. Pair with
    /// [`Self::cache_misses`] to compute hit rate.
    pub fn cache_hits(&self) -> u64 {
        self.inner.cache_hits.load(Ordering::Relaxed)
    }

    /// Cumulative count of texts that missed the cache and went through
    /// the model since startup.
    pub fn cache_misses(&self) -> u64 {
        self.inner.cache_misses.load(Ordering::Relaxed)
    }

    /// Convenience: hit fraction. Returns 0.0 if no calls yet.
    ///
    /// The two counters are read independently, so under concurrent load
    /// the value is an approximation rather than an atomic snapshot.
    pub fn cache_hit_rate(&self) -> f64 {
        let h = self.cache_hits();
        let m = self.cache_misses();
        let total = h + m;
        if total == 0 {
            0.0
        } else {
            h as f64 / total as f64
        }
    }

    /// Create a boxed clone suitable for `YantrikDB::set_embedder()`.
    pub fn boxed(&self) -> Box<dyn yantrikdb::types::Embedder + Send + Sync> {
        Box::new(self.clone())
    }
}

impl yantrikdb::types::Embedder for FastEmbedder {
    /// Embed one text, serving it from the cache when possible.
    ///
    /// On a miss the model lock is held only for the forward pass; the
    /// result is written back to the cache after the lock is released.
    fn embed(
        &self,
        text: &str,
    ) -> std::result::Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
        // Cache check FIRST — cheap (~hash + hashmap lookup) and skips
        // both the Mutex contention and the ONNX forward pass.
        if let Some(v) = self
            .inner
            .cache
            .get_for_text(text, self.inner.model_version)
        {
            self.inner.cache_hits.fetch_add(1, Ordering::Relaxed);
            return Ok(v);
        }
        self.inner.cache_misses.fetch_add(1, Ordering::Relaxed);

        // Cache miss: take the model lock, compute, release before
        // populating the cache (cache put doesn't need the model lock).
        let computed = {
            let mut model = self.inner.model.lock();
            let mut results = model.embed(vec![text], None)?;
            results
                .pop()
                .ok_or_else(|| -> Box<dyn std::error::Error + Send + Sync> {
                    "empty embedding result".into()
                })?
        };
        self.inner
            .cache
            .put_for_text(text, self.inner.model_version, computed.clone());
        Ok(computed)
    }

    /// Embed many texts, returning vectors in the same order as `texts`.
    ///
    /// Cached texts are served directly; all misses are embedded in a single
    /// model call under a single lock acquisition, then cached.
    ///
    /// # Errors
    ///
    /// Propagates model errors, and returns an error (instead of panicking
    /// or caching misaligned vectors) if the model returns a different
    /// number of vectors than it was given texts.
    fn embed_batch(
        &self,
        texts: &[&str],
    ) -> std::result::Result<Vec<Vec<f32>>, Box<dyn std::error::Error + Send + Sync>> {
        let model_ver = self.inner.model_version;

        // First pass: separate cache hits from misses. `results` is indexed
        // by input position so the output order matches `texts`.
        let mut results: Vec<Option<Vec<f32>>> = Vec::with_capacity(texts.len());
        let mut miss_indices: Vec<usize> = Vec::new();
        // Borrowed slices are enough for the model call; no owned copies.
        let mut miss_texts: Vec<&str> = Vec::new();

        for (idx, &t) in texts.iter().enumerate() {
            match self.inner.cache.get_for_text(t, model_ver) {
                Some(v) => {
                    results.push(Some(v));
                    self.inner.cache_hits.fetch_add(1, Ordering::Relaxed);
                }
                None => {
                    results.push(None);
                    miss_indices.push(idx);
                    miss_texts.push(t);
                    self.inner.cache_misses.fetch_add(1, Ordering::Relaxed);
                }
            }
        }

        // Second pass: one batch call for all misses. Single Mutex
        // acquisition, single ONNX forward pass — way more efficient
        // than N individual single-text calls.
        if !miss_texts.is_empty() {
            let computed = {
                let mut model = self.inner.model.lock();
                model.embed(miss_texts.clone(), None)?
            };

            // Without a 1:1 correspondence we cannot tell which vector
            // belongs to which text, so refuse rather than cache garbage.
            if computed.len() != miss_texts.len() {
                return Err(format!(
                    "embedding batch returned {} vectors for {} inputs",
                    computed.len(),
                    miss_texts.len()
                )
                .into());
            }

            for ((&dest_idx, &miss_text), vector) in miss_indices
                .iter()
                .zip(miss_texts.iter())
                .zip(computed)
            {
                self.inner
                    .cache
                    .put_for_text(miss_text, model_ver, vector.clone());
                results[dest_idx] = Some(vector);
            }
        }

        // The length check above guarantees every slot is filled; report an
        // error instead of panicking if that invariant is ever broken.
        results
            .into_iter()
            .collect::<Option<Vec<_>>>()
            .ok_or_else(|| -> Box<dyn std::error::Error + Send + Sync> {
                "embedding batch left an output slot unfilled".into()
            })
    }

    /// Dimensionality of the vectors produced by this embedder.
    fn dim(&self) -> usize {
        self.inner.dim
    }
}

#[cfg(test)]
mod tests {
    //! Unit tests for the parts of this module that do not need the ONNX
    //! runtime: the model-version constant, its effect on cache keys, and
    //! the hit-rate arithmetic.
    //!
    //! Nothing here constructs a `FastEmbedder`: its inner
    //! `Mutex<TextEmbedding>` can only be built by loading the ONNX model.
    //! End-to-end embedder behavior (with the actual ONNX model) is
    //! covered by integration tests + the bench harness.

    use super::*;

    /// Intentional stub: ignores `entries` and always returns `None`.
    ///
    /// A seeded-cache `FastEmbedder` cannot be built without loading ONNX,
    /// so cache-hit behavior is verified through `EmbeddingCache`'s own
    /// tests instead. Kept as the hook for a future model-free constructor.
    fn embedder_with_seeded_cache(entries: &[(&str, Vec<f32>)]) -> Option<FastEmbedder> {
        let _ = entries;
        None
    }

    /// Same arithmetic as `FastEmbedder::cache_hit_rate`, which cannot be
    /// called here without a loaded model.
    fn hit_rate(hits: u64, misses: u64) -> f64 {
        let total = hits + misses;
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }

    #[test]
    fn cache_seeded_helper_compiles() {
        // Smoke: the stub helper compiles and stays a no-op until
        // end-to-end ONNX tests can run on a CI runner with the runtime.
        let result = embedder_with_seeded_cache(&[]);
        assert!(result.is_none());
    }

    #[test]
    fn cache_key_includes_model_version() {
        // Regression: the model version must be part of the cache key. If
        // someone refactors BUILTIN_MODEL_VERSION away, this catches it.
        use crate::cache::embedding::EmbeddingCacheKey;
        let k1 = EmbeddingCacheKey::for_text("hi", BUILTIN_MODEL_VERSION);
        let k2 = EmbeddingCacheKey::for_text("hi", "different-model");
        assert_ne!(k1, k2);
    }

    #[test]
    fn cache_key_is_deterministic_for_same_inputs() {
        // Same text + same model version must map to the same key, or the
        // cache would never hit.
        use crate::cache::embedding::EmbeddingCacheKey;
        let a = EmbeddingCacheKey::for_text("hi", BUILTIN_MODEL_VERSION);
        let b = EmbeddingCacheKey::for_text("hi", BUILTIN_MODEL_VERSION);
        assert_eq!(a, b);
    }

    #[test]
    fn model_version_is_stable_string() {
        // Pinned identifier: changes here invalidate every cache entry
        // on the next request. That's intentional but rare — guard the
        // current value.
        assert_eq!(BUILTIN_MODEL_VERSION, "all-MiniLM-L6-v2");
    }

    #[test]
    fn cache_hit_rate_zero_on_no_calls() {
        // The zero-divide case must yield 0.0, not NaN.
        assert_eq!(hit_rate(0, 0), 0.0);
    }

    #[test]
    fn cache_hit_rate_correct_for_mixed() {
        // 7 hits, 3 misses → 70% hit rate.
        assert!((hit_rate(7, 3) - 0.7).abs() < 1e-9);
    }

    #[test]
    fn cache_hit_rate_extremes() {
        assert_eq!(hit_rate(5, 0), 1.0);
        assert_eq!(hit_rate(0, 5), 0.0);
    }
}