Skip to main content

innate_core/
embedding.rs

1use crate::errors::Result;
2use crate::utils::{content_hash, pack_embedding};
3
4/// Embedding provider trait — swap for real models at construction time.
5pub trait EmbeddingProvider: Send + Sync {
6    fn model_name(&self) -> &'static str {
7        "custom"
8    }
9    fn content_dim(&self) -> usize;
10    fn trigger_dim(&self) -> usize;
11    fn embed_content(&self, text: &str) -> Result<Vec<f32>>;
12    fn embed_trigger(&self, text: &str) -> Result<Vec<f32>>;
13
14    /// Embed `text` for both the content and trigger spaces. Default issues two
15    /// separate calls. Providers backed by a single shared model (e.g. a remote
16    /// embedding endpoint) should override this to make one request and avoid the
17    /// duplicate round trip on the recall hot path.
18    fn embed_both(&self, text: &str) -> Result<(Vec<f32>, Vec<f32>)> {
19        Ok((self.embed_content(text)?, self.embed_trigger(text)?))
20    }
21}
22
/// Hash-based deterministic embeddings — no model needed, good for tests.
///
/// Vectors are derived from a hash of the input text rather than from a model,
/// so identical text always maps to an identical vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DummyEmbeddingProvider {
    /// Length of the vectors produced for the content space.
    content_dim: usize,
    /// Length of the vectors produced for the trigger space.
    trigger_dim: usize,
}

impl DummyEmbeddingProvider {
    /// Content-space dimensionality used by [`Default`].
    pub const DEFAULT_CONTENT_DIM: usize = 1024;
    /// Trigger-space dimensionality used by [`Default`].
    pub const DEFAULT_TRIGGER_DIM: usize = 256;

    /// Creates a provider that emits `content_dim`-long content vectors and
    /// `trigger_dim`-long trigger vectors.
    #[must_use]
    pub const fn new(content_dim: usize, trigger_dim: usize) -> Self {
        Self {
            content_dim,
            trigger_dim,
        }
    }
}

impl Default for DummyEmbeddingProvider {
    /// Uses [`DummyEmbeddingProvider::DEFAULT_CONTENT_DIM`] (1024) and
    /// [`DummyEmbeddingProvider::DEFAULT_TRIGGER_DIM`] (256).
    fn default() -> Self {
        Self::new(Self::DEFAULT_CONTENT_DIM, Self::DEFAULT_TRIGGER_DIM)
    }
}
43
44impl EmbeddingProvider for DummyEmbeddingProvider {
45    fn model_name(&self) -> &'static str {
46        "DummyEmbeddingProvider"
47    }
48
49    fn content_dim(&self) -> usize {
50        self.content_dim
51    }
52    fn trigger_dim(&self) -> usize {
53        self.trigger_dim
54    }
55
56    fn embed_content(&self, text: &str) -> Result<Vec<f32>> {
57        Ok(hash_to_vec(text, self.content_dim))
58    }
59
60    fn embed_trigger(&self, text: &str) -> Result<Vec<f32>> {
61        Ok(hash_to_vec(text, self.trigger_dim))
62    }
63}
64
65fn hash_to_vec(text: &str, dim: usize) -> Vec<f32> {
66    let h = content_hash(text);
67    let bytes = h.as_bytes();
68    let mut v: Vec<f32> = (0..dim)
69        .map(|i| {
70            let b = bytes[i % bytes.len()] as f32;
71            (b / 255.0) * 2.0 - 1.0
72        })
73        .collect();
74    // L2-normalise
75    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
76    if norm > 0.0 {
77        for x in &mut v {
78            *x /= norm;
79        }
80    }
81    v
82}
83
84/// Serialise a provider's embedding as raw bytes for DB storage.
85pub fn embed_to_bytes(
86    provider: &dyn EmbeddingProvider,
87    text: &str,
88    trigger: bool,
89) -> Result<Vec<u8>> {
90    let vec = if trigger {
91        provider.embed_trigger(text)?
92    } else {
93        provider.embed_content(text)?
94    };
95    Ok(pack_embedding(&vec))
96}