assay_core/embeddings/
util.rs1use sha2::{Digest, Sha256};
2
/// Serializes an `f32` slice into a byte blob, four little-endian bytes per element.
///
/// The output length is always `v.len() * 4`; [`decode_vec_f32`] reverses it.
pub fn encode_vec_f32(v: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(v.len() * 4);
    blob.extend(v.iter().flat_map(|value| value.to_le_bytes()));
    blob
}
10
11#[allow(clippy::manual_is_multiple_of)]
12pub fn decode_vec_f32(bytes: &[u8]) -> anyhow::Result<Vec<f32>> {
13 if bytes.len() % 4 != 0 {
14 anyhow::bail!("config error: invalid embedding blob size");
15 }
16 let mut v = Vec::with_capacity(bytes.len() / 4);
17 for chunk in bytes.chunks_exact(4) {
18 v.push(f32::from_le_bytes(chunk.try_into().unwrap()));
19 }
20 Ok(v)
21}
22
23pub fn sha256_hex(s: &str) -> String {
24 let mut h = Sha256::new();
25 h.update(s.as_bytes());
26 hex::encode(h.finalize())
27}
28
29pub fn embed_cache_key(model_id: &str, text: &str) -> String {
30 format!("emb|{}|{}", model_id, sha256_hex(text))
31}
32
33pub fn cosine_similarity(a: &[f32], b: &[f32]) -> anyhow::Result<f64> {
34 let af: Vec<f64> = a.iter().map(|x| *x as f64).collect();
35 let bf: Vec<f64> = b.iter().map(|x| *x as f64).collect();
36 cosine_similarity_f64(&af, &bf)
37}
38
39pub fn cosine_similarity_f64(a: &[f64], b: &[f64]) -> anyhow::Result<f64> {
40 if a.is_empty() || a.len() != b.len() {
41 anyhow::bail!(
42 "config error: embedding dims mismatch (a={}, b={})",
43 a.len(),
44 b.len()
45 );
46 }
47 let mut dot = 0.0f64;
48 let mut na = 0.0f64;
49 let mut nb = 0.0f64;
50
51 for i in 0..a.len() {
52 let x = a[i];
53 let y = b[i];
54 dot += x * y;
55 na += x * x;
56 nb += y * y;
57 }
58 let denom = na.sqrt() * nb.sqrt();
59 if denom == 0.0 {
60 anyhow::bail!("config error: zero-norm embedding");
61 }
62 let s = dot / denom;
63 Ok(s.clamp(-1.0, 1.0))
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the raw bit patterns so comparisons distinguish `-0.0` from
    /// `0.0` and catch any non-bit-exact conversion.
    fn bits(xs: &[f32]) -> Vec<u32> {
        xs.iter().map(|x| x.to_bits()).collect()
    }

    #[test]
    fn encode_decode_roundtrip() -> anyhow::Result<()> {
        // Encoding is a raw little-endian byte copy, so the round trip must be
        // bit-exact (signed zero, extremes and subnormals included), not merely
        // within a tolerance.
        let v = vec![
            0.1_f32,
            -0.2_f32,
            3.5_f32,
            -0.0_f32,
            f32::MAX,
            f32::MIN_POSITIVE,
            f32::from_bits(1),
        ];
        let blob = encode_vec_f32(&v);
        assert_eq!(blob.len(), v.len() * 4);
        let out = decode_vec_f32(&blob)?;
        assert_eq!(bits(&v), bits(&out));
        Ok(())
    }

    #[test]
    fn encode_is_little_endian() {
        // 1.0_f32 == 0x3F80_0000
        assert_eq!(encode_vec_f32(&[1.0_f32]), vec![0x00, 0x00, 0x80, 0x3f]);
    }

    #[test]
    fn decode_empty_is_empty() -> anyhow::Result<()> {
        assert!(decode_vec_f32(&[])?.is_empty());
        Ok(())
    }

    #[test]
    fn decode_rejects_truncated_blob() {
        assert!(decode_vec_f32(&[0u8; 5]).is_err());
        assert!(decode_vec_f32(&[0u8; 3]).is_err());
    }

    #[test]
    fn cosine_identical_is_one() -> anyhow::Result<()> {
        let a = vec![1.0_f32, 0.0, 0.0];
        let b = vec![1.0_f32, 0.0, 0.0];
        let s = cosine_similarity(&a, &b)?;
        assert!((s - 1.0).abs() < 1e-9);
        Ok(())
    }

    #[test]
    fn cosine_orthogonal_is_zero() -> anyhow::Result<()> {
        let s = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0])?;
        assert!(s.abs() < 1e-12);
        Ok(())
    }

    #[test]
    fn cosine_opposite_is_minus_one() -> anyhow::Result<()> {
        let s = cosine_similarity(&[1.0, 2.0], &[-1.0, -2.0])?;
        assert!((s + 1.0).abs() < 1e-9);
        Ok(())
    }

    #[test]
    fn cosine_rejects_bad_dims() {
        assert!(cosine_similarity(&[1.0], &[1.0, 0.0]).is_err());
        assert!(cosine_similarity(&[], &[]).is_err());
    }

    #[test]
    fn cosine_rejects_zero_norm() {
        assert!(cosine_similarity(&[0.0, 0.0], &[1.0, 0.0]).is_err());
    }
}