1use crate::error::{Result, TextError};
11use std::collections::HashMap;
12use std::fs::File;
13use std::io::{BufRead, BufReader};
14use unicode_normalization::UnicodeNormalization;
15
// Special tokens. `BertTokenizer::new` inserts any of these that the supplied
// vocabulary lacks, so they always have an id.

/// Filler token appended when an encoding is padded.
const PAD_TOKEN: &str = "[PAD]";
/// Substituted for words WordPiece cannot segment, and for unknown ids when decoding.
const UNK_TOKEN: &str = "[UNK]";
/// Prepended to every encoded sequence.
const CLS_TOKEN: &str = "[CLS]";
/// Appended after each encoded segment.
const SEP_TOKEN: &str = "[SEP]";
/// Mask placeholder; registered in the vocabulary and skipped by `decode`.
const MASK_TOKEN: &str = "[MASK]";

/// Words with more characters than this are mapped straight to `[UNK]`
/// instead of being segmented by WordPiece.
const MAX_WORD_CHARS: usize = 200;
26
/// Encoding of a single sequence: token ids plus the parallel attention mask
/// and segment (token type) ids.
///
/// `BertTokenizer` always builds the three vectors with equal length;
/// `new` itself does not check this.
#[derive(Debug, Clone, PartialEq)]
pub struct BertEncoding {
    /// Vocabulary ids, including special and padding tokens.
    pub input_ids: Vec<u32>,
    /// Per-position mask; `BertTokenizer` writes 1 for real tokens, 0 for padding.
    pub attention_mask: Vec<u32>,
    /// Per-position segment id.
    pub token_type_ids: Vec<u32>,
}

impl BertEncoding {
    /// Bundle already-computed ids, mask and segment ids into an encoding.
    pub fn new(input_ids: Vec<u32>, attention_mask: Vec<u32>, token_type_ids: Vec<u32>) -> Self {
        Self {
            input_ids,
            attention_mask,
            token_type_ids,
        }
    }

    /// Number of positions, measured on `input_ids`.
    pub fn len(&self) -> usize {
        let ids = &self.input_ids;
        ids.len()
    }

    /// `true` when `input_ids` holds no positions.
    pub fn is_empty(&self) -> bool {
        let ids = &self.input_ids;
        ids.is_empty()
    }
}
63
64#[derive(Debug, Clone)]
68pub struct BatchEncoding {
69 pub encodings: Vec<BertEncoding>,
71}
72
73impl BatchEncoding {
74 pub fn new(encodings: Vec<BertEncoding>) -> Self {
76 BatchEncoding { encodings }
77 }
78
79 pub fn len(&self) -> usize {
81 self.encodings.len()
82 }
83
84 pub fn is_empty(&self) -> bool {
86 self.encodings.is_empty()
87 }
88
89 pub fn input_ids(&self) -> Vec<Vec<u32>> {
91 self.encodings.iter().map(|e| e.input_ids.clone()).collect()
92 }
93
94 pub fn attention_masks(&self) -> Vec<Vec<u32>> {
96 self.encodings
97 .iter()
98 .map(|e| e.attention_mask.clone())
99 .collect()
100 }
101
102 pub fn token_type_ids(&self) -> Vec<Vec<u32>> {
104 self.encodings
105 .iter()
106 .map(|e| e.token_type_ids.clone())
107 .collect()
108 }
109}
110
111#[derive(Debug, Clone)]
117struct BasicTokenizer {
118 do_lower_case: bool,
119}
120
121impl BasicTokenizer {
122 fn new(do_lower_case: bool) -> Self {
123 BasicTokenizer { do_lower_case }
124 }
125
126 fn tokenize(&self, text: &str) -> Vec<String> {
127 let text = if self.do_lower_case {
128 text.to_lowercase()
129 } else {
130 text.to_string()
131 };
132
133 let text: String = text.nfd().filter(|c| !is_combining_mark(*c)).collect();
135
136 let mut spaced = String::with_capacity(text.len() + 32);
138 for ch in text.chars() {
139 if ch.is_whitespace() {
140 spaced.push(' ');
141 } else if is_punctuation_char(ch) || is_chinese_char(ch) {
142 spaced.push(' ');
143 spaced.push(ch);
144 spaced.push(' ');
145 } else {
146 spaced.push(ch);
147 }
148 }
149
150 spaced
151 .split_whitespace()
152 .filter(|s| !s.is_empty())
153 .map(|s| s.to_string())
154 .collect()
155 }
156}
157
/// Returns `true` for characters in the Unicode combining-diacritics family
/// of blocks; these are the marks removed after NFD decomposition when
/// stripping accents.
///
/// This is a block-based approximation of the general category `Mn`: marks
/// outside these blocks (for example Hebrew or Arabic vowel points) are not
/// recognized.
fn is_combining_mark(ch: char) -> bool {
    matches!(
        ch,
        '\u{0300}'..='\u{036F}'       // Combining Diacritical Marks
            | '\u{1AB0}'..='\u{1AFF}' // Combining Diacritical Marks Extended
            | '\u{1DC0}'..='\u{1DFF}' // Combining Diacritical Marks Supplement
            | '\u{20D0}'..='\u{20FF}' // Combining Diacritical Marks for Symbols
            | '\u{FE20}'..='\u{FE2F}' // Combining Half Marks
    )
}
166
/// Returns `true` if `ch` should be split off as its own token.
///
/// ASCII: exactly the ranges the reference BERT tokenizer treats as
/// punctuation (33–47, 58–64, 91–96, 123–126). `char::is_ascii_punctuation`
/// tests this same set. Control characters, space (0–32) and DEL (127) are
/// *not* punctuation.
///
/// Non-ASCII: a short list of common CJK and typographic punctuation,
/// including the full-width forms of `, ; : ? !`. This is a subset of the
/// Unicode `P*` categories, not the whole category.
fn is_punctuation_char(ch: char) -> bool {
    if ch.is_ascii() {
        return ch.is_ascii_punctuation();
    }
    matches!(
        ch,
        '。' | '、'
            | '\u{FF0C}' // full-width comma
            | '\u{FF1B}' // full-width semicolon
            | '\u{FF1A}' // full-width colon
            | '\u{FF1F}' // full-width question mark
            | '\u{FF01}' // full-width exclamation mark
            | '—'
            | '…'
            | '\u{2018}'
            | '\u{2019}'
            | '\u{201C}'
            | '\u{201D}'
    )
}
192
/// Returns `true` if `ch` lies in one of the CJK ideograph blocks that get
/// split into single-character tokens.
///
/// These are the same ranges the reference BERT tokenizer uses. Kana and
/// Hangul are outside them and are handled like ordinary letters.
fn is_chinese_char(ch: char) -> bool {
    matches!(
        ch,
        '\u{4E00}'..='\u{9FFF}'         // CJK Unified Ideographs
            | '\u{3400}'..='\u{4DBF}'   // Extension A
            | '\u{20000}'..='\u{2A6DF}' // Extension B
            | '\u{2A700}'..='\u{2B73F}' // Extension C
            | '\u{2B740}'..='\u{2B81F}' // Extension D
            | '\u{2B820}'..='\u{2CEAF}' // Extension E
            | '\u{F900}'..='\u{FAFF}'   // Compatibility Ideographs
            | '\u{2F800}'..='\u{2FA1F}' // Compatibility Ideographs Supplement
    )
}
205
206fn wordpiece_segment(word: &str, vocab: &HashMap<String, u32>) -> Vec<String> {
213 let chars: Vec<char> = word.chars().collect();
214 if chars.len() > MAX_WORD_CHARS {
215 return vec![UNK_TOKEN.to_string()];
216 }
217
218 let n = chars.len();
219 let mut sub_tokens: Vec<String> = Vec::new();
220 let mut start = 0usize;
221
222 while start < n {
223 let mut end = n;
224 let mut found_tok: Option<String> = None;
225
226 while start < end {
227 let substr: String = chars[start..end].iter().collect();
228 let candidate = if start == 0 {
229 substr.clone()
230 } else {
231 format!("##{}", substr)
232 };
233
234 if vocab.contains_key(&candidate) {
235 found_tok = Some(candidate);
236 break;
237 }
238
239 if end == start + 1 {
240 return vec![UNK_TOKEN.to_string()];
242 }
243 end -= 1;
244 }
245
246 match found_tok {
247 Some(tok) => {
248 sub_tokens.push(tok);
249 start = end;
250 }
251 None => {
252 return vec![UNK_TOKEN.to_string()];
253 }
254 }
255 }
256
257 sub_tokens
258}
259
260#[derive(Debug, Clone)]
288pub struct BertTokenizer {
289 vocab: HashMap<String, u32>,
290 ids_to_tokens: HashMap<u32, String>,
291 cls_token_id: u32,
292 sep_token_id: u32,
293 pad_token_id: u32,
294 mask_token_id: u32,
295 unk_token_id: u32,
296 max_len: usize,
297 lowercase: bool,
298 basic: BasicTokenizer,
299}
300
301impl BertTokenizer {
302 pub fn new(mut vocab: HashMap<String, u32>, lowercase: bool) -> Self {
309 let specials = [PAD_TOKEN, UNK_TOKEN, CLS_TOKEN, SEP_TOKEN, MASK_TOKEN];
311 for tok in &specials {
312 if !vocab.contains_key(*tok) {
313 let next_id = vocab.len() as u32;
314 vocab.insert(tok.to_string(), next_id);
315 }
316 }
317
318 let cls_token_id = vocab[CLS_TOKEN];
319 let sep_token_id = vocab[SEP_TOKEN];
320 let pad_token_id = vocab[PAD_TOKEN];
321 let mask_token_id = vocab[MASK_TOKEN];
322 let unk_token_id = vocab[UNK_TOKEN];
323
324 let ids_to_tokens: HashMap<u32, String> =
325 vocab.iter().map(|(k, &v)| (v, k.clone())).collect();
326
327 BertTokenizer {
328 vocab,
329 ids_to_tokens,
330 cls_token_id,
331 sep_token_id,
332 pad_token_id,
333 mask_token_id,
334 unk_token_id,
335 max_len: 512,
336 lowercase,
337 basic: BasicTokenizer::new(lowercase),
338 }
339 }
340
341 pub fn from_vocab_file(path: &str) -> Result<Self> {
347 let file = File::open(path).map_err(|e| TextError::IoError(e.to_string()))?;
348 let reader = BufReader::new(file);
349
350 let mut vocab = HashMap::new();
351 for (idx, line) in reader.lines().enumerate() {
352 let token = line.map_err(|e| TextError::IoError(e.to_string()))?;
353 let token = token.trim().to_string();
354 if !token.is_empty() {
355 vocab.insert(token, idx as u32);
356 }
357 }
358
359 if vocab.is_empty() {
360 return Err(TextError::VocabularyError(
361 "Vocabulary file is empty".to_string(),
362 ));
363 }
364
365 Ok(Self::new(vocab, true))
366 }
367
368 pub fn with_max_len(mut self, max_len: usize) -> Self {
370 self.max_len = max_len;
371 self
372 }
373
374 pub fn cls_token_id(&self) -> u32 {
378 self.cls_token_id
379 }
380
381 pub fn sep_token_id(&self) -> u32 {
383 self.sep_token_id
384 }
385
386 pub fn pad_token_id(&self) -> u32 {
388 self.pad_token_id
389 }
390
391 pub fn mask_token_id(&self) -> u32 {
393 self.mask_token_id
394 }
395
396 pub fn unk_token_id(&self) -> u32 {
398 self.unk_token_id
399 }
400
401 pub fn vocab_size(&self) -> usize {
403 self.vocab.len()
404 }
405
406 pub fn vocab(&self) -> &HashMap<String, u32> {
408 &self.vocab
409 }
410
411 pub fn lowercase(&self) -> bool {
413 self.lowercase
414 }
415
416 pub fn tokenize(&self, text: &str) -> Vec<String> {
424 if text.is_empty() {
425 return Vec::new();
426 }
427 let words = self.basic.tokenize(text);
428 words
429 .iter()
430 .flat_map(|w| wordpiece_segment(w, &self.vocab))
431 .collect()
432 }
433
434 fn token_to_id(&self, token: &str) -> u32 {
436 self.vocab.get(token).copied().unwrap_or(self.unk_token_id)
437 }
438
439 pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
446 let sub_tokens = self.tokenize(text);
447 let mut ids = Vec::with_capacity(sub_tokens.len() + 2);
448 ids.push(self.cls_token_id);
449 ids.extend(sub_tokens.iter().map(|t| self.token_to_id(t)));
450 ids.push(self.sep_token_id);
451 Ok(ids)
452 }
453
454 pub fn encode_pair(&self, text_a: &str, text_b: &str) -> Result<(Vec<u32>, Vec<u32>)> {
461 let tokens_a = self.tokenize(text_a);
462 let tokens_b = self.tokenize(text_b);
463
464 let total = 1 + tokens_a.len() + 1 + tokens_b.len() + 1; let mut ids = Vec::with_capacity(total);
466 let mut type_ids = Vec::with_capacity(total);
467
468 ids.push(self.cls_token_id);
470 type_ids.push(0u32);
471
472 for tok in &tokens_a {
473 ids.push(self.token_to_id(tok));
474 type_ids.push(0);
475 }
476
477 ids.push(self.sep_token_id);
478 type_ids.push(0);
479
480 for tok in &tokens_b {
482 ids.push(self.token_to_id(tok));
483 type_ids.push(1);
484 }
485
486 ids.push(self.sep_token_id);
487 type_ids.push(1);
488
489 Ok((ids, type_ids))
490 }
491
492 pub fn encode_single(
499 &self,
500 text: &str,
501 max_length: usize,
502 padding: bool,
503 truncation: bool,
504 ) -> Result<BertEncoding> {
505 if max_length == 0 {
506 return Err(TextError::InvalidInput(
507 "max_length must be greater than 0".to_string(),
508 ));
509 }
510
511 let sub_tokens = self.tokenize(text);
512 let budget = max_length.saturating_sub(2);
514
515 let content: Vec<u32> = if truncation && sub_tokens.len() > budget {
516 sub_tokens[..budget]
517 .iter()
518 .map(|t| self.token_to_id(t))
519 .collect()
520 } else {
521 sub_tokens.iter().map(|t| self.token_to_id(t)).collect()
522 };
523
524 let mut ids = Vec::with_capacity(max_length);
525 ids.push(self.cls_token_id);
526 ids.extend_from_slice(&content);
527 ids.push(self.sep_token_id);
528
529 let real_len = ids.len();
530
531 if padding && ids.len() < max_length {
532 let pad_count = max_length - ids.len();
533 ids.extend(std::iter::repeat_n(self.pad_token_id, pad_count));
534 }
535
536 let seq_len = ids.len();
537 let mut mask = vec![0u32; seq_len];
538 for m in mask.iter_mut().take(real_len) {
539 *m = 1;
540 }
541 let type_ids = vec![0u32; seq_len];
542
543 Ok(BertEncoding::new(ids, mask, type_ids))
544 }
545
546 pub fn encode_batch(
552 &self,
553 texts: &[&str],
554 max_length: usize,
555 padding: bool,
556 truncation: bool,
557 ) -> Result<BatchEncoding> {
558 if max_length == 0 {
559 return Err(TextError::InvalidInput(
560 "max_length must be greater than 0".to_string(),
561 ));
562 }
563
564 let mut raw_encodings: Vec<(Vec<u32>, usize)> = Vec::with_capacity(texts.len());
566
567 for text in texts {
568 let sub_tokens = self.tokenize(text);
569 let budget = max_length.saturating_sub(2);
570 let content: Vec<u32> = if truncation && sub_tokens.len() > budget {
571 sub_tokens[..budget]
572 .iter()
573 .map(|t| self.token_to_id(t))
574 .collect()
575 } else {
576 sub_tokens.iter().map(|t| self.token_to_id(t)).collect()
577 };
578
579 let mut ids = Vec::with_capacity(content.len() + 2);
580 ids.push(self.cls_token_id);
581 ids.extend_from_slice(&content);
582 ids.push(self.sep_token_id);
583 let real_len = ids.len();
584 raw_encodings.push((ids, real_len));
585 }
586
587 let target_len = if padding {
589 let max_real = raw_encodings
590 .iter()
591 .map(|(ids, _)| ids.len())
592 .max()
593 .unwrap_or(0);
594 max_real.min(max_length)
595 } else {
596 max_length
597 };
598
599 let encodings = raw_encodings
601 .into_iter()
602 .map(|(mut ids, real_len)| {
603 if padding && ids.len() < target_len {
604 let pad_count = target_len - ids.len();
605 ids.extend(std::iter::repeat_n(self.pad_token_id, pad_count));
606 }
607
608 let seq_len = ids.len();
609 let mut mask = vec![0u32; seq_len];
610 for m in mask.iter_mut().take(real_len) {
611 *m = 1;
612 }
613 let type_ids = vec![0u32; seq_len];
614 BertEncoding::new(ids, mask, type_ids)
615 })
616 .collect();
617
618 Ok(BatchEncoding::new(encodings))
619 }
620
621 pub fn decode(&self, ids: &[u32]) -> String {
629 let special_ids: [u32; 4] = [
630 self.cls_token_id,
631 self.sep_token_id,
632 self.pad_token_id,
633 self.mask_token_id,
634 ];
635
636 let mut out = String::new();
637 for &id in ids {
638 if special_ids.contains(&id) {
639 continue;
640 }
641
642 let tok = match self.ids_to_tokens.get(&id) {
643 Some(t) => t.as_str(),
644 None => UNK_TOKEN,
645 };
646
647 if tok == UNK_TOKEN {
648 if !out.is_empty() {
649 out.push(' ');
650 }
651 out.push_str(tok);
652 continue;
653 }
654
655 if let Some(cont) = tok.strip_prefix("##") {
656 out.push_str(cont);
658 } else {
659 if !out.is_empty() {
660 out.push(' ');
661 }
662 out.push_str(tok);
663 }
664 }
665 out
666 }
667
668 pub fn convert_token_to_id(&self, token: &str) -> Option<u32> {
670 self.vocab.get(token).copied()
671 }
672
673 pub fn convert_id_to_token(&self, id: u32) -> Option<&str> {
675 self.ids_to_tokens.get(&id).map(|s| s.as_str())
676 }
677}
678
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Small fixed vocabulary whose ids are the array positions: the five
    /// special tokens take ids 0..=4, `play` is 7 and `##ing` is 8.
    fn base_vocab() -> HashMap<String, u32> {
        let tokens = [
            "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello", "world", "play", "##ing",
            "##ed", "good", "morning", "the", "quick", "brown", "fox", ",", "!",
        ];
        tokens
            .iter()
            .enumerate()
            .map(|(i, t)| (t.to_string(), i as u32))
            .collect()
    }

    fn make_tokenizer() -> BertTokenizer {
        BertTokenizer::new(base_vocab(), true)
    }

    #[test]
    fn test_bert_tokenize_basic() {
        let tok = make_tokenizer();
        let tokens = tok.tokenize("Hello, World!");
        // Lowercased words and split-off punctuation must all be present.
        for expected in ["hello", "world", ",", "!"] {
            assert!(
                tokens.iter().any(|t| t == expected),
                "expected '{}' in {:?}",
                expected,
                tokens
            );
        }
    }

    #[test]
    fn test_bert_special_tokens() {
        let tok = make_tokenizer();
        let ids = tok.encode("hello world").expect("encode failed");
        assert_eq!(ids[0], tok.cls_token_id(), "first token should be [CLS]");
        assert_eq!(
            *ids.last().expect("non-empty"),
            tok.sep_token_id(),
            "last token should be [SEP]"
        );
    }

    #[test]
    fn test_bert_wordpiece() {
        let tok = make_tokenizer();
        assert_eq!(tok.tokenize("playing"), vec!["play", "##ing"]);
    }

    #[test]
    fn test_bert_unknown() {
        let tok = make_tokenizer();
        let ids = tok.encode("xyzzy").expect("encode failed");
        // [CLS] [UNK] [SEP]
        assert_eq!(ids.len(), 3);
        assert_eq!(ids[1], tok.unk_token_id(), "OOV token should map to [UNK]");
    }

    #[test]
    fn test_bert_encode_pair() {
        let tok = make_tokenizer();
        let (ids, type_ids) = tok
            .encode_pair("hello", "world")
            .expect("encode_pair failed");
        assert_eq!(ids.len(), type_ids.len());
        assert_eq!(ids[0], tok.cls_token_id());
        assert_eq!(type_ids[0], 0);
        assert_eq!(*ids.last().expect("non-empty"), tok.sep_token_id());
        assert_eq!(*type_ids.last().expect("non-empty"), 1);

        // Everything up to and including the first [SEP] is segment A.
        let first_sep_pos = ids
            .iter()
            .position(|&id| id == tok.sep_token_id())
            .expect("has SEP");
        assert!(
            type_ids[..=first_sep_pos].iter().all(|&t| t == 0),
            "segment A must be type 0: {:?}",
            type_ids
        );
        assert!(
            type_ids[first_sep_pos + 1..].iter().all(|&t| t == 1),
            "segment B must be type 1: {:?}",
            type_ids
        );
    }

    #[test]
    fn test_bert_decode_skips_special() {
        let tok = make_tokenizer();
        let ids = tok.encode("hello world").expect("encode failed");
        let decoded = tok.decode(&ids);
        for absent in ["[CLS]", "[SEP]"] {
            assert!(
                !decoded.contains(absent),
                "decoded should not contain {}: {:?}",
                absent,
                decoded
            );
        }
        for present in ["hello", "world"] {
            assert!(
                decoded.contains(present),
                "decoded should contain '{}': {:?}",
                present,
                decoded
            );
        }
    }

    #[test]
    fn test_bert_batch_padding() {
        let tok = make_tokenizer();
        let texts = vec!["hello", "hello world"];
        let batch = tok
            .encode_batch(&texts, 10, true, false)
            .expect("encode_batch failed");

        assert_eq!(batch.len(), 2);
        assert_eq!(
            batch.encodings[0].len(),
            batch.encodings[1].len(),
            "padded lengths must be equal"
        );

        let real_tokens = |i: usize| {
            batch.encodings[i]
                .attention_mask
                .iter()
                .filter(|&&m| m == 1)
                .count()
        };
        let short_enc = &batch.encodings[0];
        assert!(
            short_enc.input_ids.contains(&tok.pad_token_id()),
            "shorter sequence should have padding; ids={:?}",
            short_enc.input_ids
        );
        assert!(
            real_tokens(0) < real_tokens(1),
            "shorter text should have fewer real tokens"
        );
        for (id, mask) in short_enc.input_ids.iter().zip(&short_enc.attention_mask) {
            if *id == tok.pad_token_id() {
                assert_eq!(*mask, 0, "padding token must have mask 0");
            }
        }
    }

    #[test]
    fn test_bert_batch_truncation() {
        let tok = make_tokenizer();
        // 4 content tokens, budget max_length - 2 = 2.
        let texts = vec!["the quick brown fox"];
        let batch = tok
            .encode_batch(&texts, 4, false, true)
            .expect("encode_batch failed");

        let enc = &batch.encodings[0];
        assert_eq!(enc.input_ids.len(), 4);
        assert_eq!(enc.input_ids[0], tok.cls_token_id());
        assert_eq!(
            *enc.input_ids.last().expect("non-empty"),
            tok.sep_token_id()
        );
    }

    #[test]
    fn test_bert_lowercase() {
        let tok_lower = BertTokenizer::new(base_vocab(), true);
        let tok_cased = BertTokenizer::new(base_vocab(), false);

        let lower_tokens = tok_lower.tokenize("HELLO");
        assert!(
            lower_tokens.contains(&"hello".to_string()),
            "lowercase should map HELLO→hello: {:?}",
            lower_tokens
        );

        // The vocabulary only has lowercase "hello".
        let cased_tokens = tok_cased.tokenize("HELLO");
        assert!(
            cased_tokens.contains(&"[UNK]".to_string()),
            "cased tokenizer should map HELLO to [UNK]: {:?}",
            cased_tokens
        );
    }

    #[test]
    fn test_bert_from_vocab_string() {
        let token_list: &[&str] = &[
            "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "rust", "is", "great",
        ];
        let vocab: HashMap<String, u32> = token_list
            .iter()
            .enumerate()
            .map(|(i, t)| (t.to_string(), i as u32))
            .collect();
        let tokenizer = BertTokenizer::new(vocab, true);
        let ids = tokenizer.encode("rust is great").expect("encode failed");
        // [CLS] rust is great [SEP]
        assert_eq!(ids.len(), 5);
        assert_eq!(ids[0], tokenizer.cls_token_id());
    }

    #[test]
    fn test_bert_empty_input() {
        let tok = make_tokenizer();
        let ids = tok.encode("").expect("encode empty");
        assert_eq!(ids, vec![tok.cls_token_id(), tok.sep_token_id()]);
    }

    #[test]
    fn test_bert_all_oov() {
        let tok = make_tokenizer();
        let ids = tok.encode("zzz yyy xxx").expect("encode all-OOV");
        // [CLS] [UNK] [UNK] [UNK] [SEP]
        assert_eq!(ids.len(), 5);
        assert!(ids[1..4].iter().all(|&id| id == tok.unk_token_id()));
    }

    #[test]
    fn test_bert_max_len_one_truncation() {
        let tok = make_tokenizer();
        // A budget of max_length - 2 saturates to 0, so only the specials remain
        // even though that exceeds max_length.
        let enc = tok
            .encode_single("hello world", 1, false, true)
            .expect("encode_single");
        assert_eq!(enc.input_ids.len(), 2, "only [CLS] and [SEP] expected");
        assert_eq!(enc.input_ids[0], tok.cls_token_id());
        assert_eq!(
            *enc.input_ids.last().expect("non-empty"),
            tok.sep_token_id()
        );
    }

    #[test]
    fn test_bert_decode_wordpiece_merge() {
        let tok = make_tokenizer();
        // 7 = "play", 8 = "##ing" in `base_vocab`.
        let decoded = tok.decode(&[7, 8]);
        assert_eq!(decoded, "playing", "expected 'playing', got '{}'", decoded);
    }

    #[test]
    fn test_bert_from_vocab_file() {
        use std::io::Write;

        // The process id keeps concurrent test runs from clobbering each
        // other's file.
        let path = std::env::temp_dir().join(format!(
            "scirs2_bert_vocab_test_{}.txt",
            std::process::id()
        ));
        let path_str = path.to_str().expect("valid path");
        {
            let mut f = std::fs::File::create(&path).expect("create temp file");
            for tok in ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "hello", "world"] {
                writeln!(f, "{}", tok).expect("write");
            }
        }

        let loaded = BertTokenizer::from_vocab_file(path_str);
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&path);

        let tokenizer = loaded.expect("from_vocab_file");
        assert_eq!(tokenizer.convert_token_to_id("[CLS]"), Some(2));
        assert_eq!(tokenizer.convert_token_to_id("hello"), Some(5));
    }
}
1056}