/// Cosine similarity between two equal-length vectors, in `[-1.0, 1.0]`.
///
/// Components are widened to `f64` before accumulating the dot product and
/// squared norms, so large `f32` inputs cannot overflow the intermediate sums.
///
/// Returns `None` when the similarity is undefined:
/// * the slices differ in length or are empty,
/// * either vector has zero magnitude,
/// * any component is non-finite (NaN or ±infinity).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Option<f64> {
    if a.len() != b.len() || a.is_empty() {
        return None;
    }

    let (dot, norm_a, norm_b) =
        a.iter()
            .zip(b)
            .fold((0.0f64, 0.0f64, 0.0f64), |(dot, na, nb), (&ai, &bi)| {
                let (x, y) = (f64::from(ai), f64::from(bi));
                (dot + x * y, na + x * x, nb + y * y)
            });

    let norm = (norm_a * norm_b).sqrt();
    // A non-finite norm means a NaN/inf component; the quotient would be NaN
    // (e.g. inf/inf), which is not a meaningful similarity.
    if norm == 0.0 || !norm.is_finite() {
        return None;
    }

    // Rounding can push the ratio a hair outside [-1, 1]; clamp so callers
    // (e.g. rescaling to [0, 1]) can rely on the documented range.
    Some((dot / norm).clamp(-1.0, 1.0))
}
37
38pub fn normalized_similarity(a: &[f32], b: &[f32]) -> Option<f64> {
45 cosine_similarity(a, b).map(|cos| (cos + 1.0) / 2.0)
46}
47
48#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
50pub struct SemanticWiringConfig {
51 pub min_similarity: f64,
54
55 pub similarity_influence: f64,
58
59 pub require_embeddings: bool,
62}
63
64impl Default for SemanticWiringConfig {
65 fn default() -> Self {
66 Self {
67 min_similarity: 0.0,
68 similarity_influence: 0.5,
69 require_embeddings: false,
70 }
71 }
72}
73
74impl SemanticWiringConfig {
75 pub fn strict() -> Self {
77 Self {
78 min_similarity: 0.3,
79 similarity_influence: 1.0,
80 require_embeddings: true,
81 }
82 }
83
84 pub fn relaxed() -> Self {
86 Self {
87 min_similarity: 0.0,
88 similarity_influence: 0.3,
89 require_embeddings: false,
90 }
91 }
92}
93
94pub fn compute_semantic_weight(
105 base_weight: f64,
106 embedding_a: Option<&[f32]>,
107 embedding_b: Option<&[f32]>,
108 config: &SemanticWiringConfig,
109) -> Option<f64> {
110 match (embedding_a, embedding_b) {
111 (Some(a), Some(b)) => {
112 let similarity = normalized_similarity(a, b)?;
113 if similarity < config.min_similarity {
114 if config.require_embeddings {
115 return None;
116 }
117 return Some(base_weight);
119 }
120 let boosted = base_weight * (1.0 + config.similarity_influence * similarity);
122 Some(boosted.min(1.0))
123 }
124 _ => {
125 if config.require_embeddings {
126 None
127 } else {
128 Some(base_weight)
129 }
130 }
131 }
132}
133
/// Euclidean (L2) distance between two equal-length vectors.
///
/// Components are widened to `f64` before subtracting. Returns `None` if the
/// slices are empty or differ in length.
pub fn l2_distance(a: &[f32], b: &[f32]) -> Option<f64> {
    if a.is_empty() || a.len() != b.len() {
        return None;
    }

    let mut squared = 0.0f64;
    for (x, y) in a.iter().zip(b) {
        let delta = f64::from(*x) - f64::from(*y);
        squared += delta * delta;
    }

    Some(squared.sqrt())
}
151
/// Dot product of two equal-length vectors, accumulated in `f64`.
///
/// Returns `None` if the slices are empty or differ in length.
pub fn dot_product(a: &[f32], b: &[f32]) -> Option<f64> {
    let comparable = a.len() == b.len() && !a.is_empty();
    comparable.then(|| {
        a.iter()
            .zip(b)
            .map(|(x, y)| f64::from(*x) * f64::from(*y))
            .sum::<f64>()
    })
}
166
/// Scales `v` in place to unit L2 norm.
///
/// The norm is accumulated and applied in `f64`, matching the other helpers in
/// this file. In `f32`, squaring components around 1e20 overflows to infinity,
/// which would zero the vector. Squaring components around 1e-30 underflows to
/// zero, which would leave the vector un-normalized.
///
/// A vector whose norm is not positive (all zeros, empty, or containing NaN)
/// is left unchanged.
pub fn l2_normalize(v: &mut [f32]) {
    let norm = v
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    if norm > 0.0 {
        for x in v.iter_mut() {
            *x = (f64::from(*x) / norm) as f32;
        }
    }
}
176
/// Returns a copy of `v` scaled to unit L2 norm; `v` itself is not modified.
///
/// The norm is accumulated and applied in `f64`, matching the other helpers in
/// this file. In `f32`, very large components overflow when squared and very
/// small ones underflow to zero, which would corrupt the result.
///
/// If the norm is not positive (all zeros, empty, or containing NaN), the
/// input is returned unchanged as a new `Vec`.
pub fn l2_normalized(v: &[f32]) -> Vec<f32> {
    let norm = v
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    if norm > 0.0 {
        v.iter().map(|&x| (f64::from(x) / norm) as f32).collect()
    } else {
        v.to_vec()
    }
}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the current test unless `actual` is strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} within {}, got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = [1.0f32, 2.0, 3.0];
        let w = [1.0f32, 2.0, 3.0];
        assert_close(cosine_similarity(&v, &w).unwrap(), 1.0, 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let x_axis = [1.0f32, 0.0, 0.0];
        let y_axis = [0.0f32, 1.0, 0.0];
        assert_close(cosine_similarity(&x_axis, &y_axis).unwrap(), 0.0, 1e-6);
    }

    #[test]
    fn cosine_similarity_opposite_vectors() {
        let v = [1.0f32, 2.0, 3.0];
        let flipped = [-1.0f32, -2.0, -3.0];
        assert_close(cosine_similarity(&v, &flipped).unwrap(), -1.0, 1e-6);
    }

    #[test]
    fn cosine_similarity_different_lengths() {
        let short = [1.0f32, 2.0];
        let long = [1.0f32, 2.0, 3.0];
        assert!(cosine_similarity(&short, &long).is_none());
    }

    #[test]
    fn cosine_similarity_zero_vector() {
        let zero = [0.0f32, 0.0, 0.0];
        let v = [1.0f32, 2.0, 3.0];
        assert!(cosine_similarity(&zero, &v).is_none());
    }

    #[test]
    fn normalized_similarity_maps_to_zero_one() {
        let v = [1.0f32, 2.0, 3.0];
        let flipped = [-1.0f32, -2.0, -3.0];
        let x_axis = [1.0f32, 0.0];
        let y_axis = [0.0f32, 1.0];

        // Identical direction maps to the top of the range.
        assert_close(normalized_similarity(&v, &v).unwrap(), 1.0, 1e-6);
        // Opposite direction maps to the bottom.
        assert_close(normalized_similarity(&v, &flipped).unwrap(), 0.0, 1e-6);
        // Orthogonal vectors land in the middle.
        assert_close(normalized_similarity(&x_axis, &y_axis).unwrap(), 0.5, 1e-6);
    }

    #[test]
    fn semantic_weight_with_embeddings() {
        let config = SemanticWiringConfig::default();
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.9f32, 0.1, 0.0];
        let weight = compute_semantic_weight(0.1, Some(&a[..]), Some(&b[..]), &config).unwrap();
        assert!(weight > 0.1);
    }

    #[test]
    fn semantic_weight_without_embeddings_relaxed() {
        let config = SemanticWiringConfig::relaxed();
        let weight = compute_semantic_weight(0.1, None, None, &config).unwrap();
        assert_close(weight, 0.1, 1e-6);
    }

    #[test]
    fn semantic_weight_without_embeddings_strict() {
        let config = SemanticWiringConfig::strict();
        assert!(compute_semantic_weight(0.1, None, None, &config).is_none());
    }

    #[test]
    fn semantic_weight_below_threshold() {
        let config = SemanticWiringConfig {
            min_similarity: 0.9,
            similarity_influence: 1.0,
            require_embeddings: true,
        };
        // Orthogonal => normalized similarity 0.5, below the 0.9 floor.
        let a = [1.0f32, 0.0];
        let b = [0.0f32, 1.0];
        assert!(compute_semantic_weight(0.1, Some(&a[..]), Some(&b[..]), &config).is_none());
    }

    #[test]
    fn l2_distance_works() {
        let origin = [0.0f32, 0.0, 0.0];
        let unit_x = [1.0f32, 0.0, 0.0];
        assert_close(l2_distance(&origin, &unit_x).unwrap(), 1.0, 1e-6);
    }

    #[test]
    fn l2_normalize_works() {
        let mut v = [3.0f32, 4.0];
        l2_normalize(&mut v);
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn dot_product_works() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        assert_close(dot_product(&a, &b).unwrap(), 32.0, 1e-6);
    }
}