Skip to main content

mnemara_core/
embedding.rs

1use crate::config::{EmbeddingProviderKind, EngineConfig};
2use std::fmt;
3use std::sync::Arc;
4
/// A dense embedding produced by a [`SemanticEmbedder`].
///
/// An empty `values` vector is what [`DisabledEmbedder`] returns to mean
/// "no embedding available".
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingVector {
    /// The vector components.
    pub values: Vec<f32>,
}

impl EmbeddingVector {
    /// Returns the cosine similarity between `self` and `other`, clamped to
    /// `[-1.0, 1.0]`.
    ///
    /// Returns `0.0` when either vector is empty, when the two vectors have
    /// different lengths (e.g. they came from embedders with different
    /// dimensions), or when either vector has zero magnitude.
    pub fn cosine_similarity(&self, other: &Self) -> f32 {
        if self.values.is_empty() || self.values.len() != other.values.len() {
            return 0.0;
        }

        // Accumulate in f64: squaring f32 components larger than ~1.8e19 overflows
        // f32 to infinity (giving inf / inf = NaN), tiny components underflow to
        // zero, and long vectors lose precision. f64 avoids all three for any
        // finite f32 input.
        let (dot, left_norm, right_norm) = self.values.iter().zip(&other.values).fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(dot, left_norm, right_norm), (&left, &right)| {
                let (left, right) = (f64::from(left), f64::from(right));
                (
                    dot + left * right,
                    left_norm + left * left,
                    right_norm + right * right,
                )
            },
        );

        if left_norm == 0.0 || right_norm == 0.0 {
            return 0.0;
        }

        // Rounding can push the ratio marginally outside [-1, 1]; clamp so callers
        // can rely on the documented range.
        ((dot / (left_norm.sqrt() * right_norm.sqrt())) as f32).clamp(-1.0, 1.0)
    }
}
32
33pub trait SemanticEmbedder: Send + Sync {
34    fn provider_kind(&self) -> EmbeddingProviderKind;
35    fn dimensions(&self) -> usize;
36    fn embed(&self, text: &str) -> EmbeddingVector;
37}
38
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
40pub struct DisabledEmbedder;
41
42impl SemanticEmbedder for DisabledEmbedder {
43    fn provider_kind(&self) -> EmbeddingProviderKind {
44        EmbeddingProviderKind::Disabled
45    }
46
47    fn dimensions(&self) -> usize {
48        0
49    }
50
51    fn embed(&self, _text: &str) -> EmbeddingVector {
52        EmbeddingVector { values: Vec::new() }
53    }
54}
55
/// Dependency-free embedder whose vectors are derived purely from seeded hashes
/// of the input's terms, so the same text and dimension count always produce the
/// same vector.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeterministicLocalEmbedder {
    dimensions: usize,
}

impl DeterministicLocalEmbedder {
    /// Creates an embedder producing vectors with `dimensions` components.
    ///
    /// A requested size of `0` is raised to `1`, so there is always at least one
    /// bucket (and `bucket_for` never divides by zero).
    pub fn new(dimensions: usize) -> Self {
        Self {
            dimensions: dimensions.max(1),
        }
    }

    /// FNV-1a-style hash of `term`'s UTF-8 bytes: each byte is XOR-ed in and the
    /// state multiplied by the 64-bit FNV prime, starting from `seed`.
    fn hash_with_seed(term: &str, seed: u64) -> u64 {
        term.as_bytes().iter().fold(seed, |hash, &byte| {
            (hash ^ u64::from(byte)).wrapping_mul(1099511628211)
        })
    }

    /// Maps `term` to a bucket index in `0..dimensions`.
    ///
    /// The modulus is taken in `u64` *before* narrowing. Casting the hash to
    /// `usize` first would drop its high 32 bits on 32-bit targets and pick
    /// different buckets there than on 64-bit targets. Doing the modulus first
    /// keeps the embedder deterministic across platforms, and results on 64-bit
    /// targets are unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `dimensions` is zero (integer remainder by zero); callers pass
    /// the value clamped by [`DeterministicLocalEmbedder::new`].
    fn bucket_for(term: &str, dimensions: usize, seed: u64) -> usize {
        // usize is at most 64 bits on supported targets, so this widening is lossless.
        let modulus = dimensions as u64;
        // The remainder is < `dimensions`, which came from a usize, so narrowing
        // back to usize is lossless.
        (Self::hash_with_seed(term, seed) % modulus) as usize
    }

    /// Returns `+1.0` or `-1.0` depending on the low bit of an independently
    /// seeded hash of `term`.
    fn signed_weight(term: &str) -> f32 {
        if Self::hash_with_seed(term, 7809847782465536322u64) & 1 == 0 {
            1.0
        } else {
            -1.0
        }
    }
}
89
90impl SemanticEmbedder for DeterministicLocalEmbedder {
91    fn provider_kind(&self) -> EmbeddingProviderKind {
92        EmbeddingProviderKind::DeterministicLocal
93    }
94
95    fn dimensions(&self) -> usize {
96        self.dimensions
97    }
98
99    fn embed(&self, text: &str) -> EmbeddingVector {
100        let mut values = vec![0.0; self.dimensions];
101        for term in text
102            .split_whitespace()
103            .map(|term| term.trim_matches(|ch: char| !ch.is_alphanumeric()))
104            .filter(|term| !term.is_empty())
105            .map(|term| term.to_ascii_lowercase())
106        {
107            let primary_bucket = Self::bucket_for(&term, self.dimensions, 1469598103934665603u64);
108            let secondary_bucket = Self::bucket_for(&term, self.dimensions, 1099511628211u64);
109            let sign = Self::signed_weight(&term);
110            values[primary_bucket] += sign;
111            if self.dimensions > 1 {
112                values[secondary_bucket] += sign * 0.5;
113            }
114        }
115
116        let norm = values.iter().map(|value| value * value).sum::<f32>().sqrt();
117        if norm > 0.0 {
118            for value in &mut values {
119                *value /= norm;
120            }
121        }
122
123        EmbeddingVector { values }
124    }
125}
126
127#[derive(Clone)]
128pub struct SharedSemanticEmbedder {
129    provider_note: String,
130    embedder: Arc<dyn SemanticEmbedder>,
131}
132
133impl SharedSemanticEmbedder {
134    pub fn new(embedder: Arc<dyn SemanticEmbedder>, provider_note: impl Into<String>) -> Self {
135        Self {
136            provider_note: provider_note.into(),
137            embedder,
138        }
139    }
140
141    pub fn provider_note(&self) -> &str {
142        &self.provider_note
143    }
144}
145
146impl fmt::Debug for SharedSemanticEmbedder {
147    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148        f.debug_struct("SharedSemanticEmbedder")
149            .field("provider_note", &self.provider_note)
150            .field("provider_kind", &self.embedder.provider_kind())
151            .field("dimensions", &self.embedder.dimensions())
152            .finish()
153    }
154}
155
156impl SemanticEmbedder for SharedSemanticEmbedder {
157    fn provider_kind(&self) -> EmbeddingProviderKind {
158        self.embedder.provider_kind()
159    }
160
161    fn dimensions(&self) -> usize {
162        self.embedder.dimensions()
163    }
164
165    fn embed(&self, text: &str) -> EmbeddingVector {
166        self.embedder.embed(text)
167    }
168}
169
170#[derive(Debug, Clone)]
171pub enum ConfiguredSemanticEmbedder {
172    Disabled(DisabledEmbedder),
173    DeterministicLocal(DeterministicLocalEmbedder),
174    Shared(SharedSemanticEmbedder),
175}
176
177impl ConfiguredSemanticEmbedder {
178    pub fn from_engine_config(config: &EngineConfig) -> Self {
179        match config.embedding_provider_kind {
180            EmbeddingProviderKind::Disabled => Self::Disabled(DisabledEmbedder),
181            EmbeddingProviderKind::DeterministicLocal => Self::DeterministicLocal(
182                DeterministicLocalEmbedder::new(config.embedding_dimensions),
183            ),
184        }
185    }
186
187    pub fn shared(embedder: Arc<dyn SemanticEmbedder>, provider_note: impl Into<String>) -> Self {
188        Self::Shared(SharedSemanticEmbedder::new(embedder, provider_note))
189    }
190
191    pub fn provider_note(&self) -> Option<String> {
192        match self {
193            Self::Disabled(_) => None,
194            Self::DeterministicLocal(_) => {
195                Some("embedding_provider=deterministic_local".to_string())
196            }
197            Self::Shared(embedder) => Some(embedder.provider_note().to_string()),
198        }
199    }
200}
201
202impl SemanticEmbedder for ConfiguredSemanticEmbedder {
203    fn provider_kind(&self) -> EmbeddingProviderKind {
204        match self {
205            Self::Disabled(embedder) => embedder.provider_kind(),
206            Self::DeterministicLocal(embedder) => embedder.provider_kind(),
207            Self::Shared(embedder) => embedder.provider_kind(),
208        }
209    }
210
211    fn dimensions(&self) -> usize {
212        match self {
213            Self::Disabled(embedder) => embedder.dimensions(),
214            Self::DeterministicLocal(embedder) => embedder.dimensions(),
215            Self::Shared(embedder) => embedder.dimensions(),
216        }
217    }
218
219    fn embed(&self, text: &str) -> EmbeddingVector {
220        match self {
221            Self::Disabled(embedder) => embedder.embed(text),
222            Self::DeterministicLocal(embedder) => embedder.embed(text),
223            Self::Shared(embedder) => embedder.embed(text),
224        }
225    }
226}
227
#[cfg(test)]
mod tests {
    // Tests start from `EngineConfig::default()` and then assign only the fields
    // they exercise, which is clearer here than struct-update syntax.
    #![allow(clippy::field_reassign_with_default)]

    use super::{
        ConfiguredSemanticEmbedder, DeterministicLocalEmbedder, EmbeddingVector, SemanticEmbedder,
    };
    use crate::config::{EmbeddingProviderKind, EngineConfig};
    use std::sync::Arc;

    /// Two-dimensional test double: texts containing "storm" map to `[1, 0]`,
    /// everything else maps to `[0, 1]`.
    #[derive(Debug)]
    struct FixedEmbedder;

    impl SemanticEmbedder for FixedEmbedder {
        fn provider_kind(&self) -> EmbeddingProviderKind {
            EmbeddingProviderKind::Disabled
        }

        fn dimensions(&self) -> usize {
            2
        }

        fn embed(&self, text: &str) -> EmbeddingVector {
            let values = if text.contains("storm") {
                vec![1.0, 0.0]
            } else {
                vec![0.0, 1.0]
            };
            EmbeddingVector { values }
        }
    }

    #[test]
    fn deterministic_embedder_returns_stable_dimensions() {
        let vector = DeterministicLocalEmbedder::new(8).embed("storm checklist storm");

        assert_eq!(vector.values.len(), 8);
        assert!(vector.values.iter().any(|&component| component > 0.0));
    }

    #[test]
    fn deterministic_embedder_scores_related_texts_higher() {
        let embedder = DeterministicLocalEmbedder::new(64);
        let query = embedder.embed("verified storm checklist");

        let related =
            query.cosine_similarity(&embedder.embed("storm checklist for verified runbook"));
        let unrelated = query.cosine_similarity(&embedder.embed("audio waveform synthesis"));

        assert!(related > unrelated);
    }

    #[test]
    fn configured_embedder_uses_engine_config_provider() {
        let mut config = EngineConfig::default();
        config.embedding_provider_kind = EmbeddingProviderKind::DeterministicLocal;
        config.embedding_dimensions = 12;

        let embedder = ConfiguredSemanticEmbedder::from_engine_config(&config);

        assert_eq!(embedder.provider_kind(), EmbeddingProviderKind::DeterministicLocal);
        assert_eq!(embedder.dimensions(), 12);
    }

    #[test]
    fn configured_embedder_disabled_is_safe_fallback() {
        let embedder = ConfiguredSemanticEmbedder::from_engine_config(&EngineConfig::default());
        let vector = embedder.embed("storm checklist remediation");

        assert_eq!(embedder.provider_kind(), EmbeddingProviderKind::Disabled);
        assert_eq!(embedder.dimensions(), 0);
        assert!(vector.values.is_empty());
    }

    #[test]
    fn cosine_similarity_returns_zero_for_mismatched_vectors() {
        let text = "storm checklist";
        let narrow = DeterministicLocalEmbedder::new(8).embed(text);
        let wide = DeterministicLocalEmbedder::new(16).embed(text);

        assert_eq!(narrow.cosine_similarity(&wide), 0.0);
    }

    #[test]
    fn shared_embedder_keeps_custom_provider_note() {
        let note = "embedding_provider=fixture_custom";
        let embedder = ConfiguredSemanticEmbedder::shared(Arc::new(FixedEmbedder), note);

        assert_eq!(embedder.dimensions(), 2);
        assert_eq!(embedder.provider_note().as_deref(), Some(note));

        let similarity = embedder
            .embed("storm checklist")
            .cosine_similarity(&embedder.embed("storm runbook"));
        assert!(similarity > 0.0);
    }
}