rustkernel_graph/
similarity.rs

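//! Graph similarity kernels.
//!
//! This module provides similarity measures for graph nodes:
//! - Jaccard similarity (neighbor set overlap)
//! - Cosine similarity (normalized dot product)
//! - Adamic-Adar index (weighted common neighbors)
//! - Common neighbors (raw count of shared neighbors)
//! - Value similarity (distance between per-node value distributions, via
//!   Jensen-Shannon divergence or Wasserstein-1 distance)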
7
8use crate::types::{CsrGraph, SimilarityScore};
9use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
10use std::collections::HashSet;
11
12// ============================================================================
13// Jaccard Similarity Kernel
14// ============================================================================
15
16/// Jaccard similarity kernel.
17///
18/// Computes Jaccard similarity: |N(u) ∩ N(v)| / |N(u) ∪ N(v)|
19/// where N(x) is the neighbor set of node x.
20#[derive(Debug, Clone)]
21pub struct JaccardSimilarity {
22    metadata: KernelMetadata,
23}
24
25impl JaccardSimilarity {
26    /// Create a new Jaccard similarity kernel.
27    #[must_use]
28    pub fn new() -> Self {
29        Self {
30            metadata: KernelMetadata::batch("graph/jaccard-similarity", Domain::GraphAnalytics)
31                .with_description("Jaccard similarity (neighbor set overlap)")
32                .with_throughput(100_000)
33                .with_latency_us(10.0),
34        }
35    }
36
37    /// Compute Jaccard similarity between two nodes.
38    pub fn compute_pair(graph: &CsrGraph, node_a: u64, node_b: u64) -> f64 {
39        let neighbors_a: HashSet<u64> = graph.neighbors(node_a).iter().copied().collect();
40        let neighbors_b: HashSet<u64> = graph.neighbors(node_b).iter().copied().collect();
41
42        let intersection = neighbors_a.intersection(&neighbors_b).count();
43        let union = neighbors_a.union(&neighbors_b).count();
44
45        if union == 0 {
46            0.0
47        } else {
48            intersection as f64 / union as f64
49        }
50    }
51
52    /// Compute Jaccard similarity for all pairs of nodes above a threshold.
53    ///
54    /// # Arguments
55    /// * `graph` - Input graph
56    /// * `min_similarity` - Only return pairs with similarity >= this threshold
57    /// * `max_pairs` - Maximum number of pairs to return
58    pub fn compute_all_pairs(
59        graph: &CsrGraph,
60        min_similarity: f64,
61        max_pairs: usize,
62    ) -> Vec<SimilarityScore> {
63        let n = graph.num_nodes;
64        let mut results = Vec::new();
65
66        for i in 0..n {
67            for j in (i + 1)..n {
68                let similarity = Self::compute_pair(graph, i as u64, j as u64);
69
70                if similarity >= min_similarity {
71                    results.push(SimilarityScore {
72                        id_a: i as u64,
73                        id_b: j as u64,
74                        similarity,
75                    });
76
77                    if results.len() >= max_pairs {
78                        return results;
79                    }
80                }
81            }
82        }
83
84        // Sort by similarity descending
85        results.sort_by(|a, b| {
86            b.similarity
87                .partial_cmp(&a.similarity)
88                .unwrap_or(std::cmp::Ordering::Equal)
89        });
90        results
91    }
92
93    /// Compute top-k most similar pairs using Jaccard similarity.
94    pub fn top_k_pairs(graph: &CsrGraph, k: usize) -> Vec<SimilarityScore> {
95        Self::compute_all_pairs(graph, 0.0, k)
96    }
97}
98
99impl Default for JaccardSimilarity {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl GpuKernel for JaccardSimilarity {
106    fn metadata(&self) -> &KernelMetadata {
107        &self.metadata
108    }
109}
110
111// ============================================================================
112// Cosine Similarity Kernel
113// ============================================================================
114
115/// Cosine similarity kernel.
116///
117/// Computes cosine similarity: |N(u) ∩ N(v)| / sqrt(|N(u)| * |N(v)|)
118/// This is the normalized version of common neighbors.
119#[derive(Debug, Clone)]
120pub struct CosineSimilarity {
121    metadata: KernelMetadata,
122}
123
124impl CosineSimilarity {
125    /// Create a new cosine similarity kernel.
126    #[must_use]
127    pub fn new() -> Self {
128        Self {
129            metadata: KernelMetadata::batch("graph/cosine-similarity", Domain::GraphAnalytics)
130                .with_description("Cosine similarity (normalized dot product)")
131                .with_throughput(100_000)
132                .with_latency_us(10.0),
133        }
134    }
135
136    /// Compute cosine similarity between two nodes.
137    pub fn compute_pair(graph: &CsrGraph, node_a: u64, node_b: u64) -> f64 {
138        let neighbors_a: HashSet<u64> = graph.neighbors(node_a).iter().copied().collect();
139        let neighbors_b: HashSet<u64> = graph.neighbors(node_b).iter().copied().collect();
140
141        let intersection = neighbors_a.intersection(&neighbors_b).count() as f64;
142        let norm = (neighbors_a.len() as f64 * neighbors_b.len() as f64).sqrt();
143
144        if norm == 0.0 {
145            0.0
146        } else {
147            intersection / norm
148        }
149    }
150
151    /// Compute cosine similarity for all pairs above a threshold.
152    pub fn compute_all_pairs(
153        graph: &CsrGraph,
154        min_similarity: f64,
155        max_pairs: usize,
156    ) -> Vec<SimilarityScore> {
157        let n = graph.num_nodes;
158        let mut results = Vec::new();
159
160        for i in 0..n {
161            for j in (i + 1)..n {
162                let similarity = Self::compute_pair(graph, i as u64, j as u64);
163
164                if similarity >= min_similarity {
165                    results.push(SimilarityScore {
166                        id_a: i as u64,
167                        id_b: j as u64,
168                        similarity,
169                    });
170
171                    if results.len() >= max_pairs {
172                        return results;
173                    }
174                }
175            }
176        }
177
178        results.sort_by(|a, b| {
179            b.similarity
180                .partial_cmp(&a.similarity)
181                .unwrap_or(std::cmp::Ordering::Equal)
182        });
183        results
184    }
185}
186
187impl Default for CosineSimilarity {
188    fn default() -> Self {
189        Self::new()
190    }
191}
192
193impl GpuKernel for CosineSimilarity {
194    fn metadata(&self) -> &KernelMetadata {
195        &self.metadata
196    }
197}
198
199// ============================================================================
200// Adamic-Adar Index Kernel
201// ============================================================================
202
203/// Adamic-Adar index kernel.
204///
205/// Computes Adamic-Adar index: Σ 1/log(|N(z)|) for all z ∈ N(u) ∩ N(v)
206/// This weights common neighbors inversely by their degree.
207#[derive(Debug, Clone)]
208pub struct AdamicAdarIndex {
209    metadata: KernelMetadata,
210}
211
212impl AdamicAdarIndex {
213    /// Create a new Adamic-Adar index kernel.
214    #[must_use]
215    pub fn new() -> Self {
216        Self {
217            metadata: KernelMetadata::batch("graph/adamic-adar", Domain::GraphAnalytics)
218                .with_description("Adamic-Adar index (weighted common neighbors)")
219                .with_throughput(100_000)
220                .with_latency_us(10.0),
221        }
222    }
223
224    /// Compute Adamic-Adar index between two nodes.
225    pub fn compute_pair(graph: &CsrGraph, node_a: u64, node_b: u64) -> f64 {
226        let neighbors_a: HashSet<u64> = graph.neighbors(node_a).iter().copied().collect();
227        let neighbors_b: HashSet<u64> = graph.neighbors(node_b).iter().copied().collect();
228
229        let common_neighbors = neighbors_a.intersection(&neighbors_b);
230
231        common_neighbors
232            .map(|&z| {
233                let degree = graph.out_degree(z) as f64;
234                if degree > 1.0 { 1.0 / degree.ln() } else { 0.0 }
235            })
236            .sum()
237    }
238
239    /// Compute Adamic-Adar for all pairs, returning those above threshold.
240    pub fn compute_all_pairs(
241        graph: &CsrGraph,
242        min_score: f64,
243        max_pairs: usize,
244    ) -> Vec<SimilarityScore> {
245        let n = graph.num_nodes;
246        let mut results = Vec::new();
247
248        for i in 0..n {
249            for j in (i + 1)..n {
250                let score = Self::compute_pair(graph, i as u64, j as u64);
251
252                if score >= min_score {
253                    results.push(SimilarityScore {
254                        id_a: i as u64,
255                        id_b: j as u64,
256                        similarity: score,
257                    });
258
259                    if results.len() >= max_pairs {
260                        return results;
261                    }
262                }
263            }
264        }
265
266        results.sort_by(|a, b| {
267            b.similarity
268                .partial_cmp(&a.similarity)
269                .unwrap_or(std::cmp::Ordering::Equal)
270        });
271        results
272    }
273
274    /// Find top-k most similar pairs for link prediction.
275    pub fn top_k_pairs(graph: &CsrGraph, k: usize) -> Vec<SimilarityScore> {
276        Self::compute_all_pairs(graph, 0.0, k)
277    }
278}
279
280impl Default for AdamicAdarIndex {
281    fn default() -> Self {
282        Self::new()
283    }
284}
285
286impl GpuKernel for AdamicAdarIndex {
287    fn metadata(&self) -> &KernelMetadata {
288        &self.metadata
289    }
290}
291
292// ============================================================================
293// Common Neighbors Kernel
294// ============================================================================
295
296/// Common neighbors kernel.
297///
298/// Simply counts the number of common neighbors: |N(u) ∩ N(v)|
299#[derive(Debug, Clone)]
300pub struct CommonNeighbors {
301    metadata: KernelMetadata,
302}
303
304impl CommonNeighbors {
305    /// Create a new common neighbors kernel.
306    #[must_use]
307    pub fn new() -> Self {
308        Self {
309            metadata: KernelMetadata::batch("graph/common-neighbors", Domain::GraphAnalytics)
310                .with_description("Common neighbors count")
311                .with_throughput(200_000)
312                .with_latency_us(5.0),
313        }
314    }
315
316    /// Count common neighbors between two nodes.
317    pub fn compute_pair(graph: &CsrGraph, node_a: u64, node_b: u64) -> usize {
318        let neighbors_a: HashSet<u64> = graph.neighbors(node_a).iter().copied().collect();
319        let neighbors_b: HashSet<u64> = graph.neighbors(node_b).iter().copied().collect();
320
321        neighbors_a.intersection(&neighbors_b).count()
322    }
323
324    /// Compute common neighbors for all pairs with count >= min_count.
325    pub fn compute_all_pairs(
326        graph: &CsrGraph,
327        min_count: usize,
328        max_pairs: usize,
329    ) -> Vec<SimilarityScore> {
330        let n = graph.num_nodes;
331        let mut results = Vec::new();
332
333        for i in 0..n {
334            for j in (i + 1)..n {
335                let count = Self::compute_pair(graph, i as u64, j as u64);
336
337                if count >= min_count {
338                    results.push(SimilarityScore {
339                        id_a: i as u64,
340                        id_b: j as u64,
341                        similarity: count as f64,
342                    });
343
344                    if results.len() >= max_pairs {
345                        return results;
346                    }
347                }
348            }
349        }
350
351        results.sort_by(|a, b| {
352            b.similarity
353                .partial_cmp(&a.similarity)
354                .unwrap_or(std::cmp::Ordering::Equal)
355        });
356        results
357    }
358}
359
360impl Default for CommonNeighbors {
361    fn default() -> Self {
362        Self::new()
363    }
364}
365
366impl GpuKernel for CommonNeighbors {
367    fn metadata(&self) -> &KernelMetadata {
368        &self.metadata
369    }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 2x3 grid graph, storing both directions of every edge:
    ///
    /// ```text
    ///     0 -- 1 -- 2
    ///     |    |    |
    ///     3 -- 4 -- 5
    /// ```
    ///
    /// Neighbor sets: N(0)={1,3}, N(1)={0,2,4}, N(2)={1,5},
    /// N(3)={0,4}, N(4)={1,3,5}, N(5)={2,4}.
    fn create_test_graph() -> CsrGraph {
        CsrGraph::from_edges(
            6,
            &[
                (0, 1),
                (1, 0),
                (1, 2),
                (2, 1),
                (0, 3),
                (3, 0),
                (1, 4),
                (4, 1),
                (2, 5),
                (5, 2),
                (3, 4),
                (4, 3),
                (4, 5),
                (5, 4),
            ],
        )
    }

    #[test]
    fn test_jaccard_similarity_metadata() {
        let kernel = JaccardSimilarity::new();
        assert_eq!(kernel.metadata().id, "graph/jaccard-similarity");
        assert_eq!(kernel.metadata().domain, Domain::GraphAnalytics);
    }

    #[test]
    fn test_jaccard_similarity_pair() {
        let graph = create_test_graph();

        // N(0) = {1, 3}, N(2) = {1, 5}: intersection {1}, union {1, 3, 5} -> 1/3.
        let sim = JaccardSimilarity::compute_pair(&graph, 0, 2);
        assert!(
            (sim - 1.0 / 3.0).abs() < 1e-12,
            "Expected ~0.33, got {}",
            sim
        );

        // A node compared with itself has identical neighbor sets -> exactly 1.0.
        assert_eq!(JaccardSimilarity::compute_pair(&graph, 0, 0), 1.0);
    }

    #[test]
    fn test_cosine_similarity_pair() {
        let graph = create_test_graph();

        // Nodes 0 and 2: common = 1, |N(0)| = |N(2)| = 2 -> 1 / sqrt(2 * 2) = 0.5.
        let sim = CosineSimilarity::compute_pair(&graph, 0, 2);
        assert!((sim - 0.5).abs() < 1e-12, "Expected 0.5, got {}", sim);

        // N(0) = {1, 3} and N(1) = {0, 2, 4} are disjoint -> 0.0.
        assert_eq!(CosineSimilarity::compute_pair(&graph, 0, 1), 0.0);
    }

    #[test]
    fn test_adamic_adar_pair() {
        let graph = create_test_graph();

        // Nodes 0 and 2 share only node 1, whose out-degree is 3 (edges to 0, 2, 4),
        // so the index is 1 / ln(3).
        let aa = AdamicAdarIndex::compute_pair(&graph, 0, 2);
        assert!(
            (aa - 1.0 / 3.0_f64.ln()).abs() < 1e-12,
            "Expected 1/ln(3), got {}",
            aa
        );

        // N(0) = {1, 3} and N(5) = {2, 4} are disjoint -> 0.
        let aa_no_common = AdamicAdarIndex::compute_pair(&graph, 0, 5);
        assert_eq!(aa_no_common, 0.0);
    }

    #[test]
    fn test_common_neighbors_pair() {
        let graph = create_test_graph();

        // Nodes 0 and 2 share neighbor 1.
        let count = CommonNeighbors::compute_pair(&graph, 0, 2);
        assert_eq!(count, 1);

        // N(0) = {1, 3}, N(1) = {0, 2, 4}: adjacent nodes, but no shared neighbor.
        let count = CommonNeighbors::compute_pair(&graph, 0, 1);
        assert_eq!(count, 0);
    }

    #[test]
    fn test_jaccard_all_pairs() {
        let graph = create_test_graph();
        let pairs = JaccardSimilarity::compute_all_pairs(&graph, 0.0, 100);

        // Every Jaccard score is >= 0.0, and 6 nodes give C(6, 2) = 15 pairs,
        // all under the cap of 100.
        assert_eq!(pairs.len(), 15);

        // Sorted by similarity descending.
        assert!(pairs.windows(2).all(|w| w[0].similarity >= w[1].similarity));
    }

    #[test]
    fn test_common_neighbors_all_pairs() {
        let graph = create_test_graph();
        let pairs = CommonNeighbors::compute_all_pairs(&graph, 1, 100);

        // Pairs sharing at least one neighbor: (0,2), (0,4), (1,3), (1,5), (2,4), (3,5);
        // the best of them share two.
        assert_eq!(pairs.len(), 6);
        assert_eq!(pairs[0].similarity, 2.0);
        assert!(pairs.windows(2).all(|w| w[0].similarity >= w[1].similarity));
    }
}
481
482// ============================================================================
483// Value Similarity Kernel
484// ============================================================================
485
/// Value distribution for similarity calculation: one histogram per node.
#[derive(Debug, Clone)]
pub struct ValueDistribution {
    /// Number of nodes (rows).
    pub node_count: usize,
    /// Number of histogram bins (columns).
    pub bin_count: usize,
    /// Probability distributions in row-major format [node_count × bin_count].
    /// Rows produced by [`ValueDistribution::from_values`] sum to 1.0, or are
    /// all zero for nodes that have no finite values.
    pub distributions: Vec<f64>,
    /// Bin edges for interpreting distributions (bin_count + 1 values).
    pub bin_edges: Vec<f64>,
    /// Binning strategy used.
    pub strategy: BinningStrategy,
}

/// Binning strategy for value distributions.
///
/// [`ValueDistribution::from_values`] always produces `EqualWidth`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinningStrategy {
    /// Equal-width bins (uniform spacing).
    EqualWidth,
    /// Logarithmic bins (geometric spacing).
    Logarithmic,
    /// Quantile bins (equal probability mass).
    Quantile,
}

impl ValueDistribution {
    /// Create an all-zero table of `node_count` rows × `bin_count` bins.
    ///
    /// Bin edges are zeroed too, and the strategy is `EqualWidth`.
    pub fn new(node_count: usize, bin_count: usize) -> Self {
        Self {
            node_count,
            bin_count,
            distributions: vec![0.0; node_count * bin_count],
            bin_edges: vec![0.0; bin_count + 1],
            strategy: BinningStrategy::EqualWidth,
        }
    }

    /// Create from raw values using equal-width binning.
    ///
    /// `values[i]` holds the raw samples of node `i`. All nodes share one
    /// global range `[min, max]`, so their histograms are comparable. That range
    /// is split into `bin_count` equal-width bins; when every sample is equal,
    /// unit-width bins starting at that value are used instead. Each node's
    /// counts are then normalized to sum to 1.0.
    ///
    /// Non-finite samples (NaN, ±∞) are ignored: they cannot be placed in an
    /// equal-width bin. If `bin_count == 0`, or there are no finite samples at
    /// all, the all-zero table from [`ValueDistribution::new`] is returned.
    pub fn from_values(values: &[Vec<f64>], bin_count: usize) -> Self {
        let node_count = values.len();
        let mut dist = Self::new(node_count, bin_count);

        // Global min/max over finite samples only.
        let (min_val, max_val) = values
            .iter()
            .flatten()
            .copied()
            .filter(|v| v.is_finite())
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });

        // Nothing to bin: no bins requested, or no finite samples (min > max).
        if bin_count == 0 || min_val > max_val {
            return dist;
        }

        let range = max_val - min_val;
        let bin_width = if range > 0.0 {
            range / bin_count as f64
        } else {
            1.0
        };

        for (i, edge) in dist.bin_edges.iter_mut().enumerate() {
            *edge = min_val + i as f64 * bin_width;
        }
        // Pad the upper edge so the max value lies inside the last bin. Never let
        // it drop below the natural edge: with a zero range, `max + 0.001` would
        // undercut the earlier edges and break monotonicity.
        dist.bin_edges[bin_count] = dist.bin_edges[bin_count].max(max_val + 0.001);

        // Compute and normalize each node's histogram.
        for (node_values, row) in values
            .iter()
            .zip(dist.distributions.chunks_exact_mut(bin_count))
        {
            for &v in node_values.iter().filter(|v| v.is_finite()) {
                // The max value maps to index `bin_count`; clamp it into the last bin.
                let bin = (((v - min_val) / bin_width).floor() as usize).min(bin_count - 1);
                row[bin] += 1.0;
            }

            let total: f64 = row.iter().sum();
            if total > 0.0 {
                row.iter_mut().for_each(|p| *p /= total);
            }
        }

        dist
    }

    /// Get the histogram (`bin_count` values) of `node`.
    ///
    /// # Panics
    ///
    /// Panics if `node >= node_count` (when `bin_count > 0`).
    pub fn get_distribution(&self, node: usize) -> &[f64] {
        let start = node * self.bin_count;
        &self.distributions[start..start + self.bin_count]
    }
}
584
/// One scored node pair produced by [`ValueSimilarity`].
#[derive(Clone, Debug)]
pub struct ValueSimilarityResult {
    /// Index of the first (or target) node.
    pub node_a: usize,
    /// Index of the node it was compared against.
    pub node_b: usize,
    /// Similarity score; 1.0 means identical distributions.
    pub similarity: f64,
    /// Raw distance (JSD or Wasserstein-1) the similarity was derived from.
    pub distance: f64,
}
597
598/// Value similarity kernel.
599///
600/// Compares probability distributions using statistical distance metrics:
601/// - Jensen-Shannon Divergence (JSD)
602/// - Wasserstein Distance (Earth Mover's)
603#[derive(Debug, Clone)]
604pub struct ValueSimilarity {
605    metadata: KernelMetadata,
606}
607
608impl Default for ValueSimilarity {
609    fn default() -> Self {
610        Self::new()
611    }
612}
613
614impl ValueSimilarity {
615    /// Create a new value similarity kernel.
616    #[must_use]
617    pub fn new() -> Self {
618        Self {
619            metadata: KernelMetadata::batch("graph/value-similarity", Domain::GraphAnalytics)
620                .with_description("Value distribution similarity via JSD/Wasserstein")
621                .with_throughput(25_000)
622                .with_latency_us(800.0),
623        }
624    }
625
626    /// Compute Jensen-Shannon Divergence between two distributions.
627    ///
628    /// JSD(P||Q) = 0.5 * KL(P||M) + 0.5 * KL(Q||M)
629    /// where M = 0.5 * (P + Q)
630    pub fn jensen_shannon_divergence(p: &[f64], q: &[f64]) -> f64 {
631        assert_eq!(p.len(), q.len(), "Distributions must have same length");
632
633        let epsilon = 1e-10;
634
635        let mut kl_pm = 0.0;
636        let mut kl_qm = 0.0;
637
638        for i in 0..p.len() {
639            let m = 0.5 * (p[i] + q[i]);
640
641            if p[i] > epsilon && m > epsilon {
642                kl_pm += p[i] * (p[i] / m).ln();
643            }
644            if q[i] > epsilon && m > epsilon {
645                kl_qm += q[i] * (q[i] / m).ln();
646            }
647        }
648
649        0.5 * kl_pm + 0.5 * kl_qm
650    }
651
652    /// Compute similarity from JSD (normalized to [0, 1]).
653    pub fn jsd_similarity(p: &[f64], q: &[f64]) -> f64 {
654        let jsd = Self::jensen_shannon_divergence(p, q);
655        // JSD is in [0, ln(2)], normalize to [0, 1] similarity
656        1.0 - (jsd / 2.0_f64.ln()).sqrt()
657    }
658
659    /// Compute Wasserstein-1 distance (Earth Mover's Distance) for 1D distributions.
660    ///
661    /// For 1D sorted bins: `W1(P,Q) = Σ|CDF_P[i] - CDF_Q[i]|`
662    pub fn wasserstein_distance(p: &[f64], q: &[f64]) -> f64 {
663        assert_eq!(p.len(), q.len(), "Distributions must have same length");
664
665        let mut cdf_p = 0.0;
666        let mut cdf_q = 0.0;
667        let mut w1 = 0.0;
668
669        for i in 0..p.len() {
670            cdf_p += p[i];
671            cdf_q += q[i];
672            w1 += (cdf_p - cdf_q).abs();
673        }
674
675        w1
676    }
677
678    /// Compute similarity from Wasserstein distance.
679    pub fn wasserstein_similarity(p: &[f64], q: &[f64]) -> f64 {
680        let w1 = Self::wasserstein_distance(p, q);
681        // W1 is in [0, n_bins], normalize to [0, 1] similarity
682        1.0 / (1.0 + w1)
683    }
684
685    /// Compute pairwise similarities using JSD.
686    pub fn compute_all_pairs_jsd(
687        distributions: &ValueDistribution,
688        min_similarity: f64,
689        max_pairs: usize,
690    ) -> Vec<ValueSimilarityResult> {
691        let n = distributions.node_count;
692        let mut results = Vec::new();
693
694        for i in 0..n {
695            for j in (i + 1)..n {
696                let p = distributions.get_distribution(i);
697                let q = distributions.get_distribution(j);
698
699                let jsd = Self::jensen_shannon_divergence(p, q);
700                let similarity = 1.0 - (jsd / 2.0_f64.ln()).sqrt();
701
702                if similarity >= min_similarity {
703                    results.push(ValueSimilarityResult {
704                        node_a: i,
705                        node_b: j,
706                        similarity,
707                        distance: jsd,
708                    });
709
710                    if results.len() >= max_pairs {
711                        return results;
712                    }
713                }
714            }
715        }
716
717        // Sort by similarity descending
718        results.sort_by(|a, b| {
719            b.similarity
720                .partial_cmp(&a.similarity)
721                .unwrap_or(std::cmp::Ordering::Equal)
722        });
723
724        results
725    }
726
727    /// Compute pairwise similarities using Wasserstein distance.
728    pub fn compute_all_pairs_wasserstein(
729        distributions: &ValueDistribution,
730        min_similarity: f64,
731        max_pairs: usize,
732    ) -> Vec<ValueSimilarityResult> {
733        let n = distributions.node_count;
734        let mut results = Vec::new();
735
736        for i in 0..n {
737            for j in (i + 1)..n {
738                let p = distributions.get_distribution(i);
739                let q = distributions.get_distribution(j);
740
741                let w1 = Self::wasserstein_distance(p, q);
742                let similarity = 1.0 / (1.0 + w1);
743
744                if similarity >= min_similarity {
745                    results.push(ValueSimilarityResult {
746                        node_a: i,
747                        node_b: j,
748                        similarity,
749                        distance: w1,
750                    });
751
752                    if results.len() >= max_pairs {
753                        return results;
754                    }
755                }
756            }
757        }
758
759        results.sort_by(|a, b| {
760            b.similarity
761                .partial_cmp(&a.similarity)
762                .unwrap_or(std::cmp::Ordering::Equal)
763        });
764
765        results
766    }
767
768    /// Find nodes with similar value distributions.
769    pub fn find_similar_nodes(
770        distributions: &ValueDistribution,
771        target_node: usize,
772        min_similarity: f64,
773        top_k: usize,
774    ) -> Vec<ValueSimilarityResult> {
775        let n = distributions.node_count;
776        let p = distributions.get_distribution(target_node);
777        let mut results = Vec::new();
778
779        for i in 0..n {
780            if i == target_node {
781                continue;
782            }
783
784            let q = distributions.get_distribution(i);
785            let similarity = Self::jsd_similarity(p, q);
786
787            if similarity >= min_similarity {
788                results.push(ValueSimilarityResult {
789                    node_a: target_node,
790                    node_b: i,
791                    similarity,
792                    distance: Self::jensen_shannon_divergence(p, q),
793                });
794            }
795        }
796
797        results.sort_by(|a, b| {
798            b.similarity
799                .partial_cmp(&a.similarity)
800                .unwrap_or(std::cmp::Ordering::Equal)
801        });
802
803        results.into_iter().take(top_k).collect()
804    }
805}
806
807impl GpuKernel for ValueSimilarity {
808    fn metadata(&self) -> &KernelMetadata {
809        &self.metadata
810    }
811}
812
#[cfg(test)]
mod value_similarity_tests {
    use super::*;
    use std::f64::consts::LN_2;

    #[test]
    fn test_value_similarity_metadata() {
        let kernel = ValueSimilarity::new();
        assert_eq!(kernel.metadata().id, "graph/value-similarity");
        assert_eq!(kernel.metadata().domain, Domain::GraphAnalytics);
    }

    #[test]
    fn test_jsd_identical_distributions() {
        let p = vec![0.25, 0.25, 0.25, 0.25];
        let q = vec![0.25, 0.25, 0.25, 0.25];

        let jsd = ValueSimilarity::jensen_shannon_divergence(&p, &q);
        assert!(jsd.abs() < 1e-12, "JSD of identical distributions should be 0");
    }

    #[test]
    fn test_jsd_different_distributions() {
        let p = vec![1.0, 0.0, 0.0, 0.0];
        let q = vec![0.0, 0.0, 0.0, 1.0];

        // Disjoint supports reach the maximum JSD of ln(2).
        let jsd = ValueSimilarity::jensen_shannon_divergence(&p, &q);
        assert!(
            (jsd - LN_2).abs() < 1e-12,
            "JSD of disjoint distributions should be ln 2, got {}",
            jsd
        );
    }

    #[test]
    fn test_jsd_is_symmetric() {
        let p = vec![0.1, 0.2, 0.3, 0.4];
        let q = vec![0.4, 0.3, 0.2, 0.1];

        let forward = ValueSimilarity::jensen_shannon_divergence(&p, &q);
        let backward = ValueSimilarity::jensen_shannon_divergence(&q, &p);
        assert!((forward - backward).abs() < 1e-15);
        assert!(forward > 0.0);
    }

    #[test]
    fn test_jsd_similarity() {
        let p = vec![0.25, 0.25, 0.25, 0.25];
        let q = vec![0.25, 0.25, 0.25, 0.25];

        let sim = ValueSimilarity::jsd_similarity(&p, &q);
        assert!(
            (sim - 1.0).abs() < 1e-12,
            "Identical distributions should have similarity 1.0"
        );
    }

    #[test]
    fn test_wasserstein_identical() {
        let p = vec![0.25, 0.25, 0.25, 0.25];
        let q = vec![0.25, 0.25, 0.25, 0.25];

        let w1 = ValueSimilarity::wasserstein_distance(&p, &q);
        assert!(
            w1.abs() < 1e-12,
            "Wasserstein of identical distributions should be 0"
        );
    }

    #[test]
    fn test_wasserstein_shifted() {
        let p = vec![1.0, 0.0, 0.0, 0.0];
        let q = vec![0.0, 1.0, 0.0, 0.0];

        // All of the mass moves by one bin -> W1 = 1, similarity = 1 / (1 + 1).
        let w1 = ValueSimilarity::wasserstein_distance(&p, &q);
        assert!((w1 - 1.0).abs() < 1e-12);

        let sim = ValueSimilarity::wasserstein_similarity(&p, &q);
        assert!((sim - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_value_distribution_from_values() {
        let values = vec![vec![1.0, 2.0, 3.0], vec![2.0, 3.0, 4.0]];

        let dist = ValueDistribution::from_values(&values, 4);

        assert_eq!(dist.node_count, 2);
        assert_eq!(dist.bin_count, 4);

        // Each row is a normalized histogram.
        let sum0: f64 = dist.get_distribution(0).iter().sum();
        let sum1: f64 = dist.get_distribution(1).iter().sum();
        assert!((sum0 - 1.0).abs() < 1e-12);
        assert!((sum1 - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_find_similar_nodes() {
        let values = vec![
            vec![1.0, 2.0, 3.0],
            vec![1.0, 2.0, 3.0],    // Same as node 0
            vec![10.0, 11.0, 12.0], // Disjoint bins from node 0
        ];

        let dist = ValueDistribution::from_values(&values, 5);
        let similar = ValueSimilarity::find_similar_nodes(&dist, 0, 0.5, 5);

        // Only node 1 clears the threshold, and it matches node 0 exactly.
        assert_eq!(similar.len(), 1);
        assert_eq!(similar[0].node_a, 0);
        assert_eq!(similar[0].node_b, 1);
        assert!((similar[0].similarity - 1.0).abs() < 1e-12);
    }
}