use crate::backend::{select_backend, EmbeddingBackend};
use crate::error::{InferenceError, Result};
use crate::models::ModelConfig;
use std::sync::Arc;
use tracing::{debug, info};
/// Embedding engine that holds two backend handles: one used by
/// `embed_for_write` and one used by `embed_for_read` / `embed_query`.
///
/// When tiered mode is off, `new` stores the same `Arc` in both fields, so
/// both paths go through a single backend.
pub struct TieredEngine {
/// Backend used by `embed_for_write`.
fast_backend: Arc<dyn EmbeddingBackend>,
/// Backend used by `embed_for_read` and, through it, `embed_query`.
quality_backend: Arc<dyn EmbeddingBackend>,
/// Whether tiered mode was enabled at construction; reported by `is_tiered`.
tiered_enabled: bool,
}
impl TieredEngine {
    /// Builds an engine from `config`.
    ///
    /// Tiered mode is switched on only when the `DAKERA_TIERED` environment
    /// variable is exactly `"1"` or equals `"true"` ignoring ASCII case. Any
    /// other value, or a variable that cannot be read, leaves it off.
    ///
    /// The quality backend always comes from `select_backend(config)`. In
    /// tiered mode the fast backend comes from a second `select_backend` call
    /// on a copy of `config` whose `backend_override` is forced to
    /// `BackendKind::Static`; otherwise the fast handle is a clone of the
    /// quality `Arc`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `select_backend`.
    pub async fn new(config: &ModelConfig) -> Result<Self> {
        // The flag is read before any backend is selected.
        let tiered_enabled = match std::env::var("DAKERA_TIERED") {
            Ok(raw) => raw == "1" || raw.eq_ignore_ascii_case("true"),
            Err(_) => false,
        };

        let quality_backend = select_backend(config).await?;

        if !tiered_enabled {
            debug!("TieredEngine: tiered mode disabled — single backend");
            // Both paths share one backend instance.
            return Ok(Self {
                fast_backend: Arc::clone(&quality_backend),
                quality_backend,
                tiered_enabled,
            });
        }

        info!(
            "TieredEngine: tiered mode enabled — fast=static, quality={}",
            quality_backend.backend_kind()
        );

        // Same configuration as the caller's, except the backend kind is pinned.
        let mut static_config = config.clone();
        static_config.backend_override = Some(crate::backend::BackendKind::Static);
        let fast_backend: Arc<dyn EmbeddingBackend> = select_backend(&static_config).await?;

        Ok(Self {
            fast_backend,
            quality_backend,
            tiered_enabled,
        })
    }

    /// Embeds `texts` with the fast backend.
    ///
    /// An empty slice returns an empty vector without calling the backend.
    ///
    /// # Errors
    ///
    /// Returns whatever error the fast backend's `embed_batch` produces.
    pub async fn embed_for_write(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        match texts {
            [] => Ok(vec![]),
            batch => self.fast_backend.embed_batch(batch).await,
        }
    }

    /// Embeds `texts` with the quality backend.
    ///
    /// An empty slice returns an empty vector without calling the backend.
    ///
    /// # Errors
    ///
    /// Returns whatever error the quality backend's `embed_batch` produces.
    pub async fn embed_for_read(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        match texts {
            [] => Ok(vec![]),
            batch => self.quality_backend.embed_batch(batch).await,
        }
    }

    /// Embeds a single query string through the read path.
    ///
    /// # Errors
    ///
    /// Propagates errors from `embed_for_read`, and returns
    /// `InferenceError::InferenceError` if the backend yields no embedding.
    pub async fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
        let batch = [query.to_string()];
        let mut embeddings = self.embed_for_read(&batch).await?;
        match embeddings.pop() {
            Some(vector) => Ok(vector),
            None => Err(InferenceError::InferenceError(
                "empty embedding result".into(),
            )),
        }
    }

    /// Reports whether tiered mode is enabled for this engine.
    pub fn is_tiered(&self) -> bool {
        self.tiered_enabled
    }

    /// Dimension reported by the fast backend.
    pub fn fast_dimension(&self) -> usize {
        self.fast_backend.dimension()
    }

    /// Dimension reported by the quality backend.
    pub fn quality_dimension(&self) -> usize {
        self.quality_backend.dimension()
    }

    /// Returns a new `Arc` handle to the fast backend.
    pub fn fast_backend(&self) -> Arc<dyn EmbeddingBackend> {
        Arc::clone(&self.fast_backend)
    }

    /// Returns a new `Arc` handle to the quality backend.
    pub fn quality_backend(&self) -> Arc<dyn EmbeddingBackend> {
        Arc::clone(&self.quality_backend)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::BackendKind;
    use async_trait::async_trait;

    /// Test double that returns the same stored vector for every input text.
    struct MockBackend {
        dim: usize,
        kind: BackendKind,
        fixed: Vec<f32>,
    }

    impl MockBackend {
        /// Creates a mock whose vector has `dim` components, each `1 / sqrt(dim)`.
        fn new(dim: usize, kind: BackendKind) -> Self {
            let component = 1.0f32 / (dim as f32).sqrt();
            Self {
                dim,
                kind,
                fixed: vec![component; dim],
            }
        }
    }

    #[async_trait]
    impl EmbeddingBackend for MockBackend {
        async fn embed_batch(&self, texts: &[String]) -> crate::error::Result<Vec<Vec<f32>>> {
            // One copy of the stored vector per input, regardless of content.
            Ok(vec![self.fixed.clone(); texts.len()])
        }

        fn dimension(&self) -> usize {
            self.dim
        }

        fn backend_kind(&self) -> BackendKind {
            self.kind
        }
    }

    /// Engine with two distinct mock backends and tiered mode on.
    fn mock_tiered(fast_dim: usize, quality_dim: usize) -> TieredEngine {
        let fast = MockBackend::new(fast_dim, BackendKind::Static);
        let quality = MockBackend::new(quality_dim, BackendKind::Onnx);
        TieredEngine {
            fast_backend: Arc::new(fast),
            quality_backend: Arc::new(quality),
            tiered_enabled: true,
        }
    }

    /// Engine where both paths share one mock backend and tiered mode is off.
    fn mock_single(dim: usize) -> TieredEngine {
        let shared: Arc<dyn EmbeddingBackend> = Arc::new(MockBackend::new(dim, BackendKind::Onnx));
        TieredEngine {
            fast_backend: Arc::clone(&shared),
            quality_backend: shared,
            tiered_enabled: false,
        }
    }

    #[tokio::test]
    async fn test_embed_for_write_returns_fast_dim() {
        let engine = mock_tiered(256, 1024);
        let input = ["hello".to_string()];
        let vectors = engine.embed_for_write(&input).await.unwrap();
        assert_eq!(vectors.len(), 1);
        assert_eq!(vectors[0].len(), 256, "write path must use fast backend dim");
    }

    #[tokio::test]
    async fn test_embed_for_read_returns_quality_dim() {
        let engine = mock_tiered(256, 1024);
        let input = ["hello".to_string()];
        let vectors = engine.embed_for_read(&input).await.unwrap();
        assert_eq!(vectors.len(), 1);
        assert_eq!(
            vectors[0].len(),
            1024,
            "read path must use quality backend dim"
        );
    }

    #[tokio::test]
    async fn test_embed_query_returns_quality_dim() {
        let engine = mock_tiered(256, 1024);
        let vector = engine.embed_query("test query").await.unwrap();
        assert_eq!(vector.len(), 1024, "embed_query must use quality backend");
    }

    #[tokio::test]
    async fn test_single_backend_write_read_same_dim() {
        let engine = mock_single(768);
        let input = ["x".to_string()];
        let written = engine.embed_for_write(&input).await.unwrap();
        let read = engine.embed_for_read(&input).await.unwrap();
        assert_eq!(
            written[0].len(),
            read[0].len(),
            "non-tiered: write/read same dim"
        );
        assert_eq!(written[0].len(), 768);
    }

    #[tokio::test]
    async fn test_empty_write_returns_empty() {
        let engine = mock_tiered(256, 1024);
        let vectors = engine.embed_for_write(&[]).await.unwrap();
        assert!(vectors.is_empty());
    }

    #[tokio::test]
    async fn test_empty_read_returns_empty() {
        let engine = mock_tiered(256, 1024);
        let vectors = engine.embed_for_read(&[]).await.unwrap();
        assert!(vectors.is_empty());
    }

    #[tokio::test]
    async fn test_is_tiered_flag() {
        let tiered = mock_tiered(256, 1024);
        let single = mock_single(768);
        assert!(tiered.is_tiered());
        assert!(!single.is_tiered());
    }

    #[tokio::test]
    async fn test_fast_dimension_accessor() {
        let engine = mock_tiered(256, 1024);
        assert_eq!(engine.fast_dimension(), 256);
    }

    #[tokio::test]
    async fn test_quality_dimension_accessor() {
        let engine = mock_tiered(256, 1024);
        assert_eq!(engine.quality_dimension(), 1024);
    }

    #[tokio::test]
    async fn test_batch_write_multiple_texts() {
        let engine = mock_tiered(256, 1024);
        let mut texts: Vec<String> = Vec::with_capacity(5);
        for i in 0..5 {
            texts.push(format!("text {i}"));
        }
        let vectors = engine.embed_for_write(&texts).await.unwrap();
        assert_eq!(vectors.len(), 5, "must return one embedding per text");
        for vector in vectors.iter() {
            assert_eq!(vector.len(), 256);
        }
    }

    #[tokio::test]
    async fn test_backend_arc_accessors() {
        let engine = mock_tiered(256, 1024);
        let fast = engine.fast_backend();
        let quality = engine.quality_backend();
        assert_eq!(fast.backend_kind(), BackendKind::Static);
        assert_eq!(quality.backend_kind(), BackendKind::Onnx);
    }
}