use std::collections::HashMap as StdHashMap;
use std::path::{Path, PathBuf};

use fxhash::FxHashMap as HashMap;

use gtars_core::models::Region;
use gtars_overlaprs::Overlapper;

use crate::utils::create_tokenize_core_from_universe;

use super::config::{TokenizerConfig, TokenizerInputFileType, TokenizerType};
use super::error::TokenizerError;
use super::universe::Universe;
use super::utils::prepare_universe_and_special_tokens;
use super::utils::special_tokens::SpecialTokens;

#[cfg(feature = "huggingface")]
use hf_hub::api::sync::Api;

pub const DEFAULT_UNIVERSE_FILENAME: &str = "universe.bed.gz";

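/// A single token produced by the tokenizer: a numeric id paired with its
/// string representation (a universe region such as `chr1:100-200` or a
/// special token).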
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Token {
    pub id: u32,
    pub value: String,
}

impl Token {
    pub fn new(id: u32, value: String) -> Self {
        Self { id, value }
    }
}

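/// Tokenizes genomic regions against a fixed universe. Each chromosome maps to
/// an overlap index (`Overlapper`) over the universe regions; overlap hits are
/// converted to ids and string tokens through the `Universe`, and
/// `SpecialTokens` supplies the unknown/padding/etc. tokens.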
pub struct Tokenizer {
    core: HashMap<String, Box<dyn Overlapper<u32, u32>>>,
    universe: Universe,
    special_tokens: SpecialTokens,
}

impl Tokenizer {
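    /// Creates a tokenizer from pre-built components: a per-chromosome overlap
    /// index, the universe of regions, and the special-token set.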
    pub fn new(
        core: HashMap<String, Box<dyn Overlapper<u32, u32>>>,
        universe: Universe,
        special_tokens: SpecialTokens,
    ) -> Self {
        Self {
            core,
            universe,
            special_tokens,
        }
    }

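    /// Builds a tokenizer from a TOML config file. The universe path in the
    /// config is resolved relative to the config file's directory; special
    /// tokens fall back to defaults when the config omits them.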
    pub fn from_config<P: AsRef<Path>>(cfg_path: P) -> Result<Self, TokenizerError> {
        let config = TokenizerConfig::try_from(cfg_path.as_ref())?;

        let config_path = cfg_path.as_ref().parent().unwrap();
        let universe_path = config_path.join(&config.universe);
        let special_tokens = match config.special_tokens {
            Some(tokens) => SpecialTokens::from(tokens),
            None => SpecialTokens::default(),
        };

        let (universe, special_tokens) =
            prepare_universe_and_special_tokens(universe_path, special_tokens)?;

        let core = create_tokenize_core_from_universe(
            &universe,
            config.tokenizer_type.unwrap_or(TokenizerType::Bits),
        );

        Ok(Tokenizer {
            core,
            special_tokens,
            universe,
        })
    }

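    /// Builds a tokenizer directly from a BED (or gzipped BED) file, using the
    /// default special tokens and the default `Bits` overlap backend.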
    pub fn from_bed<P: AsRef<Path>>(bed_path: P) -> Result<Self, TokenizerError> {
        let special_tokens = SpecialTokens::default();
        let (universe, special_tokens) =
            prepare_universe_and_special_tokens(bed_path.as_ref(), special_tokens)?;
        let core = create_tokenize_core_from_universe(&universe, TokenizerType::Bits);

        Ok(Tokenizer {
            core,
            special_tokens,
            universe,
        })
    }

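    /// Loads a tokenizer from a local directory if `path` exists, otherwise
    /// treats `path` as a Hugging Face Hub repository id and fetches the
    /// universe file (`universe.bed.gz`) from that repository.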
    #[cfg(feature = "huggingface")]
    pub fn from_pretrained<P: AsRef<Path>>(path: P) -> Result<Self, TokenizerError> {
        let universe_file_path: PathBuf = if path.as_ref().exists() {
            path.as_ref().join(DEFAULT_UNIVERSE_FILENAME)
        } else {
            let api = Api::new().unwrap();
            let repo = api.model(
                path.as_ref()
                    .to_str()
                    .expect("Path is not valid UTF-8")
                    .to_string(),
            );
            repo.get(DEFAULT_UNIVERSE_FILENAME).unwrap()
        };
        let file_type = TokenizerInputFileType::from_path(universe_file_path.as_path())?;
        match file_type {
            TokenizerInputFileType::Toml => Tokenizer::from_config(universe_file_path),
            TokenizerInputFileType::Bed => Tokenizer::from_bed(universe_file_path),
            TokenizerInputFileType::BedGz => Tokenizer::from_bed(universe_file_path),
        }
    }

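    /// Builds a tokenizer from a path, dispatching on the file type:
    /// TOML configs go through `from_config`, BED and gzipped BED files
    /// through `from_bed`.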
    pub fn from_auto<P: AsRef<Path>>(path: P) -> Result<Self, TokenizerError> {
        let file_type = TokenizerInputFileType::from_path(path.as_ref())?;
        match file_type {
            TokenizerInputFileType::Toml => Tokenizer::from_config(path),
            TokenizerInputFileType::Bed => Tokenizer::from_bed(path),
            TokenizerInputFileType::BedGz => Tokenizer::from_bed(path),
        }
    }

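    /// Maps each query region to the universe regions it overlaps and returns
    /// their string tokens. Regions on chromosomes missing from the universe
    /// are skipped; if nothing overlaps, a single unknown token is returned.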
    pub fn tokenize(&self, regions: &[Region]) -> Result<Vec<String>, TokenizerError> {
        let tokens = regions
            .iter()
            .filter_map(|region| {
                self.core.get(&region.chr).map(|tree| {
                    tree.find(region.start, region.end)
                        .iter()
                        .map(|i| i.val)
                        .collect::<Vec<u32>>()
                })
            })
            .flatten()
            .map(|id| Token {
                value: self.universe.convert_id_to_token(id).unwrap(),
                id,
            })
            .collect::<Vec<Token>>();

        if tokens.is_empty() {
            return Ok(vec![self.special_tokens.unk.clone()]);
        }

        Ok(tokens.into_iter().map(|t| t.value).collect())
    }

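    /// Tokenizes the regions and converts the resulting tokens to their ids.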
    pub fn encode(&self, regions: &[Region]) -> Result<Vec<u32>, TokenizerError> {
        let tokenized = self.tokenize(regions)?;
        Ok(tokenized
            .into_iter()
            .map(|t| self.convert_token_to_id(&t).unwrap())
            .collect())
    }

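    /// Converts ids back to their string tokens, substituting the unknown
    /// token for any id that is not in the universe.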
    pub fn decode(&self, ids: &[u32]) -> Result<Vec<String>, TokenizerError> {
        let decoded: Vec<String> = ids
            .iter()
            .map(|id| {
                self.universe
                    .convert_id_to_token(*id)
                    .unwrap_or_else(|| self.special_tokens.unk.clone())
            })
            .collect();
        Ok(decoded)
    }

    pub fn convert_token_to_id(&self, token: &str) -> Option<u32> {
        self.universe.convert_token_to_id(token)
    }

    pub fn convert_id_to_token(&self, id: u32) -> Option<String> {
        self.universe.convert_id_to_token(id)
    }

    pub fn get_vocab_size(&self) -> usize {
        self.universe.len()
    }

    pub fn get_vocab(&self) -> StdHashMap<String, u32> {
        self.universe.region_to_id.clone()
    }

    pub fn get_unk_token(&self) -> String {
        self.special_tokens.unk.clone()
    }

    pub fn get_pad_token(&self) -> String {
        self.special_tokens.pad.clone()
    }

    pub fn get_mask_token(&self) -> String {
        self.special_tokens.mask.clone()
    }

    pub fn get_cls_token(&self) -> String {
        self.special_tokens.cls.clone()
    }

    pub fn get_eos_token(&self) -> String {
        self.special_tokens.eos.clone()
    }

    pub fn get_bos_token(&self) -> String {
        self.special_tokens.bos.clone()
    }

    pub fn get_sep_token(&self) -> String {
        self.special_tokens.sep.clone()
    }

    pub fn get_unk_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.unk).unwrap()
    }

    pub fn get_pad_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.pad).unwrap()
    }

    pub fn get_mask_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.mask).unwrap()
    }

    pub fn get_cls_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.cls).unwrap()
    }

    pub fn get_eos_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.eos).unwrap()
    }

    pub fn get_bos_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.bos).unwrap()
    }

    pub fn get_sep_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.sep).unwrap()
    }

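    /// Returns a boolean mask over `tokens`, marking which entries are special
    /// tokens (unk, pad, mask, cls, eos, bos, or sep).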
    pub fn get_special_tokens_mask(&self, tokens: &[String]) -> Vec<bool> {
        tokens
            .iter()
            .map(|token| {
                token == &self.special_tokens.unk
                    || token == &self.special_tokens.pad
                    || token == &self.special_tokens.mask
                    || token == &self.special_tokens.cls
                    || token == &self.special_tokens.eos
                    || token == &self.special_tokens.bos
                    || token == &self.special_tokens.sep
            })
            .collect()
    }

    pub fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    pub fn get_universe(&self) -> &Universe {
        &self.universe
    }
}

#[cfg(test)]
mod tokenizer_tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rstest::*;

    #[fixture]
    fn anndata_path() -> String {
        "../tests/data/tokenizers/pbmc_hg38.h5ad".to_string()
    }

    #[rstest]
    fn test_tokenizer_creation_from_config() {
        let cfg_path = "../tests/data/tokenizers/tokenizer.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");
        assert_eq!(tokenizer.get_vocab_size(), 32);
    }

    #[rstest]
    fn test_tokenizer_creation_from_bed() {
        let bed_path = "../tests/data/tokenizers/peaks.bed";
        let tokenizer =
            Tokenizer::from_bed(bed_path).expect("Failed to create tokenizer from BED file.");
        assert_eq!(tokenizer.get_vocab_size(), 32);
    }

    #[rstest]
    fn test_tokenizer_creation_from_bed_gz() {
        let bed_path = "../tests/data/tokenizers/peaks.bed.gz";
        let tokenizer =
            Tokenizer::from_bed(bed_path).expect("Failed to create tokenizer from gzipped BED file.");
        assert_eq!(tokenizer.get_vocab_size(), 32);
    }

    #[rstest]
    fn test_tokenizer_creation_auto_all() {
        let bed_path = "../tests/data/tokenizers/peaks.bed";
        let tokenizer =
            Tokenizer::from_auto(bed_path).expect("Failed to create tokenizer from BED file.");
        assert_eq!(tokenizer.get_vocab_size(), 32);

        let cfg_path = "../tests/data/tokenizers/tokenizer.toml";
        let tokenizer =
            Tokenizer::from_auto(cfg_path).expect("Failed to create tokenizer from config.");
        assert_eq!(tokenizer.get_vocab_size(), 32);

        let bed_path = "../tests/data/tokenizers/peaks.bed.gz";
        let tokenizer =
            Tokenizer::from_auto(bed_path).expect("Failed to create tokenizer from gzipped BED file.");
        assert_eq!(tokenizer.get_vocab_size(), 32);
    }

    #[rstest]
    fn test_tokenizer_bad_tokenizer_type() {
        let cfg_path = "../tests/data/tokenizers/tokenizer_bad_ttype.toml";
        let tokenizer = Tokenizer::from_config(cfg_path);
        assert!(tokenizer.is_err());
    }

    #[rstest]
    fn test_tokenizer_custom_special_tokens() {
        let cfg_path = "../tests/data/tokenizers/tokenizer_custom_specials.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");

        assert_eq!(tokenizer.get_vocab_size(), 32);
        assert_eq!(tokenizer.get_unk_token(), "<UNKNOWN>");

        assert_eq!(tokenizer.get_pad_token(), "<pad>");
    }

    #[rstest]
    fn test_tokenize_single_region_not_overlapping() {
        let cfg_path = "../tests/data/tokenizers/tokenizer.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");

        let regions = vec![Region {
            chr: "chr1".to_string(),
            start: 50,
            end: 150,
            rest: None,
        }];

        let tokenized = tokenizer.tokenize(&regions);
        assert!(tokenized.is_ok());
        let tokenized = tokenized.unwrap();
        assert_eq!(tokenized.len(), 1);
        assert_eq!(tokenized[0], "<unk>");
    }

    #[rstest]
    fn test_tokenize_unk_chrom() {
        let cfg_path = "../tests/data/tokenizers/tokenizer.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");

        let regions = vec![Region {
            chr: "chr999".to_string(),
            start: 50,
            end: 150,
            rest: None,
        }];

        let tokenized = tokenizer.tokenize(&regions);
        assert!(tokenized.is_ok());
        let tokenized = tokenized.unwrap();

        assert_eq!(tokenized.len(), 1);
    }

    #[rstest]
    fn test_tokenize_on_two_chroms() {
        let cfg_path = "../tests/data/tokenizers/tokenizer.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");

        let regions = vec![
            Region {
                chr: "chr1".to_string(),
                start: 151399441,
                end: 151399547,
                rest: None,
            },
            Region {
                chr: "chr2".to_string(),
                start: 203871220,
                end: 203871381,
                rest: None,
            },
        ];

        let tokenized = tokenizer.tokenize(&regions);
        assert!(tokenized.is_ok());

        let tokenized = tokenized.unwrap();
        assert_eq!(tokenized.len(), 2);

        assert_eq!(tokenized[0], "chr1:151399431-151399527");
        assert_eq!(tokenizer.convert_token_to_id(&tokenized[0]), Some(6));

        assert_eq!(tokenized[1], "chr2:203871200-203871375");
        assert_eq!(tokenizer.convert_token_to_id(&tokenized[1]), Some(7));
    }

    #[rstest]
    fn test_tokenize_on_two_chroms_ailist() {
        let cfg_path = "../tests/data/tokenizers/tokenizer_ailist.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");

        let regions = vec![
            Region {
                chr: "chr1".to_string(),
                start: 151399441,
                end: 151399547,
                rest: None,
            },
            Region {
                chr: "chr2".to_string(),
                start: 203871220,
                end: 203871381,
                rest: None,
            },
        ];

        let tokenized = tokenizer.tokenize(&regions);
        assert!(tokenized.is_ok());

        let tokenized = tokenized.unwrap();
        assert_eq!(tokenized.len(), 2);

        assert_eq!(tokenized[0], "chr1:151399431-151399527");
        assert_eq!(tokenizer.convert_token_to_id(&tokenized[0]), Some(6));

        assert_eq!(tokenized[1], "chr2:203871200-203871375");
        assert_eq!(tokenizer.convert_token_to_id(&tokenized[1]), Some(7));
    }

    #[rstest]
    fn test_tokenize_with_multi_overlap() {
        let cfg_path = "../tests/data/tokenizers/tokenizer.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");

        let regions = vec![Region {
            chr: "chr2".to_string(),
            start: 203871346,
            end: 203871616,
            rest: None,
        }];

        let tokenized = tokenizer.tokenize(&regions);
        assert!(tokenized.is_ok());

        let tokenized = tokenized.unwrap();
        assert_eq!(tokenized.len(), 2);

        assert_eq!(tokenized[0], "chr2:203871200-203871375");
        assert_eq!(tokenizer.convert_token_to_id(&tokenized[0]), Some(7));

        assert_eq!(tokenized[1], "chr2:203871387-203871588");
        assert_eq!(tokenizer.convert_token_to_id(&tokenized[1]), Some(8));
    }
}