rustkernel_ml/nlp.rs

//! Natural Language Processing and LLM integration kernels.
//!
//! This module provides GPU-accelerated NLP algorithms:
//! - EmbeddingGeneration - Text to vector embeddings
//! - SemanticSimilarity - Document/entity similarity matching
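//!
//! A minimal usage sketch (marked `ignore` because it assumes this module
//! is reachable as `rustkernel_ml::nlp`; adjust the path to your crate
//! layout):
//!
//! ```ignore
//! use rustkernel_ml::nlp::{
//!     EmbeddingConfig, EmbeddingGeneration, SemanticSimilarity, SimilarityConfig,
//! };
//!
//! let emb = EmbeddingGeneration::compute(&["hello world"], &EmbeddingConfig::default());
//! let sim = SemanticSimilarity::compute(&emb.embeddings, &emb.embeddings, &SimilarityConfig::default());
//! assert_eq!(sim.query_count, 1);
//! ```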

use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

// ============================================================================
// Embedding Generation Kernel
// ============================================================================

/// Configuration for embedding generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Embedding dimension.
    pub dimension: usize,
    /// Maximum sequence length.
    pub max_seq_length: usize,
    /// Whether to normalize embeddings.
    pub normalize: bool,
    /// Pooling strategy for sequence embeddings.
    pub pooling: PoolingStrategy,
    /// Vocabulary size for hash-based embeddings.
    pub vocab_size: usize,
}

impl Default for EmbeddingConfig {
    fn default() -> Self {
        Self {
            dimension: 384,
            max_seq_length: 512,
            normalize: true,
            pooling: PoolingStrategy::Mean,
            vocab_size: 50000,
        }
    }
}

/// Pooling strategy for combining token embeddings.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Average of all token embeddings.
    Mean,
    /// Max pooling across tokens.
    Max,
    /// Use CLS token embedding (first token).
    CLS,
    /// Weighted average by attention.
    AttentionWeighted,
}

/// Result of embedding generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResult {
    /// Generated embeddings (one per input text).
    pub embeddings: Vec<Vec<f64>>,
    /// Token counts per input.
    pub token_counts: Vec<usize>,
    /// Embedding dimension.
    pub dimension: usize,
}

/// Embedding Generation kernel.
///
/// Generates dense vector embeddings from text using hash-based
/// token embeddings with configurable pooling strategies.
/// Suitable for semantic search, clustering, and similarity tasks.
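///
/// Note: each embedding is a deterministic function of the token text alone
/// (no learned weights), so identical inputs always produce identical
/// vectors. Results are reproducible, but they capture lexical overlap
/// rather than learned semantics.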
#[derive(Debug, Clone)]
pub struct EmbeddingGeneration {
    metadata: KernelMetadata,
}

impl Default for EmbeddingGeneration {
    fn default() -> Self {
        Self::new()
    }
}

impl EmbeddingGeneration {
    /// Create a new Embedding Generation kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/embedding-generation", Domain::StatisticalML)
                .with_description("GPU-accelerated text embedding generation")
                .with_throughput(10_000)
                .with_latency_us(50.0),
        }
    }

    /// Generate embeddings for a batch of texts.
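    ///
    /// A minimal sketch of the expected output shape (marked `ignore`; the
    /// exact vector values depend on the hash functions):
    ///
    /// ```ignore
    /// let config = EmbeddingConfig::default();
    /// let result = EmbeddingGeneration::compute(&["hello world"], &config);
    /// assert_eq!(result.embeddings.len(), 1);
    /// assert_eq!(result.embeddings[0].len(), config.dimension); // 384 by default
    /// ```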
    pub fn compute(texts: &[&str], config: &EmbeddingConfig) -> EmbeddingResult {
        if texts.is_empty() {
            return EmbeddingResult {
                embeddings: Vec::new(),
                token_counts: Vec::new(),
                dimension: config.dimension,
            };
        }

        let mut embeddings = Vec::with_capacity(texts.len());
        let mut token_counts = Vec::with_capacity(texts.len());

        for text in texts {
            let tokens = Self::tokenize(text, config.max_seq_length);
            token_counts.push(tokens.len());

            let token_embeddings: Vec<Vec<f64>> = tokens
                .iter()
                .map(|token| Self::hash_embedding(token, config.dimension, config.vocab_size))
                .collect();

            let pooled = Self::pool_embeddings(&token_embeddings, config);

            let final_embedding = if config.normalize {
                Self::normalize_vector(&pooled)
            } else {
                pooled
            };

            embeddings.push(final_embedding);
        }

        EmbeddingResult {
            embeddings,
            token_counts,
            dimension: config.dimension,
        }
    }

    /// Whitespace tokenization with lowercasing; non-alphanumeric
    /// characters are stripped and empty tokens discarded.
    fn tokenize(text: &str, max_length: usize) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .take(max_length)
            .map(|s| s.chars().filter(|c| c.is_alphanumeric()).collect())
            .filter(|s: &String| !s.is_empty())
            .collect()
    }

    /// Generate embedding from token using hash-based approach.
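    ///
    /// This resembles the classic feature-hashing trick: each token is
    /// projected into `dimension` components via several independent hashes,
    /// with hash-derived signs so that unrelated tokens tend to cancel out
    /// when pooled. The constants (31, 37, 41 and the 1.0/0.5/0.25 weights)
    /// are arbitrary mixing choices, not tuned values.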
    fn hash_embedding(token: &str, dimension: usize, vocab_size: usize) -> Vec<f64> {
        let mut embedding = vec![0.0; dimension];

        // Use multiple hash functions for better distribution
        let hash1 = Self::hash_token(token, 0) as usize;
        let hash2 = Self::hash_token(token, 1) as usize;
        let hash3 = Self::hash_token(token, 2) as usize;

        // Fill every component from the three hashes (the result is dense,
        // not sparse: each dimension receives a contribution)
        for i in 0..dimension {
            let idx1 = (hash1 + i * 31) % vocab_size;
            let idx2 = (hash2 + i * 37) % vocab_size;
            let idx3 = (hash3 + i * 41) % vocab_size;

            // Combine hashes to create embedding value
            let sign1 = if (idx1 % 2) == 0 { 1.0 } else { -1.0 };
            let sign2 = if (idx2 % 2) == 0 { 1.0 } else { -1.0 };

            embedding[i] = sign1 * ((idx1 as f64 / vocab_size as f64) - 0.5)
                + sign2 * ((idx2 as f64 / vocab_size as f64) - 0.5) * 0.5
                + ((idx3 as f64 / vocab_size as f64) - 0.5) * 0.25;
        }

        embedding
    }

    /// Simple hash function for tokens.
    fn hash_token(token: &str, seed: u64) -> u64 {
        let mut hash: u64 = seed.wrapping_mul(0x517cc1b727220a95);
        for byte in token.bytes() {
            hash = hash.wrapping_mul(31).wrapping_add(byte as u64);
        }
        hash
    }

    /// Pool token embeddings according to strategy.
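    ///
    /// For example, mean pooling of `[1.0, 3.0]` and `[3.0, 1.0]` yields
    /// `[2.0, 2.0]`, while max pooling yields `[3.0, 3.0]`.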
    fn pool_embeddings(embeddings: &[Vec<f64>], config: &EmbeddingConfig) -> Vec<f64> {
        if embeddings.is_empty() {
            return vec![0.0; config.dimension];
        }

        match config.pooling {
            PoolingStrategy::Mean => {
                let mut result = vec![0.0; config.dimension];
                for emb in embeddings {
                    for (i, &v) in emb.iter().enumerate() {
                        result[i] += v;
                    }
                }
                let n = embeddings.len() as f64;
                result.iter_mut().for_each(|v| *v /= n);
                result
            }
            PoolingStrategy::Max => {
                let mut result = vec![f64::NEG_INFINITY; config.dimension];
                for emb in embeddings {
                    for (i, &v) in emb.iter().enumerate() {
                        result[i] = result[i].max(v);
                    }
                }
                result
            }
            PoolingStrategy::CLS => embeddings[0].clone(),
            PoolingStrategy::AttentionWeighted => {
                // Simple attention proxy: weight by position (earlier = higher weight)
                let mut result = vec![0.0; config.dimension];
                let mut total_weight = 0.0;

                for (pos, emb) in embeddings.iter().enumerate() {
                    let weight = 1.0 / (1.0 + pos as f64 * 0.1);
                    total_weight += weight;
                    for (i, &v) in emb.iter().enumerate() {
                        result[i] += v * weight;
                    }
                }

                result.iter_mut().for_each(|v| *v /= total_weight);
                result
            }
        }
    }

    /// Normalize vector to unit length.
    fn normalize_vector(v: &[f64]) -> Vec<f64> {
        let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
        if norm < 1e-10 {
            v.to_vec()
        } else {
            v.iter().map(|x| x / norm).collect()
        }
    }
}

impl GpuKernel for EmbeddingGeneration {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

// ============================================================================
// Semantic Similarity Kernel
// ============================================================================

/// Configuration for semantic similarity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityConfig {
    /// Similarity metric to use.
    pub metric: SimilarityMetric,
    /// Minimum similarity threshold for matches.
    pub threshold: f64,
    /// Maximum number of matches to return per query.
    pub top_k: usize,
    /// Whether to include self-matches (pairs where the query index equals
    /// the corpus index; only meaningful when queries and corpus are the
    /// same collection).
    pub include_self: bool,
}

impl Default for SimilarityConfig {
    fn default() -> Self {
        Self {
            metric: SimilarityMetric::Cosine,
            threshold: 0.5,
            top_k: 10,
            include_self: false,
        }
    }
}

/// Similarity metric.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum SimilarityMetric {
    /// Cosine similarity (dot product of normalized vectors).
    Cosine,
    /// Euclidean distance (converted to similarity).
    Euclidean,
    /// Dot product (unnormalized).
    DotProduct,
    /// Manhattan distance (converted to similarity).
    Manhattan,
}

/// A similarity match result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityMatch {
    /// Index of the query item.
    pub query_idx: usize,
    /// Index of the matched item.
    pub match_idx: usize,
    /// Similarity score.
    pub score: f64,
}

/// Result of semantic similarity computation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityResult {
    /// All matches above threshold.
    pub matches: Vec<SimilarityMatch>,
    /// Full similarity matrix (if computed).
    pub similarity_matrix: Option<Vec<Vec<f64>>>,
    /// Number of query embeddings.
    pub query_count: usize,
    /// Number of corpus embeddings.
    pub corpus_count: usize,
}

/// Semantic Similarity kernel.
///
/// Computes semantic similarity between text embeddings for
/// document matching, entity resolution, and semantic search.
#[derive(Debug, Clone)]
pub struct SemanticSimilarity {
    metadata: KernelMetadata,
}

impl Default for SemanticSimilarity {
    fn default() -> Self {
        Self::new()
    }
}

impl SemanticSimilarity {
    /// Create a new Semantic Similarity kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/semantic-similarity", Domain::StatisticalML)
                .with_description("Semantic similarity matching for documents and entities")
                .with_throughput(50_000)
                .with_latency_us(20.0),
        }
    }

    /// Compute similarity between query embeddings and corpus embeddings.
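    ///
    /// A minimal sketch (marked `ignore`); note that `include_self: true` is
    /// needed here because index-equal pairs are otherwise skipped:
    ///
    /// ```ignore
    /// let config = SimilarityConfig { include_self: true, ..Default::default() };
    /// let queries = vec![vec![1.0, 0.0]];
    /// let corpus = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
    /// let result = SemanticSimilarity::compute(&queries, &corpus, &config);
    /// // With the default cosine metric, the identical vector scores ~1.0
    /// // and ranks first among the returned matches.
    /// ```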
    pub fn compute(
        queries: &[Vec<f64>],
        corpus: &[Vec<f64>],
        config: &SimilarityConfig,
    ) -> SimilarityResult {
        if queries.is_empty() || corpus.is_empty() {
            return SimilarityResult {
                matches: Vec::new(),
                similarity_matrix: None,
                query_count: queries.len(),
                corpus_count: corpus.len(),
            };
        }

        let mut all_matches: Vec<SimilarityMatch> = Vec::new();
        let mut similarity_matrix: Vec<Vec<f64>> = Vec::with_capacity(queries.len());

        for (q_idx, query) in queries.iter().enumerate() {
            let mut row_scores: Vec<(usize, f64)> = Vec::with_capacity(corpus.len());

            for (c_idx, doc) in corpus.iter().enumerate() {
                if !config.include_self && q_idx == c_idx {
                    continue;
                }

                let score = Self::compute_similarity(query, doc, config.metric);
                row_scores.push((c_idx, score));
            }

            // Sort by score descending
            row_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

            // Take top-k above threshold
            for (c_idx, score) in row_scores.iter().take(config.top_k) {
                if *score >= config.threshold {
                    all_matches.push(SimilarityMatch {
                        query_idx: q_idx,
                        match_idx: *c_idx,
                        score: *score,
                    });
                }
            }

            // Build full row for the matrix; pairs skipped by `include_self`
            // keep the default score of 0.0
            let mut full_row = vec![0.0; corpus.len()];
            for (c_idx, score) in row_scores {
                full_row[c_idx] = score;
            }
            similarity_matrix.push(full_row);
        }

        SimilarityResult {
            matches: all_matches,
            similarity_matrix: Some(similarity_matrix),
            query_count: queries.len(),
            corpus_count: corpus.len(),
        }
    }

    /// Find most similar documents for each query.
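    ///
    /// Returns one `Vec` per query, each holding `(corpus_index, score, label)`
    /// tuples sorted by descending score. `label` is `None` when no label
    /// slice is supplied or the index is out of range.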
    pub fn find_similar(
        queries: &[Vec<f64>],
        corpus: &[Vec<f64>],
        labels: Option<&[String]>,
        config: &SimilarityConfig,
    ) -> Vec<Vec<(usize, f64, Option<String>)>> {
        let result = Self::compute(queries, corpus, config);

        let mut grouped: HashMap<usize, Vec<(usize, f64)>> = HashMap::new();
        for m in result.matches {
            grouped
                .entry(m.query_idx)
                .or_default()
                .push((m.match_idx, m.score));
        }

        // Sort each group by score descending
        for matches in grouped.values_mut() {
            matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        }

        queries
            .iter()
            .enumerate()
            .map(|(q_idx, _)| {
                grouped
                    .get(&q_idx)
                    .map(|matches| {
                        matches
                            .iter()
                            .map(|(idx, score)| {
                                let label = labels.and_then(|l| l.get(*idx).cloned());
                                (*idx, *score, label)
                            })
                            .collect()
                    })
                    .unwrap_or_default()
            })
            .collect()
    }

    /// Compute pairwise similarity between two vectors.
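    ///
    /// Cosine returns `dot(a, b) / (|a| * |b|)`; Euclidean and Manhattan
    /// distances `d` are mapped into `(0, 1]` as `1 / (1 + d)` so that larger
    /// values always mean "more similar"; dot product is returned raw.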
    fn compute_similarity(a: &[f64], b: &[f64], metric: SimilarityMetric) -> f64 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }

        match metric {
            SimilarityMetric::Cosine => {
                let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
                let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
                let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
                if norm_a < 1e-10 || norm_b < 1e-10 {
                    0.0
                } else {
                    dot / (norm_a * norm_b)
                }
            }
            SimilarityMetric::Euclidean => {
                let dist: f64 = a
                    .iter()
                    .zip(b.iter())
                    .map(|(x, y)| (x - y).powi(2))
                    .sum::<f64>()
                    .sqrt();
                1.0 / (1.0 + dist)
            }
            SimilarityMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
            SimilarityMetric::Manhattan => {
                let dist: f64 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
                1.0 / (1.0 + dist)
            }
        }
    }

    /// Deduplicate a corpus by similarity threshold, returning the indices
    /// of items to keep.
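    ///
    /// Greedy first-wins scan: an item is dropped when its cosine similarity
    /// to any already-kept item reaches `threshold`. A minimal sketch
    /// (marked `ignore`):
    ///
    /// ```ignore
    /// let kept = SemanticSimilarity::deduplicate(
    ///     &[vec![1.0, 0.0], vec![0.99, 0.01], vec![0.0, 1.0]],
    ///     0.95,
    /// );
    /// assert_eq!(kept, vec![0, 2]); // index 1 near-duplicates index 0
    /// ```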
    pub fn deduplicate(embeddings: &[Vec<f64>], threshold: f64) -> Vec<usize> {
        if embeddings.is_empty() {
            return Vec::new();
        }

        let mut keep: Vec<usize> = vec![0]; // Always keep first

        for i in 1..embeddings.len() {
            let is_duplicate = keep.iter().any(|&j| {
                let sim = Self::compute_similarity(
                    &embeddings[i],
                    &embeddings[j],
                    SimilarityMetric::Cosine,
                );
                sim >= threshold
            });

            if !is_duplicate {
                keep.push(i);
            }
        }

        keep
    }
}

impl GpuKernel for SemanticSimilarity {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_generation_metadata() {
        let kernel = EmbeddingGeneration::new();
        assert_eq!(kernel.metadata().id, "ml/embedding-generation");
    }

    #[test]
    fn test_embedding_generation_basic() {
        let config = EmbeddingConfig::default();
        let texts = vec!["hello world", "machine learning"];

        let result = EmbeddingGeneration::compute(&texts, &config);

        assert_eq!(result.embeddings.len(), 2);
        assert_eq!(result.embeddings[0].len(), config.dimension);
        assert_eq!(result.token_counts, vec![2, 2]);
    }

    #[test]
    fn test_embedding_normalization() {
        let config = EmbeddingConfig {
            normalize: true,
            ..Default::default()
        };

        let result = EmbeddingGeneration::compute(&["test text"], &config);

        let norm: f64 = result.embeddings[0]
            .iter()
            .map(|x| x * x)
            .sum::<f64>()
            .sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_embedding_empty() {
        let config = EmbeddingConfig::default();
        let result = EmbeddingGeneration::compute(&[], &config);
        assert!(result.embeddings.is_empty());
    }

    #[test]
    fn test_pooling_strategies() {
        let texts = vec!["a b c d e"];

        for pooling in [
            PoolingStrategy::Mean,
            PoolingStrategy::Max,
            PoolingStrategy::CLS,
            PoolingStrategy::AttentionWeighted,
        ] {
            let config = EmbeddingConfig {
                pooling,
                ..Default::default()
            };
            let result = EmbeddingGeneration::compute(&texts, &config);
            assert_eq!(result.embeddings.len(), 1);
            assert_eq!(result.embeddings[0].len(), config.dimension);
        }
    }

    #[test]
    fn test_semantic_similarity_metadata() {
        let kernel = SemanticSimilarity::new();
        assert_eq!(kernel.metadata().id, "ml/semantic-similarity");
    }

    #[test]
    fn test_semantic_similarity_basic() {
        let queries = vec![vec![1.0, 0.0, 0.0]];
        let corpus = vec![
            vec![1.0, 0.0, 0.0], // Same as query
            vec![0.0, 1.0, 0.0], // Orthogonal
            vec![0.7, 0.7, 0.0], // Partially similar
        ];

        let config = SimilarityConfig {
            threshold: 0.0,
            include_self: true,
            ..Default::default()
        };

        let result = SemanticSimilarity::compute(&queries, &corpus, &config);

        assert!(!result.matches.is_empty());
        // First match should be the identical vector
        assert_eq!(result.matches[0].match_idx, 0);
        assert!((result.matches[0].score - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_similarity_metrics() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];

        for metric in [
            SimilarityMetric::Cosine,
            SimilarityMetric::Euclidean,
            SimilarityMetric::DotProduct,
            SimilarityMetric::Manhattan,
        ] {
            let sim = SemanticSimilarity::compute_similarity(&a, &b, metric);
            assert!(
                sim > 0.0,
                "Identical vectors should have positive similarity for {:?}",
                metric
            );
        }
    }

    #[test]
    fn test_deduplicate() {
        let embeddings = vec![
            vec![1.0, 0.0],
            vec![0.99, 0.01], // Very similar to first
            vec![0.0, 1.0],   // Different
            vec![0.01, 0.99], // Very similar to third
        ];

        let kept = SemanticSimilarity::deduplicate(&embeddings, 0.95);

        assert_eq!(kept.len(), 2);
        assert!(kept.contains(&0));
        assert!(kept.contains(&2));
    }

    #[test]
    fn test_find_similar_with_labels() {
        let queries = vec![vec![1.0, 0.0]];
        let corpus = vec![vec![0.9, 0.1], vec![0.0, 1.0]];
        let labels = vec!["doc_a".to_string(), "doc_b".to_string()];

        let config = SimilarityConfig {
            threshold: 0.0,
            include_self: true, // Include all comparisons since query != corpus
            ..Default::default()
        };

        let results = SemanticSimilarity::find_similar(&queries, &corpus, Some(&labels), &config);

        assert_eq!(results.len(), 1);
        assert!(!results[0].is_empty());
        // The highest similarity should come first (doc_a has higher cosine sim to query)
        assert_eq!(results[0][0].2, Some("doc_a".to_string()));
    }
}
655}