use std::collections::HashMap;

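/// Whitespace word tokenizer with optional lowercasing, punctuation
/// stripping, and a minimum word length, configured builder-style.
///
/// A usage sketch; the doctest is marked `ignore` because the crate path
/// for the import is not known here:
///
/// ```ignore
/// let mut tok = WordTokenizer::new().min_word_length(2);
/// tok.fit(&["Hello, world!".to_string()]);
/// assert_eq!(tok.vocab_size(), 2); // {"hello", "world"}
/// ```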
pub struct WordTokenizer {
    pub lowercase: bool,
    pub remove_punctuation: bool,
    pub min_word_length: usize,
    vocab_: Option<HashMap<String, usize>>,
    index_to_word_: Option<Vec<String>>,
}

impl WordTokenizer {
    pub fn new() -> Self {
        WordTokenizer {
            lowercase: true,
            remove_punctuation: true,
            min_word_length: 1,
            vocab_: None,
            index_to_word_: None,
        }
    }

    pub fn lowercase(mut self, lowercase: bool) -> Self {
        self.lowercase = lowercase;
        self
    }

    pub fn remove_punctuation(mut self, remove: bool) -> Self {
        self.remove_punctuation = remove;
        self
    }

    pub fn min_word_length(mut self, length: usize) -> Self {
        self.min_word_length = length;
        self
    }

    fn preprocess(&self, text: &str) -> String {
        let mut processed = text.to_string();

        if self.lowercase {
            processed = processed.to_lowercase();
        }

        if self.remove_punctuation {
            processed = processed
                .chars()
                .filter(|c| c.is_alphanumeric() || c.is_whitespace())
                .collect();
        }

        processed
    }

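    /// Splits `text` into tokens after preprocessing, keeping words with at
    /// least `min_word_length` characters.
    ///
    /// A sketch of the default behavior (`ignore`d doctest, crate path
    /// unknown here):
    ///
    /// ```ignore
    /// let tok = WordTokenizer::new();
    /// assert_eq!(tok.tokenize("Hello, World!"), vec!["hello", "world"]);
    /// ```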
    pub fn tokenize(&self, text: &str) -> Vec<String> {
        let processed = self.preprocess(text);

        processed
            .split_whitespace()
            // Compare character counts rather than byte lengths so that
            // multi-byte (non-ASCII) words are measured correctly.
            .filter(|word| word.chars().count() >= self.min_word_length)
            .map(|s| s.to_string())
            .collect()
    }

    pub fn fit(&mut self, texts: &[String]) {
        let mut vocab = HashMap::new();
        let mut index = 0;

        for text in texts {
            for token in self.tokenize(text) {
                if !vocab.contains_key(&token) {
                    vocab.insert(token, index);
                    index += 1;
                }
            }
        }

        let mut index_to_word = vec![String::new(); vocab.len()];
        for (word, &idx) in &vocab {
            index_to_word[idx] = word.clone();
        }

        self.vocab_ = Some(vocab);
        self.index_to_word_ = Some(index_to_word);
    }

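    /// Maps each text to a sequence of vocabulary ids. Tokens not seen
    /// during `fit` are silently dropped rather than mapped to an
    /// unknown-token id, so sequences can be shorter than the token count:
    ///
    /// ```ignore
    /// tok.fit(&["a b".to_string()]);
    /// // "c" is out-of-vocabulary, so only two ids come back.
    /// assert_eq!(tok.texts_to_sequences(&["a c b".to_string()])[0].len(), 2);
    /// ```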
    pub fn texts_to_sequences(&self, texts: &[String]) -> Vec<Vec<usize>> {
        let vocab = self.vocab_.as_ref().expect("Tokenizer not fitted");

        texts
            .iter()
            .map(|text| {
                self.tokenize(text)
                    .iter()
                    .filter_map(|token| vocab.get(token).copied())
                    .collect()
            })
            .collect()
    }

    pub fn sequences_to_texts(&self, sequences: &[Vec<usize>]) -> Vec<String> {
        let index_to_word = self.index_to_word_.as_ref().expect("Tokenizer not fitted");

        sequences
            .iter()
            .map(|seq| {
                seq.iter()
                    .filter_map(|&idx| index_to_word.get(idx).cloned())
                    .collect::<Vec<_>>()
                    .join(" ")
            })
            .collect()
    }

    pub fn vocab_size(&self) -> usize {
        self.vocab_.as_ref().map(|v| v.len()).unwrap_or(0)
    }

    pub fn vocab(&self) -> Option<&HashMap<String, usize>> {
        self.vocab_.as_ref()
    }
}

impl Default for WordTokenizer {
    fn default() -> Self {
        Self::new()
    }
}

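/// Character-level tokenizer: every distinct character in the fitted texts
/// gets its own id.
///
/// A usage sketch (`ignore`d doctest, crate path unknown here):
///
/// ```ignore
/// let mut tok = CharTokenizer::new();
/// tok.fit(&["abca".to_string()]);
/// assert_eq!(tok.vocab_size(), 3); // {'a', 'b', 'c'}
/// ```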
pub struct CharTokenizer {
    pub lowercase: bool,
    char_to_idx_: Option<HashMap<char, usize>>,
    idx_to_char_: Option<Vec<char>>,
}

impl CharTokenizer {
    pub fn new() -> Self {
        CharTokenizer {
            lowercase: true,
            char_to_idx_: None,
            idx_to_char_: None,
        }
    }

    pub fn lowercase(mut self, lowercase: bool) -> Self {
        self.lowercase = lowercase;
        self
    }

    pub fn fit(&mut self, texts: &[String]) {
        let mut chars = std::collections::HashSet::new();

        for text in texts {
            let processed = if self.lowercase {
                text.to_lowercase()
            } else {
                text.clone()
            };

            for c in processed.chars() {
                chars.insert(c);
            }
        }

        // Sort so ids are deterministic across runs; HashSet iteration
        // order would otherwise vary from run to run.
        let mut sorted: Vec<char> = chars.into_iter().collect();
        sorted.sort_unstable();

        let mut char_to_idx = HashMap::new();
        let mut idx_to_char = Vec::new();

        for (i, c) in sorted.into_iter().enumerate() {
            char_to_idx.insert(c, i);
            idx_to_char.push(c);
        }

        self.char_to_idx_ = Some(char_to_idx);
        self.idx_to_char_ = Some(idx_to_char);
    }

    pub fn texts_to_sequences(&self, texts: &[String]) -> Vec<Vec<usize>> {
        let char_to_idx = self.char_to_idx_.as_ref().expect("Tokenizer not fitted");

        texts
            .iter()
            .map(|text| {
                let processed = if self.lowercase {
                    text.to_lowercase()
                } else {
                    text.clone()
                };

                processed
                    .chars()
                    .filter_map(|c| char_to_idx.get(&c).copied())
                    .collect()
            })
            .collect()
    }

    pub fn sequences_to_texts(&self, sequences: &[Vec<usize>]) -> Vec<String> {
        let idx_to_char = self.idx_to_char_.as_ref().expect("Tokenizer not fitted");

        sequences
            .iter()
            .map(|seq| {
                seq.iter()
                    .filter_map(|&idx| idx_to_char.get(idx).copied())
                    .collect()
            })
            .collect()
    }

    pub fn vocab_size(&self) -> usize {
        self.char_to_idx_.as_ref().map(|v| v.len()).unwrap_or(0)
    }
}

impl Default for CharTokenizer {
    fn default() -> Self {
        Self::new()
    }
}

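/// Byte-pair encoding learned over whitespace-separated words.
///
/// Training starts from single characters and repeatedly merges the most
/// frequent adjacent pair (ties broken lexicographically, see `fit`). For
/// the corpus `"low low lower"` the first merge is `("l", "o")`, turning
/// the words into `["lo", "w"]` and `["lo", "w", "e", "r"]`; the second is
/// `("lo", "w")`, giving `["low"]` and `["low", "e", "r"]`; and so on until
/// `vocab_size` merges have been made or no pairs remain.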
pub struct BPETokenizer {
    pub vocab_size: usize,
    merges_: Vec<(String, String)>,
    vocab_: HashMap<String, usize>,
}

impl BPETokenizer {
    pub fn new(vocab_size: usize) -> Self {
        BPETokenizer {
            vocab_size,
            merges_: Vec::new(),
            vocab_: HashMap::new(),
        }
    }

    fn get_pairs(word: &[String]) -> Vec<(String, String)> {
        let mut pairs = Vec::new();
        for i in 0..word.len().saturating_sub(1) {
            pairs.push((word[i].clone(), word[i + 1].clone()));
        }
        pairs
    }

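    /// Learns merge rules from `texts`. Note that `vocab_size` caps the
    /// number of merge rounds, not the size of the final token vocabulary;
    /// `vocab_size_actual()` reports the latter.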
    pub fn fit(&mut self, texts: &[String]) {
        let mut vocab: HashMap<Vec<String>, usize> = HashMap::new();

        for text in texts {
            for word in text.split_whitespace() {
                let chars: Vec<String> = word.chars().map(|c| c.to_string()).collect();
                *vocab.entry(chars).or_insert(0) += 1;
            }
        }

        for _ in 0..self.vocab_size {
            let mut pair_freqs: HashMap<(String, String), usize> = HashMap::new();

            for (word, &freq) in &vocab {
                for pair in Self::get_pairs(word) {
                    *pair_freqs.entry(pair).or_insert(0) += freq;
                }
            }

            if pair_freqs.is_empty() {
                break;
            }

            // Most frequent pair wins; frequency ties break toward the
            // lexicographically smaller pair so training is deterministic.
            let best_pair = pair_freqs
                .iter()
                .max_by(|a, b| a.1.cmp(b.1).then_with(|| b.0.cmp(a.0)))
                .map(|(pair, _)| pair.clone());

            if let Some((first, second)) = best_pair {
                self.merges_.push((first.clone(), second.clone()));

                let mut new_vocab = HashMap::new();
                for (word, freq) in vocab {
                    let new_word = self.merge_pair(&word, &first, &second);
                    new_vocab.insert(new_word, freq);
                }
                vocab = new_vocab;
            } else {
                break;
            }
        }

        // Assign ids to the tokens that survive all merges. Intermediate
        // tokens consumed by later merges are not registered, so encoding
        // unseen words can drop symbols.
        let mut idx = 0;
        for (word, _) in vocab {
            for token in word {
                if !self.vocab_.contains_key(&token) {
                    self.vocab_.insert(token, idx);
                    idx += 1;
                }
            }
        }
    }

    fn merge_pair(&self, word: &[String], first: &str, second: &str) -> Vec<String> {
        let mut result = Vec::new();
        let mut i = 0;

        while i < word.len() {
            if i + 1 < word.len() && word[i] == first && word[i + 1] == second {
                result.push(format!("{}{}", first, second));
                i += 2;
            } else {
                result.push(word[i].clone());
                i += 1;
            }
        }

        result
    }

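    /// Encodes `text` by splitting each word into characters, replaying the
    /// learned merges in order, and looking up token ids. Tokens absent from
    /// the vocabulary are dropped:
    ///
    /// ```ignore
    /// let mut bpe = BPETokenizer::new(10);
    /// bpe.fit(&["low low lower".to_string()]);
    /// assert!(!bpe.encode("low").is_empty());
    /// ```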
    pub fn encode(&self, text: &str) -> Vec<usize> {
        let mut tokens = Vec::new();

        for word in text.split_whitespace() {
            let mut chars: Vec<String> = word.chars().map(|c| c.to_string()).collect();

            // Replay the learned merges in training order.
            for (first, second) in &self.merges_ {
                chars = self.merge_pair(&chars, first, second);
            }

            for token in chars {
                if let Some(&idx) = self.vocab_.get(&token) {
                    tokens.push(idx);
                }
            }
        }

        tokens
    }

    pub fn vocab_size_actual(&self) -> usize {
        self.vocab_.len()
    }
}

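/// TF-IDF vectorizer using the smoothed inverse document frequency
/// `idf(t) = ln((N + 1) / (df(t) + 1)) + 1`, where `N` is the number of
/// documents and `df(t)` the number of documents containing `t` (the same
/// smoothing scikit-learn applies); output rows are L2-normalized.
///
/// A usage sketch (`ignore`d doctest, crate path unknown here):
///
/// ```ignore
/// let docs = vec!["the cat sat".to_string(), "the dog sat".to_string()];
/// let mut v = TfidfVectorizer::new().max_features(1000);
/// let rows = v.fit_transform(&docs); // one Vec<f32> per document
/// ```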
pub struct TfidfVectorizer {
    pub max_features: Option<usize>,
    pub min_df: usize,
    pub max_df: f32,
    tokenizer: WordTokenizer,
    idf_: Option<Vec<f32>>,
    vocab_: Option<HashMap<String, usize>>,
}

impl TfidfVectorizer {
    pub fn new() -> Self {
        TfidfVectorizer {
            max_features: None,
            min_df: 1,
            max_df: 1.0,
            tokenizer: WordTokenizer::new(),
            idf_: None,
            vocab_: None,
        }
    }

    pub fn max_features(mut self, max: usize) -> Self {
        self.max_features = Some(max);
        self
    }

    pub fn min_df(mut self, min: usize) -> Self {
        self.min_df = min;
        self
    }

    pub fn fit(&mut self, texts: &[String]) {
        let mut doc_freq: HashMap<String, usize> = HashMap::new();
        let n_docs = texts.len();

        for text in texts {
            let tokens = self.tokenizer.tokenize(text);
            let unique_tokens: std::collections::HashSet<_> = tokens.into_iter().collect();

            for token in unique_tokens {
                *doc_freq.entry(token).or_insert(0) += 1;
            }
        }

        // Drop terms that are too rare (min_df) or too common (max_df).
        let max_df_count = (self.max_df * n_docs as f32) as usize;
        let mut filtered_vocab: Vec<(String, usize)> = doc_freq
            .into_iter()
            .filter(|(_, freq)| *freq >= self.min_df && *freq <= max_df_count)
            .collect();

        if let Some(max_feat) = self.max_features {
            // Keep the most frequent terms; ties break alphabetically so
            // the truncation is deterministic.
            filtered_vocab.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
            filtered_vocab.truncate(max_feat);
        }

        let mut vocab = HashMap::new();
        let mut idf = Vec::new();

        for (i, (term, df)) in filtered_vocab.iter().enumerate() {
            vocab.insert(term.clone(), i);
            // Smoothed idf, as in scikit-learn: ln((N + 1) / (df + 1)) + 1.
            let idf_value = ((n_docs + 1) as f32 / (*df + 1) as f32).ln() + 1.0;
            idf.push(idf_value);
        }

        self.vocab_ = Some(vocab);
        self.idf_ = Some(idf);
    }

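    /// Transforms texts into TF-IDF rows: count terms, normalize counts to
    /// term frequencies, weight by the fitted idf values, then L2-normalize
    /// so the dot product of two rows equals their cosine similarity.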
    pub fn transform(&self, texts: &[String]) -> Vec<Vec<f32>> {
        let vocab = self.vocab_.as_ref().expect("Vectorizer not fitted");
        let idf = self.idf_.as_ref().expect("Vectorizer not fitted");
        let vocab_size = vocab.len();

        texts
            .iter()
            .map(|text| {
                let tokens = self.tokenizer.tokenize(text);
                let mut tf = vec![0.0f32; vocab_size];

                for token in &tokens {
                    if let Some(&idx) = vocab.get(token) {
                        tf[idx] += 1.0;
                    }
                }

                // Normalize counts to term frequencies.
                let total: f32 = tf.iter().sum();
                if total > 0.0 {
                    for t in &mut tf {
                        *t /= total;
                    }
                }

                // Weight by inverse document frequency.
                for (i, t) in tf.iter_mut().enumerate() {
                    *t *= idf[i];
                }

                // L2-normalize the row.
                let norm: f32 = tf.iter().map(|&x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    for t in &mut tf {
                        *t /= norm;
                    }
                }

                tf
            })
            .collect()
    }

    pub fn fit_transform(&mut self, texts: &[String]) -> Vec<Vec<f32>> {
        self.fit(texts);
        self.transform(texts)
    }

    pub fn vocab_size(&self) -> usize {
        self.vocab_.as_ref().map(|v| v.len()).unwrap_or(0)
    }
}

impl Default for TfidfVectorizer {
    fn default() -> Self {
        Self::new()
    }
}

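/// Word embeddings trained from co-occurrence within a sliding window.
///
/// This is a deliberately simplified trainer, not the full skip-gram with
/// negative sampling of the original word2vec: each step merely pulls a
/// center word's vector a small distance toward its context vectors (see
/// `fit`).
///
/// A usage sketch (`ignore`d doctest, crate path unknown here):
///
/// ```ignore
/// let corpus = vec!["the quick brown fox".to_string()];
/// let mut w2v = Word2Vec::new(50).window_size(3).min_count(1);
/// w2v.fit(&corpus);
/// let sim = w2v.similarity("quick", "fox"); // None if either word is OOV
/// ```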
pub struct Word2Vec {
    pub embedding_dim: usize,
    pub window_size: usize,
    pub min_count: usize,
    pub learning_rate: f32,
    pub epochs: usize,
    embeddings_: Option<Vec<Vec<f32>>>,
    vocab_: Option<HashMap<String, usize>>,
}

impl Word2Vec {
    pub fn new(embedding_dim: usize) -> Self {
        Word2Vec {
            embedding_dim,
            window_size: 5,
            min_count: 5,
            learning_rate: 0.025,
            epochs: 5,
            embeddings_: None,
            vocab_: None,
        }
    }

    pub fn window_size(mut self, size: usize) -> Self {
        self.window_size = size;
        self
    }

    pub fn min_count(mut self, count: usize) -> Self {
        self.min_count = count;
        self
    }

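    /// Builds the vocabulary (words occurring at least `min_count` times),
    /// randomly initializes the embedding matrix, then for every (center,
    /// context) pair within `window_size` applies the simplified update
    ///
    /// `v_center += learning_rate * 0.01 * (v_context - v_center)`
    ///
    /// Requires the `rand` crate for initialization.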
    pub fn fit(&mut self, texts: &[String]) {
        let mut word_counts: HashMap<String, usize> = HashMap::new();
        let tokenizer = WordTokenizer::new();

        for text in texts {
            for token in tokenizer.tokenize(text) {
                *word_counts.entry(token).or_insert(0) += 1;
            }
        }

        let mut vocab = HashMap::new();
        let mut idx = 0;
        for (word, count) in word_counts {
            if count >= self.min_count {
                vocab.insert(word, idx);
                idx += 1;
            }
        }

        let vocab_size = vocab.len();

        // Small random initialization: uniform in (-0.5, 0.5) scaled by
        // 1 / embedding_dim, as in the original word2vec implementation.
        use rand::prelude::*;
        let mut rng = thread_rng();
        let mut embeddings = vec![vec![0.0f32; self.embedding_dim]; vocab_size];

        for emb in &mut embeddings {
            for val in emb {
                *val = (rng.gen::<f32>() - 0.5) / self.embedding_dim as f32;
            }
        }

        for _epoch in 0..self.epochs {
            for text in texts {
                let tokens = tokenizer.tokenize(text);
                let indices: Vec<usize> = tokens
                    .iter()
                    .filter_map(|t| vocab.get(t).copied())
                    .collect();

                for (i, &center_idx) in indices.iter().enumerate() {
                    let start = i.saturating_sub(self.window_size);
                    let end = (i + self.window_size + 1).min(indices.len());

                    for j in start..end {
                        if i == j {
                            continue;
                        }
                        let context_idx = indices[j];

                        // Pull the center vector a small step toward the
                        // context vector.
                        for d in 0..self.embedding_dim {
                            let grad = embeddings[context_idx][d] - embeddings[center_idx][d];
                            embeddings[center_idx][d] += self.learning_rate * grad * 0.01;
                        }
                    }
                }
            }
        }

        self.embeddings_ = Some(embeddings);
        self.vocab_ = Some(vocab);
    }

    pub fn get_vector(&self, word: &str) -> Option<&[f32]> {
        let vocab = self.vocab_.as_ref()?;
        let embeddings = self.embeddings_.as_ref()?;
        let idx = vocab.get(word)?;
        Some(&embeddings[*idx])
    }

    pub fn similarity(&self, word1: &str, word2: &str) -> Option<f32> {
        let vec1 = self.get_vector(word1)?;
        let vec2 = self.get_vector(word2)?;

        // Cosine similarity; the epsilon guards against division by zero.
        let dot: f32 = vec1.iter().zip(vec2.iter()).map(|(a, b)| a * b).sum();
        let norm1: f32 = vec1.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm2: f32 = vec2.iter().map(|x| x * x).sum::<f32>().sqrt();

        Some(dot / (norm1 * norm2).max(1e-10))
    }

    pub fn vocab_size(&self) -> usize {
        self.vocab_.as_ref().map(|v| v.len()).unwrap_or(0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_word_tokenizer() {
        let mut tokenizer = WordTokenizer::new();
        let texts = vec![
            "Hello world".to_string(),
            "Hello Rust".to_string(),
        ];

        tokenizer.fit(&texts);
        // Lowercasing is on by default, so "Hello" counts once:
        // {hello, world, rust}.
        assert_eq!(tokenizer.vocab_size(), 3);

        let sequences = tokenizer.texts_to_sequences(&texts);
        assert_eq!(sequences.len(), 2);
        assert_eq!(sequences[0].len(), 2);
    }

    #[test]
    fn test_char_tokenizer() {
        let mut tokenizer = CharTokenizer::new();
        let texts = vec!["abc".to_string(), "def".to_string()];

        tokenizer.fit(&texts);
        assert_eq!(tokenizer.vocab_size(), 6);

        let sequences = tokenizer.texts_to_sequences(&texts);
        assert_eq!(sequences[0].len(), 3);
    }

    #[test]
    fn test_tfidf() {
        let texts = vec![
            "the cat sat on the mat".to_string(),
            "the dog sat on the log".to_string(),
        ];

        let mut vectorizer = TfidfVectorizer::new();
        let vectors = vectorizer.fit_transform(&texts);

        assert_eq!(vectors.len(), 2);
        assert!(!vectors[0].is_empty());
    }

    #[test]
    fn test_word2vec() {
        let texts = vec![
            "the quick brown fox jumps".to_string(),
            "the lazy dog sleeps".to_string(),
        ];

        let mut w2v = Word2Vec::new(10).min_count(1);
        w2v.epochs = 2;
        w2v.fit(&texts);

        assert!(w2v.vocab_size() > 0);
        assert!(w2v.get_vector("the").is_some());
    }
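
    // Minimal BPE sanity check; exact merge counts and token ids are
    // implementation details, so the assertions stay deliberately loose.
    #[test]
    fn test_bpe() {
        let texts = vec![
            "low low low lower lowest".to_string(),
            "new newer newest".to_string(),
        ];

        let mut bpe = BPETokenizer::new(10);
        bpe.fit(&texts);

        assert!(bpe.vocab_size_actual() > 0);
        // Words seen during fit should encode to at least one token id.
        assert!(!bpe.encode("low").is_empty());
    }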
}