Skip to main content

orbok_models/
lib.rs

1//! # orbok-models
2//!
3//! Local AI model vocabulary (RFC-012). Milestone M1–M6 only needs the
4//! shared types and the "what is available" summary the UI shows; the
5//! install/locate/validate workflow lands in M12.
6//!
7//! Privacy rule carried from the requirements: model *download* is the
8//! only network operation orbok may ever perform, it is explicit, and
9//! it never involves document contents.
10
11use serde::{Deserialize, Serialize};
12
13/// Model roles (catalog `models.role`).
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum ModelRole {
17    Embedding,
18    Reranker,
19}
20
21impl ModelRole {
22    pub fn as_str(&self) -> &'static str {
23        match self {
24            ModelRole::Embedding => "embedding",
25            ModelRole::Reranker => "reranker",
26        }
27    }
28}
29
30/// Model availability (catalog `models.status`).
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum ModelStatus {
34    Available,
35    Missing,
36    Invalid,
37    Installing,
38    Disabled,
39}
40
41impl ModelStatus {
42    pub fn as_str(&self) -> &'static str {
43        match self {
44            ModelStatus::Available => "available",
45            ModelStatus::Missing => "missing",
46            ModelStatus::Invalid => "invalid",
47            ModelStatus::Installing => "installing",
48            ModelStatus::Disabled => "disabled",
49        }
50    }
51}
52
/// Search capability derived from model availability. Keyword search
/// never depends on models (RFC-007: works with zero models installed).
///
/// Variants are declared from least to most capable. The enum derives no
/// ordering traits, so compare with `==` / `match` rather than `<`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchCapability {
    /// Keyword only: no embedding model available.
    KeywordOnly,
    /// Keyword + semantic: embedding model available.
    Hybrid,
    /// Keyword + semantic + rerank refinement.
    HybridWithRerank,
}
64
65/// Derive the capability shown in the UI from model statuses.
66pub fn search_capability(
67    embedding: Option<ModelStatus>,
68    reranker: Option<ModelStatus>,
69) -> SearchCapability {
70    match (embedding, reranker) {
71        (Some(ModelStatus::Available), Some(ModelStatus::Available)) => {
72            SearchCapability::HybridWithRerank
73        }
74        (Some(ModelStatus::Available), _) => SearchCapability::Hybrid,
75        _ => SearchCapability::KeywordOnly,
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Every status other than `Available`.
    const NOT_AVAILABLE: [ModelStatus; 4] = [
        ModelStatus::Missing,
        ModelStatus::Invalid,
        ModelStatus::Installing,
        ModelStatus::Disabled,
    ];

    // RFC-007/RFC-010: search degrades gracefully without models.
    #[test]
    fn capability_degrades_gracefully() {
        assert_eq!(search_capability(None, None), SearchCapability::KeywordOnly);
        assert_eq!(
            search_capability(Some(ModelStatus::Missing), None),
            SearchCapability::KeywordOnly
        );
        assert_eq!(
            search_capability(Some(ModelStatus::Available), None),
            SearchCapability::Hybrid
        );
        assert_eq!(
            search_capability(Some(ModelStatus::Available), Some(ModelStatus::Missing)),
            SearchCapability::Hybrid
        );
        assert_eq!(
            search_capability(Some(ModelStatus::Available), Some(ModelStatus::Available)),
            SearchCapability::HybridWithRerank
        );
    }

    // A usable reranker must never upgrade the capability when the
    // embedding model is absent or in any non-`Available` state.
    #[test]
    fn reranker_alone_is_keyword_only() {
        assert_eq!(
            search_capability(None, Some(ModelStatus::Available)),
            SearchCapability::KeywordOnly
        );
        for status in NOT_AVAILABLE {
            assert_eq!(
                search_capability(Some(status), Some(ModelStatus::Available)),
                SearchCapability::KeywordOnly,
                "embedding status {:?} must not enable semantic search",
                status
            );
        }
    }

    // Any non-`Available` reranker status falls back to plain hybrid search.
    #[test]
    fn unusable_reranker_falls_back_to_hybrid() {
        for status in NOT_AVAILABLE {
            assert_eq!(
                search_capability(Some(ModelStatus::Available), Some(status)),
                SearchCapability::Hybrid,
                "reranker status {:?} must not enable reranking",
                status
            );
        }
    }
}
105
/// A vector search candidate (RFC-008 §13).
#[derive(Debug, Clone)]
pub struct VectorCandidate {
    /// Identifier of the candidate chunk.
    pub chunk_id: orbok_core::ChunkId,
    /// Identifier of a file — presumably the one containing the chunk;
    /// confirm against the producer.
    pub file_id: orbok_core::FileId,
    /// Rank of this candidate.
    /// NOTE(review): whether ranks start at 0 or 1 is not visible here.
    pub rank: u32,
    /// Score of this candidate — presumably a cosine similarity; confirm
    /// against the producer.
    pub score: f32,
}
114
/// Local embedding model abstraction (RFC-008 §6).
///
/// Implementations must not transmit text externally (NFR-001).
///
/// The `Send + Sync` supertraits mean every implementation can be shared
/// across threads.
pub trait EmbeddingModel: Send + Sync {
    /// Stable name stored in `models.model_name`.
    fn name(&self) -> &str;
    /// Version string stored in `models.model_version`.
    fn version(&self) -> &str;
    /// Output dimension — must match stored embeddings (RFC-008 §11).
    fn dimension(&self) -> u32;
    /// Embed a batch of normalized texts. Returns one vector per input,
    /// each L2-normalized.
    ///
    /// Failures are reported through `orbok_core::OrbokResult`.
    fn embed_batch(&self, texts: &[&str]) -> orbok_core::OrbokResult<Vec<Vec<f32>>>;
}
129
/// Dot product of `a` and `b`, which equals their cosine similarity when
/// both vectors are already L2-normalized.
///
/// No normalization is performed here. If the slices differ in length, the
/// extra trailing components of the longer one are ignored.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
134
/// L2-normalize a vector in place so that its Euclidean norm becomes 1.
///
/// Takes a mutable slice, so it works on a `Vec<f32>` (via deref coercion),
/// a fixed-size array, or a sub-slice of a larger buffer.
///
/// Vectors whose norm is not greater than `1e-10` (including the zero
/// vector and the empty slice) are left untouched to avoid dividing by a
/// vanishing norm.
pub fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 1e-10 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
144
/// Encode a vector as a byte blob for storage (RFC-008 §12.1
/// "sqlite_blob with FP32").
///
/// Each component is written as its 4-byte little-endian representation,
/// in order, so the output is `4 * v.len()` bytes long.
pub fn vec_to_blob(v: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(v.len() * 4);
    for component in v {
        blob.extend_from_slice(&component.to_le_bytes());
    }
    blob
}
150
/// Decode a blob of little-endian FP32 values back into a vector.
///
/// Returns `None` unless `blob` is exactly `expected_dim * 4` bytes long.
/// The expected byte length is computed with overflow checking, so a huge
/// `expected_dim` on a target with a narrow `usize` yields `None` instead
/// of panicking (debug builds) or wrapping to a bogus length (release).
pub fn blob_to_vec(blob: &[u8], expected_dim: u32) -> Option<Vec<f32>> {
    const BYTES_PER_F32: usize = std::mem::size_of::<f32>();

    let expected_len = usize::try_from(expected_dim)
        .ok()?
        .checked_mul(BYTES_PER_F32)?;
    if blob.len() != expected_len {
        return None;
    }

    Some(
        blob.chunks_exact(BYTES_PER_F32)
            // `chunks_exact` guarantees every chunk holds exactly 4 bytes.
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect(),
    )
}
163
164// ── Mock model ──────────────────────────────────────────────────────
165
166/// Deterministic 8-dimensional mock embedding model.
167///
168/// Uses the SHA-256 of the input text as a pseudo-random source for 8
169/// f32 components, then L2-normalizes the result.  **Never use for
170/// semantic search** — the outputs are semantically meaningless.
171/// Suitable for pipeline correctness tests (RFC-008 §24 tests 1–10).
172pub struct MockEmbeddingModel;
173
174impl EmbeddingModel for MockEmbeddingModel {
175    fn name(&self) -> &str {
176        "mock"
177    }
178    fn version(&self) -> &str {
179        "v1"
180    }
181    fn dimension(&self) -> u32 {
182        8
183    }
184    fn embed_batch(&self, texts: &[&str]) -> orbok_core::OrbokResult<Vec<Vec<f32>>> {
185        use sha2::{Digest, Sha256};
186        texts
187            .iter()
188            .map(|text| {
189                let digest = Sha256::digest(text.as_bytes());
190                let mut v: Vec<f32> = digest[..8]
191                    .iter()
192                    .map(|&b| b as f32 / 255.0)
193                    .collect();
194                l2_normalize(&mut v);
195                Ok(v)
196            })
197            .collect()
198    }
199}
200
#[cfg(test)]
mod embedding_tests {
    use super::*;

    // RFC-008 §24 test 2: embedding generation succeeds for sample chunks.
    // Checks one vector per input and that each has the advertised dimension.
    #[test]
    fn mock_embed_batch() {
        let model = MockEmbeddingModel;
        let vecs = model.embed_batch(&["hello world", "foo bar"]).unwrap();
        assert_eq!(vecs.len(), 2);
        for v in &vecs {
            assert_eq!(v.len(), model.dimension() as usize);
        }
    }

    // RFC-008 §24 test 3: dimension mismatch can be detected by caller.
    // Also checks the FP32 blob round-trip: 8 components -> 32 bytes -> 8 components.
    #[test]
    fn blob_roundtrip_and_dim_mismatch() {
        let v = vec![0.1_f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let blob = vec_to_blob(&v);
        assert_eq!(blob.len(), 32);
        let back = blob_to_vec(&blob, 8).unwrap();
        for (a, b) in v.iter().zip(&back) {
            assert!((a - b).abs() < 1e-6);
        }
        // Asking for 16 components from a 32-byte blob must fail.
        assert!(blob_to_vec(&blob, 16).is_none(), "dim mismatch must return None");
    }

    // L2 normalization: unit-length vectors.
    // Uses the 3-4-5 triangle, so the input norm is exactly 5.
    #[test]
    fn normalize_produces_unit_vector() {
        let mut v = vec![3.0_f32, 4.0];
        l2_normalize(&mut v);
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-6);
    }

    // RFC-008 §24 test 9: cosine sim of identical vectors = 1.0.
    // The vector is normalized first because `cosine_similarity` is a plain dot product.
    #[test]
    fn cosine_sim_identical_vectors() {
        let mut v = vec![1.0_f32, 2.0, 3.0];
        l2_normalize(&mut v);
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-6);
    }
}
247
248// ── Reranker (RFC-010) ───────────────────────────────────────────────
249
/// A candidate document passed to the reranker.
#[derive(Debug, Clone)]
pub struct RerankCandidate {
    /// Identifier of the chunk this candidate refers to.
    pub chunk_id: orbok_core::ChunkId,
    /// Best available text for the passage — typically the loaded snippet.
    pub passage_text: String,
}
257
/// Per-candidate rerank score (higher = more relevant).
#[derive(Debug, Clone)]
pub struct RerankScore {
    /// Identifier of the chunk that was scored.
    pub chunk_id: orbok_core::ChunkId,
    /// Relevance score; higher means more relevant.
    /// NOTE(review): no value range is fixed here — it is up to the reranker.
    pub score: f32,
}
264
/// Optional local cross-encoder reranker (RFC-010 §5).
///
/// - Reranking is always optional; missing model must not break search.
/// - Implementors must not log `passage_text` (NFR-014).
///
/// The `Send + Sync` supertraits mean every implementation can be shared
/// across threads.
pub trait CrossEncoderReranker: Send + Sync {
    /// Name of the reranker model.
    fn name(&self) -> &str;
    /// Version string of the reranker model.
    fn version(&self) -> &str;
    /// Maximum candidates to rerank (RFC-010 §9 top-N limit).
    fn max_candidates(&self) -> u32;
    /// Score `candidates` against `query`; failures are reported through
    /// `orbok_core::OrbokResult`.
    ///
    /// NOTE(review): whether the returned scores must be sorted, and
    /// whether the caller or the implementation enforces
    /// `max_candidates`, is not specified here — confirm against RFC-010.
    fn rerank(&self, query: &str, candidates: &[RerankCandidate])
        -> orbok_core::OrbokResult<Vec<RerankScore>>;
}
277
278/// Deterministic mock reranker: scores by passage length (longer = more
279/// informative). Useful for pipeline testing without an ML model.
280pub struct MockReranker;
281
282impl CrossEncoderReranker for MockReranker {
283    fn name(&self) -> &str { "mock-reranker" }
284    fn version(&self) -> &str { "v1" }
285    fn max_candidates(&self) -> u32 { 20 }
286    fn rerank(
287        &self,
288        _query: &str,
289        candidates: &[RerankCandidate],
290    ) -> orbok_core::OrbokResult<Vec<RerankScore>> {
291        let mut scores: Vec<RerankScore> = candidates
292            .iter()
293            .map(|c| RerankScore {
294                chunk_id: c.chunk_id.clone(),
295                score: c.passage_text.len() as f32,
296            })
297            .collect();
298        scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
299        Ok(scores)
300    }
301}
302
#[cfg(test)]
mod reranker_tests {
    use super::*;
    use orbok_core::ChunkId;

    // RFC-010 §19 test 4: reranker changes final order when scores differ.
    // Input order is (short, long); the mock must return the longer passage first.
    #[test]
    fn mock_reranker_orders_by_length() {
        let r = MockReranker;
        let candidates = vec![
            RerankCandidate { chunk_id: ChunkId::from_string("c1".to_string()), passage_text: "short".into() },
            RerankCandidate { chunk_id: ChunkId::from_string("c2".to_string()), passage_text: "a much longer passage".into() },
        ];
        let scores = r.rerank("query", &candidates).unwrap();
        assert_eq!(scores[0].chunk_id.as_str(), "c2", "longer passage should rank first");
    }

    // Sanity check: the mock advertises a non-zero top-N limit (RFC-010 §9).
    // NOTE(review): this test was labelled "RFC-010 §20: missing reranker does
    // not break search", but it only asserts that `max_candidates()` is
    // positive; the missing-reranker path is not exercised here.
    #[test]
    fn rerank_max_candidates_limit() {
        assert!(MockReranker.max_candidates() > 0);
    }
}
326
327// ── Inference backend (M12) ──────────────────────────────────────────
328
/// The compute backend used for local inference.
///
/// A fieldless enum, so it derives `Copy` like the other vocabulary enums
/// in this crate (`ModelRole`, `ModelStatus`); callers can read it out of a
/// borrowed config without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InferenceBackend {
    /// CPU-only inference via candle (no GPU required).
    CandleCpu,
    /// GPU inference via candle + CUDA.
    CandleCuda,
    /// ONNX Runtime (CPU or GPU via execution provider).
    OnnxRuntime,
    /// Mock backend for tests — deterministic, no model files.
    Mock,
}

impl InferenceBackend {
    /// Returns the kebab-case string identifier for this backend.
    pub fn as_str(&self) -> &'static str {
        match self {
            InferenceBackend::CandleCpu => "candle-cpu",
            InferenceBackend::CandleCuda => "candle-cuda",
            InferenceBackend::OnnxRuntime => "onnx-runtime",
            InferenceBackend::Mock => "mock",
        }
    }
}
352
353/// Configuration for loading a real embedding model from disk.
354///
355/// This is the configuration type callers populate to construct a real
356/// `EmbeddingModel` implementation via a future `BackendLoader`. The
357/// `MockEmbeddingModel` ignores this; it is used only when testing the
358/// pipeline without model files.
359///
360/// Once a `candle` or `onnx-runtime` integration crate is added (M12
361/// full implementation), it will consume this config and return a
362/// `Box<dyn EmbeddingModel>`.
363#[derive(Debug, Clone)]
364pub struct EmbeddingModelConfig {
365    /// Path to the model weights file (ONNX `.onnx` or safetensors).
366    pub weights_path: String,
367    /// Tokenizer config path (tokenizer.json for HuggingFace tokenizers).
368    pub tokenizer_path: Option<String>,
369    /// Expected embedding dimension.
370    pub dimension: u32,
371    /// Maximum input token length (truncation limit).
372    pub max_seq_len: u32,
373    /// Compute backend selection.
374    pub backend: InferenceBackend,
375    /// Model name for registry (e.g. "nomic-embed-text-v1.5").
376    pub model_name: String,
377    /// Model version string.
378    pub model_version: String,
379}
380
381impl EmbeddingModelConfig {
382    /// Check that the model weights file exists on disk.
383    pub fn weights_exist(&self) -> bool {
384        std::path::Path::new(&self.weights_path).exists()
385    }
386}
387
/// Configuration for a cross-encoder reranker model.
///
/// NOTE(review): the fields share names with those of
/// `EmbeddingModelConfig` and presumably carry the same meaning — confirm
/// once a loader consumes this type.
#[derive(Debug, Clone)]
pub struct RerankerConfig {
    /// Path to the model weights.
    pub weights_path: String,
    /// Optional path to the tokenizer config.
    pub tokenizer_path: Option<String>,
    /// Maximum sequence length — presumably in tokens; confirm.
    pub max_seq_len: u32,
    /// Compute backend selection.
    pub backend: InferenceBackend,
    /// Model name.
    pub model_name: String,
    /// Model version string.
    pub model_version: String,
}
398
399// ── Vector quantization (RFC-024) ───────────────────────────────────
400
/// Quantize an L2-normalized FP32 vector to INT8.
///
/// Each component is multiplied by 127, rounded to the nearest integer
/// (ties away from zero) and clipped, so `[-1.0, +1.0]` maps onto
/// `[-127, +127]` and anything outside saturates at ±127. `-128` is never
/// produced; a NaN component becomes `0`.
///
/// Storage cost: 4× smaller than FP32 (1 byte vs 4 bytes per component).
/// Quality impact: typically < 2% recall degradation for 384-dim models.
pub fn quantize_to_i8(v: &[f32]) -> Vec<i8> {
    let mut quantized = Vec::with_capacity(v.len());
    for &component in v {
        let scaled = (component * 127.0).round();
        quantized.push(scaled.clamp(-127.0, 127.0) as i8);
    }
    quantized
}
411
/// Convert an INT8 vector back to FP32 for similarity computation.
///
/// Every component is divided by 127, the inverse of the scaling used by
/// `quantize_to_i8`. A raw `-128` therefore maps slightly below `-1.0`.
pub fn dequantize_from_i8(v: &[i8]) -> Vec<f32> {
    let mut restored = Vec::with_capacity(v.len());
    for &component in v {
        restored.push(component as f32 / 127.0);
    }
    restored
}
416
/// Encode an INT8 vector as a byte blob for storage.
///
/// Each `i8` is stored as the single byte holding its two's-complement
/// bit pattern (e.g. `-1` becomes `0xFF`), so the blob has `v.len()` bytes.
pub fn i8_vec_to_blob(v: &[i8]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(v.len());
    for component in v {
        // A one-byte integer has a single byte regardless of endianness.
        blob.push(component.to_le_bytes()[0]);
    }
    blob
}
422
/// Decode an INT8 vector from a byte blob.
///
/// Returns `None` unless the blob holds exactly `expected_dim` bytes;
/// otherwise each byte is reinterpreted as a two's-complement `i8`.
pub fn i8_blob_to_vec(blob: &[u8], expected_dim: u32) -> Option<Vec<i8>> {
    if blob.len() == expected_dim as usize {
        let decoded = blob.iter().map(|&byte| i8::from_le_bytes([byte])).collect();
        Some(decoded)
    } else {
        None
    }
}
430
/// Compute approximate cosine similarity between two INT8-quantized vectors.
///
/// Each component is dequantized on the fly (`x / 127`) and the FP32 dot
/// product is accumulated in a single pass, without allocating temporary
/// vectors. The result is only approximate because quantization rounding
/// means the dequantized vectors are no longer exactly unit-length.
///
/// This is not cheaper in arithmetic than the FP32 path; the benefit of
/// INT8 is storage size. A SIMD integer dot product would be the faster
/// option if this ever becomes hot.
///
/// If the slices differ in length, the extra trailing components of the
/// longer one are ignored.
pub fn cosine_similarity_i8(a: &[i8], b: &[i8]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(&x, &y)| (x as f32 / 127.0) * (y as f32 / 127.0))
        .sum()
}
437
#[cfg(test)]
mod quantization_tests {
    use super::*;

    // RFC-024 AC: FP32 baseline exists — quantization is optional.
    // Both encodings are produced from the same input and their sizes compared.
    #[test]
    fn fp32_and_i8_both_available() {
        let v = vec![0.6f32, 0.8, 0.0, -0.5];
        let blob_fp32 = vec_to_blob(&v);
        let i8_vec = quantize_to_i8(&v);
        let blob_i8 = i8_vec_to_blob(&i8_vec);
        // INT8 is 4× smaller.
        assert_eq!(blob_i8.len() * 4, blob_fp32.len());
    }

    // RFC-024 AC: Storage savings measured (4× with INT8).
    // Uses a 384-component vector: 1536 bytes as FP32 versus 384 bytes as INT8.
    #[test]
    fn int8_is_4x_smaller_than_fp32() {
        let v: Vec<f32> = (0..384).map(|i| (i as f32 / 384.0) - 0.5).collect();
        let mut vn = v.clone();
        l2_normalize(&mut vn);
        let fp32_bytes = vec_to_blob(&vn).len();
        let int8_bytes = i8_vec_to_blob(&quantize_to_i8(&vn)).len();
        assert_eq!(fp32_bytes, 384 * 4);
        assert_eq!(int8_bytes, 384);
        assert_eq!(fp32_bytes / int8_bytes, 4);
    }

    // RFC-024 AC: Quality loss measured (cosine sim error < 0.02 for normalised vectors).
    // NOTE(review): only the self-similarity of one vector is compared before
    // and after quantization; similarity between distinct vectors is not covered.
    #[test]
    fn quantization_quality_loss_is_small() {
        let mut v: Vec<f32> = (0..384)
            .map(|i| ((i as f32 * 0.017).sin()))
            .collect();
        l2_normalize(&mut v);
        let q = quantize_to_i8(&v);
        let original_self_sim = cosine_similarity(&v, &v);
        let quantized_self_sim = cosine_similarity_i8(&q, &q);
        // After dequantize, self-sim should still be ~1.0.
        assert!((quantized_self_sim - original_self_sim).abs() < 0.02,
            "quantization quality loss too high: {:.4}", (quantized_self_sim - original_self_sim).abs());
    }

    // RFC-024 AC: Vector format migration defined (FP32 ↔ INT8 round-trip).
    // Each component must survive quantize -> dequantize within 0.01.
    #[test]
    fn fp32_int8_roundtrip_within_tolerance() {
        let mut v: Vec<f32> = vec![0.3, -0.7, 0.5, 0.1, -0.2, 0.8, -0.4, 0.6];
        l2_normalize(&mut v);
        let quantized = quantize_to_i8(&v);
        let dequantized = dequantize_from_i8(&quantized);
        for (orig, deq) in v.iter().zip(&dequantized) {
            assert!((orig - deq).abs() < 0.01,
                "round-trip error too large: {orig:.4} → {deq:.4}");
        }
    }
}