// phago_core/semantic.rs
//! Semantic similarity utilities for vector embeddings.
//!
//! Provides cosine similarity computation and semantic wiring logic
//! for the knowledge graph. When nodes have embeddings, edge weights
//! can be modulated by semantic similarity.

/// Cosine similarity of two equal-length vectors.
///
/// The result lies in [-1, 1]:
/// - 1.0  => identical direction
/// - 0.0  => orthogonal
/// - -1.0 => opposite direction
///
/// Returns `None` when the slices differ in length, are empty, or either
/// vector has zero norm (similarity is undefined in those cases).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Option<f64> {
    if a.len() != b.len() || a.is_empty() {
        return None;
    }

    // Accumulate dot product and both squared norms in a single pass,
    // widening to f64 for accuracy.
    let (dot, norm_a, norm_b) = a.iter().zip(b.iter()).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (x as f64, y as f64);
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    let denom = (norm_a * norm_b).sqrt();
    if denom == 0.0 {
        None
    } else {
        Some(dot / denom)
    }
}
38/// Compute similarity between two embeddings, normalized to [0, 1].
39///
40/// Uses cosine similarity internally but maps the result from [-1, 1] to [0, 1]
41/// using: `(cosine + 1) / 2`
42///
43/// This is more suitable for edge weights which should be non-negative.
44pub fn normalized_similarity(a: &[f32], b: &[f32]) -> Option<f64> {
45    cosine_similarity(a, b).map(|cos| (cos + 1.0) / 2.0)
46}
47
48/// Configuration for semantic wiring.
49#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
50pub struct SemanticWiringConfig {
51    /// Minimum similarity threshold for creating/strengthening edges.
52    /// Edges between concepts with similarity below this are not created.
53    pub min_similarity: f64,
54
55    /// Weight multiplier for semantic similarity.
56    /// Final weight = base_weight * (1 + similarity_influence * similarity)
57    pub similarity_influence: f64,
58
59    /// Whether to require both nodes to have embeddings.
60    /// If false, edges between nodes without embeddings use base weight only.
61    pub require_embeddings: bool,
62}
63
64impl Default for SemanticWiringConfig {
65    fn default() -> Self {
66        Self {
67            min_similarity: 0.0,
68            similarity_influence: 0.5,
69            require_embeddings: false,
70        }
71    }
72}
73
74impl SemanticWiringConfig {
75    /// Create a strict configuration that only wires semantically similar concepts.
76    pub fn strict() -> Self {
77        Self {
78            min_similarity: 0.3,
79            similarity_influence: 1.0,
80            require_embeddings: true,
81        }
82    }
83
84    /// Create a relaxed configuration that uses similarity as a boost.
85    pub fn relaxed() -> Self {
86        Self {
87            min_similarity: 0.0,
88            similarity_influence: 0.3,
89            require_embeddings: false,
90        }
91    }
92}
93
94/// Compute the edge weight based on base weight and semantic similarity.
95///
96/// If both nodes have embeddings and similarity meets the threshold:
97/// `weight = base_weight * (1 + similarity_influence * similarity)`
98///
99/// If embeddings are missing and `require_embeddings` is false:
100/// `weight = base_weight`
101///
102/// If embeddings are missing and `require_embeddings` is true:
103/// Returns None (edge should not be created).
104pub fn compute_semantic_weight(
105    base_weight: f64,
106    embedding_a: Option<&[f32]>,
107    embedding_b: Option<&[f32]>,
108    config: &SemanticWiringConfig,
109) -> Option<f64> {
110    match (embedding_a, embedding_b) {
111        (Some(a), Some(b)) => {
112            let similarity = normalized_similarity(a, b)?;
113            if similarity < config.min_similarity {
114                if config.require_embeddings {
115                    return None;
116                }
117                // Below threshold but not requiring similarity — use base weight
118                return Some(base_weight);
119            }
120            // Boost weight based on similarity
121            let boosted = base_weight * (1.0 + config.similarity_influence * similarity);
122            Some(boosted.min(1.0))
123        }
124        _ => {
125            if config.require_embeddings {
126                None
127            } else {
128                Some(base_weight)
129            }
130        }
131    }
132}
133
/// Euclidean (L2) distance between two equal-length vectors.
///
/// Returns `None` for mismatched lengths or empty input; otherwise the
/// square root of the summed squared component differences, accumulated
/// in f64 for accuracy.
pub fn l2_distance(a: &[f32], b: &[f32]) -> Option<f64> {
    if a.len() != b.len() || a.is_empty() {
        return None;
    }

    let mut sum_sq = 0.0f64;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let d = x as f64 - y as f64;
        sum_sq += d * d;
    }

    Some(sum_sq.sqrt())
}
/// Dot product of two equal-length vectors, accumulated in f64.
///
/// Returns `None` for mismatched lengths or empty slices.
pub fn dot_product(a: &[f32], b: &[f32]) -> Option<f64> {
    if a.len() != b.len() || a.is_empty() {
        return None;
    }

    let mut acc = 0.0f64;
    for (&x, &y) in a.iter().zip(b.iter()) {
        acc += x as f64 * y as f64;
    }

    Some(acc)
}
/// Scale `v` in place so its L2 norm becomes 1.
///
/// A zero vector is left untouched — there is no direction to normalize.
pub fn l2_normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}
/// Return an L2-normalized copy of `v`.
///
/// A zero vector is returned as an unmodified copy, mirroring the in-place
/// variant [`l2_normalize`].
pub fn l2_normalized(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    if norm > 0.0 {
        v.iter().map(|x| x / norm).collect()
    } else {
        v.to_vec()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared tolerance for f64 comparisons.
    const EPS: f64 = 1e-6;

    #[test]
    fn cosine_of_a_vector_with_itself_is_one() {
        let v = vec![1.0, 2.0, 3.0];
        assert!((cosine_similarity(&v, &v).unwrap() - 1.0).abs() < EPS);
    }

    #[test]
    fn cosine_of_orthogonal_vectors_is_zero() {
        let x = vec![1.0, 0.0, 0.0];
        let y = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&x, &y).unwrap().abs() < EPS);
    }

    #[test]
    fn cosine_of_negated_vector_is_minus_one() {
        let v = vec![1.0, 2.0, 3.0];
        let neg: Vec<f32> = v.iter().map(|x| -x).collect();
        assert!((cosine_similarity(&v, &neg).unwrap() + 1.0).abs() < EPS);
    }

    #[test]
    fn mismatched_lengths_yield_none() {
        assert!(cosine_similarity(&[1.0, 2.0], &[1.0, 2.0, 3.0]).is_none());
    }

    #[test]
    fn zero_vector_yields_none() {
        assert!(cosine_similarity(&[0.0; 3], &[1.0, 2.0, 3.0]).is_none());
    }

    #[test]
    fn normalized_similarity_spans_zero_to_one() {
        // Identical direction maps to 1.0.
        let v = vec![1.0, 2.0, 3.0];
        assert!((normalized_similarity(&v, &v).unwrap() - 1.0).abs() < EPS);

        // Opposite direction maps to 0.0.
        let neg: Vec<f32> = v.iter().map(|x| -x).collect();
        assert!(normalized_similarity(&v, &neg).unwrap().abs() < EPS);

        // Orthogonal maps to 0.5.
        assert!((normalized_similarity(&[1.0, 0.0], &[0.0, 1.0]).unwrap() - 0.5).abs() < EPS);
    }

    #[test]
    fn weight_boosted_when_embeddings_are_similar() {
        let cfg = SemanticWiringConfig::default();
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.9, 0.1, 0.0]; // nearly parallel to `a`
        let w = compute_semantic_weight(0.1, Some(&a), Some(&b), &cfg).unwrap();
        assert!(w > 0.1, "similar embeddings should boost the base weight");
    }

    #[test]
    fn missing_embeddings_fall_back_to_base_weight_when_relaxed() {
        let cfg = SemanticWiringConfig::relaxed();
        let w = compute_semantic_weight(0.1, None, None, &cfg).unwrap();
        assert!((w - 0.1).abs() < EPS);
    }

    #[test]
    fn missing_embeddings_rejected_when_strict() {
        let cfg = SemanticWiringConfig::strict();
        assert!(compute_semantic_weight(0.1, None, None, &cfg).is_none());
    }

    #[test]
    fn strict_config_rejects_pairs_below_threshold() {
        let cfg = SemanticWiringConfig {
            min_similarity: 0.9,
            similarity_influence: 1.0,
            require_embeddings: true,
        };
        // Orthogonal embeddings have normalized similarity 0.5 < 0.9.
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(compute_semantic_weight(0.1, Some(&a), Some(&b), &cfg).is_none());
    }

    #[test]
    fn l2_distance_of_unit_step_is_one() {
        assert!((l2_distance(&[0.0; 3], &[1.0, 0.0, 0.0]).unwrap() - 1.0).abs() < EPS);
    }

    #[test]
    fn l2_normalize_produces_unit_norm() {
        let mut v = vec![3.0f32, 4.0];
        l2_normalize(&mut v);
        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn dot_product_matches_manual_sum() {
        // 1*4 + 2*5 + 3*6 = 32
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!((dot_product(&a, &b).unwrap() - 32.0).abs() < EPS);
    }
}