1use crate::error::Result;
27use async_trait::async_trait;
28use std::sync::Arc;
29
30pub mod onnx;
31pub mod static_backend;
32
33#[cfg(feature = "candle")]
34pub mod candle_backend;
35#[cfg(feature = "candle")]
36pub mod gguf_backend;
37
38pub use onnx::OnnxBackend;
39pub use static_backend::StaticBackend;
40
41#[cfg(feature = "candle")]
42pub use candle_backend::CandleBackend;
43#[cfg(feature = "candle")]
44pub use gguf_backend::GgufBackend;
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
50#[serde(rename_all = "snake_case")]
51pub enum BackendKind {
52 Onnx,
54 Candle,
56 Static,
58 Gguf,
60}
61
62impl std::fmt::Display for BackendKind {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 match self {
65 BackendKind::Onnx => write!(f, "onnx"),
66 BackendKind::Candle => write!(f, "candle"),
67 BackendKind::Static => write!(f, "static"),
68 BackendKind::Gguf => write!(f, "gguf"),
69 }
70 }
71}
72
73#[async_trait]
76pub trait EmbeddingBackend: Send + Sync {
77 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
83
84 fn dimension(&self) -> usize;
86
87 fn backend_kind(&self) -> BackendKind;
89}
90
91pub async fn select_backend(
102 config: &crate::models::ModelConfig,
103) -> Result<Arc<dyn EmbeddingBackend>> {
104 let override_str = config.backend_override.as_ref().map(|k| k.to_string());
106 let explicit = override_str.or_else(|| {
107 std::env::var("DAKERA_BACKEND")
108 .ok()
109 .map(|v| v.to_lowercase())
110 });
111
112 match explicit.as_deref() {
113 Some("static") => {
114 tracing::info!("Backend: static (Model2Vec vocabulary lookup)");
115 let backend = StaticBackend::new(config).await?;
116 Ok(Arc::new(backend))
117 }
118 #[cfg(feature = "candle")]
119 Some("candle") => {
120 tracing::info!("Backend: candle (pure-Rust Candle)");
121 let backend = CandleBackend::new(config).await?;
122 Ok(Arc::new(backend))
123 }
124 #[cfg(feature = "candle")]
125 Some("gguf") => {
126 tracing::info!("Backend: gguf (candle-gguf quantized)");
127 let backend = GgufBackend::new(config).await?;
128 Ok(Arc::new(backend))
129 }
130 Some("onnx") | None => {
131 #[cfg(feature = "candle")]
132 {
133 let use_gpu = std::env::var("DAKERA_USE_GPU")
134 .map(|v| v == "1")
135 .unwrap_or(config.use_gpu);
136 if use_gpu && explicit.is_none() {
137 tracing::info!("Backend: candle (GPU auto-selected via DAKERA_USE_GPU=1)");
138 let backend = CandleBackend::new(config).await?;
139 return Ok(Arc::new(backend));
140 }
141 }
142 tracing::info!("Backend: onnx (ONNX Runtime INT8)");
143 let backend = OnnxBackend::new(config).await?;
144 Ok(Arc::new(backend))
145 }
146 Some(other) => {
147 #[cfg(not(feature = "candle"))]
148 if other == "candle" || other == "gguf" {
149 return Err(crate::error::InferenceError::InvalidInput(format!(
150 "Backend '{other}' requires the 'candle' feature to be compiled in"
151 )));
152 }
153 tracing::warn!("Unknown DAKERA_BACKEND={other}, falling back to onnx");
154 let backend = OnnxBackend::new(config).await?;
155 Ok(Arc::new(backend))
156 }
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant paired with the lowercase name it is expected to use for
    /// both `Display` and serde.
    const NAMED_KINDS: [(BackendKind, &str); 4] = [
        (BackendKind::Onnx, "onnx"),
        (BackendKind::Candle, "candle"),
        (BackendKind::Static, "static"),
        (BackendKind::Gguf, "gguf"),
    ];

    #[test]
    fn test_backend_kind_display() {
        for (kind, name) in NAMED_KINDS {
            assert_eq!(kind.to_string(), name);
        }
    }

    #[test]
    fn test_backend_kind_serde_roundtrip() {
        for (kind, _) in NAMED_KINDS {
            let json = serde_json::to_string(&kind).unwrap();
            let decoded: BackendKind = serde_json::from_str(&json).unwrap();
            assert_eq!(kind, decoded);
        }
    }

    #[test]
    fn test_backend_kind_equality() {
        assert_eq!(BackendKind::Onnx, BackendKind::Onnx);
        assert_ne!(BackendKind::Onnx, BackendKind::Candle);
        assert_ne!(BackendKind::Static, BackendKind::Gguf);
    }

    #[test]
    fn test_backend_kind_copy() {
        // Using `original` after the assignment only compiles if the type is `Copy`.
        let original = BackendKind::Candle;
        let duplicate = original;
        assert_eq!(original, duplicate);
    }

    #[test]
    fn test_backend_kind_debug() {
        let rendered = format!("{:?}", BackendKind::Onnx);
        assert!(rendered.contains("Onnx"));
    }

    #[test]
    fn test_backend_kind_serde_snake_case() {
        // Covers all variants, not just a sample, so a renamed or newly added
        // variant cannot silently change the wire format.
        for (kind, name) in NAMED_KINDS {
            let json = serde_json::to_string(&kind).unwrap();
            assert_eq!(json, format!("\"{name}\""));
        }
    }
}