Skip to main content

harn_hostlib/embed/
backend.rs

1//! Pluggable embedding backends.
2//!
3//! The [`Embedder`] trait is the futureproof seam: every backend turns a
4//! string into a fixed-dimension `f32` vector that downstream cosine math
5//! ranks. Two pure-Rust, zero-asset, fully-offline backends ship in-tree:
6//!
7//! * [`LexicalEmbedder`] — the always-available default. A hashed
8//!   bag-of-features (word tokens + char trigrams) projected into a fixed
9//!   dimension via the hashing trick, then L2-normalized. No model, no
10//!   asset, no network; microsecond latency; deterministic across OSes.
11//!   This is the graceful-degradation floor: it is what every other
12//!   backend falls back to when its asset is missing.
13//!
14//! * [`StaticEmbedder`] — a Model2Vec / "potion"-style static
15//!   token-pooled embedder. Loads precomputed per-token vectors from a
16//!   resolved-on-disk asset, then `tokenize -> lookup -> mean -> normalize`
17//!   with no neural-network inference. Microsecond latency, ~92% of
18//!   MiniLM-class quality when the asset is present. Constructed via
19//!   [`StaticEmbedder::from_asset_dir`], which fails cleanly (so callers
20//!   fall back to lexical) when the asset is absent.
21//!
22//! A higher-accuracy on-device transformer backend (candle / ONNX) can be
23//! added later behind a Cargo feature without changing this trait or any
24//! consumer: implement [`Embedder`], resolve its asset the same way, and
25//! register it as the active backend when the feature + setting are on.
26
27use std::collections::HashMap;
28use std::path::{Path, PathBuf};
29
30use super::tokenize;
31
/// Contract for anything that turns text into a fixed-length `f32` vector.
///
/// The `Send + Sync` supertraits let one backend instance be shared by
/// every Harn VM / thread, matching the `code_index` concurrency model.
/// Implementations SHOULD return L2-normalized vectors so cosine similarity
/// collapses to a dot product; [`super::similarity`] still normalizes
/// defensively, so this is a recommendation rather than a requirement.
pub trait Embedder: Send + Sync {
    /// Embed one string.
    ///
    /// Empty or degenerate input yields a zero vector of length
    /// [`Embedder::dim`]; cosine similarity against it is `0.0`.
    fn embed(&self, text: &str) -> Vec<f32>;

    /// Length of every vector this backend produces. Does not change over
    /// the backend's lifetime.
    fn dim(&self) -> usize;

    /// Stable identifier reported through `hostlib_embed_info`, letting
    /// consumers (and evals) record which backend produced a score.
    fn name(&self) -> &str;

    /// Embed several strings, preserving input order.
    ///
    /// The provided implementation simply calls [`Embedder::embed`] once per
    /// element; backends with a cheaper batched path may override it.
    fn embed_batch(&self, texts: &[String]) -> Vec<Vec<f32>> {
        let mut embedded = Vec::with_capacity(texts.len());
        for text in texts {
            embedded.push(self.embed(text));
        }
        embedded
    }
}
57
/// Seeded 64-bit FNV-1a over `bytes`.
///
/// The stdlib `DefaultHasher` carries no cross-version stability guarantee,
/// which would let embeddings drift between toolchains. This hand-written
/// hash is fixed, so the projection is the same on macOS, Linux, and
/// Windows and a query embedded on one host ranks a corpus embedded on
/// another.
///
/// `seed` is XORed into the standard FNV offset basis; `seed == 0` gives
/// plain FNV-1a.
fn fnv1a(bytes: &[u8], seed: u64) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01B3;

    bytes.iter().fold(seed ^ OFFSET_BASIS, |state, &byte| {
        // FNV-1a order: XOR the byte in first, then multiply by the prime.
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
72
/// Hashing-trick lexical embedder. Always available, no asset.
///
/// Features are word tokens (camel/snake-split, weighted higher) plus char
/// trigrams (weighted lower, for typo/root robustness). Each feature is
/// hashed into one of `dim` buckets with a signed contribution (a second
/// hash bit picks the sign, which de-biases hash collisions — the standard
/// signed hashing trick). The accumulated vector is L2-normalized.
///
/// The type is cheap to clone (one `usize` plus a short name) and derives
/// `Debug` so it can be logged and embedded in other `Debug` types.
#[derive(Debug, Clone)]
pub struct LexicalEmbedder {
    /// Number of hash buckets, i.e. the length of every produced vector.
    dim: usize,
    /// Backend identifier returned by `Embedder::name`.
    name: String,
}
84
85impl LexicalEmbedder {
86    /// Construct with a given output dimension (clamped to `>= 16`). 256 is
87    /// a good default: enough buckets to keep collisions rare for
88    /// sentence-length inputs while staying cache-friendly.
89    pub fn new(dim: usize) -> Self {
90        Self {
91            dim: dim.max(16),
92            name: "lexical-hash".to_string(),
93        }
94    }
95
96    fn add_feature(&self, vec: &mut [f32], feature: &str, weight: f32) {
97        let h = fnv1a(feature.as_bytes(), 0);
98        let bucket = (h % self.dim as u64) as usize;
99        // Sign from an independent hash so collisions cancel in expectation.
100        let sign = if fnv1a(feature.as_bytes(), 0x9e37_79b9_7f4a_7c15) & 1 == 0 {
101            1.0
102        } else {
103            -1.0
104        };
105        vec[bucket] += sign * weight;
106    }
107}
108
109impl Default for LexicalEmbedder {
110    fn default() -> Self {
111        Self::new(256)
112    }
113}
114
115impl Embedder for LexicalEmbedder {
116    fn embed(&self, text: &str) -> Vec<f32> {
117        let mut vec = vec![0.0f32; self.dim];
118        for token in tokenize::word_tokens(text) {
119            self.add_feature(&mut vec, &token, 1.0);
120        }
121        for gram in tokenize::char_ngrams(text, 3) {
122            // Lower weight: char-ngrams are a denser, noisier signal.
123            self.add_feature(&mut vec, &gram, 0.35);
124        }
125        l2_normalize(&mut vec);
126        vec
127    }
128
129    fn dim(&self) -> usize {
130        self.dim
131    }
132
133    fn name(&self) -> &str {
134        &self.name
135    }
136}
137
138/// Model2Vec / potion-style static token-pooled embedder.
139///
140/// Holds a precomputed `token -> vector` table loaded from a vendored
141/// asset. Embedding is `tokenize -> lookup each token's vector -> mean ->
142/// L2-normalize`, with NO neural-network forward pass — that is the entire
143/// point of static embeddings (microsecond latency, tiny footprint).
144///
145/// ## Asset format (intentionally simple + dependency-free)
146///
147/// The asset directory contains `static-embeddings.json`:
148/// ```json
149/// { "dim": 8, "vectors": { "rate": [...8 floats...], "limit": [...] } }
150/// ```
151/// Tokens are the same word tokens [`tokenize::word_tokens`] produces, so a
152/// distilled potion table can be exported into this shape offline. A real
153/// `.safetensors` loader (via `model2vec-rs`) can be added behind a feature
154/// later without touching this trait or the JSON fallback — both satisfy
155/// the same `token -> vector` contract.
156pub struct StaticEmbedder {
157    dim: usize,
158    vectors: HashMap<String, Vec<f32>>,
159    name: String,
160    /// Lexical fallback used when a query contains *no* known tokens, so a
161    /// previously-unseen identifier still gets a non-degenerate vector
162    /// instead of collapsing to zero.
163    fallback: LexicalEmbedder,
164}
165
166impl StaticEmbedder {
167    /// Resolve and load `static-embeddings.json` under `asset_dir`.
168    ///
169    /// Returns `Err` (so the caller can fall back to lexical) when the
170    /// directory or file is missing, unreadable, malformed, or empty. This
171    /// is the sandbox/settings-aware degradation contract: a missing asset
172    /// never panics and never blocks — it just selects the lexical floor.
173    pub fn from_asset_dir(asset_dir: &Path) -> Result<Self, String> {
174        let path = asset_dir.join("static-embeddings.json");
175        let raw = std::fs::read_to_string(&path)
176            .map_err(|e| format!("static embedding asset {} unreadable: {e}", path.display()))?;
177        Self::from_json(&raw)
178    }
179
180    /// Parse an in-memory asset document. Split out for testability and so
181    /// future loaders (safetensors) can reuse the validation.
182    pub fn from_json(raw: &str) -> Result<Self, String> {
183        // Hand-rolled minimal parse keeps the default build dependency-free
184        // (no serde_json pulled in just for an optional asset). The format
185        // is small and fixed; we accept the documented shape only.
186        let doc: AssetDoc = parse_asset(raw)?;
187        if doc.vectors.is_empty() {
188            return Err("static embedding asset has no vectors".to_string());
189        }
190        for (tok, v) in &doc.vectors {
191            if v.len() != doc.dim {
192                return Err(format!(
193                    "static embedding vector for `{tok}` has length {} but dim is {}",
194                    v.len(),
195                    doc.dim
196                ));
197            }
198        }
199        Ok(Self {
200            dim: doc.dim,
201            vectors: doc.vectors,
202            name: "static-model2vec".to_string(),
203            fallback: LexicalEmbedder::new(doc.dim),
204        })
205    }
206}
207
208impl Embedder for StaticEmbedder {
209    fn embed(&self, text: &str) -> Vec<f32> {
210        let mut acc = vec![0.0f32; self.dim];
211        let mut hits = 0usize;
212        for token in tokenize::word_tokens(text) {
213            if let Some(v) = self.vectors.get(&token) {
214                for (a, x) in acc.iter_mut().zip(v.iter()) {
215                    *a += x;
216                }
217                hits += 1;
218            }
219        }
220        if hits == 0 {
221            // No known tokens: fall back to the lexical projection so a
222            // novel identifier is still comparable rather than all-zero.
223            return self.fallback.embed(text);
224        }
225        let inv = 1.0 / hits as f32;
226        for a in acc.iter_mut() {
227            *a *= inv;
228        }
229        l2_normalize(&mut acc);
230        acc
231    }
232
233    fn dim(&self) -> usize {
234        self.dim
235    }
236
237    fn name(&self) -> &str {
238        &self.name
239    }
240}
241
/// Scale `vec` in place to unit Euclidean length.
///
/// A vector whose norm is not strictly positive (all zeros, empty) is left
/// untouched; cosine treats such vectors as `0.0`.
pub(crate) fn l2_normalize(vec: &mut [f32]) {
    let mut sum_of_squares = 0.0f32;
    for &component in vec.iter() {
        sum_of_squares += component * component;
    }
    let norm = sum_of_squares.sqrt();
    if norm > 0.0 {
        // Multiply by the reciprocal once rather than dividing per element.
        let scale = 1.0 / norm;
        vec.iter_mut().for_each(|component| *component *= scale);
    }
}
253
254// --- minimal, dependency-free asset parser -------------------------------
255
256struct AssetDoc {
257    dim: usize,
258    vectors: HashMap<String, Vec<f32>>,
259}
260
261/// Parse the fixed `{ "dim": N, "vectors": { "tok": [floats] } }` shape.
262/// This avoids adding a JSON dependency to the default build for what is an
263/// optional asset; a future safetensors loader supersedes it entirely.
264fn parse_asset(raw: &str) -> Result<AssetDoc, String> {
265    // We lean on the harn-vm value layer? No — keep it standalone. Use a
266    // tiny tolerant scanner: find "dim": <int>, then "vectors": { ... }.
267    let dim = extract_int(raw, "\"dim\"")
268        .ok_or_else(|| "static embedding asset missing integer `dim`".to_string())?;
269    if dim == 0 {
270        return Err("static embedding `dim` must be > 0".to_string());
271    }
272    let vectors = extract_vectors(raw)?;
273    Ok(AssetDoc {
274        dim: dim as usize,
275        vectors,
276    })
277}
278
/// Find the first `key: <integer>` pair in `raw` and return the integer.
///
/// `key` is matched textually, quotes included (e.g. `"\"dim\""`). Every
/// occurrence is tried in order, and an occurrence only counts when the
/// next non-whitespace character is `:` and an integer follows. This keeps
/// the scanner from being derailed by the same text appearing earlier as a
/// string value, or as a vocabulary token whose value is an array.
///
/// Returns `None` when no occurrence is followed by a parseable integer.
fn extract_int(raw: &str, key: &str) -> Option<i64> {
    raw.match_indices(key).find_map(|(idx, matched)| {
        let after_key = raw[idx + matched.len()..].trim_start();
        let value = after_key.strip_prefix(':')?.trim_start();
        // Take the leading run of digits / minus signs; `parse` rejects
        // malformed runs such as `1-2` or an empty run.
        let end = value
            .find(|c: char| !c.is_ascii_digit() && c != '-')
            .unwrap_or(value.len());
        value[..end].parse::<i64>().ok()
    })
}
289
290fn extract_vectors(raw: &str) -> Result<HashMap<String, Vec<f32>>, String> {
291    let key = "\"vectors\"";
292    let idx = raw
293        .find(key)
294        .ok_or_else(|| "static embedding asset missing `vectors`".to_string())?;
295    let after = &raw[idx + key.len()..];
296    let open = after
297        .find('{')
298        .ok_or_else(|| "`vectors` is not an object".to_string())?;
299    let body = &after[open + 1..];
300    let mut map = HashMap::new();
301    let bytes = body.as_bytes();
302    let mut i = 0usize;
303    while i < bytes.len() {
304        // find next quote (start of a key) or closing brace of the object
305        match bytes[i] {
306            b'}' => break,
307            b'"' => {
308                let (k, next) = parse_string(body, i)?;
309                i = next;
310                // skip to colon
311                while i < bytes.len() && bytes[i] != b':' {
312                    i += 1;
313                }
314                i += 1;
315                // skip to array open
316                while i < bytes.len() && bytes[i] != b'[' {
317                    i += 1;
318                }
319                let (vec, next) = parse_float_array(body, i)?;
320                i = next;
321                map.insert(k, vec);
322            }
323            _ => i += 1,
324        }
325    }
326    Ok(map)
327}
328
/// Parse the JSON string literal whose opening quote is at byte `start`.
///
/// Returns the decoded contents and the byte index just past the closing
/// quote. The text is walked as UTF-8 characters, so multi-byte tokens
/// (`café`, CJK, ...) are kept intact instead of being split into one
/// `char` per byte. The standard JSON escapes are decoded, including
/// `\uXXXX` with surrogate pairs; any other escaped character is taken
/// literally.
///
/// # Errors
///
/// Returns `Err` when `start` is not the position of a `"`, when the
/// closing quote is missing, or when a `\u` escape is malformed or is an
/// unpaired surrogate.
fn parse_string(s: &str, start: usize) -> Result<(String, usize), String> {
    const BAD_UNICODE: &str = "bad \\u escape in static embedding asset";

    if s.as_bytes().get(start) != Some(&b'"') {
        return Err("expected string in static embedding asset".to_string());
    }
    // The opening quote is one ASCII byte, so `start + 1` is a char boundary.
    let body_start = start + 1;
    let mut chars = s[body_start..].char_indices();
    let mut out = String::new();

    while let Some((offset, c)) = chars.next() {
        match c {
            '"' => return Ok((out, body_start + offset + 1)),
            '\\' => {
                let escaped = match chars.next() {
                    Some((_, escaped)) => escaped,
                    // Trailing backslash: the string can never be closed.
                    None => break,
                };
                match escaped {
                    'n' => out.push('\n'),
                    't' => out.push('\t'),
                    'r' => out.push('\r'),
                    'b' => out.push('\u{0008}'),
                    'f' => out.push('\u{000C}'),
                    'u' => {
                        let unit = read_hex4(&mut chars)?;
                        let code = if (0xD800..0xDC00).contains(&unit) {
                            // High surrogate: a `\uDC00..=\uDFFF` low half
                            // must follow immediately.
                            match (chars.next(), chars.next()) {
                                (Some((_, '\\')), Some((_, 'u'))) => {
                                    let low = read_hex4(&mut chars)?;
                                    if !(0xDC00..0xE000).contains(&low) {
                                        return Err(BAD_UNICODE.to_string());
                                    }
                                    0x10000 + ((unit - 0xD800) << 10) + (low - 0xDC00)
                                }
                                _ => return Err(BAD_UNICODE.to_string()),
                            }
                        } else {
                            unit
                        };
                        // `from_u32` rejects a lone low surrogate.
                        out.push(char::from_u32(code).ok_or_else(|| BAD_UNICODE.to_string())?);
                    }
                    // `\"`, `\\`, `\/` and anything unrecognised: keep the
                    // escaped character itself.
                    other => out.push(other),
                }
            }
            other => out.push(other),
        }
    }
    Err("unterminated string in static embedding asset".to_string())
}

/// Read exactly four hexadecimal digits from `chars` as one UTF-16 unit.
fn read_hex4(chars: &mut std::str::CharIndices<'_>) -> Result<u32, String> {
    let mut value = 0u32;
    for _ in 0..4 {
        let digit = chars
            .next()
            .and_then(|(_, c)| c.to_digit(16))
            .ok_or_else(|| "bad \\u escape in static embedding asset".to_string())?;
        value = value * 16 + digit;
    }
    Ok(value)
}
349
/// Parse the `[f, f, ...]` array whose `[` is at byte `start`.
///
/// Returns the parsed values and the byte index just past the closing `]`.
/// Elements are comma-separated; surrounding whitespace is ignored and
/// empty elements (as in `[]` or a trailing comma) are skipped.
///
/// # Errors
///
/// Returns `Err` when `start` is not the position of a `[`, when the
/// closing `]` is missing, when an element is not a valid float, or when
/// an element is not finite (`NaN`, `inf`, or a literal that overflows
/// `f32`). A single non-finite component would make every vector pooled
/// from it non-finite, so such assets are rejected at load time.
fn parse_float_array(s: &str, start: usize) -> Result<(Vec<f32>, usize), String> {
    if s.as_bytes().get(start) != Some(&b'[') {
        return Err("expected float array in static embedding asset".to_string());
    }
    // `[` is one ASCII byte, so `start + 1` is a char boundary.
    let inner_start = start + 1;
    let close = s[inner_start..]
        .find(']')
        .ok_or_else(|| "unterminated float array in static embedding asset".to_string())?;
    let inner = &s[inner_start..inner_start + close];

    let mut out = Vec::new();
    for piece in inner.split(',') {
        let literal = piece.trim();
        if literal.is_empty() {
            continue;
        }
        let value = literal
            .parse::<f32>()
            .map_err(|_| format!("bad float `{literal}` in static embedding asset"))?;
        if !value.is_finite() {
            return Err(format!(
                "non-finite float `{literal}` in static embedding asset"
            ));
        }
        out.push(value);
    }
    Ok((out, inner_start + close + 1))
}
388
/// Resolve the asset directory for a named embedding model, honoring an
/// explicit override before falling back to a conventional location under
/// the data dir. Returns `None` when nothing resolvable exists, which the
/// caller treats as "use the lexical floor".
///
/// Resolution order (sandbox/settings-aware):
/// 1. explicit `override_dir` (from a Harn setting / host call param),
/// 2. `<data_dir>/embeddings/<model>` (conventional vendored location).
///
/// A directory only counts when it contains a `static-embeddings.json`
/// file.
///
/// `model` may name nested directories (`org/name`), but a name containing
/// `..`, a root, or a drive prefix is rejected, because joining it would
/// resolve outside `<data_dir>/embeddings`. The function never touches the
/// network. NOTE(review): symlinks inside the roots are still followed by
/// the `is_file` probe — confirm whether the sandbox needs to forbid them.
pub fn resolve_asset_dir(
    override_dir: Option<&Path>,
    data_dir: Option<&Path>,
    model: &str,
) -> Option<PathBuf> {
    const ASSET_FILE: &str = "static-embeddings.json";

    if let Some(dir) = override_dir {
        if dir.join(ASSET_FILE).is_file() {
            return Some(dir.to_path_buf());
        }
    }

    let base = data_dir?;
    // `Path::join` replaces the base for an absolute argument and keeps
    // `..` segments verbatim, so only plain relative names are accepted.
    let stays_inside = Path::new(model).components().all(|part| {
        matches!(
            part,
            std::path::Component::Normal(_) | std::path::Component::CurDir
        )
    });
    if !stays_inside {
        return None;
    }

    let candidate = base.join("embeddings").join(model);
    if candidate.join(ASSET_FILE).is_file() {
        Some(candidate)
    } else {
        None
    }
}
418
#[cfg(test)]
mod tests {
    use super::*;

    // An embedding compared with itself has cosine similarity 1.
    #[test]
    fn lexical_identical_text_is_self_similar() {
        let e = LexicalEmbedder::default();
        let v = e.embed("rate limiter middleware");
        assert_eq!(v.len(), 256);
        let sim = super::super::similarity::cosine(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5, "self-sim was {sim}");
    }

    // Shared vocabulary must outrank unrelated vocabulary.
    #[test]
    fn lexical_related_beats_unrelated() {
        let e = LexicalEmbedder::default();
        let query = e.embed("rate limiter for the API");
        let related = e.embed("RateLimiter API throttle");
        let unrelated = e.embed("parse markdown table renderer");
        let s_rel = super::super::similarity::cosine(&query, &related);
        let s_unrel = super::super::similarity::cosine(&query, &unrelated);
        assert!(
            s_rel > s_unrel,
            "related {s_rel} should beat unrelated {s_unrel}"
        );
    }

    // Empty input contributes no features, so the vector stays all-zero.
    #[test]
    fn lexical_empty_is_zero_vector() {
        let e = LexicalEmbedder::default();
        let v = e.embed("");
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn lexical_is_l2_normalized() {
        let e = LexicalEmbedder::default();
        let v = e.embed("hello world embedding test");
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "norm was {norm}");
    }

    #[test]
    fn lexical_is_deterministic_cross_run() {
        // The whole cross-platform contract: same input -> same vector.
        let e = LexicalEmbedder::default();
        assert_eq!(e.embed("getUserById"), e.embed("getUserById"));
    }

    #[test]
    fn static_embedder_pools_known_tokens() {
        let json = r#"{ "dim": 2, "vectors": {
            "rate": [1.0, 0.0],
            "limit": [0.0, 1.0],
            "throttle": [0.7071, 0.7071]
        } }"#;
        let e = StaticEmbedder::from_json(json).expect("parse");
        assert_eq!(e.dim(), 2);
        // "rate limit" pools (1,0)+(0,1) -> (0.5,0.5) -> normalized (1/sqrt2, 1/sqrt2)
        let v = e.embed("rate limit");
        let expected = std::f32::consts::FRAC_1_SQRT_2;
        assert!((v[0] - expected).abs() < 1e-3, "{v:?}");
        assert!((v[1] - expected).abs() < 1e-3, "{v:?}");
        // "throttle" should be very close to "rate limit" semantically here.
        let sim = super::super::similarity::cosine(&v, &e.embed("throttle"));
        assert!(sim > 0.99, "throttle sim {sim}");
    }

    #[test]
    fn static_embedder_falls_back_for_unknown_tokens() {
        let json = r#"{ "dim": 2, "vectors": { "rate": [1.0, 0.0] } }"#;
        let e = StaticEmbedder::from_json(json).expect("parse");
        // Unknown tokens -> lexical fallback, non-zero, comparable.
        let v = e.embed("zzz totally unknown words");
        assert!(v.iter().any(|&x| x != 0.0));
    }

    #[test]
    fn static_embedder_rejects_malformed_asset() {
        assert!(StaticEmbedder::from_json("not json").is_err());
        assert!(StaticEmbedder::from_json(r#"{ "dim": 2, "vectors": {} }"#).is_err());
        // length mismatch
        assert!(
            StaticEmbedder::from_json(r#"{ "dim": 3, "vectors": { "x": [1.0, 2.0] } }"#).is_err()
        );
    }

    #[test]
    fn resolve_asset_dir_respects_override_and_absence() {
        // Absent directory: neither root resolves.
        let tmp = std::env::temp_dir().join("embed-resolve-test-absent-xyz");
        let _ = std::fs::remove_dir_all(&tmp);
        assert_eq!(resolve_asset_dir(Some(&tmp), None, "potion"), None);
        assert_eq!(resolve_asset_dir(None, Some(&tmp), "potion"), None);

        // Present override: the override directory itself is returned.
        let present = std::env::temp_dir().join(format!(
            "embed-resolve-test-present-{}",
            std::process::id()
        ));
        std::fs::create_dir_all(&present).expect("create temp asset dir");
        std::fs::write(present.join("static-embeddings.json"), "{}").expect("write asset");
        assert_eq!(
            resolve_asset_dir(Some(&present), None, "potion"),
            Some(present.clone())
        );
        let _ = std::fs::remove_dir_all(&present);
    }

    #[test]
    fn parse_handles_negative_and_scientific_floats() {
        // Covers a negative value plus lower- and upper-case exponents.
        let json = r#"{ "dim": 3, "vectors": { "x": [-1.5, 2e-1, 3.0E0] } }"#;
        let e = StaticEmbedder::from_json(json).expect("parse");
        assert_eq!(e.dim(), 3);
        assert_eq!(e.vectors["x"], vec![-1.5f32, 0.2, 3.0]);
    }
}