Skip to main content

inference/
tiered.rs

1//! Tiered embedding engine — fast writes, quality reads.
2//!
3//! # Architecture
4//!
5//! ```text
6//!  ┌────────────────────────────────────────────────────────────────┐
7//!  │                        TieredEngine                           │
8//!  │                                                                │
9//!  │  WRITE path → fast_backend (StaticBackend / Model2Vec)        │
10//!  │               ~500× faster than transformer, no GPU needed    │
11//!  │                                                                │
12//!  │  READ path  → quality_backend (CandleBackend / OnnxBackend)   │
13//!  │               full transformer quality for recall scoring      │
14//!  └────────────────────────────────────────────────────────────────┘
15//! ```
16//!
//! Controlled by `DAKERA_TIERED` — `1` or `true` (case-insensitive) enables it (default: off, legacy ONNX path).
18//! When enabled, `DAKERA_BACKEND` selects the quality backend.
19//!
20//! # Background Re-embedding
21//!
22//! New memories are stored with fast (static) embeddings first.
23//! A background job (`ReembedJob`) upgrades high-importance memories to
24//! transformer-quality embeddings at a rate of ≤50 per 5 minutes.
25
26use crate::backend::{select_backend, EmbeddingBackend};
27use crate::error::{InferenceError, Result};
28use crate::models::ModelConfig;
29use std::sync::Arc;
30use tracing::{debug, info};
31
32/// Tiered embedding engine: fast writes via static backend, quality reads via transformer.
33///
34/// Enabled with `DAKERA_TIERED=1`. When disabled, `embed_for_write` and `embed_for_read`
35/// both delegate to the single quality backend (no tier split).
36pub struct TieredEngine {
37    fast_backend: Arc<dyn EmbeddingBackend>,
38    quality_backend: Arc<dyn EmbeddingBackend>,
39    tiered_enabled: bool,
40}
41
42impl TieredEngine {
43    /// Build from the given `ModelConfig`.
44    ///
45    /// When `DAKERA_TIERED=1`, initialises both the static fast backend and the
46    /// transformer quality backend.  Otherwise both slots use the quality backend.
47    pub async fn new(config: &ModelConfig) -> Result<Self> {
48        let tiered_enabled = std::env::var("DAKERA_TIERED")
49            .ok()
50            .as_deref()
51            .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
52            .unwrap_or(false);
53
54        let quality_backend = select_backend(config).await?;
55
56        let fast_backend: Arc<dyn EmbeddingBackend> = if tiered_enabled {
57            info!(
58                "TieredEngine: tiered mode enabled — fast=static, quality={}",
59                quality_backend.backend_kind()
60            );
61            // Build a static backend for fast writes
62            let static_config = ModelConfig {
63                backend_override: Some(crate::backend::BackendKind::Static),
64                ..config.clone()
65            };
66            select_backend(&static_config).await?
67        } else {
68            debug!("TieredEngine: tiered mode disabled — single backend");
69            Arc::clone(&quality_backend)
70        };
71
72        Ok(Self {
73            fast_backend,
74            quality_backend,
75            tiered_enabled,
76        })
77    }
78
79    /// Embed texts for storage (write path).
80    ///
81    /// Uses the fast backend when tiered mode is enabled, quality backend otherwise.
82    /// Returns embeddings suitable for indexing — may be static Model2Vec quality.
83    pub async fn embed_for_write(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
84        if texts.is_empty() {
85            return Ok(vec![]);
86        }
87        self.fast_backend.embed_batch(texts).await
88    }
89
90    /// Embed texts for retrieval scoring (read path).
91    ///
92    /// Always uses the quality backend so recall scoring is at transformer quality.
93    pub async fn embed_for_read(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
94        if texts.is_empty() {
95            return Ok(vec![]);
96        }
97        self.quality_backend.embed_batch(texts).await
98    }
99
100    /// Embed a single query for retrieval.
101    pub async fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
102        let mut results = self.embed_for_read(&[query.to_string()]).await?;
103        results
104            .pop()
105            .ok_or_else(|| InferenceError::InferenceError("empty embedding result".into()))
106    }
107
108    /// Whether the tiered split is active (fast writes / quality reads).
109    pub fn is_tiered(&self) -> bool {
110        self.tiered_enabled
111    }
112
113    /// Dimension of the fast (write) backend.
114    pub fn fast_dimension(&self) -> usize {
115        self.fast_backend.dimension()
116    }
117
118    /// Dimension of the quality (read) backend.
119    pub fn quality_dimension(&self) -> usize {
120        self.quality_backend.dimension()
121    }
122
123    /// Expose fast backend for background re-embedding tasks.
124    pub fn fast_backend(&self) -> Arc<dyn EmbeddingBackend> {
125        Arc::clone(&self.fast_backend)
126    }
127
128    /// Expose quality backend for background re-embedding tasks.
129    pub fn quality_backend(&self) -> Arc<dyn EmbeddingBackend> {
130        Arc::clone(&self.quality_backend)
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    // These unit tests run entirely against in-memory mock backends, so no model
    // is downloaded.  Engines are assembled directly from their fields rather
    // than through `TieredEngine::new`, so the `DAKERA_TIERED` lookup is not
    // exercised here.
    // NOTE(review): an earlier comment pointed at integration tests in an
    // `#[ignore]` block; none is visible in this module — confirm where they live.

    use crate::backend::BackendKind;
    use async_trait::async_trait;

    /// Test double that answers every request with the same vector.
    struct MockBackend {
        dim: usize,
        kind: BackendKind,
        /// Vector handed back for each input text: `dim` copies of `1/sqrt(dim)`.
        fixed: Vec<f32>,
    }

    impl MockBackend {
        /// Build a mock reporting `dim` and `kind`, with a uniform fixed vector.
        fn new(dim: usize, kind: BackendKind) -> Self {
            let component = 1.0f32 / (dim as f32).sqrt();
            Self {
                dim,
                kind,
                fixed: vec![component; dim],
            }
        }
    }

    #[async_trait]
    impl EmbeddingBackend for MockBackend {
        /// One copy of the fixed vector per input text.
        async fn embed_batch(&self, texts: &[String]) -> crate::error::Result<Vec<Vec<f32>>> {
            let copies = std::iter::repeat_with(|| self.fixed.clone())
                .take(texts.len())
                .collect();
            Ok(copies)
        }

        fn dimension(&self) -> usize {
            self.dim
        }

        fn backend_kind(&self) -> BackendKind {
            self.kind
        }
    }

    /// Engine with distinct mock backends in the fast and quality slots.
    fn mock_tiered(fast_dim: usize, quality_dim: usize) -> TieredEngine {
        let fast = MockBackend::new(fast_dim, BackendKind::Static);
        let quality = MockBackend::new(quality_dim, BackendKind::Onnx);
        TieredEngine {
            fast_backend: Arc::new(fast),
            quality_backend: Arc::new(quality),
            tiered_enabled: true,
        }
    }

    /// Engine whose two slots share a single mock backend.
    fn mock_single(dim: usize) -> TieredEngine {
        let shared: Arc<dyn EmbeddingBackend> =
            Arc::new(MockBackend::new(dim, BackendKind::Onnx));
        TieredEngine {
            fast_backend: Arc::clone(&shared),
            quality_backend: shared,
            tiered_enabled: false,
        }
    }

    #[tokio::test]
    async fn test_embed_for_write_returns_fast_dim() {
        let engine = mock_tiered(256, 1024);
        let input = [String::from("hello")];
        let vectors = engine.embed_for_write(&input).await.unwrap();
        assert_eq!(vectors.len(), 1);
        assert_eq!(
            vectors[0].len(),
            256,
            "write path must use fast backend dim"
        );
    }

    #[tokio::test]
    async fn test_embed_for_read_returns_quality_dim() {
        let engine = mock_tiered(256, 1024);
        let input = [String::from("hello")];
        let vectors = engine.embed_for_read(&input).await.unwrap();
        assert_eq!(vectors.len(), 1);
        assert_eq!(
            vectors[0].len(),
            1024,
            "read path must use quality backend dim"
        );
    }

    #[tokio::test]
    async fn test_embed_query_returns_quality_dim() {
        let engine = mock_tiered(256, 1024);
        let vector = engine.embed_query("test query").await.unwrap();
        assert_eq!(vector.len(), 1024, "embed_query must use quality backend");
    }

    #[tokio::test]
    async fn test_single_backend_write_read_same_dim() {
        let engine = mock_single(768);
        let input = [String::from("x")];
        let written = engine.embed_for_write(&input).await.unwrap();
        let read = engine.embed_for_read(&input).await.unwrap();
        assert_eq!(
            written[0].len(),
            read[0].len(),
            "non-tiered: write/read same dim"
        );
        assert_eq!(written[0].len(), 768);
    }

    #[tokio::test]
    async fn test_empty_write_returns_empty() {
        let engine = mock_tiered(256, 1024);
        let vectors = engine.embed_for_write(&[]).await.unwrap();
        assert!(vectors.is_empty());
    }

    #[tokio::test]
    async fn test_empty_read_returns_empty() {
        let engine = mock_tiered(256, 1024);
        let vectors = engine.embed_for_read(&[]).await.unwrap();
        assert!(vectors.is_empty());
    }

    #[tokio::test]
    async fn test_is_tiered_flag() {
        let tiered = mock_tiered(256, 1024);
        assert!(tiered.is_tiered());
        let single = mock_single(768);
        assert!(!single.is_tiered());
    }

    #[tokio::test]
    async fn test_fast_dimension_accessor() {
        let reported = mock_tiered(256, 1024).fast_dimension();
        assert_eq!(reported, 256);
    }

    #[tokio::test]
    async fn test_quality_dimension_accessor() {
        let reported = mock_tiered(256, 1024).quality_dimension();
        assert_eq!(reported, 1024);
    }

    #[tokio::test]
    async fn test_batch_write_multiple_texts() {
        let engine = mock_tiered(256, 1024);
        let mut texts: Vec<String> = Vec::with_capacity(5);
        for i in 0..5 {
            texts.push(format!("text {i}"));
        }
        let vectors = engine.embed_for_write(&texts).await.unwrap();
        assert_eq!(vectors.len(), 5, "must return one embedding per text");
        for vector in vectors.iter() {
            assert_eq!(vector.len(), 256);
        }
    }

    #[tokio::test]
    async fn test_backend_arc_accessors() {
        let engine = mock_tiered(256, 1024);
        let fast = engine.fast_backend();
        let quality = engine.quality_backend();
        assert_eq!(fast.backend_kind(), BackendKind::Static);
        assert_eq!(quality.backend_kind(), BackendKind::Onnx);
    }
}