// assay_core/embeddings/util.rs

1use sha2::{Digest, Sha256};
2
/// Serializes `v` into a byte blob: each element becomes its 4 little-endian
/// IEEE-754 bytes, in order, so the output is always `v.len() * 4` bytes long.
///
/// [`decode_vec_f32`] performs the inverse conversion.
pub fn encode_vec_f32(v: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(v.len() * 4);
    blob.extend(v.iter().flat_map(|value| value.to_le_bytes()));
    blob
}
10
11#[allow(clippy::manual_is_multiple_of)]
12pub fn decode_vec_f32(bytes: &[u8]) -> anyhow::Result<Vec<f32>> {
13    if bytes.len() % 4 != 0 {
14        anyhow::bail!("config error: invalid embedding blob size");
15    }
16    let mut v = Vec::with_capacity(bytes.len() / 4);
17    for chunk in bytes.chunks_exact(4) {
18        v.push(f32::from_le_bytes(chunk.try_into().unwrap()));
19    }
20    Ok(v)
21}
22
23pub fn sha256_hex(s: &str) -> String {
24    let mut h = Sha256::new();
25    h.update(s.as_bytes());
26    hex::encode(h.finalize())
27}
28
29pub fn embed_cache_key(model_id: &str, text: &str) -> String {
30    format!("emb|{}|{}", model_id, sha256_hex(text))
31}
32
33pub fn cosine_similarity(a: &[f32], b: &[f32]) -> anyhow::Result<f64> {
34    let af: Vec<f64> = a.iter().map(|x| *x as f64).collect();
35    let bf: Vec<f64> = b.iter().map(|x| *x as f64).collect();
36    cosine_similarity_f64(&af, &bf)
37}
38
39pub fn cosine_similarity_f64(a: &[f64], b: &[f64]) -> anyhow::Result<f64> {
40    if a.is_empty() || a.len() != b.len() {
41        anyhow::bail!(
42            "config error: embedding dims mismatch (a={}, b={})",
43            a.len(),
44            b.len()
45        );
46    }
47    let mut dot = 0.0f64;
48    let mut na = 0.0f64;
49    let mut nb = 0.0f64;
50
51    for i in 0..a.len() {
52        let x = a[i];
53        let y = b[i];
54        dot += x * y;
55        na += x * x;
56        nb += y * y;
57    }
58    let denom = na.sqrt() * nb.sqrt();
59    if denom == 0.0 {
60        anyhow::bail!("config error: zero-norm embedding");
61    }
62    let s = dot / denom;
63    Ok(s.clamp(-1.0, 1.0))
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode_decode_roundtrip() -> anyhow::Result<()> {
        let v = vec![
            0.1_f32,
            -0.2_f32,
            3.5_f32,
            0.0,
            f32::MAX,
            f32::MIN_POSITIVE,
        ];
        let blob = encode_vec_f32(&v);
        assert_eq!(blob.len(), v.len() * 4);
        // The encoding is a raw little-endian byte copy, so the round trip is
        // bit-exact: compare exactly rather than within a tolerance.
        let out = decode_vec_f32(&blob)?;
        assert_eq!(v, out);
        Ok(())
    }

    #[test]
    fn encode_is_little_endian() {
        assert!(encode_vec_f32(&[]).is_empty());
        assert_eq!(encode_vec_f32(&[1.0_f32]), vec![0x00, 0x00, 0x80, 0x3f]);
    }

    #[test]
    fn decode_empty_blob_is_empty() -> anyhow::Result<()> {
        assert!(decode_vec_f32(&[])?.is_empty());
        Ok(())
    }

    #[test]
    fn decode_rejects_partial_element() {
        assert!(decode_vec_f32(&[0, 0, 0]).is_err());
        assert!(decode_vec_f32(&[0, 0, 0, 0, 0]).is_err());
    }

    #[test]
    fn sha256_hex_known_vectors() {
        assert_eq!(
            sha256_hex(""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            sha256_hex("abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn embed_cache_key_layout() {
        assert_eq!(
            embed_cache_key("m", "abc"),
            format!("emb|m|{}", sha256_hex("abc"))
        );
    }

    #[test]
    fn cosine_identical_is_one() -> anyhow::Result<()> {
        let a = vec![1.0_f32, 0.0, 0.0];
        let b = vec![1.0_f32, 0.0, 0.0];
        let s = cosine_similarity(&a, &b)?;
        assert!((s - 1.0).abs() < 1e-9);
        Ok(())
    }

    #[test]
    fn cosine_orthogonal_is_zero() -> anyhow::Result<()> {
        let s = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0])?;
        assert!(s.abs() < 1e-12);
        Ok(())
    }

    #[test]
    fn cosine_opposite_is_minus_one() -> anyhow::Result<()> {
        let s = cosine_similarity(&[1.0, 2.0, 3.0], &[-1.0, -2.0, -3.0])?;
        assert!((s + 1.0).abs() < 1e-9);
        assert!(s >= -1.0);
        Ok(())
    }

    #[test]
    fn cosine_rejects_degenerate_inputs() {
        assert!(cosine_similarity(&[], &[]).is_err());
        assert!(cosine_similarity(&[1.0], &[1.0, 2.0]).is_err());
        assert!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]).is_err());
    }
}