Skip to main content

ainl_context_compiler/
embedder.rs

//! Tier 2 embedding surface (M3).
//!
//! M1 ships the [`Embedder`] *trait*, so callers can already opt in via `with_embedder`; the M3
//! milestone adds a concrete adapter (e.g. `OnnxMiniLMEmbedder`) and the rerank step in the
//! orchestrator. With no embedder injected, the orchestrator stays at Tier 0 / Tier 1.
//!
//! This module also provides [`PlaceholderEmbedder`], a hash-based stand-in for tests, and the
//! [`cosine`] similarity helper.
6
7use std::error::Error;
8use std::fmt;
9use std::hash::{Hash, Hasher};
10
/// Dimension of the vectors produced by [`PlaceholderEmbedder`]: one component per 4-bit nibble
/// of a 64-bit hash of the input text.
const PLACEHOLDER_EMBED_DIM: usize = (u64::BITS / 4) as usize;
12
/// Failure modes an [`Embedder`] backend can report.
#[derive(Debug)]
pub enum EmbedderError {
    /// The backend could not be reached or an I/O operation failed; carries a description.
    Transport(String),
    /// The backend has no model loaded, so it cannot produce embeddings.
    ModelMissing,
    /// Any other failure, with a free-form description.
    Other(String),
}

impl fmt::Display for EmbedderError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every message is a fixed prefix followed (for the string-carrying variants) by the
        // detail text; formatter width/fill options are intentionally not applied.
        let (prefix, detail) = match self {
            Self::Transport(msg) => ("embedder transport: ", msg.as_str()),
            Self::ModelMissing => ("embedder model not loaded", ""),
            Self::Other(msg) => ("embedder: ", msg.as_str()),
        };
        f.write_str(prefix)?;
        f.write_str(detail)
    }
}

/// Lets `EmbedderError` be boxed as `dyn Error` or propagated with `?`; it wraps no source.
impl Error for EmbedderError {}
35
36/// Pluggable embedding backend (M3).
37///
38/// Implementations should return a fixed-dimension `Vec<f32>` per text; the orchestrator
39/// computes cosine similarity between the latest user query and each segment to rerank.
40///
41/// Marked `Send + Sync` so a single embedder instance can be shared via `Arc`.
42pub trait Embedder: Send + Sync {
43    /// Embed a single text. Returns a fixed-dimension vector.
44    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError>;
45
46    /// Embed a batch (default impl loops; backends should override for efficiency).
47    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedderError> {
48        texts.iter().map(|t| self.embed(t)).collect()
49    }
50}
51
/// Hash-based stand-in embedder for **tests** and offline M3 development.
///
/// The whole input text is hashed and the hash bits are spread over `PLACEHOLDER_EMBED_DIM`
/// components, which are then L2-normalized (unless every component is zero). The output is
/// deterministic for a given text within one build, but carries no semantic information: texts
/// that share words are no more similar under [`cosine`] than unrelated texts.
#[derive(Debug, Default, Clone, Copy)]
pub struct PlaceholderEmbedder;

impl PlaceholderEmbedder {
    /// Returns the stateless placeholder embedder (same value as `PlaceholderEmbedder::default()`).
    #[must_use]
    pub const fn new() -> Self {
        PlaceholderEmbedder
    }
}
65
66impl Embedder for PlaceholderEmbedder {
67    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
68        use std::collections::hash_map::DefaultHasher;
69        let mut h = DefaultHasher::new();
70        text.hash(&mut h);
71        let x = h.finish();
72        let mut v = vec![0f32; PLACEHOLDER_EMBED_DIM];
73        for (i, slot) in v.iter_mut().enumerate() {
74            *slot = (((x >> (i * 4)) & 0xF) as f32) / 15.0;
75        }
76        let n: f32 = v.iter().map(|e| e * e).sum::<f32>().sqrt();
77        if n > 0.0 {
78            for t in v.iter_mut() {
79                *t /= n;
80            }
81        }
82        Ok(v)
83    }
84}
85
/// Cosine similarity between two equal-length vectors.
///
/// Returns `0.0` when the lengths differ, when either slice is empty, or when either vector has
/// zero magnitude (cosine is undefined there, so "no similarity" is reported).
///
/// Sums are accumulated in `f64`, so very large or very small (but finite) `f32` components
/// neither overflow to infinity nor underflow to zero. The result is clamped to `[-1.0, 1.0]` to
/// absorb rounding. `NaN` components propagate to a `NaN` result.
#[must_use]
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Single pass over both slices: dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }
    let similarity = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
    // After clamping the value lies in [-1, 1] (or is NaN), so narrowing to f32 only rounds.
    similarity.clamp(-1.0, 1.0) as f32
}
101
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_identical_is_one() {
        let v = vec![1.0, 2.0, 3.0];
        assert!((cosine(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_orthogonal_is_zero() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_mismatched_lengths_returns_zero() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert_eq!(cosine(&a, &b), 0.0);
    }

    #[test]
    fn cosine_empty_and_zero_vectors_return_zero() {
        let empty: [f32; 0] = [];
        assert_eq!(cosine(&empty, &empty), 0.0);
        assert_eq!(cosine(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn cosine_opposite_is_minus_one() {
        assert!((cosine(&[1.0, 2.0], &[-1.0, -2.0]) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn placeholder_l2_unit_vector() {
        let e = PlaceholderEmbedder::new();
        let v = e.embed("hello world").expect("ok");
        assert_eq!(v.len(), PLACEHOLDER_EMBED_DIM);
        let n: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((n - 1.0).abs() < 1e-5 || n.abs() < 1e-5);
    }

    #[test]
    fn placeholder_is_deterministic() {
        let e = PlaceholderEmbedder::new();
        assert_eq!(e.embed("same text").expect("ok"), e.embed("same text").expect("ok"));
    }

    #[test]
    fn placeholder_distinguishes_texts() {
        let e = PlaceholderEmbedder::new();
        assert_ne!(e.embed("alpha").expect("ok"), e.embed("beta").expect("ok"));
    }

    #[test]
    fn embed_batch_default_matches_embed() {
        let e = PlaceholderEmbedder::new();
        let batch = e.embed_batch(&["a", "b"]).expect("ok");
        assert_eq!(batch, vec![e.embed("a").expect("ok"), e.embed("b").expect("ok")]);
        let none: [&str; 0] = [];
        assert!(e.embed_batch(&none).expect("ok").is_empty());
    }

    #[test]
    fn embed_batch_stops_at_first_error() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        /// Fails on the text "bad" and counts how many texts it was asked to embed.
        struct FailOnBad(AtomicUsize);

        impl Embedder for FailOnBad {
            fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
                self.0.fetch_add(1, Ordering::SeqCst);
                if text == "bad" {
                    Err(EmbedderError::ModelMissing)
                } else {
                    Ok(vec![1.0])
                }
            }
        }

        let e = FailOnBad(AtomicUsize::new(0));
        let err = e.embed_batch(&["ok", "bad", "never"]).expect_err("batch should fail");
        assert!(matches!(err, EmbedderError::ModelMissing));
        // The default implementation must not embed anything after the first failure.
        assert_eq!(e.0.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn error_display_messages() {
        assert_eq!(
            EmbedderError::Transport("timeout".into()).to_string(),
            "embedder transport: timeout"
        );
        assert_eq!(EmbedderError::ModelMissing.to_string(), "embedder model not loaded");
        assert_eq!(EmbedderError::Other("boom".into()).to_string(), "embedder: boom");
    }
}