1use crate::backend::{select_backend, EmbeddingBackend};
27use crate::error::{InferenceError, Result};
28use crate::models::ModelConfig;
29use std::sync::Arc;
30use tracing::{debug, info};
31
32pub struct TieredEngine {
37 fast_backend: Arc<dyn EmbeddingBackend>,
38 quality_backend: Arc<dyn EmbeddingBackend>,
39 tiered_enabled: bool,
40}
41
42impl TieredEngine {
43 pub async fn new(config: &ModelConfig) -> Result<Self> {
48 let tiered_enabled = std::env::var("DAKERA_TIERED")
49 .ok()
50 .as_deref()
51 .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
52 .unwrap_or(false);
53
54 let quality_backend = select_backend(config).await?;
55
56 let fast_backend: Arc<dyn EmbeddingBackend> = if tiered_enabled {
57 info!(
58 "TieredEngine: tiered mode enabled — fast=static, quality={}",
59 quality_backend.backend_kind()
60 );
61 let static_config = ModelConfig {
63 backend_override: Some(crate::backend::BackendKind::Static),
64 ..config.clone()
65 };
66 select_backend(&static_config).await?
67 } else {
68 debug!("TieredEngine: tiered mode disabled — single backend");
69 Arc::clone(&quality_backend)
70 };
71
72 Ok(Self {
73 fast_backend,
74 quality_backend,
75 tiered_enabled,
76 })
77 }
78
79 pub async fn embed_for_write(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
84 if texts.is_empty() {
85 return Ok(vec![]);
86 }
87 self.fast_backend.embed_batch(texts).await
88 }
89
90 pub async fn embed_for_read(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
94 if texts.is_empty() {
95 return Ok(vec![]);
96 }
97 self.quality_backend.embed_batch(texts).await
98 }
99
100 pub async fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
102 let mut results = self.embed_for_read(&[query.to_string()]).await?;
103 results
104 .pop()
105 .ok_or_else(|| InferenceError::InferenceError("empty embedding result".into()))
106 }
107
108 pub fn is_tiered(&self) -> bool {
110 self.tiered_enabled
111 }
112
113 pub fn fast_dimension(&self) -> usize {
115 self.fast_backend.dimension()
116 }
117
118 pub fn quality_dimension(&self) -> usize {
120 self.quality_backend.dimension()
121 }
122
123 pub fn fast_backend(&self) -> Arc<dyn EmbeddingBackend> {
125 Arc::clone(&self.fast_backend)
126 }
127
128 pub fn quality_backend(&self) -> Arc<dyn EmbeddingBackend> {
130 Arc::clone(&self.quality_backend)
131 }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    use crate::backend::BackendKind;
    use async_trait::async_trait;

    /// Deterministic stand-in backend.
    ///
    /// Every input text maps to the same vector of length `dim`, and the backend
    /// reports a configurable `BackendKind`.
    struct MockBackend {
        dim: usize,
        kind: BackendKind,
        fixed: Vec<f32>,
    }

    impl MockBackend {
        fn new(dim: usize, kind: BackendKind) -> Self {
            // Every component is 1/sqrt(dim), giving (approximately) unit L2 norm for dim > 0.
            let component = 1.0f32 / (dim as f32).sqrt();
            Self {
                fixed: vec![component; dim],
                kind,
                dim,
            }
        }
    }

    #[async_trait]
    impl EmbeddingBackend for MockBackend {
        fn dimension(&self) -> usize {
            self.dim
        }

        fn backend_kind(&self) -> BackendKind {
            self.kind
        }

        async fn embed_batch(&self, texts: &[String]) -> crate::error::Result<Vec<Vec<f32>>> {
            // One copy of the fixed vector per input text.
            Ok(vec![self.fixed.clone(); texts.len()])
        }
    }

    /// Tiered engine: a `Static` mock on the write path, an `Onnx` mock on the read path.
    fn tiered_engine(fast_dim: usize, quality_dim: usize) -> TieredEngine {
        let fast: Arc<dyn EmbeddingBackend> =
            Arc::new(MockBackend::new(fast_dim, BackendKind::Static));
        let quality: Arc<dyn EmbeddingBackend> =
            Arc::new(MockBackend::new(quality_dim, BackendKind::Onnx));
        TieredEngine {
            fast_backend: fast,
            quality_backend: quality,
            tiered_enabled: true,
        }
    }

    /// Non-tiered engine: both paths share a single `Onnx` mock of width `dim`.
    fn single_engine(dim: usize) -> TieredEngine {
        let shared: Arc<dyn EmbeddingBackend> =
            Arc::new(MockBackend::new(dim, BackendKind::Onnx));
        TieredEngine {
            fast_backend: Arc::clone(&shared),
            quality_backend: shared,
            tiered_enabled: false,
        }
    }

    /// Converts borrowed string literals into the owned batch type the engine expects.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().copied().map(String::from).collect()
    }

    #[tokio::test]
    async fn test_embed_for_write_returns_fast_dim() {
        let engine = tiered_engine(256, 1024);
        let embs = engine.embed_for_write(&strings(&["hello"])).await.unwrap();
        assert_eq!(embs.len(), 1);
        assert_eq!(embs[0].len(), 256, "write path must use fast backend dim");
    }

    #[tokio::test]
    async fn test_embed_for_read_returns_quality_dim() {
        let engine = tiered_engine(256, 1024);
        let embs = engine.embed_for_read(&strings(&["hello"])).await.unwrap();
        assert_eq!(embs.len(), 1);
        assert_eq!(
            embs[0].len(),
            1024,
            "read path must use quality backend dim"
        );
    }

    #[tokio::test]
    async fn test_embed_query_returns_quality_dim() {
        let engine = tiered_engine(256, 1024);
        let emb = engine.embed_query("test query").await.unwrap();
        assert_eq!(emb.len(), 1024, "embed_query must use quality backend");
    }

    #[tokio::test]
    async fn test_single_backend_write_read_same_dim() {
        let engine = single_engine(768);
        let input = strings(&["x"]);
        let written = engine.embed_for_write(&input).await.unwrap();
        let read = engine.embed_for_read(&input).await.unwrap();
        assert_eq!(
            written[0].len(),
            read[0].len(),
            "non-tiered: write/read same dim"
        );
        assert_eq!(written[0].len(), 768);
    }

    #[tokio::test]
    async fn test_empty_write_returns_empty() {
        let engine = tiered_engine(256, 1024);
        assert!(engine.embed_for_write(&[]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_empty_read_returns_empty() {
        let engine = tiered_engine(256, 1024);
        assert!(engine.embed_for_read(&[]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_is_tiered_flag() {
        assert!(tiered_engine(256, 1024).is_tiered());
        assert!(!single_engine(768).is_tiered());
    }

    #[tokio::test]
    async fn test_fast_dimension_accessor() {
        assert_eq!(tiered_engine(256, 1024).fast_dimension(), 256);
    }

    #[tokio::test]
    async fn test_quality_dimension_accessor() {
        assert_eq!(tiered_engine(256, 1024).quality_dimension(), 1024);
    }

    #[tokio::test]
    async fn test_batch_write_multiple_texts() {
        let engine = tiered_engine(256, 1024);
        let texts: Vec<String> = (0..5).map(|i| format!("text {i}")).collect();
        let embs = engine.embed_for_write(&texts).await.unwrap();
        assert_eq!(embs.len(), 5, "must return one embedding per text");
        assert!(embs.iter().all(|e| e.len() == 256));
    }

    #[tokio::test]
    async fn test_backend_arc_accessors() {
        let engine = tiered_engine(256, 1024);
        let fast_kind = engine.fast_backend().backend_kind();
        let quality_kind = engine.quality_backend().backend_kind();
        assert_eq!(fast_kind, BackendKind::Static);
        assert_eq!(quality_kind, BackendKind::Onnx);
    }
}