// harn_hostlib/embed/mod.rs
//! Text-similarity / embedding host capability.
//!
//! A cross-platform, offline, DRY core for cosine/semantic similarity. It
//! is the single source of truth shared by two consumers:
//!
//! 1. **Push-context Tier-2** (host pipelines): auto-injecting
//!    skills/canon/memory/few-shot above a similarity threshold.
//! 2. **`SymbolRelevance`** (host-side): symbol ranking, today split
//!    between macOS-only `NLEmbedding` and a Linux Jaccard fallback. Both
//!    can now route through these builtins for one cross-platform path.
//!
//! ## Surface
//!
//! | Builtin                       | What it does                                              |
//! |-------------------------------|-----------------------------------------------------------|
//! | `hostlib_embed_similarity`    | Cosine similarity of two strings via the active backend.  |
//! | `hostlib_embed_top_k`         | Rank a corpus of strings against a query, return top `k`. |
//! | `hostlib_embed_vector`        | Embed one string to its raw `f32` vector.                 |
//! | `hostlib_embed_info`          | Active backend name + dimensionality.                     |
//!
//! ## Backend selection
//!
//! The capability owns one [`backend::Embedder`] behind an `Arc`, shared
//! across every VM/thread (mirroring `code_index`). The default is the
//! always-available [`backend::LexicalEmbedder`] (no asset required,
//! microsecond latency, deterministic across OSes). When a
//! Model2Vec/"potion"-style static asset is resolvable (settings/sandbox-aware,
//! no network), the capability upgrades to [`backend::StaticEmbedder`]; if
//! the asset is missing or malformed it degrades cleanly back to lexical. A
//! future candle/ONNX transformer tier can slot in behind a Cargo feature
//! without changing this surface or either consumer.
32
33mod backend;
34mod similarity;
35mod tokenize;
36
37pub use backend::{resolve_asset_dir, Embedder, LexicalEmbedder, StaticEmbedder};
38pub use similarity::{cosine, top_k, Scored};
39
40use std::path::Path;
41use std::sync::Arc;
42
43use harn_vm::VmValue;
44
45use crate::error::HostlibError;
46use crate::registry::{BuiltinRegistry, HostlibCapability, RegisteredBuiltin, SyncHandler};
47use crate::tools::args::{build_dict, dict_arg, optional_int, require_string};
48use crate::value_args;
49
/// Registered name of the two-string cosine-similarity builtin.
pub const BUILTIN_SIMILARITY: &str = "hostlib_embed_similarity";
/// Registered name of the builtin that ranks a corpus against a query (top-k).
pub const BUILTIN_TOP_K: &str = "hostlib_embed_top_k";
/// Registered name of the builtin that returns one string's raw embedding vector.
pub const BUILTIN_VECTOR: &str = "hostlib_embed_vector";
/// Registered name of the builtin that reports the active backend and its dimensionality.
pub const BUILTIN_INFO: &str = "hostlib_embed_info";
58
59/// Embedding capability handle. Cloning shares the active backend.
60#[derive(Clone)]
61pub struct EmbedCapability {
62    embedder: Arc<dyn Embedder>,
63}
64
65impl Default for EmbedCapability {
66    fn default() -> Self {
67        Self::lexical()
68    }
69}
70
71impl EmbedCapability {
72    /// Capability using the always-available lexical backend.
73    pub fn lexical() -> Self {
74        Self {
75            embedder: Arc::new(LexicalEmbedder::default()),
76        }
77    }
78
79    /// Capability backed by an explicit [`Embedder`]. Embedders construct
80    /// their own fallback, so this never fails.
81    pub fn with_embedder(embedder: Arc<dyn Embedder>) -> Self {
82        Self { embedder }
83    }
84
85    /// Resolve a static (Model2Vec-style) asset and use it if present,
86    /// otherwise fall back to lexical. Sandbox/settings-aware: pass the
87    /// override dir (from a setting) and the data dir; nothing else is
88    /// touched and there is no network access.
89    pub fn resolve(override_dir: Option<&Path>, data_dir: Option<&Path>, model: &str) -> Self {
90        if let Some(dir) = resolve_asset_dir(override_dir, data_dir, model) {
91            if let Ok(static_embedder) = StaticEmbedder::from_asset_dir(&dir) {
92                return Self {
93                    embedder: Arc::new(static_embedder),
94                };
95            }
96        }
97        Self::lexical()
98    }
99
100    /// Borrow the active backend (tests / embedders).
101    pub fn embedder(&self) -> &Arc<dyn Embedder> {
102        &self.embedder
103    }
104
105    fn run_similarity(&self, args: &[VmValue]) -> Result<VmValue, HostlibError> {
106        let raw = dict_arg(BUILTIN_SIMILARITY, args)?;
107        let dict = raw.as_ref();
108        let a = require_string(BUILTIN_SIMILARITY, dict, "a")?;
109        let b = require_string(BUILTIN_SIMILARITY, dict, "b")?;
110        let va = self.embedder.embed(&a);
111        let vb = self.embedder.embed(&b);
112        let sim = cosine(&va, &vb);
113        // Surface both the raw cosine ([-1,1]) and a clamped relatedness
114        // score ([0,1]) so each consumer picks the shape it wants without
115        // re-deriving it (DRY: one definition of "relatedness").
116        Ok(build_dict([
117            ("similarity", VmValue::Float(sim as f64)),
118            ("relatedness", VmValue::Float(sim.max(0.0) as f64)),
119        ]))
120    }
121
122    fn run_top_k(&self, args: &[VmValue]) -> Result<VmValue, HostlibError> {
123        let raw = dict_arg(BUILTIN_TOP_K, args)?;
124        let dict = raw.as_ref();
125        let query = require_string(BUILTIN_TOP_K, dict, "query")?;
126        let corpus = require_string_list(BUILTIN_TOP_K, dict, "corpus")?;
127        let k = optional_int(BUILTIN_TOP_K, dict, "k", 10)?.max(0) as usize;
128        let min_score =
129            optional_float(BUILTIN_TOP_K, dict, "min_score")?.unwrap_or(f64::NEG_INFINITY);
130
131        let query_vec = self.embedder.embed(&query);
132        let corpus_vecs: Vec<Vec<f32>> = self.embedder.embed_batch(&corpus);
133        let ranked = top_k(&query_vec, &corpus_vecs, k);
134
135        let results: Vec<VmValue> = ranked
136            .into_iter()
137            .filter(|s| (s.score as f64) >= min_score)
138            .map(|s| {
139                build_dict([
140                    ("index", VmValue::Int(s.index as i64)),
141                    (
142                        "text",
143                        VmValue::string(corpus.get(s.index).map(String::as_str).unwrap_or("")),
144                    ),
145                    ("score", VmValue::Float(s.score as f64)),
146                    ("relatedness", VmValue::Float((s.score.max(0.0)) as f64)),
147                ])
148            })
149            .collect();
150        Ok(build_dict([("results", VmValue::List(Arc::new(results)))]))
151    }
152
153    fn run_vector(&self, args: &[VmValue]) -> Result<VmValue, HostlibError> {
154        let raw = dict_arg(BUILTIN_VECTOR, args)?;
155        let dict = raw.as_ref();
156        let text = require_string(BUILTIN_VECTOR, dict, "text")?;
157        let v = self.embedder.embed(&text);
158        let values: Vec<VmValue> = v.into_iter().map(|x| VmValue::Float(x as f64)).collect();
159        Ok(build_dict([
160            ("dim", VmValue::Int(self.embedder.dim() as i64)),
161            ("vector", VmValue::List(Arc::new(values))),
162        ]))
163    }
164
165    fn run_info(&self, _args: &[VmValue]) -> Result<VmValue, HostlibError> {
166        Ok(build_dict([
167            ("backend", VmValue::string(self.embedder.name())),
168            ("dim", VmValue::Int(self.embedder.dim() as i64)),
169        ]))
170    }
171}
172
173impl HostlibCapability for EmbedCapability {
174    fn module_name(&self) -> &'static str {
175        "embed"
176    }
177
178    fn register_builtins(&self, registry: &mut BuiltinRegistry) {
179        let cap = self.clone();
180        let handler: SyncHandler = Arc::new(move |args| cap.run_similarity(args));
181        registry.register(RegisteredBuiltin {
182            name: BUILTIN_SIMILARITY,
183            module: "embed",
184            method: "similarity",
185            handler,
186        });
187
188        let cap = self.clone();
189        let handler: SyncHandler = Arc::new(move |args| cap.run_top_k(args));
190        registry.register(RegisteredBuiltin {
191            name: BUILTIN_TOP_K,
192            module: "embed",
193            method: "top_k",
194            handler,
195        });
196
197        let cap = self.clone();
198        let handler: SyncHandler = Arc::new(move |args| cap.run_vector(args));
199        registry.register(RegisteredBuiltin {
200            name: BUILTIN_VECTOR,
201            module: "embed",
202            method: "vector",
203            handler,
204        });
205
206        let cap = self.clone();
207        let handler: SyncHandler = Arc::new(move |args| cap.run_info(args));
208        registry.register(RegisteredBuiltin {
209            name: BUILTIN_INFO,
210            module: "embed",
211            method: "info",
212            handler,
213        });
214    }
215}
216
217// --- local arg helpers (string list + float) -----------------------------
218
219fn require_string_list(
220    builtin: &'static str,
221    dict: &harn_vm::value::DictMap,
222    key: &'static str,
223) -> Result<Vec<String>, HostlibError> {
224    match value_args::optional_string_list(builtin, dict, key)? {
225        Some(list) => Ok(list),
226        None => Err(HostlibError::MissingParameter {
227            builtin,
228            param: key,
229        }),
230    }
231}
232
233fn optional_float(
234    builtin: &'static str,
235    dict: &harn_vm::value::DictMap,
236    key: &'static str,
237) -> Result<Option<f64>, HostlibError> {
238    match dict.get(key) {
239        None | Some(VmValue::Nil) => Ok(None),
240        Some(VmValue::Float(f)) => Ok(Some(*f)),
241        Some(VmValue::Int(i)) => Ok(Some(*i as f64)),
242        Some(other) => Err(HostlibError::InvalidParameter {
243            builtin,
244            param: key,
245            message: format!("expected number, got {}", value_args::describe(other)),
246        }),
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use harn_vm::value::{intern_key, DictMap};

    /// Invoke the `run_*` entry point for `builtin` with a single dict
    /// argument, panicking if it returns an error.
    fn call(cap: &EmbedCapability, builtin: &str, dict: DictMap) -> VmValue {
        let args = [VmValue::dict(dict)];
        match builtin {
            BUILTIN_SIMILARITY => cap.run_similarity(&args).unwrap(),
            BUILTIN_TOP_K => cap.run_top_k(&args).unwrap(),
            BUILTIN_VECTOR => cap.run_vector(&args).unwrap(),
            BUILTIN_INFO => cap.run_info(&args).unwrap(),
            _ => panic!("unknown builtin"),
        }
    }

    /// Build a `DictMap` from `(key, value)` pairs.
    fn dict_of(pairs: &[(&str, VmValue)]) -> DictMap {
        let mut m = DictMap::new();
        for (k, v) in pairs {
            m.insert(intern_key(k), v.clone());
        }
        m
    }

    /// Extract float field `key` from a dict result, panicking otherwise.
    fn get_float(v: &VmValue, key: &str) -> f64 {
        if let VmValue::Dict(d) = v {
            if let Some(VmValue::Float(f)) = d.get(key) {
                return *f;
            }
        }
        panic!("no float {key} in {v:?}");
    }

    /// Extract int field `key`, panicking otherwise.
    fn dict_int(d: &DictMap, key: &str) -> i64 {
        match d.get(key) {
            Some(VmValue::Int(i)) => *i,
            other => panic!("no int {key}: {other:?}"),
        }
    }

    /// Extract string field `key`, panicking otherwise.
    fn dict_str(d: &DictMap, key: &str) -> String {
        match d.get(key) {
            Some(VmValue::String(s)) => s.to_string(),
            other => panic!("no string {key}: {other:?}"),
        }
    }

    /// Per-process scratch directory. Concurrent test processes sharing one
    /// temp dir (e.g. two `cargo test` runs) therefore cannot clobber each
    /// other's fixtures.
    fn scratch_dir(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("embed-cap-{tag}-{}", std::process::id()))
    }

    #[test]
    fn similarity_self_is_one() {
        let cap = EmbedCapability::lexical();
        let out = call(
            &cap,
            BUILTIN_SIMILARITY,
            dict_of(&[
                ("a", VmValue::string("rate limiter")),
                ("b", VmValue::string("rate limiter")),
            ]),
        );
        assert!((get_float(&out, "similarity") - 1.0).abs() < 1e-5);
        assert!((get_float(&out, "relatedness") - 1.0).abs() < 1e-5);
    }

    #[test]
    fn similarity_relatedness_is_clamped() {
        let cap = EmbedCapability::lexical();
        let out = call(
            &cap,
            BUILTIN_SIMILARITY,
            dict_of(&[
                ("a", VmValue::string("alpha beta gamma")),
                ("b", VmValue::string("delta epsilon zeta")),
            ]),
        );
        // Disjoint text can score slightly negative under signed hashing;
        // relatedness must never be < 0 and must equal max(similarity, 0).
        let sim = get_float(&out, "similarity");
        let rel = get_float(&out, "relatedness");
        assert!(rel >= 0.0);
        assert!((rel - sim.max(0.0)).abs() < 1e-9);
    }

    #[test]
    fn top_k_ranks_corpus() {
        let cap = EmbedCapability::lexical();
        let out = call(
            &cap,
            BUILTIN_TOP_K,
            dict_of(&[
                ("query", VmValue::string("rate limiter middleware")),
                (
                    "corpus",
                    VmValue::List(Arc::new(vec![
                        VmValue::string("markdown table renderer"),
                        VmValue::string("RateLimiter middleware for the API"),
                        VmValue::string("json parser"),
                    ])),
                ),
                ("k", VmValue::Int(2)),
            ]),
        );
        let VmValue::Dict(d) = &out else { panic!() };
        let VmValue::List(results) = d.get("results").unwrap() else {
            panic!()
        };
        assert_eq!(results.len(), 2);
        // Top hit must be the rate-limiter entry (index 1).
        let VmValue::Dict(first) = &results[0] else {
            panic!()
        };
        assert_eq!(dict_int(first, "index"), 1);
    }

    #[test]
    fn top_k_min_score_filters() {
        let cap = EmbedCapability::lexical();
        let out = call(
            &cap,
            BUILTIN_TOP_K,
            dict_of(&[
                ("query", VmValue::string("rate limiter")),
                (
                    "corpus",
                    VmValue::List(Arc::new(vec![VmValue::string(
                        "completely different topic",
                    )])),
                ),
                ("k", VmValue::Int(5)),
                ("min_score", VmValue::Float(0.99)),
            ]),
        );
        let VmValue::Dict(d) = &out else { panic!() };
        let VmValue::List(results) = d.get("results").unwrap() else {
            panic!()
        };
        assert!(results.is_empty(), "min_score should filter out weak match");
    }

    #[test]
    fn vector_has_declared_dim() {
        let cap = EmbedCapability::lexical();
        let out = call(
            &cap,
            BUILTIN_VECTOR,
            dict_of(&[("text", VmValue::string("hello"))]),
        );
        let VmValue::Dict(d) = &out else { panic!() };
        assert_eq!(dict_int(d, "dim"), 256);
        let VmValue::List(v) = d.get("vector").unwrap() else {
            panic!()
        };
        assert_eq!(v.len(), 256);
    }

    #[test]
    fn info_reports_lexical_default() {
        let cap = EmbedCapability::lexical();
        let out = call(&cap, BUILTIN_INFO, DictMap::new());
        let VmValue::Dict(d) = &out else { panic!() };
        assert_eq!(dict_str(d, "backend"), "lexical-hash");
        assert_eq!(dict_int(d, "dim"), 256);
    }

    #[test]
    fn resolve_degrades_to_lexical_when_absent() {
        let absent = scratch_dir("absent");
        let _ = std::fs::remove_dir_all(&absent);
        let cap = EmbedCapability::resolve(Some(&absent), None, "potion");
        assert_eq!(cap.embedder().name(), "lexical-hash");
    }

    #[test]
    fn resolve_uses_static_asset_when_present() {
        let dir = scratch_dir("present");
        std::fs::create_dir_all(&dir).expect("create scratch asset dir");
        std::fs::write(
            dir.join("static-embeddings.json"),
            r#"{ "dim": 2, "vectors": { "rate": [1.0, 0.0], "limit": [0.0, 1.0] } }"#,
        )
        .expect("write static asset fixture");
        let cap = EmbedCapability::resolve(Some(&dir), None, "potion");
        let name = cap.embedder().name().to_string();
        let dim = cap.embedder().dim();
        // Clean up before asserting so a failing assertion does not leak
        // the scratch directory.
        let _ = std::fs::remove_dir_all(&dir);
        assert_eq!(name, "static-model2vec");
        assert_eq!(dim, 2);
    }

    #[test]
    fn missing_required_param_errors() {
        let cap = EmbedCapability::lexical();
        let args = [VmValue::dict(dict_of(&[("a", VmValue::string("x"))]))];
        assert!(matches!(
            cap.run_similarity(&args),
            Err(HostlibError::MissingParameter { param: "b", .. })
        ));
    }

    #[test]
    fn registers_four_builtins() {
        let cap = EmbedCapability::lexical();
        let mut reg = BuiltinRegistry::new();
        cap.register_builtins(&mut reg);
        let names: Vec<_> = reg.iter().map(|b| b.name).collect();
        assert_eq!(
            names,
            vec![
                BUILTIN_SIMILARITY,
                BUILTIN_TOP_K,
                BUILTIN_VECTOR,
                BUILTIN_INFO
            ]
        );
    }
}