Skip to main content

inference/backend/
mod.rs

1//! Pluggable embedding backend abstraction.
2//!
3//! # Overview
4//!
//! The `EmbeddingBackend` trait decouples embedding generation from the underlying
//! inference runtime (ONNX Runtime, Candle, static lookup, GGUF quantized).  The engine
//! selects a backend at startup via environment variables:
8//!
9//! | Env var | Backend chosen |
10//! |---------|----------------|
11//! | `DAKERA_BACKEND=onnx` | [`OnnxBackend`] — production default (INT8 quantized) |
12//! | `DAKERA_BACKEND=candle` | [`CandleBackend`] — pure-Rust Candle (requires `candle` feature) |
13//! | `DAKERA_BACKEND=static` | [`StaticBackend`] — Model2Vec static lookup (~500× faster ingest) |
14//! | `DAKERA_BACKEND=gguf` | [`GgufBackend`] — GGUF quantized via candle-gguf (requires `candle` feature) |
//! | *(unset)* | [`OnnxBackend`], or [`CandleBackend`] when `DAKERA_USE_GPU=1` and the `candle` feature is compiled in |
16//!
17//! All backends implement the same `EmbeddingBackend` trait so callers are fully backend-agnostic.
18//!
19//! # Feature flags
20//!
21//! * `candle` — enables [`CandleBackend`] and [`GgufBackend`] (requires HuggingFace Candle crates)
22//!
23//! The `onnx` and `static` backends are always compiled — they have no additional dependencies
24//! beyond what the `dakera-inference` crate already requires.
25
26use crate::error::Result;
27use async_trait::async_trait;
28use std::sync::Arc;
29
30pub mod onnx;
31pub mod static_backend;
32
33#[cfg(feature = "candle")]
34pub mod candle_backend;
35#[cfg(feature = "candle")]
36pub mod gguf_backend;
37
38pub use onnx::OnnxBackend;
39pub use static_backend::StaticBackend;
40
41#[cfg(feature = "candle")]
42pub use candle_backend::CandleBackend;
43#[cfg(feature = "candle")]
44pub use gguf_backend::GgufBackend;
45
46/// Identifies which backend is active.  Stored in HNSW node metadata so the
47/// background re-embed job knows which memories were indexed with a fast
48/// (static) embedding and need to be upgraded with the quality transformer.
49#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
50#[serde(rename_all = "snake_case")]
51pub enum BackendKind {
52    /// ONNX Runtime INT8 quantized — production default.
53    Onnx,
54    /// Candle pure-Rust FP32/BF16 transformer.
55    Candle,
56    /// Model2Vec static vocabulary lookup — fast ingest path.
57    Static,
58    /// GGUF quantized inference via candle-gguf.
59    Gguf,
60}
61
62impl std::fmt::Display for BackendKind {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        match self {
65            BackendKind::Onnx => write!(f, "onnx"),
66            BackendKind::Candle => write!(f, "candle"),
67            BackendKind::Static => write!(f, "static"),
68            BackendKind::Gguf => write!(f, "gguf"),
69        }
70    }
71}
72
73/// Core embedding backend trait.  All backends must be `Send + Sync` so they
74/// can be held behind an `Arc` and called from async tasks.
75#[async_trait]
76pub trait EmbeddingBackend: Send + Sync {
77    /// Embed a batch of raw texts and return a float32 embedding per input.
78    ///
79    /// Callers must pre-process texts (apply prefixes, truncation) before
80    /// calling this method.  The returned `Vec` has the same length as
81    /// `texts`.  Empty input returns an empty vec immediately.
82    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
83
84    /// Embedding dimensionality (fixed for the lifetime of the backend).
85    fn dimension(&self) -> usize;
86
87    /// The concrete backend kind — used for metrics and re-embed bookkeeping.
88    fn backend_kind(&self) -> BackendKind;
89}
90
91/// Select an embedding backend based on environment variables.
92///
93/// Priority:
94/// 1. `DAKERA_BACKEND` — explicit backend selection
95/// 2. `DAKERA_USE_GPU=1` → Candle (CUDA/Metal) when the `candle` feature is compiled in
96/// 3. Default → [`OnnxBackend`]
97///
98/// # Errors
99///
100/// Returns an error if the requested backend is not compiled in or fails to initialise.
101pub async fn select_backend(
102    config: &crate::models::ModelConfig,
103) -> Result<Arc<dyn EmbeddingBackend>> {
104    // backend_override takes precedence over env var — used by TieredEngine
105    let override_str = config.backend_override.as_ref().map(|k| k.to_string());
106    let explicit = override_str.or_else(|| {
107        std::env::var("DAKERA_BACKEND")
108            .ok()
109            .map(|v| v.to_lowercase())
110    });
111
112    match explicit.as_deref() {
113        Some("static") => {
114            tracing::info!("Backend: static (Model2Vec vocabulary lookup)");
115            let backend = StaticBackend::new(config).await?;
116            Ok(Arc::new(backend))
117        }
118        #[cfg(feature = "candle")]
119        Some("candle") => {
120            tracing::info!("Backend: candle (pure-Rust Candle)");
121            let backend = CandleBackend::new(config).await?;
122            Ok(Arc::new(backend))
123        }
124        #[cfg(feature = "candle")]
125        Some("gguf") => {
126            tracing::info!("Backend: gguf (candle-gguf quantized)");
127            let backend = GgufBackend::new(config).await?;
128            Ok(Arc::new(backend))
129        }
130        Some("onnx") | None => {
131            #[cfg(feature = "candle")]
132            {
133                let use_gpu = std::env::var("DAKERA_USE_GPU")
134                    .map(|v| v == "1")
135                    .unwrap_or(config.use_gpu);
136                if use_gpu && explicit.is_none() {
137                    tracing::info!("Backend: candle (GPU auto-selected via DAKERA_USE_GPU=1)");
138                    let backend = CandleBackend::new(config).await?;
139                    return Ok(Arc::new(backend));
140                }
141            }
142            tracing::info!("Backend: onnx (ONNX Runtime INT8)");
143            let backend = OnnxBackend::new(config).await?;
144            Ok(Arc::new(backend))
145        }
146        Some(other) => {
147            #[cfg(not(feature = "candle"))]
148            if other == "candle" || other == "gguf" {
149                return Err(crate::error::InferenceError::InvalidInput(format!(
150                    "Backend '{other}' requires the 'candle' feature to be compiled in"
151                )));
152            }
153            tracing::warn!("Unknown DAKERA_BACKEND={other}, falling back to onnx");
154            let backend = OnnxBackend::new(config).await?;
155            Ok(Arc::new(backend))
156        }
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` yields the lowercase name of each variant.
    #[test]
    fn test_backend_kind_display() {
        let expected = [
            (BackendKind::Onnx, "onnx"),
            (BackendKind::Candle, "candle"),
            (BackendKind::Static, "static"),
            (BackendKind::Gguf, "gguf"),
        ];
        for (kind, name) in expected {
            assert_eq!(kind.to_string(), name);
        }
    }

    /// Serialising to JSON and back returns the same variant.
    #[test]
    fn test_backend_kind_serde_roundtrip() {
        for original in [
            BackendKind::Onnx,
            BackendKind::Candle,
            BackendKind::Static,
            BackendKind::Gguf,
        ] {
            let encoded = serde_json::to_string(&original).unwrap();
            let restored: BackendKind = serde_json::from_str(&encoded).unwrap();
            assert_eq!(original, restored);
        }
    }

    /// Variants compare equal only to themselves.
    #[test]
    fn test_backend_kind_equality() {
        assert_eq!(BackendKind::Onnx, BackendKind::Onnx);
        assert_ne!(BackendKind::Onnx, BackendKind::Candle);
        assert_ne!(BackendKind::Static, BackendKind::Gguf);
    }

    /// `BackendKind` is `Copy`: the source stays usable after assignment.
    #[test]
    fn test_backend_kind_copy() {
        let original = BackendKind::Candle;
        let duplicate = original;
        assert_eq!(original, duplicate);
    }

    /// `Debug` output contains the variant name.
    #[test]
    fn test_backend_kind_debug() {
        let rendered = format!("{:?}", BackendKind::Onnx);
        assert!(rendered.contains("Onnx"));
    }

    /// snake_case serialisation for API/storage compat.
    #[test]
    fn test_backend_kind_serde_snake_case() {
        let cases = [
            (BackendKind::Onnx, "\"onnx\""),
            (BackendKind::Candle, "\"candle\""),
        ];
        for (kind, expected_json) in cases {
            let json = serde_json::to_string(&kind).unwrap();
            assert_eq!(json, expected_json);
        }
    }
}
211        assert_eq!(json, "\"onnx\"");
212        let json = serde_json::to_string(&BackendKind::Candle).unwrap();
213        assert_eq!(json, "\"candle\"");
214    }
215}