1mod backend;
34mod similarity;
35mod tokenize;
36
37pub use backend::{resolve_asset_dir, Embedder, LexicalEmbedder, StaticEmbedder};
38pub use similarity::{cosine, top_k, Scored};
39
40use std::path::Path;
41use std::sync::Arc;
42
43use harn_vm::VmValue;
44
45use crate::error::HostlibError;
46use crate::registry::{BuiltinRegistry, HostlibCapability, RegisteredBuiltin, SyncHandler};
47use crate::tools::args::{build_dict, dict_arg, optional_int, require_string};
48use crate::value_args;
49
/// Registered name of the `embed.similarity` builtin (cosine similarity of two texts).
pub const BUILTIN_SIMILARITY: &str = "hostlib_embed_similarity";
/// Registered name of the `embed.top_k` builtin (ranks a corpus against a query).
pub const BUILTIN_TOP_K: &str = "hostlib_embed_top_k";
/// Registered name of the `embed.vector` builtin (returns the raw embedding of a text).
pub const BUILTIN_VECTOR: &str = "hostlib_embed_vector";
/// Registered name of the `embed.info` builtin (reports backend name and dimension).
pub const BUILTIN_INFO: &str = "hostlib_embed_info";
58
/// Host capability that exposes the `embed.*` builtins (`similarity`, `top_k`,
/// `vector`, `info`).
///
/// Cloning is cheap: the embedding backend is shared behind an [`Arc`], so every
/// clone (including the per-builtin clones made during registration) uses the
/// same embedder instance.
#[derive(Clone)]
pub struct EmbedCapability {
    /// Backend that turns text into vectors for every builtin.
    embedder: Arc<dyn Embedder>,
}
64
65impl Default for EmbedCapability {
66 fn default() -> Self {
67 Self::lexical()
68 }
69}
70
71impl EmbedCapability {
72 pub fn lexical() -> Self {
74 Self {
75 embedder: Arc::new(LexicalEmbedder::default()),
76 }
77 }
78
79 pub fn with_embedder(embedder: Arc<dyn Embedder>) -> Self {
82 Self { embedder }
83 }
84
85 pub fn resolve(override_dir: Option<&Path>, data_dir: Option<&Path>, model: &str) -> Self {
90 if let Some(dir) = resolve_asset_dir(override_dir, data_dir, model) {
91 if let Ok(static_embedder) = StaticEmbedder::from_asset_dir(&dir) {
92 return Self {
93 embedder: Arc::new(static_embedder),
94 };
95 }
96 }
97 Self::lexical()
98 }
99
100 pub fn embedder(&self) -> &Arc<dyn Embedder> {
102 &self.embedder
103 }
104
105 fn run_similarity(&self, args: &[VmValue]) -> Result<VmValue, HostlibError> {
106 let raw = dict_arg(BUILTIN_SIMILARITY, args)?;
107 let dict = raw.as_ref();
108 let a = require_string(BUILTIN_SIMILARITY, dict, "a")?;
109 let b = require_string(BUILTIN_SIMILARITY, dict, "b")?;
110 let va = self.embedder.embed(&a);
111 let vb = self.embedder.embed(&b);
112 let sim = cosine(&va, &vb);
113 Ok(build_dict([
117 ("similarity", VmValue::Float(sim as f64)),
118 ("relatedness", VmValue::Float(sim.max(0.0) as f64)),
119 ]))
120 }
121
122 fn run_top_k(&self, args: &[VmValue]) -> Result<VmValue, HostlibError> {
123 let raw = dict_arg(BUILTIN_TOP_K, args)?;
124 let dict = raw.as_ref();
125 let query = require_string(BUILTIN_TOP_K, dict, "query")?;
126 let corpus = require_string_list(BUILTIN_TOP_K, dict, "corpus")?;
127 let k = optional_int(BUILTIN_TOP_K, dict, "k", 10)?.max(0) as usize;
128 let min_score =
129 optional_float(BUILTIN_TOP_K, dict, "min_score")?.unwrap_or(f64::NEG_INFINITY);
130
131 let query_vec = self.embedder.embed(&query);
132 let corpus_vecs: Vec<Vec<f32>> = self.embedder.embed_batch(&corpus);
133 let ranked = top_k(&query_vec, &corpus_vecs, k);
134
135 let results: Vec<VmValue> = ranked
136 .into_iter()
137 .filter(|s| (s.score as f64) >= min_score)
138 .map(|s| {
139 build_dict([
140 ("index", VmValue::Int(s.index as i64)),
141 (
142 "text",
143 VmValue::string(corpus.get(s.index).map(String::as_str).unwrap_or("")),
144 ),
145 ("score", VmValue::Float(s.score as f64)),
146 ("relatedness", VmValue::Float((s.score.max(0.0)) as f64)),
147 ])
148 })
149 .collect();
150 Ok(build_dict([("results", VmValue::List(Arc::new(results)))]))
151 }
152
153 fn run_vector(&self, args: &[VmValue]) -> Result<VmValue, HostlibError> {
154 let raw = dict_arg(BUILTIN_VECTOR, args)?;
155 let dict = raw.as_ref();
156 let text = require_string(BUILTIN_VECTOR, dict, "text")?;
157 let v = self.embedder.embed(&text);
158 let values: Vec<VmValue> = v.into_iter().map(|x| VmValue::Float(x as f64)).collect();
159 Ok(build_dict([
160 ("dim", VmValue::Int(self.embedder.dim() as i64)),
161 ("vector", VmValue::List(Arc::new(values))),
162 ]))
163 }
164
165 fn run_info(&self, _args: &[VmValue]) -> Result<VmValue, HostlibError> {
166 Ok(build_dict([
167 ("backend", VmValue::string(self.embedder.name())),
168 ("dim", VmValue::Int(self.embedder.dim() as i64)),
169 ]))
170 }
171}
172
173impl HostlibCapability for EmbedCapability {
174 fn module_name(&self) -> &'static str {
175 "embed"
176 }
177
178 fn register_builtins(&self, registry: &mut BuiltinRegistry) {
179 let cap = self.clone();
180 let handler: SyncHandler = Arc::new(move |args| cap.run_similarity(args));
181 registry.register(RegisteredBuiltin {
182 name: BUILTIN_SIMILARITY,
183 module: "embed",
184 method: "similarity",
185 handler,
186 });
187
188 let cap = self.clone();
189 let handler: SyncHandler = Arc::new(move |args| cap.run_top_k(args));
190 registry.register(RegisteredBuiltin {
191 name: BUILTIN_TOP_K,
192 module: "embed",
193 method: "top_k",
194 handler,
195 });
196
197 let cap = self.clone();
198 let handler: SyncHandler = Arc::new(move |args| cap.run_vector(args));
199 registry.register(RegisteredBuiltin {
200 name: BUILTIN_VECTOR,
201 module: "embed",
202 method: "vector",
203 handler,
204 });
205
206 let cap = self.clone();
207 let handler: SyncHandler = Arc::new(move |args| cap.run_info(args));
208 registry.register(RegisteredBuiltin {
209 name: BUILTIN_INFO,
210 module: "embed",
211 method: "info",
212 handler,
213 });
214 }
215}
216
217fn require_string_list(
220 builtin: &'static str,
221 dict: &harn_vm::value::DictMap,
222 key: &'static str,
223) -> Result<Vec<String>, HostlibError> {
224 match value_args::optional_string_list(builtin, dict, key)? {
225 Some(list) => Ok(list),
226 None => Err(HostlibError::MissingParameter {
227 builtin,
228 param: key,
229 }),
230 }
231}
232
233fn optional_float(
234 builtin: &'static str,
235 dict: &harn_vm::value::DictMap,
236 key: &'static str,
237) -> Result<Option<f64>, HostlibError> {
238 match dict.get(key) {
239 None | Some(VmValue::Nil) => Ok(None),
240 Some(VmValue::Float(f)) => Ok(Some(*f)),
241 Some(VmValue::Int(i)) => Ok(Some(*i as f64)),
242 Some(other) => Err(HostlibError::InvalidParameter {
243 builtin,
244 param: key,
245 message: format!("expected number, got {}", value_args::describe(other)),
246 }),
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use harn_vm::value::{intern_key, DictMap};

    /// Invokes the handler for `builtin` directly with a single dict argument,
    /// panicking if it returns an error.
    fn call(cap: &EmbedCapability, builtin: &str, dict: DictMap) -> VmValue {
        let args = [VmValue::dict(dict)];
        match builtin {
            BUILTIN_SIMILARITY => cap.run_similarity(&args).unwrap(),
            BUILTIN_TOP_K => cap.run_top_k(&args).unwrap(),
            BUILTIN_VECTOR => cap.run_vector(&args).unwrap(),
            BUILTIN_INFO => cap.run_info(&args).unwrap(),
            _ => panic!("unknown builtin"),
        }
    }

    /// Builds a `DictMap` from `(key, value)` pairs.
    fn dict_of(pairs: &[(&str, VmValue)]) -> DictMap {
        let mut m = DictMap::new();
        for (k, v) in pairs {
            m.insert(intern_key(k), v.clone());
        }
        m
    }

    /// Extracts float field `key` from a dict value, panicking if it is absent.
    fn get_float(v: &VmValue, key: &str) -> f64 {
        if let VmValue::Dict(d) = v {
            if let Some(VmValue::Float(f)) = d.get(key) {
                return *f;
            }
        }
        panic!("no float {key} in {v:?}");
    }

    /// Extracts int field `key` from a dict, panicking if it is absent.
    fn dict_int(d: &DictMap, key: &str) -> i64 {
        match d.get(key) {
            Some(VmValue::Int(i)) => *i,
            other => panic!("no int {key}: {other:?}"),
        }
    }

    /// Extracts string field `key` from a dict, panicking if it is absent.
    fn dict_str(d: &DictMap, key: &str) -> String {
        match d.get(key) {
            Some(VmValue::String(s)) => s.to_string(),
            other => panic!("no string {key}: {other:?}"),
        }
    }

    /// Temp-directory path unique to this test process, so concurrent test runs
    /// (for example, several `cargo test` invocations) cannot delete or
    /// overwrite each other's fixtures.
    fn unique_temp_dir(label: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{label}-{}", std::process::id()))
    }

    #[test]
    fn similarity_self_is_one() {
        let cap = EmbedCapability::lexical();
        let out = call(
            &cap,
            BUILTIN_SIMILARITY,
            dict_of(&[
                ("a", VmValue::string("rate limiter")),
                ("b", VmValue::string("rate limiter")),
            ]),
        );
        assert!((get_float(&out, "similarity") - 1.0).abs() < 1e-5);
        assert!((get_float(&out, "relatedness") - 1.0).abs() < 1e-5);
    }

    #[test]
    fn similarity_relatedness_is_clamped() {
        let cap = EmbedCapability::lexical();
        let out = call(
            &cap,
            BUILTIN_SIMILARITY,
            dict_of(&[
                ("a", VmValue::string("alpha beta gamma")),
                ("b", VmValue::string("delta epsilon zeta")),
            ]),
        );
        assert!(get_float(&out, "relatedness") >= 0.0);
    }

    #[test]
    fn top_k_ranks_corpus() {
        let cap = EmbedCapability::lexical();
        let out = call(
            &cap,
            BUILTIN_TOP_K,
            dict_of(&[
                ("query", VmValue::string("rate limiter middleware")),
                (
                    "corpus",
                    VmValue::List(Arc::new(vec![
                        VmValue::string("markdown table renderer"),
                        VmValue::string("RateLimiter middleware for the API"),
                        VmValue::string("json parser"),
                    ])),
                ),
                ("k", VmValue::Int(2)),
            ]),
        );
        let VmValue::Dict(d) = &out else { panic!() };
        let VmValue::List(results) = d.get("results").unwrap() else {
            panic!()
        };
        assert_eq!(results.len(), 2);
        let VmValue::Dict(first) = &results[0] else {
            panic!()
        };
        assert_eq!(dict_int(first, "index"), 1);
    }

    #[test]
    fn top_k_min_score_filters() {
        let cap = EmbedCapability::lexical();
        let out = call(
            &cap,
            BUILTIN_TOP_K,
            dict_of(&[
                ("query", VmValue::string("rate limiter")),
                (
                    "corpus",
                    VmValue::List(Arc::new(vec![VmValue::string(
                        "completely different topic",
                    )])),
                ),
                ("k", VmValue::Int(5)),
                ("min_score", VmValue::Float(0.99)),
            ]),
        );
        let VmValue::Dict(d) = &out else { panic!() };
        let VmValue::List(results) = d.get("results").unwrap() else {
            panic!()
        };
        assert!(results.is_empty(), "min_score should filter out weak match");
    }

    #[test]
    fn vector_has_declared_dim() {
        let cap = EmbedCapability::lexical();
        let out = call(
            &cap,
            BUILTIN_VECTOR,
            dict_of(&[("text", VmValue::string("hello"))]),
        );
        let VmValue::Dict(d) = &out else { panic!() };
        assert_eq!(dict_int(d, "dim"), 256);
        let VmValue::List(v) = d.get("vector").unwrap() else {
            panic!()
        };
        assert_eq!(v.len(), 256);
    }

    #[test]
    fn info_reports_lexical_default() {
        let cap = EmbedCapability::lexical();
        let out = call(&cap, BUILTIN_INFO, DictMap::new());
        let VmValue::Dict(d) = &out else { panic!() };
        assert_eq!(dict_str(d, "backend"), "lexical-hash");
        assert_eq!(dict_int(d, "dim"), 256);
    }

    #[test]
    fn resolve_degrades_to_lexical_when_absent() {
        let absent = unique_temp_dir("embed-cap-absent");
        let _ = std::fs::remove_dir_all(&absent);
        let cap = EmbedCapability::resolve(Some(&absent), None, "potion");
        assert_eq!(cap.embedder().name(), "lexical-hash");
    }

    #[test]
    fn resolve_uses_static_asset_when_present() {
        let dir = unique_temp_dir("embed-cap-present");
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(
            dir.join("static-embeddings.json"),
            r#"{ "dim": 2, "vectors": { "rate": [1.0, 0.0], "limit": [0.0, 1.0] } }"#,
        )
        .unwrap();
        let cap = EmbedCapability::resolve(Some(&dir), None, "potion");
        let name = cap.embedder().name().to_string();
        let dim = cap.embedder().dim();
        // Clean up before asserting so a failing assertion does not leak the fixture.
        let _ = std::fs::remove_dir_all(&dir);
        assert_eq!(name, "static-model2vec");
        assert_eq!(dim, 2);
    }

    #[test]
    fn missing_required_param_errors() {
        let cap = EmbedCapability::lexical();
        let args = [VmValue::dict(dict_of(&[("a", VmValue::string("x"))]))];
        assert!(matches!(
            cap.run_similarity(&args),
            Err(HostlibError::MissingParameter { param: "b", .. })
        ));
    }

    #[test]
    fn registers_four_builtins() {
        let cap = EmbedCapability::lexical();
        let mut reg = BuiltinRegistry::new();
        cap.register_builtins(&mut reg);
        let names: Vec<_> = reg.iter().map(|b| b.name).collect();
        assert_eq!(
            names,
            vec![
                BUILTIN_SIMILARITY,
                BUILTIN_TOP_K,
                BUILTIN_VECTOR,
                BUILTIN_INFO
            ]
        );
    }
}