dakera-inference 0.11.80

Embedded inference engine for Dakera - generates embeddings locally via ONNX Runtime
Documentation
//! Pluggable embedding backend abstraction.
//!
//! # Overview
//!
//! The `EmbeddingBackend` trait decouples embedding generation from the underlying
//! inference runtime (ONNX Runtime, Candle, static lookup, GGUF quantized).  The engine
//! selects a backend at startup via environment variables:
//!
//! | Env var | Backend chosen |
//! |---------|----------------|
//! | `DAKERA_BACKEND=onnx` | [`OnnxBackend`] — production default (INT8 quantized) |
//! | `DAKERA_BACKEND=candle` | [`CandleBackend`] — pure-Rust Candle (requires `candle` feature) |
//! | `DAKERA_BACKEND=static` | [`StaticBackend`] — Model2Vec static lookup (~500× faster ingest) |
//! | `DAKERA_BACKEND=gguf` | [`GgufBackend`] — GGUF quantized via candle-gguf (requires `candle` feature) |
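//! | *(unset)* | [`OnnxBackend`] — same as `onnx`, except that `DAKERA_USE_GPU=1` auto-selects [`CandleBackend`] when the `candle` feature is compiled in |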
//!
//! All backends implement the same `EmbeddingBackend` trait so callers are fully backend-agnostic.
//!
//! # Feature flags
//!
//! * `candle` — enables [`CandleBackend`] and [`GgufBackend`] (requires HuggingFace Candle crates)
//!
//! The `onnx` and `static` backends are always compiled — they have no additional dependencies
//! beyond what the `dakera-inference` crate already requires.

use crate::error::Result;
use async_trait::async_trait;
use std::sync::Arc;

pub mod onnx;
pub mod static_backend;

#[cfg(feature = "candle")]
pub mod candle_backend;
#[cfg(feature = "candle")]
pub mod gguf_backend;

pub use onnx::OnnxBackend;
pub use static_backend::StaticBackend;

#[cfg(feature = "candle")]
pub use candle_backend::CandleBackend;
#[cfg(feature = "candle")]
pub use gguf_backend::GgufBackend;

/// Identifies which backend is active.  Stored in HNSW node metadata so the
/// background re-embed job knows which memories were indexed with a fast
/// (static) embedding and need to be upgraded with the quality transformer.
///
/// The type is `Copy`, so it can be passed around by value.  With
/// `rename_all = "snake_case"` each variant serialises as its lowercase name
/// (`"onnx"`, `"candle"`, `"static"`, `"gguf"`), which are the same strings
/// the `Display` implementation produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BackendKind {
    /// ONNX Runtime INT8 quantized — production default.
    /// Serialised and displayed as `onnx`.
    Onnx,
    /// Candle pure-Rust FP32/BF16 transformer.
    /// Serialised and displayed as `candle`.
    Candle,
    /// Model2Vec static vocabulary lookup — fast ingest path.
    /// Serialised and displayed as `static`.
    Static,
    /// GGUF quantized inference via candle-gguf.
    /// Serialised and displayed as `gguf`.
    Gguf,
}

impl std::fmt::Display for BackendKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BackendKind::Onnx => write!(f, "onnx"),
            BackendKind::Candle => write!(f, "candle"),
            BackendKind::Static => write!(f, "static"),
            BackendKind::Gguf => write!(f, "gguf"),
        }
    }
}

/// Core embedding backend trait.  All backends must be `Send + Sync` so they
/// can be held behind an `Arc` and called from async tasks.
///
/// The `#[async_trait]` attribute is what allows the `async fn` below to be
/// declared on a trait that is used as a trait object (`dyn EmbeddingBackend`).
#[async_trait]
pub trait EmbeddingBackend: Send + Sync {
    /// Embed a batch of raw texts and return a float32 embedding per input.
    ///
    /// Callers must pre-process texts (apply prefixes, truncation) before
    /// calling this method.  The returned `Vec` has the same length as
    /// `texts`.  Empty input returns an empty vec immediately.
    ///
    /// # Errors
    ///
    /// Failures are reported through the crate-level `Result`; the specific
    /// failure conditions depend on the implementing backend.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;

    /// Embedding dimensionality (fixed for the lifetime of the backend).
    ///
    /// NOTE(review): presumably the length of every vector returned by
    /// `embed_batch` — confirm against each backend implementation.
    fn dimension(&self) -> usize;

    /// The concrete backend kind — used for metrics and re-embed bookkeeping.
    fn backend_kind(&self) -> BackendKind;
}

/// Select an embedding backend based on configuration and environment variables.
///
/// Priority:
/// 1. `config.backend_override` — programmatic selection (used by `TieredEngine`)
/// 2. `DAKERA_BACKEND` — explicit selection via the environment.  The value is
///    trimmed and matched case-insensitively; an empty or whitespace-only value
///    is treated as unset.
/// 3. GPU auto-selection → Candle (CUDA/Metal) — only when nothing was selected
///    explicitly and the `candle` feature is compiled in.  Enabled by
///    `DAKERA_USE_GPU=1`; if that variable cannot be read, `config.use_gpu`
///    decides instead.
/// 4. Default → [`OnnxBackend`]
///
/// An unrecognised backend name logs a warning and falls back to [`OnnxBackend`].
///
/// # Errors
///
/// Returns an error if the requested backend is not compiled in or fails to initialise.
pub async fn select_backend(
    config: &crate::models::ModelConfig,
) -> Result<Arc<dyn EmbeddingBackend>> {
    // backend_override takes precedence over env var — used by TieredEngine
    let explicit = config
        .backend_override
        .as_ref()
        .map(|kind| kind.to_string())
        .or_else(|| {
            std::env::var("DAKERA_BACKEND")
                .ok()
                // Normalise so that " Static\n" still selects the static backend
                // instead of silently falling back to ONNX.
                .map(|raw| raw.trim().to_lowercase())
                // `DAKERA_BACKEND=` (common in container manifests) means "not set",
                // so the default path — including GPU auto-selection — still applies.
                .filter(|name| !name.is_empty())
        });

    match explicit.as_deref() {
        Some("static") => {
            tracing::info!("Backend: static (Model2Vec vocabulary lookup)");
            let backend = StaticBackend::new(config).await?;
            Ok(Arc::new(backend))
        }
        #[cfg(feature = "candle")]
        Some("candle") => {
            tracing::info!("Backend: candle (pure-Rust Candle)");
            let backend = CandleBackend::new(config).await?;
            Ok(Arc::new(backend))
        }
        #[cfg(feature = "candle")]
        Some("gguf") => {
            tracing::info!("Backend: gguf (candle-gguf quantized)");
            let backend = GgufBackend::new(config).await?;
            Ok(Arc::new(backend))
        }
        Some("onnx") | None => {
            #[cfg(feature = "candle")]
            {
                // GPU auto-selection must never override an explicit "onnx" request,
                // so it is only considered when no backend was named at all.
                if explicit.is_none() {
                    let use_gpu = std::env::var("DAKERA_USE_GPU")
                        .map(|v| v == "1")
                        .unwrap_or(config.use_gpu);
                    if use_gpu {
                        tracing::info!(
                            "Backend: candle (GPU auto-selected via DAKERA_USE_GPU=1)"
                        );
                        let backend = CandleBackend::new(config).await?;
                        return Ok(Arc::new(backend));
                    }
                }
            }
            tracing::info!("Backend: onnx (ONNX Runtime INT8)");
            let backend = OnnxBackend::new(config).await?;
            Ok(Arc::new(backend))
        }
        Some(other) => {
            // Without the `candle` feature the "candle"/"gguf" arms above do not exist,
            // so those names land here and must be rejected rather than downgraded.
            #[cfg(not(feature = "candle"))]
            if other == "candle" || other == "gguf" {
                return Err(crate::error::InferenceError::InvalidInput(format!(
                    "Backend '{other}' requires the 'candle' feature to be compiled in"
                )));
            }
            tracing::warn!("Unknown DAKERA_BACKEND={other}, falling back to onnx");
            let backend = OnnxBackend::new(config).await?;
            Ok(Arc::new(backend))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` renders each variant as its lowercase backend name.
    #[test]
    fn test_backend_kind_display() {
        let expected = [
            (BackendKind::Onnx, "onnx"),
            (BackendKind::Candle, "candle"),
            (BackendKind::Static, "static"),
            (BackendKind::Gguf, "gguf"),
        ];
        for (kind, name) in expected {
            assert_eq!(kind.to_string(), name);
        }
    }

    /// Serialising to JSON and back yields the variant we started with.
    #[test]
    fn test_backend_kind_serde_roundtrip() {
        for original in [
            BackendKind::Onnx,
            BackendKind::Candle,
            BackendKind::Static,
            BackendKind::Gguf,
        ] {
            let encoded = serde_json::to_string(&original).unwrap();
            let restored: BackendKind = serde_json::from_str(&encoded).unwrap();
            assert_eq!(original, restored);
        }
    }

    /// Equal variants compare equal; distinct variants do not.
    #[test]
    fn test_backend_kind_equality() {
        let (onnx, candle) = (BackendKind::Onnx, BackendKind::Candle);
        let (static_kind, gguf) = (BackendKind::Static, BackendKind::Gguf);

        assert_eq!(onnx, BackendKind::Onnx);
        assert_ne!(onnx, candle);
        assert_ne!(static_kind, gguf);
    }

    /// Using the source after assignment only compiles because the type is `Copy`.
    #[test]
    fn test_backend_kind_copy() {
        let first = BackendKind::Candle;
        let second = first;
        assert_eq!(first, second);
    }

    /// The derived `Debug` output contains the variant identifier.
    #[test]
    fn test_backend_kind_debug() {
        let rendered = format!("{:?}", BackendKind::Onnx);
        assert!(rendered.contains("Onnx"));
    }

    /// snake_case serialisation for API/storage compat.
    #[test]
    fn test_backend_kind_serde_snake_case() {
        let cases = [
            (BackendKind::Onnx, "\"onnx\""),
            (BackendKind::Candle, "\"candle\""),
        ];
        for (kind, wire) in cases {
            let json = serde_json::to_string(&kind).unwrap();
            assert_eq!(json, wire);
        }
    }
}