1use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Configuration for hash-based text embedding generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Dimensionality of the output embedding vectors.
    pub dimension: usize,
    /// Maximum number of whitespace tokens considered per text; the rest are truncated.
    pub max_seq_length: usize,
    /// Whether to L2-normalize the pooled embedding.
    pub normalize: bool,
    /// Strategy used to collapse per-token vectors into a single vector.
    pub pooling: PoolingStrategy,
    /// Modulus for the hashing trick: hash-derived indices are taken mod this value.
    pub vocab_size: usize,
}
29
30impl Default for EmbeddingConfig {
31 fn default() -> Self {
32 Self {
33 dimension: 384,
34 max_seq_length: 512,
35 normalize: true,
36 pooling: PoolingStrategy::Mean,
37 vocab_size: 50000,
38 }
39 }
40}
41
/// How per-token vectors are pooled into a single text embedding.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Element-wise mean over all token vectors.
    Mean,
    /// Element-wise maximum over all token vectors.
    Max,
    /// First token's vector only (CLS-style).
    CLS,
    /// Position-weighted mean; earlier tokens contribute more.
    AttentionWeighted,
}
54
/// Output of [`EmbeddingGeneration::compute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResult {
    /// One embedding vector per input text, in input order.
    pub embeddings: Vec<Vec<f64>>,
    /// Token count per input text (after tokenization and truncation).
    pub token_counts: Vec<usize>,
    /// Dimensionality of each embedding vector.
    pub dimension: usize,
}
65
/// Batch kernel producing deterministic, hash-based text embeddings.
#[derive(Debug, Clone)]
pub struct EmbeddingGeneration {
    // Kernel registration metadata (id, domain, throughput/latency hints).
    metadata: KernelMetadata,
}
75
76impl Default for EmbeddingGeneration {
77 fn default() -> Self {
78 Self::new()
79 }
80}
81
82impl EmbeddingGeneration {
83 #[must_use]
85 pub fn new() -> Self {
86 Self {
87 metadata: KernelMetadata::batch("ml/embedding-generation", Domain::StatisticalML)
88 .with_description("GPU-accelerated text embedding generation")
89 .with_throughput(10_000)
90 .with_latency_us(50.0),
91 }
92 }
93
94 pub fn compute(texts: &[&str], config: &EmbeddingConfig) -> EmbeddingResult {
96 if texts.is_empty() {
97 return EmbeddingResult {
98 embeddings: Vec::new(),
99 token_counts: Vec::new(),
100 dimension: config.dimension,
101 };
102 }
103
104 let mut embeddings = Vec::with_capacity(texts.len());
105 let mut token_counts = Vec::with_capacity(texts.len());
106
107 for text in texts {
108 let tokens = Self::tokenize(text, config.max_seq_length);
109 token_counts.push(tokens.len());
110
111 let token_embeddings: Vec<Vec<f64>> = tokens
112 .iter()
113 .map(|token| Self::hash_embedding(token, config.dimension, config.vocab_size))
114 .collect();
115
116 let pooled = Self::pool_embeddings(&token_embeddings, config);
117
118 let final_embedding = if config.normalize {
119 Self::normalize_vector(&pooled)
120 } else {
121 pooled
122 };
123
124 embeddings.push(final_embedding);
125 }
126
127 EmbeddingResult {
128 embeddings,
129 token_counts,
130 dimension: config.dimension,
131 }
132 }
133
134 fn tokenize(text: &str, max_length: usize) -> Vec<String> {
136 text.to_lowercase()
137 .split_whitespace()
138 .take(max_length)
139 .map(|s| s.chars().filter(|c| c.is_alphanumeric()).collect())
140 .filter(|s: &String| !s.is_empty())
141 .collect()
142 }
143
144 fn hash_embedding(token: &str, dimension: usize, vocab_size: usize) -> Vec<f64> {
146 let mut embedding = vec![0.0; dimension];
147
148 let hash1 = Self::hash_token(token, 0) as usize;
150 let hash2 = Self::hash_token(token, 1) as usize;
151 let hash3 = Self::hash_token(token, 2) as usize;
152
153 for i in 0..dimension {
155 let idx1 = (hash1 + i * 31) % vocab_size;
156 let idx2 = (hash2 + i * 37) % vocab_size;
157 let idx3 = (hash3 + i * 41) % vocab_size;
158
159 let sign1 = if (idx1 % 2) == 0 { 1.0 } else { -1.0 };
161 let sign2 = if (idx2 % 2) == 0 { 1.0 } else { -1.0 };
162
163 embedding[i] = sign1 * ((idx1 as f64 / vocab_size as f64) - 0.5)
164 + sign2 * ((idx2 as f64 / vocab_size as f64) - 0.5) * 0.5
165 + ((idx3 as f64 / vocab_size as f64) - 0.5) * 0.25;
166 }
167
168 embedding
169 }
170
171 fn hash_token(token: &str, seed: u64) -> u64 {
173 let mut hash: u64 = seed.wrapping_mul(0x517cc1b727220a95);
174 for byte in token.bytes() {
175 hash = hash.wrapping_mul(31).wrapping_add(byte as u64);
176 }
177 hash
178 }
179
180 fn pool_embeddings(embeddings: &[Vec<f64>], config: &EmbeddingConfig) -> Vec<f64> {
182 if embeddings.is_empty() {
183 return vec![0.0; config.dimension];
184 }
185
186 match config.pooling {
187 PoolingStrategy::Mean => {
188 let mut result = vec![0.0; config.dimension];
189 for emb in embeddings {
190 for (i, &v) in emb.iter().enumerate() {
191 result[i] += v;
192 }
193 }
194 let n = embeddings.len() as f64;
195 result.iter_mut().for_each(|v| *v /= n);
196 result
197 }
198 PoolingStrategy::Max => {
199 let mut result = vec![f64::NEG_INFINITY; config.dimension];
200 for emb in embeddings {
201 for (i, &v) in emb.iter().enumerate() {
202 result[i] = result[i].max(v);
203 }
204 }
205 result
206 }
207 PoolingStrategy::CLS => embeddings[0].clone(),
208 PoolingStrategy::AttentionWeighted => {
209 let mut result = vec![0.0; config.dimension];
211 let mut total_weight = 0.0;
212
213 for (pos, emb) in embeddings.iter().enumerate() {
214 let weight = 1.0 / (1.0 + pos as f64 * 0.1);
215 total_weight += weight;
216 for (i, &v) in emb.iter().enumerate() {
217 result[i] += v * weight;
218 }
219 }
220
221 result.iter_mut().for_each(|v| *v /= total_weight);
222 result
223 }
224 }
225 }
226
227 fn normalize_vector(v: &[f64]) -> Vec<f64> {
229 let norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
230 if norm < 1e-10 {
231 v.to_vec()
232 } else {
233 v.iter().map(|x| x / norm).collect()
234 }
235 }
236}
237
impl GpuKernel for EmbeddingGeneration {
    /// Exposes this kernel's registration metadata.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
243
/// Configuration for semantic similarity matching.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityConfig {
    /// Similarity metric used to score vector pairs.
    pub metric: SimilarityMetric,
    /// Minimum score for a pair to be reported as a match.
    pub threshold: f64,
    /// Maximum number of matches reported per query.
    pub top_k: usize,
    /// Whether to score the pair whose query and corpus indices are equal
    /// (relevant when queries and corpus are the same collection).
    pub include_self: bool,
}
260
261impl Default for SimilarityConfig {
262 fn default() -> Self {
263 Self {
264 metric: SimilarityMetric::Cosine,
265 threshold: 0.5,
266 top_k: 10,
267 include_self: false,
268 }
269 }
270}
271
/// Supported vector similarity metrics.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum SimilarityMetric {
    /// Cosine of the angle between vectors (in [-1, 1]).
    Cosine,
    /// Euclidean distance mapped to (0, 1] via 1 / (1 + dist).
    Euclidean,
    /// Raw dot product (unbounded).
    DotProduct,
    /// Manhattan (L1) distance mapped to (0, 1] via 1 / (1 + dist).
    Manhattan,
}
284
/// A single query-to-corpus match produced by [`SemanticSimilarity::compute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityMatch {
    /// Index into the queries slice.
    pub query_idx: usize,
    /// Index into the corpus slice.
    pub match_idx: usize,
    /// Similarity score under the configured metric.
    pub score: f64,
}
295
/// Output of [`SemanticSimilarity::compute`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityResult {
    /// Thresholded top-k matches, grouped by query in query order.
    pub matches: Vec<SimilarityMatch>,
    /// Dense query-by-corpus score matrix; `None` when either input is empty.
    pub similarity_matrix: Option<Vec<Vec<f64>>>,
    /// Number of query vectors.
    pub query_count: usize,
    /// Number of corpus vectors.
    pub corpus_count: usize,
}
308
/// Batch kernel for pairwise semantic similarity scoring and matching.
#[derive(Debug, Clone)]
pub struct SemanticSimilarity {
    // Kernel registration metadata (id, domain, throughput/latency hints).
    metadata: KernelMetadata,
}
317
318impl Default for SemanticSimilarity {
319 fn default() -> Self {
320 Self::new()
321 }
322}
323
324impl SemanticSimilarity {
325 #[must_use]
327 pub fn new() -> Self {
328 Self {
329 metadata: KernelMetadata::batch("ml/semantic-similarity", Domain::StatisticalML)
330 .with_description("Semantic similarity matching for documents and entities")
331 .with_throughput(50_000)
332 .with_latency_us(20.0),
333 }
334 }
335
336 pub fn compute(
338 queries: &[Vec<f64>],
339 corpus: &[Vec<f64>],
340 config: &SimilarityConfig,
341 ) -> SimilarityResult {
342 if queries.is_empty() || corpus.is_empty() {
343 return SimilarityResult {
344 matches: Vec::new(),
345 similarity_matrix: None,
346 query_count: queries.len(),
347 corpus_count: corpus.len(),
348 };
349 }
350
351 let mut all_matches: Vec<SimilarityMatch> = Vec::new();
352 let mut similarity_matrix: Vec<Vec<f64>> = Vec::with_capacity(queries.len());
353
354 for (q_idx, query) in queries.iter().enumerate() {
355 let mut row_scores: Vec<(usize, f64)> = Vec::with_capacity(corpus.len());
356
357 for (c_idx, doc) in corpus.iter().enumerate() {
358 if !config.include_self && q_idx == c_idx {
359 continue;
360 }
361
362 let score = Self::compute_similarity(query, doc, config.metric);
363 row_scores.push((c_idx, score));
364 }
365
366 row_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
368
369 for (c_idx, score) in row_scores.iter().take(config.top_k) {
371 if *score >= config.threshold {
372 all_matches.push(SimilarityMatch {
373 query_idx: q_idx,
374 match_idx: *c_idx,
375 score: *score,
376 });
377 }
378 }
379
380 let mut full_row = vec![0.0; corpus.len()];
382 for (c_idx, score) in row_scores {
383 full_row[c_idx] = score;
384 }
385 similarity_matrix.push(full_row);
386 }
387
388 SimilarityResult {
389 matches: all_matches,
390 similarity_matrix: Some(similarity_matrix),
391 query_count: queries.len(),
392 corpus_count: corpus.len(),
393 }
394 }
395
396 pub fn find_similar(
398 queries: &[Vec<f64>],
399 corpus: &[Vec<f64>],
400 labels: Option<&[String]>,
401 config: &SimilarityConfig,
402 ) -> Vec<Vec<(usize, f64, Option<String>)>> {
403 let result = Self::compute(queries, corpus, config);
404
405 let mut grouped: HashMap<usize, Vec<(usize, f64)>> = HashMap::new();
406 for m in result.matches {
407 grouped
408 .entry(m.query_idx)
409 .or_default()
410 .push((m.match_idx, m.score));
411 }
412
413 for matches in grouped.values_mut() {
415 matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
416 }
417
418 queries
419 .iter()
420 .enumerate()
421 .map(|(q_idx, _)| {
422 grouped
423 .get(&q_idx)
424 .map(|matches| {
425 matches
426 .iter()
427 .map(|(idx, score)| {
428 let label = labels.map(|l| l.get(*idx).cloned()).flatten();
429 (*idx, *score, label)
430 })
431 .collect()
432 })
433 .unwrap_or_default()
434 })
435 .collect()
436 }
437
438 fn compute_similarity(a: &[f64], b: &[f64], metric: SimilarityMetric) -> f64 {
440 if a.len() != b.len() || a.is_empty() {
441 return 0.0;
442 }
443
444 match metric {
445 SimilarityMetric::Cosine => {
446 let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
447 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
448 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
449 if norm_a < 1e-10 || norm_b < 1e-10 {
450 0.0
451 } else {
452 dot / (norm_a * norm_b)
453 }
454 }
455 SimilarityMetric::Euclidean => {
456 let dist: f64 = a
457 .iter()
458 .zip(b.iter())
459 .map(|(x, y)| (x - y).powi(2))
460 .sum::<f64>()
461 .sqrt();
462 1.0 / (1.0 + dist)
463 }
464 SimilarityMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
465 SimilarityMetric::Manhattan => {
466 let dist: f64 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
467 1.0 / (1.0 + dist)
468 }
469 }
470 }
471
472 pub fn deduplicate(embeddings: &[Vec<f64>], threshold: f64) -> Vec<usize> {
474 if embeddings.is_empty() {
475 return Vec::new();
476 }
477
478 let mut keep: Vec<usize> = vec![0]; for i in 1..embeddings.len() {
481 let is_duplicate = keep.iter().any(|&j| {
482 let sim = Self::compute_similarity(
483 &embeddings[i],
484 &embeddings[j],
485 SimilarityMetric::Cosine,
486 );
487 sim >= threshold
488 });
489
490 if !is_duplicate {
491 keep.push(i);
492 }
493 }
494
495 keep
496 }
497}
498
impl GpuKernel for SemanticSimilarity {
    /// Exposes this kernel's registration metadata.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
504
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_generation_metadata() {
        assert_eq!(
            EmbeddingGeneration::new().metadata().id,
            "ml/embedding-generation"
        );
    }

    #[test]
    fn test_embedding_generation_basic() {
        let cfg = EmbeddingConfig::default();
        let inputs = ["hello world", "machine learning"];

        let result = EmbeddingGeneration::compute(&inputs, &cfg);

        assert_eq!(result.embeddings.len(), 2);
        assert_eq!(result.embeddings[0].len(), cfg.dimension);
        assert_eq!(result.token_counts, vec![2, 2]);
    }

    #[test]
    fn test_embedding_normalization() {
        let cfg = EmbeddingConfig {
            normalize: true,
            ..Default::default()
        };

        let result = EmbeddingGeneration::compute(&["test text"], &cfg);

        // A normalized embedding should have unit L2 norm.
        let norm = result.embeddings[0]
            .iter()
            .map(|x| x * x)
            .sum::<f64>()
            .sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_embedding_empty() {
        let cfg = EmbeddingConfig::default();
        assert!(EmbeddingGeneration::compute(&[], &cfg).embeddings.is_empty());
    }

    #[test]
    fn test_pooling_strategies() {
        let inputs = ["a b c d e"];
        let strategies = [
            PoolingStrategy::Mean,
            PoolingStrategy::Max,
            PoolingStrategy::CLS,
            PoolingStrategy::AttentionWeighted,
        ];

        for pooling in strategies {
            let cfg = EmbeddingConfig {
                pooling,
                ..Default::default()
            };
            let result = EmbeddingGeneration::compute(&inputs, &cfg);
            assert_eq!(result.embeddings.len(), 1);
            assert_eq!(result.embeddings[0].len(), cfg.dimension);
        }
    }

    #[test]
    fn test_semantic_similarity_metadata() {
        assert_eq!(
            SemanticSimilarity::new().metadata().id,
            "ml/semantic-similarity"
        );
    }

    #[test]
    fn test_semantic_similarity_basic() {
        let queries = vec![vec![1.0, 0.0, 0.0]];
        let corpus = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.7, 0.7, 0.0],
        ];

        let cfg = SimilarityConfig {
            threshold: 0.0,
            include_self: true,
            ..Default::default()
        };

        let result = SemanticSimilarity::compute(&queries, &corpus, &cfg);

        // The identical vector must be the best match with score ~1.0.
        assert!(!result.matches.is_empty());
        assert_eq!(result.matches[0].match_idx, 0);
        assert!((result.matches[0].score - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_similarity_metrics() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        let metrics = [
            SimilarityMetric::Cosine,
            SimilarityMetric::Euclidean,
            SimilarityMetric::DotProduct,
            SimilarityMetric::Manhattan,
        ];

        for metric in metrics {
            let sim = SemanticSimilarity::compute_similarity(&a, &b, metric);
            assert!(
                sim > 0.0,
                "Identical vectors should have positive similarity for {:?}",
                metric
            );
        }
    }

    #[test]
    fn test_deduplicate() {
        let embeddings = [
            vec![1.0, 0.0],
            vec![0.99, 0.01],
            vec![0.0, 1.0],
            vec![0.01, 0.99],
        ];

        let kept = SemanticSimilarity::deduplicate(&embeddings, 0.95);

        // One representative per near-duplicate cluster survives.
        assert_eq!(kept.len(), 2);
        assert!(kept.contains(&0));
        assert!(kept.contains(&2));
    }

    #[test]
    fn test_find_similar_with_labels() {
        let queries = vec![vec![1.0, 0.0]];
        let corpus = vec![vec![0.9, 0.1], vec![0.0, 1.0]];
        let labels = vec!["doc_a".to_string(), "doc_b".to_string()];

        let cfg = SimilarityConfig {
            threshold: 0.0,
            include_self: true,
            ..Default::default()
        };

        let results = SemanticSimilarity::find_similar(&queries, &corpus, Some(&labels), &cfg);

        assert_eq!(results.len(), 1);
        assert!(!results[0].is_empty());
        assert_eq!(results[0][0].2, Some("doc_a".to_string()));
    }
}