use crate::error::Result;
use async_trait::async_trait;
use std::sync::Arc;
pub mod onnx;
pub mod static_backend;
#[cfg(feature = "candle")]
pub mod candle_backend;
#[cfg(feature = "candle")]
pub mod gguf_backend;
pub use onnx::OnnxBackend;
pub use static_backend::StaticBackend;
#[cfg(feature = "candle")]
pub use candle_backend::CandleBackend;
#[cfg(feature = "candle")]
pub use gguf_backend::GgufBackend;
/// Identifies which embedding backend implementation is in use.
///
/// Each variant serializes to the same lowercase name that its `Display`
/// implementation prints (`"onnx"`, `"candle"`, `"static"`, `"gguf"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum BackendKind {
    /// ONNX Runtime backend (`OnnxBackend`).
    #[serde(rename = "onnx")]
    Onnx,
    /// Pure-Rust Candle backend (`CandleBackend`); only constructible with the
    /// `candle` feature.
    #[serde(rename = "candle")]
    Candle,
    /// Model2Vec vocabulary-lookup backend (`StaticBackend`).
    #[serde(rename = "static")]
    Static,
    /// Quantized candle-gguf backend (`GgufBackend`); only constructible with the
    /// `candle` feature.
    #[serde(rename = "gguf")]
    Gguf,
}
impl std::fmt::Display for BackendKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BackendKind::Onnx => write!(f, "onnx"),
BackendKind::Candle => write!(f, "candle"),
BackendKind::Static => write!(f, "static"),
BackendKind::Gguf => write!(f, "gguf"),
}
}
}
/// A text-embedding engine.
///
/// `select_backend` returns implementations as `Arc<dyn EmbeddingBackend>`, which is
/// why the trait requires `Send + Sync`.
#[async_trait]
pub trait EmbeddingBackend: Send + Sync {
    /// Reports which [`BackendKind`] this implementation is.
    fn backend_kind(&self) -> BackendKind;

    /// Size of the embedding vectors this backend produces.
    ///
    /// NOTE(review): presumably equals the length of every inner `Vec<f32>` returned
    /// by `embed_batch` — confirm in the implementations.
    fn dimension(&self) -> usize;

    /// Embeds the strings in `texts`.
    ///
    /// NOTE(review): presumably returns one vector per input text, in input order —
    /// confirm in the implementations.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
}
/// Selects and constructs the embedding backend described by `config`.
///
/// The backend name is resolved in this order:
///
/// 1. `config.backend_override`, when set.
/// 2. The `DAKERA_BACKEND` environment variable, lowercased and with surrounding
///    whitespace removed. An empty (or all-whitespace) value counts as unset.
/// 3. Neither set: ONNX. The exception is a build with the `candle` feature that also
///    requests the GPU; in that case Candle is chosen. `DAKERA_USE_GPU` decides the
///    GPU request when it is set (only the exact value `"1"` enables it); otherwise
///    `config.use_gpu` decides.
///
/// An explicit `"onnx"` never triggers GPU auto-selection. An unrecognised name logs
/// a warning and falls back to ONNX, also without GPU auto-selection.
///
/// # Errors
///
/// Returns `InferenceError::InvalidInput` if `candle` or `gguf` is requested in a
/// build without the `candle` feature. Also propagates any error returned by the
/// selected backend's constructor.
pub async fn select_backend(
    config: &crate::models::ModelConfig,
) -> Result<Arc<dyn EmbeddingBackend>> {
    let explicit = config
        .backend_override
        .as_ref()
        .map(|k| k.to_string())
        .or_else(|| {
            std::env::var("DAKERA_BACKEND")
                .ok()
                .map(|v| v.trim().to_lowercase())
                // `DAKERA_BACKEND=` (e.g. from an env template) means "not set".
                // Treating it as an unknown name would log a bogus warning and
                // disable GPU auto-selection.
                .filter(|v| !v.is_empty())
        });
    match explicit.as_deref() {
        Some("static") => {
            tracing::info!("Backend: static (Model2Vec vocabulary lookup)");
            let backend = StaticBackend::new(config).await?;
            Ok(Arc::new(backend))
        }
        #[cfg(feature = "candle")]
        Some("candle") => {
            tracing::info!("Backend: candle (pure-Rust Candle)");
            let backend = CandleBackend::new(config).await?;
            Ok(Arc::new(backend))
        }
        #[cfg(feature = "candle")]
        Some("gguf") => {
            tracing::info!("Backend: gguf (candle-gguf quantized)");
            let backend = GgufBackend::new(config).await?;
            Ok(Arc::new(backend))
        }
        Some("onnx") | None => {
            #[cfg(feature = "candle")]
            {
                // When DAKERA_USE_GPU is present it wins over the config flag. Track
                // which one decided so the log line reports the real reason.
                let (use_gpu, gpu_source) = match std::env::var("DAKERA_USE_GPU") {
                    Ok(v) => (v == "1", "DAKERA_USE_GPU=1"),
                    Err(_) => (config.use_gpu, "config.use_gpu"),
                };
                // Only auto-select when nothing was requested explicitly.
                if use_gpu && explicit.is_none() {
                    tracing::info!("Backend: candle (GPU auto-selected via {gpu_source})");
                    let backend = CandleBackend::new(config).await?;
                    return Ok(Arc::new(backend));
                }
            }
            tracing::info!("Backend: onnx (ONNX Runtime INT8)");
            let backend = OnnxBackend::new(config).await?;
            Ok(Arc::new(backend))
        }
        Some(other) => {
            // Without the feature the candle/gguf arms above are compiled out, so a
            // request for them lands here and must be rejected rather than ignored.
            #[cfg(not(feature = "candle"))]
            if other == "candle" || other == "gguf" {
                return Err(crate::error::InferenceError::InvalidInput(format!(
                    "Backend '{other}' requires the 'candle' feature to be compiled in"
                )));
            }
            tracing::warn!("Unknown DAKERA_BACKEND={other}, falling back to onnx");
            let backend = OnnxBackend::new(config).await?;
            Ok(Arc::new(backend))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant paired with its expected lowercase name, used by the
    /// table-driven tests below so that new variants only need adding here.
    const ALL: [(BackendKind, &str); 4] = [
        (BackendKind::Onnx, "onnx"),
        (BackendKind::Candle, "candle"),
        (BackendKind::Static, "static"),
        (BackendKind::Gguf, "gguf"),
    ];

    #[test]
    fn test_backend_kind_display() {
        for (kind, name) in ALL {
            assert_eq!(kind.to_string(), name);
        }
    }

    #[test]
    fn test_backend_kind_serde_roundtrip() {
        for (kind, _) in ALL {
            let json = serde_json::to_string(&kind).unwrap();
            let decoded: BackendKind = serde_json::from_str(&json).unwrap();
            assert_eq!(kind, decoded);
        }
    }

    #[test]
    fn test_backend_kind_equality() {
        assert_eq!(BackendKind::Onnx, BackendKind::Onnx);
        assert_ne!(BackendKind::Onnx, BackendKind::Candle);
        assert_ne!(BackendKind::Static, BackendKind::Gguf);
    }

    #[test]
    fn test_backend_kind_copy() {
        let k = BackendKind::Candle;
        let k2 = k;
        assert_eq!(k, k2);
    }

    #[test]
    fn test_backend_kind_debug() {
        let s = format!("{:?}", BackendKind::Onnx);
        assert!(s.contains("Onnx"));
    }

    /// Pins the exact wire name of every variant in both directions. A roundtrip
    /// alone would not catch a rename.
    #[test]
    fn test_backend_kind_serde_snake_case() {
        for (kind, name) in ALL {
            let expected = format!("\"{name}\"");
            assert_eq!(serde_json::to_string(&kind).unwrap(), expected);
            let parsed: BackendKind = serde_json::from_str(&expected).unwrap();
            assert_eq!(parsed, kind);
        }
    }

    /// `select_backend` matches `backend_override.to_string()` against string arms,
    /// so the `Display` names must stay in sync with the serialized names.
    #[test]
    fn test_backend_kind_display_matches_serde() {
        for (kind, _) in ALL {
            assert_eq!(
                serde_json::to_string(&kind).unwrap(),
                format!("\"{kind}\"")
            );
        }
    }
}