use std::collections::HashMap;
use trustformers_core::errors::Result;
use trustformers_core::traits::{TokenizedInput, Tokenizer};
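/// Character-level tokenizer in the style of CANINE: instead of a learned
/// subword vocabulary, each Unicode character is mapped to an id directly
/// (ASCII) or by hashing its code point (everything else).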
#[derive(Debug, Clone)]
pub struct CanineTokenizer {
max_length: Option<usize>,
downsample_rate: usize,
hash_size: usize,
cls_token_id: u32,
sep_token_id: u32,
pad_token_id: u32,
mask_token_id: u32,
add_special_tokens: bool,
}
impl CanineTokenizer {
pub fn new() -> Self {
Self {
max_length: None,
downsample_rate: 1,
hash_size: 16384,
cls_token_id: 0,
sep_token_id: 1,
pad_token_id: 2,
mask_token_id: 3,
add_special_tokens: true,
}
}
pub fn with_max_length(mut self, max_length: usize) -> Self {
self.max_length = Some(max_length);
self
}
pub fn with_downsample_rate(mut self, downsample_rate: usize) -> Self {
self.downsample_rate = downsample_rate;
self
}
pub fn with_hash_size(mut self, hash_size: usize) -> Self {
self.hash_size = hash_size;
self
}
pub fn with_special_tokens(
mut self,
cls_token_id: u32,
sep_token_id: u32,
pad_token_id: u32,
mask_token_id: u32,
) -> Self {
self.cls_token_id = cls_token_id;
self.sep_token_id = sep_token_id;
self.pad_token_id = pad_token_id;
self.mask_token_id = mask_token_id;
self
}
pub fn with_add_special_tokens(mut self, add_special_tokens: bool) -> Self {
self.add_special_tokens = add_special_tokens;
self
}
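/// Maps a character to a token id. Ids 0-3 are reserved for special tokens,
/// ASCII characters occupy ids 4-131, and all other code points are hashed
/// with FNV-1a into the range [132, hash_size).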
fn hash_char(&self, ch: char) -> u32 {
let code_point = ch as u32;
// ASCII characters map directly to ids 4..=131, after the 4 special-token ids.
if code_point <= 127 {
return 4 + code_point;
}
// Hash all other code points with FNV-1a.
let mut hash: u64 = 0xcbf29ce484222325;
let fnv_prime: u64 = 0x100000001b3;
let bytes = code_point.to_le_bytes();
for byte in bytes {
hash ^= byte as u64;
hash = hash.wrapping_mul(fnv_prime);
}
// Fold the hash into [132, hash_size): ids 0-3 are special tokens, 4-131 are ASCII.
let hashed = (hash % (self.hash_size as u64 - 132)) + 132;
hashed as u32
}
fn chars_to_ids(&self, text: &str) -> Vec<u32> {
text.chars().map(|ch| self.hash_char(ch)).collect()
}
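/// Keeps every `downsample_rate`-th character id (strided downsampling).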
fn downsample_sequence(&self, token_ids: Vec<u32>) -> Vec<u32> {
if self.downsample_rate <= 1 {
return token_ids;
}
token_ids
.into_iter()
.enumerate()
.filter_map(
|(i, id)| {
if i % self.downsample_rate == 0 {
Some(id)
} else {
None
}
},
)
.collect()
}
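/// Wraps the sequence as [CLS] ... [SEP] when special tokens are enabled.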
fn add_special_tokens_to_sequence(&self, token_ids: Vec<u32>) -> Vec<u32> {
if !self.add_special_tokens {
return token_ids;
}
let mut result = Vec::new();
result.push(self.cls_token_id);
result.extend(token_ids);
result.push(self.sep_token_id);
result
}
fn create_attention_mask(&self, length: usize) -> Vec<u8> {
vec![1; length]
}
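/// Truncates to `max_length` (re-inserting [SEP] as the final token) or pads
/// with `pad_token_id`, extending the attention mask with zeros.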
fn pad_or_truncate(
&self,
mut token_ids: Vec<u32>,
mut attention_mask: Vec<u8>,
) -> (Vec<u32>, Vec<u8>) {
if let Some(max_len) = self.max_length {
if token_ids.len() > max_len {
token_ids.truncate(max_len);
attention_mask.truncate(max_len);
if self.add_special_tokens && max_len > 0 {
token_ids[max_len - 1] = self.sep_token_id;
}
} else if token_ids.len() < max_len {
let pad_length = max_len - token_ids.len();
token_ids.extend(vec![self.pad_token_id; pad_length]);
attention_mask.extend(vec![0; pad_length]);
}
}
(token_ids, attention_mask)
}
}
impl Default for CanineTokenizer {
fn default() -> Self {
Self::new()
}
}
impl Tokenizer for CanineTokenizer {
fn encode(&self, text: &str) -> Result<TokenizedInput> {
let char_ids = self.chars_to_ids(text);
let downsampled_ids = self.downsample_sequence(char_ids);
let token_ids = self.add_special_tokens_to_sequence(downsampled_ids);
let attention_mask = self.create_attention_mask(token_ids.len());
let (final_token_ids, final_attention_mask) =
self.pad_or_truncate(token_ids, attention_mask);
Ok(TokenizedInput {
input_ids: final_token_ids,
attention_mask: final_attention_mask,
token_type_ids: None,
special_tokens_mask: None,
offset_mapping: None,
overflowing_tokens: None,
})
}
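/// Best-effort decoding: special tokens are skipped, ASCII ids map back to
/// characters, and hashed (non-ASCII) ids are not reversible, so they decode
/// to the U+FFFD replacement character.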
fn decode(&self, token_ids: &[u32]) -> Result<String> {
let mut result = String::new();
for &token_id in token_ids {
if token_id == self.cls_token_id
|| token_id == self.sep_token_id
|| token_id == self.pad_token_id
{
continue;
}
if (4..=131).contains(&token_id) {
let ascii_code = token_id - 4;
if let Some(ch) = char::from_u32(ascii_code) {
result.push(ch);
}
} else {
result.push('�');
}
}
Ok(result)
}
fn vocab_size(&self) -> usize {
self.hash_size
}
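/// Encodes a sentence pair as [CLS] A [SEP] B [SEP], with token type id 0 for
/// the first segment and 1 for the second.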
fn encode_pair(&self, text: &str, text2: &str) -> Result<TokenizedInput> {
let char_ids1 = self.chars_to_ids(text);
let char_ids2 = self.chars_to_ids(text2);
let downsampled_ids1 = self.downsample_sequence(char_ids1);
let downsampled_ids2 = self.downsample_sequence(char_ids2);
// Count [CLS]/[SEP] only when special tokens are actually added.
let special_count = if self.add_special_tokens { 1 } else { 0 };
let first_seq_len = special_count + downsampled_ids1.len() + special_count;
let mut token_ids = Vec::new();
if self.add_special_tokens {
token_ids.push(self.cls_token_id);
}
token_ids.extend(downsampled_ids1);
if self.add_special_tokens {
token_ids.push(self.sep_token_id);
}
token_ids.extend(downsampled_ids2);
if self.add_special_tokens {
token_ids.push(self.sep_token_id);
}
let attention_mask = self.create_attention_mask(token_ids.len());
let mut token_type_ids = Vec::new();
token_type_ids.extend(vec![0; first_seq_len]);
token_type_ids.extend(vec![1; token_ids.len() - first_seq_len]);
let (final_token_ids, final_attention_mask) =
self.pad_or_truncate(token_ids, attention_mask);
// Keep token_type_ids the same length as the padded/truncated input ids.
token_type_ids.resize(final_token_ids.len(), 0);
Ok(TokenizedInput {
input_ids: final_token_ids,
attention_mask: final_attention_mask,
token_type_ids: Some(token_type_ids),
special_tokens_mask: None,
offset_mapping: None,
overflowing_tokens: None,
})
}
fn get_vocab(&self) -> HashMap<String, u32> {
// Character hashing has no materialized vocabulary to enumerate.
HashMap::new()
}
fn token_to_id(&self, token: &str) -> Option<u32> {
// A "token" is a single character; count chars rather than bytes so that
// multi-byte characters are handled too.
let mut chars = token.chars();
match (chars.next(), chars.next()) {
(Some(ch), None) => Some(self.hash_char(ch)),
_ => None,
}
}
fn id_to_token(&self, id: u32) -> Option<String> {
if (4..=131).contains(&id) {
Some(((id - 4) as u8 as char).to_string())
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_canine_basic_encoding() {
let tokenizer = CanineTokenizer::new();
let text = "Hello";
let encoded = tokenizer.encode(text).expect("Encoding failed");
// "Hello" (5 chars) + [CLS] + [SEP] = 7 tokens.
assert_eq!(encoded.input_ids.len(), 7);
assert_eq!(encoded.input_ids[0], tokenizer.cls_token_id);
assert_eq!(encoded.input_ids[6], tokenizer.sep_token_id);
}
#[test]
fn test_canine_ascii_characters() {
let tokenizer = CanineTokenizer::new();
let text = "A";
let encoded = tokenizer.encode(text).expect("Encoding failed");
// 'A' is code point 65; with the 4 special-token ids it maps to 69.
assert_eq!(encoded.input_ids[1], 69);
}
#[test]
fn test_canine_downsampling() {
let tokenizer = CanineTokenizer::new().with_downsample_rate(2);
let text = "Hello World";
let encoded = tokenizer.encode(text).expect("Encoding failed");
let expected_downsampled_chars = text.len().div_ceil(2);
let expected_total = expected_downsampled_chars + 2;
assert_eq!(encoded.input_ids.len(), expected_total);
}
#[test]
fn test_canine_max_length() {
let tokenizer = CanineTokenizer::new().with_max_length(5);
let text = "Hello World";
let encoded = tokenizer.encode(text).expect("Encoding failed");
assert_eq!(encoded.input_ids.len(), 5);
assert_eq!(encoded.attention_mask.len(), 5);
assert_eq!(encoded.input_ids[4], tokenizer.sep_token_id);
}
#[test]
fn test_canine_encode_pair() {
let tokenizer = CanineTokenizer::new();
let text1 = "Hello";
let text2 = "World";
let encoded = tokenizer.encode_pair(text1, text2).expect("Operation failed in test");
// [CLS] + text1 + [SEP] + text2 + [SEP]
let expected_len = 1 + text1.len() + 1 + text2.len() + 1;
assert_eq!(encoded.input_ids.len(), expected_len);
assert!(encoded.token_type_ids.is_some());
let token_types = encoded.token_type_ids.expect("Operation failed in test");
assert_eq!(token_types.len(), expected_len);
assert_eq!(token_types[0], 0);
assert_eq!(token_types[1], 0);
// The second segment starts after [CLS] + text1 + [SEP].
let second_seq_start = 1 + text1.len() + 1;
assert_eq!(token_types[second_seq_start], 1);
}
#[test]
fn test_canine_unicode_handling() {
let tokenizer = CanineTokenizer::new();
let text = "Hello 世界";
let encoded = tokenizer.encode(text).expect("Encoding failed");
assert!(encoded.input_ids.len() > 2);
// 'H' is code point 72, offset by the 4 special-token ids.
let h_id = encoded.input_ids[1];
assert_eq!(h_id, 4 + 72);
}
#[test]
fn test_canine_decode_ascii() {
let tokenizer = CanineTokenizer::new();
let text = "Hello";
let encoded = tokenizer.encode(text).expect("Encoding failed");
let decoded = tokenizer.decode(&encoded.input_ids).expect("Decoding failed");
assert!(decoded.contains("Hello"));
}
#[test]
fn test_canine_no_special_tokens() {
let tokenizer = CanineTokenizer::new().with_add_special_tokens(false);
let text = "Hi";
let encoded = tokenizer.encode(text).expect("Encoding failed");
assert_eq!(encoded.input_ids.len(), text.len());
}
}