Skip to main content

orbok_models/
lib.rs

1//! # orbok-models
2//!
3//! Local AI model vocabulary (RFC-012). Milestone M1–M6 only needs the
4//! shared types and the "what is available" summary the UI shows; the
5//! install/locate/validate workflow lands in M12.
6//!
7//! Privacy rule carried from the requirements: model *download* is the
8//! only network operation orbok may ever perform, it is explicit, and
9//! it never involves document contents.
10
11use serde::{Deserialize, Serialize};
12
13/// Model roles (catalog `models.role`).
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15#[serde(rename_all = "snake_case")]
16pub enum ModelRole {
17    Embedding,
18    Reranker,
19}
20
21impl ModelRole {
22    pub fn as_str(&self) -> &'static str {
23        match self {
24            ModelRole::Embedding => "embedding",
25            ModelRole::Reranker => "reranker",
26        }
27    }
28}
29
30/// Model availability (catalog `models.status`).
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
32#[serde(rename_all = "snake_case")]
33pub enum ModelStatus {
34    Available,
35    Missing,
36    Invalid,
37    Installing,
38    Disabled,
39}
40
41impl ModelStatus {
42    pub fn as_str(&self) -> &'static str {
43        match self {
44            ModelStatus::Available => "available",
45            ModelStatus::Missing => "missing",
46            ModelStatus::Invalid => "invalid",
47            ModelStatus::Installing => "installing",
48            ModelStatus::Disabled => "disabled",
49        }
50    }
51}
52
/// What kind of search the installed models allow.
///
/// Keyword search is part of every variant: it never depends on models
/// (RFC-007: works with zero models installed).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SearchCapability {
    /// No embedding model is available, so only keyword search runs.
    KeywordOnly,
    /// An embedding model is available: keyword + semantic search.
    Hybrid,
    /// Keyword + semantic search with an additional rerank refinement.
    HybridWithRerank,
}
64
65/// Derive the capability shown in the UI from model statuses.
66pub fn search_capability(
67    embedding: Option<ModelStatus>,
68    reranker: Option<ModelStatus>,
69) -> SearchCapability {
70    match (embedding, reranker) {
71        (Some(ModelStatus::Available), Some(ModelStatus::Available)) => {
72            SearchCapability::HybridWithRerank
73        }
74        (Some(ModelStatus::Available), _) => SearchCapability::Hybrid,
75        _ => SearchCapability::KeywordOnly,
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    // RFC-007/RFC-010: search degrades gracefully without models.
    #[test]
    fn capability_degrades_gracefully() {
        assert_eq!(search_capability(None, None), SearchCapability::KeywordOnly);
        assert_eq!(
            search_capability(Some(ModelStatus::Missing), None),
            SearchCapability::KeywordOnly
        );
        assert_eq!(
            search_capability(Some(ModelStatus::Available), None),
            SearchCapability::Hybrid
        );
        assert_eq!(
            search_capability(Some(ModelStatus::Available), Some(ModelStatus::Missing)),
            SearchCapability::Hybrid
        );
        assert_eq!(
            search_capability(Some(ModelStatus::Available), Some(ModelStatus::Available)),
            SearchCapability::HybridWithRerank
        );
    }

    // A reranker on its own never upgrades search, and only `Available`
    // counts as usable for either model.
    #[test]
    fn capability_requires_available_embedding() {
        assert_eq!(
            search_capability(None, Some(ModelStatus::Available)),
            SearchCapability::KeywordOnly
        );
        let not_available = [
            ModelStatus::Missing,
            ModelStatus::Invalid,
            ModelStatus::Installing,
            ModelStatus::Disabled,
        ];
        for status in not_available {
            assert_eq!(
                search_capability(Some(status), Some(ModelStatus::Available)),
                SearchCapability::KeywordOnly,
                "embedding status {} must not enable semantic search",
                status.as_str()
            );
            assert_eq!(
                search_capability(Some(ModelStatus::Available), Some(status)),
                SearchCapability::Hybrid,
                "reranker status {} must not enable reranking",
                status.as_str()
            );
        }
    }
}
105
106/// A vector search candidate (RFC-008 §13).
107#[derive(Debug, Clone)]
108pub struct VectorCandidate {
109    pub chunk_id: orbok_core::ChunkId,
110    pub file_id: orbok_core::FileId,
111    pub rank: u32,
112    pub score: f32,
113}
114
115/// Local embedding model abstraction (RFC-008 §6).
116///
117/// Implementations must not transmit text externally (NFR-001).
118pub trait EmbeddingModel: Send + Sync {
119    /// Stable name stored in `models.model_name`.
120    fn name(&self) -> &str;
121    /// Version string stored in `models.model_version`.
122    fn version(&self) -> &str;
123    /// Output dimension — must match stored embeddings (RFC-008 §11).
124    fn dimension(&self) -> u32;
125    /// Embed a batch of normalized texts. Returns one vector per input,
126    /// each L2-normalized.
127    fn embed_batch(&self, texts: &[&str]) -> orbok_core::OrbokResult<Vec<Vec<f32>>>;
128}
129
/// Cosine similarity of two vectors that are already L2-normalized.
///
/// The value is the plain dot product, so it only equals the cosine when
/// both inputs have unit length. Components are paired up position by
/// position; if the slices differ in length, the surplus components of
/// the longer one are ignored.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b.iter()).map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}
134
/// L2-normalize a vector in place so that its Euclidean length becomes 1.
///
/// Accepts any mutable `f32` slice. Existing callers that pass
/// `&mut Vec<f32>` keep working through deref coercion, and arrays or
/// sub-slices can now be normalized without first copying into a `Vec`.
///
/// The input is left untouched when its norm is not greater than `1e-10`.
/// That covers the zero vector and the empty slice, and also a NaN norm,
/// because the comparison is false for NaN.
pub fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    // Guard against dividing by (almost) zero.
    if norm > 1e-10 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}
144
/// Encode a vector as bytes for BLOB storage (RFC-008 §12.1
/// "sqlite_blob with FP32").
///
/// Each component becomes four little-endian bytes, written in input
/// order, so the result is `4 * v.len()` bytes long.
pub fn vec_to_blob(v: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(v.len() * std::mem::size_of::<f32>());
    for component in v {
        blob.extend_from_slice(&component.to_le_bytes());
    }
    blob
}
150
/// Decode a BLOB of little-endian FP32 values back into a vector.
///
/// Returns `None` unless `blob` is exactly `expected_dim * 4` bytes long.
/// The expected length is computed with checked arithmetic, so a huge
/// `expected_dim` yields `None` on 32-bit targets instead of overflowing
/// `usize`.
pub fn blob_to_vec(blob: &[u8], expected_dim: u32) -> Option<Vec<f32>> {
    const BYTES_PER_COMPONENT: usize = std::mem::size_of::<f32>();

    let expected_len = usize::try_from(expected_dim)
        .ok()?
        .checked_mul(BYTES_PER_COMPONENT)?;
    if blob.len() != expected_len {
        return None;
    }
    Some(
        blob.chunks_exact(BYTES_PER_COMPONENT)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect(),
    )
}
163
164// ── Mock model ──────────────────────────────────────────────────────
165
166/// Deterministic 8-dimensional mock embedding model.
167///
168/// Uses the SHA-256 of the input text as a pseudo-random source for 8
169/// f32 components, then L2-normalizes the result.  **Never use for
170/// semantic search** — the outputs are semantically meaningless.
171/// Suitable for pipeline correctness tests (RFC-008 §24 tests 1–10).
172pub struct MockEmbeddingModel;
173
174impl EmbeddingModel for MockEmbeddingModel {
175    fn name(&self) -> &str {
176        "mock"
177    }
178    fn version(&self) -> &str {
179        "v1"
180    }
181    fn dimension(&self) -> u32 {
182        8
183    }
184    fn embed_batch(&self, texts: &[&str]) -> orbok_core::OrbokResult<Vec<Vec<f32>>> {
185        use sha2::{Digest, Sha256};
186        texts
187            .iter()
188            .map(|text| {
189                let digest = Sha256::digest(text.as_bytes());
190                let mut v: Vec<f32> = digest[..8].iter().map(|&b| b as f32 / 255.0).collect();
191                l2_normalize(&mut v);
192                Ok(v)
193            })
194            .collect()
195    }
196}
197
#[cfg(test)]
mod embedding_tests {
    use super::*;

    /// Euclidean length of `v`.
    fn euclidean_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    // RFC-008 §24 test 2: embedding generation succeeds for sample chunks.
    #[test]
    fn mock_embed_batch() {
        let model = MockEmbeddingModel;
        let inputs = ["hello world", "foo bar"];
        let vecs = model.embed_batch(&inputs).unwrap();
        assert_eq!(vecs.len(), 2);
        let dim = model.dimension() as usize;
        assert!(vecs.iter().all(|v| v.len() == dim));
    }

    // RFC-008 §24 test 3: dimension mismatch can be detected by caller.
    #[test]
    fn blob_roundtrip_and_dim_mismatch() {
        let original = vec![0.1_f32, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];
        let blob = vec_to_blob(&original);
        assert_eq!(blob.len(), 32);

        let restored = blob_to_vec(&blob, 8).unwrap();
        let close_enough = original
            .iter()
            .zip(&restored)
            .all(|(a, b)| (a - b).abs() < 1e-6);
        assert!(close_enough);

        let wrong_dim = blob_to_vec(&blob, 16);
        assert!(wrong_dim.is_none(), "dim mismatch must return None");
    }

    // L2 normalization: unit-length vectors.
    #[test]
    fn normalize_produces_unit_vector() {
        let mut v = vec![3.0_f32, 4.0];
        l2_normalize(&mut v);
        assert!((euclidean_norm(&v) - 1.0).abs() < 1e-6);
    }

    // RFC-008 §24 test 9: cosine sim of identical vectors = 1.0.
    #[test]
    fn cosine_sim_identical_vectors() {
        let mut v = vec![1.0_f32, 2.0, 3.0];
        l2_normalize(&mut v);
        let self_similarity = cosine_similarity(&v, &v);
        assert!((self_similarity - 1.0).abs() < 1e-6);
    }
}
247
248// ── Reranker (RFC-010) ───────────────────────────────────────────────
249
250/// A candidate document passed to the reranker.
251#[derive(Debug, Clone)]
252pub struct RerankCandidate {
253    pub chunk_id: orbok_core::ChunkId,
254    /// Best available text for the passage — typically the loaded snippet.
255    pub passage_text: String,
256}
257
258/// Per-candidate rerank score (higher = more relevant).
259#[derive(Debug, Clone)]
260pub struct RerankScore {
261    pub chunk_id: orbok_core::ChunkId,
262    pub score: f32,
263}
264
265/// Optional local cross-encoder reranker (RFC-010 §5).
266///
267/// - Reranking is always optional; missing model must not break search.
268/// - Implementors must not log `passage_text` (NFR-014).
269pub trait CrossEncoderReranker: Send + Sync {
270    fn name(&self) -> &str;
271    fn version(&self) -> &str;
272    /// Maximum candidates to rerank (RFC-010 §9 top-N limit).
273    fn max_candidates(&self) -> u32;
274    fn rerank(
275        &self,
276        query: &str,
277        candidates: &[RerankCandidate],
278    ) -> orbok_core::OrbokResult<Vec<RerankScore>>;
279}
280
281/// Deterministic mock reranker: scores by passage length (longer = more
282/// informative). Useful for pipeline testing without an ML model.
283pub struct MockReranker;
284
285impl CrossEncoderReranker for MockReranker {
286    fn name(&self) -> &str {
287        "mock-reranker"
288    }
289    fn version(&self) -> &str {
290        "v1"
291    }
292    fn max_candidates(&self) -> u32 {
293        20
294    }
295    fn rerank(
296        &self,
297        _query: &str,
298        candidates: &[RerankCandidate],
299    ) -> orbok_core::OrbokResult<Vec<RerankScore>> {
300        let mut scores: Vec<RerankScore> = candidates
301            .iter()
302            .map(|c| RerankScore {
303                chunk_id: c.chunk_id.clone(),
304                score: c.passage_text.len() as f32,
305            })
306            .collect();
307        scores.sort_by(|a, b| {
308            b.score
309                .partial_cmp(&a.score)
310                .unwrap_or(std::cmp::Ordering::Equal)
311        });
312        Ok(scores)
313    }
314}
315
#[cfg(test)]
mod reranker_tests {
    use super::*;
    use orbok_core::ChunkId;

    // RFC-010 §19 test 4: reranker changes final order when scores differ.
    #[test]
    fn mock_reranker_orders_by_length() {
        let r = MockReranker;
        let candidates = vec![
            RerankCandidate {
                chunk_id: ChunkId::from_string("c1".to_string()),
                passage_text: "short".into(),
            },
            RerankCandidate {
                chunk_id: ChunkId::from_string("c2".to_string()),
                passage_text: "a much longer passage".into(),
            },
        ];
        let scores = r.rerank("query", &candidates).unwrap();
        assert_eq!(
            scores.len(),
            candidates.len(),
            "every candidate must receive a score"
        );
        assert_eq!(
            scores[0].chunk_id.as_str(),
            "c2",
            "longer passage should rank first"
        );
        assert_eq!(scores[1].chunk_id.as_str(), "c1");
        assert!(scores[0].score > scores[1].score);
    }

    // An empty candidate list is valid input and yields no scores.
    #[test]
    fn mock_reranker_accepts_empty_input() {
        let scores = MockReranker.rerank("query", &[]).unwrap();
        assert!(scores.is_empty());
    }

    // RFC-010 §9: the top-N rerank limit must be non-zero to be usable.
    #[test]
    fn rerank_max_candidates_limit() {
        assert!(MockReranker.max_candidates() > 0);
    }
}
349
350// ── Inference backend (M12) ──────────────────────────────────────────
351
/// The compute backend used for local inference.
///
/// The enum has no payload, so it is `Copy` like the other fieldless
/// enums in this crate (`ModelRole`, `ModelStatus`); existing `.clone()`
/// calls keep compiling.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InferenceBackend {
    /// CPU-only inference via candle (no GPU required).
    CandleCpu,
    /// GPU inference via candle + CUDA.
    CandleCuda,
    /// ONNX Runtime (CPU or GPU via execution provider).
    OnnxRuntime,
    /// Mock backend for tests — deterministic, no model files.
    Mock,
}

impl InferenceBackend {
    /// Returns the kebab-case identifier of this backend
    /// (`"candle-cpu"`, `"candle-cuda"`, `"onnx-runtime"` or `"mock"`).
    pub fn as_str(&self) -> &'static str {
        match self {
            InferenceBackend::CandleCpu => "candle-cpu",
            InferenceBackend::CandleCuda => "candle-cuda",
            InferenceBackend::OnnxRuntime => "onnx-runtime",
            InferenceBackend::Mock => "mock",
        }
    }
}
375
376/// Configuration for loading a real embedding model from disk.
377///
378/// This is the configuration type callers populate to construct a real
379/// `EmbeddingModel` implementation via a future `BackendLoader`. The
380/// `MockEmbeddingModel` ignores this; it is used only when testing the
381/// pipeline without model files.
382///
383/// Once a `candle` or `onnx-runtime` integration crate is added (M12
384/// full implementation), it will consume this config and return a
385/// `Box<dyn EmbeddingModel>`.
386#[derive(Debug, Clone)]
387pub struct EmbeddingModelConfig {
388    /// Path to the model weights file (ONNX `.onnx` or safetensors).
389    pub weights_path: String,
390    /// Tokenizer config path (tokenizer.json for HuggingFace tokenizers).
391    pub tokenizer_path: Option<String>,
392    /// Expected embedding dimension.
393    pub dimension: u32,
394    /// Maximum input token length (truncation limit).
395    pub max_seq_len: u32,
396    /// Compute backend selection.
397    pub backend: InferenceBackend,
398    /// Model name for registry (e.g. "nomic-embed-text-v1.5").
399    pub model_name: String,
400    /// Model version string.
401    pub model_version: String,
402}
403
404impl EmbeddingModelConfig {
405    /// Check that the model weights file exists on disk.
406    pub fn weights_exist(&self) -> bool {
407        std::path::Path::new(&self.weights_path).exists()
408    }
409}
410
411/// Configuration for a cross-encoder reranker model.
412#[derive(Debug, Clone)]
413pub struct RerankerConfig {
414    pub weights_path: String,
415    pub tokenizer_path: Option<String>,
416    pub max_seq_len: u32,
417    pub backend: InferenceBackend,
418    pub model_name: String,
419    pub model_version: String,
420}
421
422// ── Vector quantization (RFC-024) ───────────────────────────────────
423
/// Convert an L2-normalized FP32 vector into INT8 codes.
///
/// Every component is multiplied by 127, rounded to the nearest integer
/// and clamped, so `[-1.0, +1.0]` maps onto `[-127, +127]` and anything
/// outside that range saturates at ±127. A NaN component becomes `0`.
///
/// Storage cost: 4× smaller than FP32 (1 byte vs 4 bytes per component).
/// Quality impact: typically < 2% recall degradation for 384-dim models.
pub fn quantize_to_i8(v: &[f32]) -> Vec<i8> {
    const SCALE: f32 = 127.0;

    let mut codes = Vec::with_capacity(v.len());
    for &component in v {
        let scaled = (component * SCALE).round();
        codes.push(scaled.clamp(-SCALE, SCALE) as i8);
    }
    codes
}
434
/// Expand INT8 codes back to FP32 for similarity computation.
///
/// Each code is divided by 127, so `127` becomes `1.0` and `-127`
/// becomes `-1.0`.
pub fn dequantize_from_i8(v: &[i8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(v.len());
    values.extend(v.iter().map(|&code| f32::from(code) / 127.0));
    values
}
439
/// Encode an INT8 vector as bytes for BLOB storage.
///
/// Every `i8` is written as the single byte with the same bit pattern,
/// so the blob is exactly `v.len()` bytes long.
pub fn i8_vec_to_blob(v: &[i8]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(v.len());
    blob.extend(v.iter().map(|code| code.to_ne_bytes()[0]));
    blob
}
445
/// Decode an INT8 vector from BLOB bytes.
///
/// Returns `None` unless the blob holds exactly `expected_dim` bytes.
/// Otherwise every byte is reinterpreted as the `i8` with the same bit
/// pattern.
pub fn i8_blob_to_vec(blob: &[u8], expected_dim: u32) -> Option<Vec<i8>> {
    let wanted = expected_dim as usize;
    if blob.len() == wanted {
        Some(blob.iter().map(|&byte| i8::from_ne_bytes([byte])).collect())
    } else {
        None
    }
}
453
/// Compute approximate cosine similarity between two INT8-quantized vectors.
///
/// Each pair of codes is converted to FP32 (code / 127) and multiplied,
/// and the products are summed. The arithmetic itself is FP32, so this
/// is not cheaper than the FP32 path. For an exact INT8 dot-product, a
/// SIMD-optimised path would be preferable.
///
/// The conversion happens on the fly inside one loop, so no temporary
/// vectors are allocated. The operations and their order match
/// dequantizing both inputs first and then taking the dot product.
/// If the slices differ in length, the surplus codes of the longer one
/// are ignored.
pub fn cosine_similarity_i8(a: &[i8], b: &[i8]) -> f32 {
    // Must stay in sync with the scale used when quantizing/dequantizing.
    const SCALE: f32 = 127.0;

    a.iter()
        .zip(b)
        .map(|(&x, &y)| (f32::from(x) / SCALE) * (f32::from(y) / SCALE))
        .sum()
}
460
#[cfg(test)]
mod quantization_tests {
    use super::*;

    // RFC-024 AC: FP32 baseline exists — quantization is optional.
    #[test]
    fn fp32_and_i8_both_available() {
        let sample = [0.6f32, 0.8, 0.0, -0.5];
        let fp32_len = vec_to_blob(&sample).len();
        let int8_len = i8_vec_to_blob(&quantize_to_i8(&sample)).len();
        // Four FP32 bytes per component versus a single INT8 byte.
        assert_eq!(int8_len * 4, fp32_len);
    }

    // RFC-024 AC: Storage savings measured (4× with INT8).
    #[test]
    fn int8_is_4x_smaller_than_fp32() {
        let mut ramp: Vec<f32> = (0..384).map(|i| (i as f32 / 384.0) - 0.5).collect();
        l2_normalize(&mut ramp);

        let fp32_bytes = vec_to_blob(&ramp).len();
        let int8_bytes = i8_vec_to_blob(&quantize_to_i8(&ramp)).len();

        assert_eq!(fp32_bytes, 384 * 4);
        assert_eq!(int8_bytes, 384);
        assert_eq!(fp32_bytes / int8_bytes, 4);
    }

    // RFC-024 AC: Quality loss measured (cosine sim error < 0.02 for normalised vectors).
    #[test]
    fn quantization_quality_loss_is_small() {
        let mut wave: Vec<f32> = (0..384).map(|i| (i as f32 * 0.017).sin()).collect();
        l2_normalize(&mut wave);
        let codes = quantize_to_i8(&wave);

        // Self-similarity should stay close to 1.0 after quantization.
        let drift = (cosine_similarity_i8(&codes, &codes) - cosine_similarity(&wave, &wave)).abs();
        assert!(
            drift < 0.02,
            "quantization quality loss too high: {:.4}",
            drift
        );
    }

    // RFC-024 AC: Vector format migration defined (FP32 ↔ INT8 round-trip).
    #[test]
    fn fp32_int8_roundtrip_within_tolerance() {
        let mut source: Vec<f32> = vec![0.3, -0.7, 0.5, 0.1, -0.2, 0.8, -0.4, 0.6];
        l2_normalize(&mut source);

        let restored = dequantize_from_i8(&quantize_to_i8(&source));
        for (orig, deq) in source.iter().zip(restored.iter()) {
            let error = (orig - deq).abs();
            assert!(
                error < 0.01,
                "round-trip error too large: {orig:.4} → {deq:.4}"
            );
        }
    }
}