mnemara_core/
embedding.rs1use crate::config::{EmbeddingProviderKind, EngineConfig};
2use std::fmt;
3use std::sync::Arc;
4
/// A dense embedding stored as a flat list of `f32` components.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingVector {
    /// Raw components of the embedding; may be empty.
    pub values: Vec<f32>,
}

impl EmbeddingVector {
    /// Returns the cosine similarity between `self` and `other`.
    ///
    /// The result is `0.0` when either vector is empty, when the two vectors
    /// have different lengths, or when either vector has zero magnitude.
    /// Otherwise the value lies in `[-1.0, 1.0]` (or is `NaN` if a component
    /// is `NaN` or infinite).
    pub fn cosine_similarity(&self, other: &Self) -> f32 {
        if self.values.is_empty() || self.values.len() != other.values.len() {
            return 0.0;
        }

        // Accumulate in f64: squaring an f32 can overflow to infinity (large
        // components) or underflow to zero (tiny components), which would turn
        // a perfectly valid comparison into NaN or a spurious 0.0.
        let (dot, left_norm, right_norm) = self.values.iter().zip(&other.values).fold(
            (0.0_f64, 0.0_f64, 0.0_f64),
            |(dot, left_norm, right_norm), (&left, &right)| {
                let (left, right) = (f64::from(left), f64::from(right));
                (
                    dot + left * right,
                    left_norm + left * left,
                    right_norm + right * right,
                )
            },
        );

        if left_norm == 0.0 || right_norm == 0.0 {
            return 0.0;
        }

        // Rounding can push the ratio a hair outside the mathematical range;
        // clamp so callers can rely on [-1, 1]. `clamp` leaves NaN untouched.
        let cosine = dot / (left_norm.sqrt() * right_norm.sqrt());
        cosine.clamp(-1.0, 1.0) as f32
    }
}
32
33pub trait SemanticEmbedder: Send + Sync {
34 fn provider_kind(&self) -> EmbeddingProviderKind;
35 fn dimensions(&self) -> usize;
36 fn embed(&self, text: &str) -> EmbeddingVector;
37}
38
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
40pub struct DisabledEmbedder;
41
42impl SemanticEmbedder for DisabledEmbedder {
43 fn provider_kind(&self) -> EmbeddingProviderKind {
44 EmbeddingProviderKind::Disabled
45 }
46
47 fn dimensions(&self) -> usize {
48 0
49 }
50
51 fn embed(&self, _text: &str) -> EmbeddingVector {
52 EmbeddingVector { values: Vec::new() }
53 }
54}
55
/// Embedder whose helpers derive bucket indices and signs purely from a
/// seeded byte hash of each term, with no external model or state.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeterministicLocalEmbedder {
    /// Number of buckets in every produced vector; always at least 1.
    dimensions: usize,
}

impl DeterministicLocalEmbedder {
    /// Creates an embedder with the requested number of dimensions.
    ///
    /// A request for `0` dimensions is raised to `1` so bucket selection never
    /// divides by zero.
    pub fn new(dimensions: usize) -> Self {
        Self {
            dimensions: dimensions.max(1),
        }
    }

    /// Hashes the UTF-8 bytes of `term`, starting from `seed`.
    ///
    /// Each byte is XOR-ed into the state, which is then multiplied (with
    /// wrap-around) by the 64-bit FNV prime — the FNV-1a mixing step.
    fn hash_with_seed(term: &str, seed: u64) -> u64 {
        const FNV_PRIME: u64 = 1099511628211;
        term.bytes()
            .fold(seed, |hash, byte| (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME))
    }

    /// Maps `term` to a bucket index in `0..dimensions` for the given `seed`.
    ///
    /// # Panics
    ///
    /// Panics if `dimensions` is `0`.
    fn bucket_for(term: &str, dimensions: usize, seed: u64) -> usize {
        // Reduce in u64 *before* narrowing: casting the hash to `usize` first
        // would drop its upper bits on 32-bit targets and make bucket choice
        // platform-dependent. The remainder is < `dimensions`, so it always
        // fits back into `usize`.
        (Self::hash_with_seed(term, seed) % (dimensions as u64)) as usize
    }

    /// Returns `1.0` or `-1.0` depending on the lowest bit of a dedicated hash
    /// of `term`.
    fn signed_weight(term: &str) -> f32 {
        let parity = Self::hash_with_seed(term, 7809847782465536322u64) & 1;
        if parity == 0 {
            1.0
        } else {
            -1.0
        }
    }
}
89
90impl SemanticEmbedder for DeterministicLocalEmbedder {
91 fn provider_kind(&self) -> EmbeddingProviderKind {
92 EmbeddingProviderKind::DeterministicLocal
93 }
94
95 fn dimensions(&self) -> usize {
96 self.dimensions
97 }
98
99 fn embed(&self, text: &str) -> EmbeddingVector {
100 let mut values = vec![0.0; self.dimensions];
101 for term in text
102 .split_whitespace()
103 .map(|term| term.trim_matches(|ch: char| !ch.is_alphanumeric()))
104 .filter(|term| !term.is_empty())
105 .map(|term| term.to_ascii_lowercase())
106 {
107 let primary_bucket = Self::bucket_for(&term, self.dimensions, 1469598103934665603u64);
108 let secondary_bucket = Self::bucket_for(&term, self.dimensions, 1099511628211u64);
109 let sign = Self::signed_weight(&term);
110 values[primary_bucket] += sign;
111 if self.dimensions > 1 {
112 values[secondary_bucket] += sign * 0.5;
113 }
114 }
115
116 let norm = values.iter().map(|value| value * value).sum::<f32>().sqrt();
117 if norm > 0.0 {
118 for value in &mut values {
119 *value /= norm;
120 }
121 }
122
123 EmbeddingVector { values }
124 }
125}
126
127#[derive(Clone)]
128pub struct SharedSemanticEmbedder {
129 provider_note: String,
130 embedder: Arc<dyn SemanticEmbedder>,
131}
132
133impl SharedSemanticEmbedder {
134 pub fn new(embedder: Arc<dyn SemanticEmbedder>, provider_note: impl Into<String>) -> Self {
135 Self {
136 provider_note: provider_note.into(),
137 embedder,
138 }
139 }
140
141 pub fn provider_note(&self) -> &str {
142 &self.provider_note
143 }
144}
145
146impl fmt::Debug for SharedSemanticEmbedder {
147 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148 f.debug_struct("SharedSemanticEmbedder")
149 .field("provider_note", &self.provider_note)
150 .field("provider_kind", &self.embedder.provider_kind())
151 .field("dimensions", &self.embedder.dimensions())
152 .finish()
153 }
154}
155
156impl SemanticEmbedder for SharedSemanticEmbedder {
157 fn provider_kind(&self) -> EmbeddingProviderKind {
158 self.embedder.provider_kind()
159 }
160
161 fn dimensions(&self) -> usize {
162 self.embedder.dimensions()
163 }
164
165 fn embed(&self, text: &str) -> EmbeddingVector {
166 self.embedder.embed(text)
167 }
168}
169
170#[derive(Debug, Clone)]
171pub enum ConfiguredSemanticEmbedder {
172 Disabled(DisabledEmbedder),
173 DeterministicLocal(DeterministicLocalEmbedder),
174 Shared(SharedSemanticEmbedder),
175}
176
177impl ConfiguredSemanticEmbedder {
178 pub fn from_engine_config(config: &EngineConfig) -> Self {
179 match config.embedding_provider_kind {
180 EmbeddingProviderKind::Disabled => Self::Disabled(DisabledEmbedder),
181 EmbeddingProviderKind::DeterministicLocal => Self::DeterministicLocal(
182 DeterministicLocalEmbedder::new(config.embedding_dimensions),
183 ),
184 }
185 }
186
187 pub fn shared(embedder: Arc<dyn SemanticEmbedder>, provider_note: impl Into<String>) -> Self {
188 Self::Shared(SharedSemanticEmbedder::new(embedder, provider_note))
189 }
190
191 pub fn provider_note(&self) -> Option<String> {
192 match self {
193 Self::Disabled(_) => None,
194 Self::DeterministicLocal(_) => {
195 Some("embedding_provider=deterministic_local".to_string())
196 }
197 Self::Shared(embedder) => Some(embedder.provider_note().to_string()),
198 }
199 }
200}
201
202impl SemanticEmbedder for ConfiguredSemanticEmbedder {
203 fn provider_kind(&self) -> EmbeddingProviderKind {
204 match self {
205 Self::Disabled(embedder) => embedder.provider_kind(),
206 Self::DeterministicLocal(embedder) => embedder.provider_kind(),
207 Self::Shared(embedder) => embedder.provider_kind(),
208 }
209 }
210
211 fn dimensions(&self) -> usize {
212 match self {
213 Self::Disabled(embedder) => embedder.dimensions(),
214 Self::DeterministicLocal(embedder) => embedder.dimensions(),
215 Self::Shared(embedder) => embedder.dimensions(),
216 }
217 }
218
219 fn embed(&self, text: &str) -> EmbeddingVector {
220 match self {
221 Self::Disabled(embedder) => embedder.embed(text),
222 Self::DeterministicLocal(embedder) => embedder.embed(text),
223 Self::Shared(embedder) => embedder.embed(text),
224 }
225 }
226}
227
#[cfg(test)]
mod tests {
    // The config tests below start from `EngineConfig::default()` and then
    // assign individual fields, which this lint would otherwise flag.
    #![allow(clippy::field_reassign_with_default)]

    use super::{
        ConfiguredSemanticEmbedder, DeterministicLocalEmbedder, EmbeddingVector, SemanticEmbedder,
    };
    use crate::config::{EmbeddingProviderKind, EngineConfig};
    use std::sync::Arc;

    /// Two-dimensional stub: texts mentioning "storm" map to the first axis,
    /// everything else to the second.
    #[derive(Debug)]
    struct FixedEmbedder;

    impl SemanticEmbedder for FixedEmbedder {
        fn provider_kind(&self) -> EmbeddingProviderKind {
            EmbeddingProviderKind::Disabled
        }

        fn dimensions(&self) -> usize {
            2
        }

        fn embed(&self, text: &str) -> EmbeddingVector {
            if text.contains("storm") {
                EmbeddingVector {
                    values: vec![1.0, 0.0],
                }
            } else {
                EmbeddingVector {
                    values: vec![0.0, 1.0],
                }
            }
        }
    }

    #[test]
    fn deterministic_embedder_returns_stable_dimensions() {
        let embedder = DeterministicLocalEmbedder::new(8);
        let vector = embedder.embed("storm checklist storm");
        assert_eq!(vector.values.len(), 8);
        // Term weights are signed by a hash, so only require that the vector
        // is non-trivial rather than that some component happens to be positive.
        assert!(vector.values.iter().any(|value| *value != 0.0));
    }

    #[test]
    fn deterministic_embedder_scores_related_texts_higher() {
        let embedder = DeterministicLocalEmbedder::new(64);
        let related = embedder
            .embed("verified storm checklist")
            .cosine_similarity(&embedder.embed("storm checklist for verified runbook"));
        let unrelated = embedder
            .embed("verified storm checklist")
            .cosine_similarity(&embedder.embed("audio waveform synthesis"));
        assert!(related > unrelated);
    }

    #[test]
    fn configured_embedder_uses_engine_config_provider() {
        let mut config = EngineConfig::default();
        config.embedding_provider_kind = EmbeddingProviderKind::DeterministicLocal;
        config.embedding_dimensions = 12;

        let embedder = ConfiguredSemanticEmbedder::from_engine_config(&config);
        assert_eq!(
            embedder.provider_kind(),
            EmbeddingProviderKind::DeterministicLocal
        );
        assert_eq!(embedder.dimensions(), 12);
    }

    #[test]
    fn configured_embedder_disabled_is_safe_fallback() {
        let config = EngineConfig::default();

        let embedder = ConfiguredSemanticEmbedder::from_engine_config(&config);
        let vector = embedder.embed("storm checklist remediation");

        assert_eq!(embedder.provider_kind(), EmbeddingProviderKind::Disabled);
        assert_eq!(embedder.dimensions(), 0);
        assert!(vector.values.is_empty());
    }

    #[test]
    fn cosine_similarity_returns_zero_for_mismatched_vectors() {
        let left = DeterministicLocalEmbedder::new(8).embed("storm checklist");
        let right = DeterministicLocalEmbedder::new(16).embed("storm checklist");

        assert_eq!(left.cosine_similarity(&right), 0.0);
    }

    #[test]
    fn shared_embedder_keeps_custom_provider_note() {
        let embedder = ConfiguredSemanticEmbedder::shared(
            Arc::new(FixedEmbedder),
            "embedding_provider=fixture_custom",
        );

        assert_eq!(embedder.dimensions(), 2);
        assert_eq!(
            embedder.provider_note().as_deref(),
            Some("embedding_provider=fixture_custom")
        );
        assert!(
            embedder
                .embed("storm checklist")
                .cosine_similarity(&embedder.embed("storm runbook"))
                > 0.0
        );
    }
}
336}