ainl_context_compiler/
embedder.rs1use std::error::Error;
8use std::fmt;
9use std::hash::{Hash, Hasher};
10
/// Number of components in the vectors produced by `PlaceholderEmbedder`.
///
/// `PlaceholderEmbedder::embed` fills one slot per 4-bit nibble of a 64-bit
/// hash value, so 16 slots consume exactly the 64 available bits.
const PLACEHOLDER_EMBED_DIM: usize = 16;
12
/// Failure modes reported by an embedder.
///
/// `Clone`, `PartialEq` and `Eq` are derived in addition to `Debug` so that
/// callers and tests can duplicate and compare errors directly
/// (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbedderError {
    /// A transport-level failure; the payload is the message shown after the
    /// `embedder transport: ` prefix when the error is displayed.
    Transport(String),
    /// The embedding model is not loaded.
    ModelMissing,
    /// Any other failure; the payload is the message shown after the
    /// `embedder: ` prefix when the error is displayed.
    Other(String),
}
23
24impl fmt::Display for EmbedderError {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 match self {
27 Self::Transport(m) => write!(f, "embedder transport: {m}"),
28 Self::ModelMissing => f.write_str("embedder model not loaded"),
29 Self::Other(m) => write!(f, "embedder: {m}"),
30 }
31 }
32}
33
// No `source()` override: the variants carry only message strings, so there
// is no underlying error value to expose and the trait default is used.
impl Error for EmbedderError {}
35
36pub trait Embedder: Send + Sync {
43 fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError>;
45
46 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbedderError> {
48 texts.iter().map(|t| self.embed(t)).collect()
49 }
50}
51
/// Stateless stand-in embedder.
///
/// Its `Embedder` implementation derives every vector purely from a hash of
/// the input text, so the output carries no semantic meaning.
#[derive(Clone, Copy, Debug, Default)]
pub struct PlaceholderEmbedder;
57
58impl PlaceholderEmbedder {
59 #[must_use]
61 pub const fn new() -> Self {
62 Self
63 }
64}
65
66impl Embedder for PlaceholderEmbedder {
67 fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedderError> {
68 use std::collections::hash_map::DefaultHasher;
69 let mut h = DefaultHasher::new();
70 text.hash(&mut h);
71 let x = h.finish();
72 let mut v = vec![0f32; PLACEHOLDER_EMBED_DIM];
73 for (i, slot) in v.iter_mut().enumerate() {
74 *slot = (((x >> (i * 4)) & 0xF) as f32) / 15.0;
75 }
76 let n: f32 = v.iter().map(|e| e * e).sum::<f32>().sqrt();
77 if n > 0.0 {
78 for t in v.iter_mut() {
79 *t /= n;
80 }
81 }
82 Ok(v)
83 }
84}
85
/// Cosine similarity between `a` and `b`.
///
/// Returns `0.0` when the slices are empty, differ in length, or when either
/// has a Euclidean norm of exactly zero; otherwise returns the dot product
/// divided by the product of the two norms.
#[must_use]
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    /// Euclidean (L2) length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|c| c * c).sum::<f32>().sqrt()
    }

    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let (norm_a, norm_b) = (l2_norm(a), l2_norm(b));
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot_product: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot_product / (norm_a * norm_b)
}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length of `v`, used to check normalisation.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn cosine_identical_is_one() {
        let v: [f32; 3] = [1.0, 2.0, 3.0];
        assert!((cosine(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_orthogonal_is_zero() {
        let a: [f32; 2] = [1.0, 0.0];
        let b: [f32; 2] = [0.0, 1.0];
        assert!(cosine(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_mismatched_lengths_returns_zero() {
        let a: [f32; 2] = [1.0, 0.0];
        let b: [f32; 3] = [1.0, 0.0, 0.0];
        // Exact comparison is intended: this path returns the literal 0.0.
        assert_eq!(cosine(&a, &b), 0.0);
    }

    #[test]
    fn cosine_empty_input_returns_zero() {
        let empty: [f32; 0] = [];
        assert_eq!(cosine(&empty, &empty), 0.0);
    }

    #[test]
    fn cosine_zero_norm_returns_zero() {
        let zeros: [f32; 2] = [0.0, 0.0];
        let unit: [f32; 2] = [1.0, 0.0];
        assert_eq!(cosine(&zeros, &unit), 0.0);
        assert_eq!(cosine(&unit, &zeros), 0.0);
    }

    #[test]
    fn placeholder_l2_unit_vector() {
        let e = PlaceholderEmbedder::new();
        let v = e.embed("hello world").expect("ok");
        assert_eq!(v.len(), PLACEHOLDER_EMBED_DIM);
        let n = l2_norm(&v);
        // A digest whose nibbles are all zero yields the zero vector, which
        // `embed` leaves unnormalised; accept that case as well.
        assert!((n - 1.0).abs() < 1e-5 || n.abs() < 1e-5);
    }

    #[test]
    fn placeholder_is_deterministic() {
        let e = PlaceholderEmbedder::new();
        let first = e.embed("same text").expect("ok");
        let second = e.embed("same text").expect("ok");
        assert_eq!(first, second);
    }

    #[test]
    fn embed_batch_matches_individual_embeds() {
        let e = PlaceholderEmbedder::new();
        let texts = ["alpha", "beta", ""];
        let batch = e.embed_batch(&texts).expect("ok");
        assert_eq!(batch.len(), texts.len());
        for (text, vector) in texts.iter().zip(&batch) {
            assert_eq!(&e.embed(text).expect("ok"), vector);
        }
    }
}