gtars_tokenizers/tokenizer.rs

use std::collections::HashMap as StdHashMap;
use std::path::{Path, PathBuf};

use fxhash::FxHashMap as HashMap;

use gtars_core::models::Region;
use gtars_overlaprs::Overlapper;

use crate::utils::create_tokenize_core_from_universe;

use super::config::{TokenizerConfig, TokenizerInputFileType, TokenizerType};
use super::error::TokenizerError;
use super::universe::Universe;
use super::utils::prepare_universe_and_special_tokens;
use super::utils::special_tokens::SpecialTokens;

#[cfg(feature = "huggingface")]
use hf_hub::api::sync::Api;

pub const DEFAULT_UNIVERSE_FILENAME: &str = "universe.bed.gz";

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Token {
    pub id: u32,
    pub value: String,
}

impl Token {
    pub fn new(id: u32, value: String) -> Self {
        Self { id, value }
    }
}

pub struct Tokenizer {
    core: HashMap<String, Box<dyn Overlapper<u32, u32>>>,
    universe: Universe,
    special_tokens: SpecialTokens,
}

impl Tokenizer {
    ///
    /// Create a new tokenizer
    ///
    pub fn new(
        core: HashMap<String, Box<dyn Overlapper<u32, u32>>>,
        universe: Universe,
        special_tokens: SpecialTokens,
    ) -> Self {
        Self {
            core,
            universe,
            special_tokens,
        }
    }

    ///
    /// Create a new tokenizer from a config file
    ///
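    /// # Example
    ///
    /// A minimal usage sketch; the config path below is illustrative, and the
    /// import path for `Tokenizer` may differ depending on how the crate
    /// re-exports it.
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::from_config("path/to/tokenizer.toml")
    ///     .expect("failed to build tokenizer from config");
    /// println!("vocab size: {}", tokenizer.get_vocab_size());
    /// ```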
    pub fn from_config<P: AsRef<Path>>(cfg_path: P) -> Result<Self, TokenizerError> {
        let config = TokenizerConfig::try_from(cfg_path.as_ref())?;

        let config_path = cfg_path.as_ref().parent().unwrap();
        let universe_path = config_path.join(&config.universe);
        let special_tokens = match config.special_tokens {
            Some(tokens) => SpecialTokens::from(tokens),
            None => SpecialTokens::default(),
        };

        let (universe, special_tokens) =
            prepare_universe_and_special_tokens(universe_path, special_tokens)?;

        let core = create_tokenize_core_from_universe(
            &universe,
            config.tokenizer_type.unwrap_or(TokenizerType::Bits),
        );

        Ok(Tokenizer {
            core,
            special_tokens,
            universe,
        })
    }

    ///
    /// Create a new tokenizer from a bed file
    ///
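    /// # Example
    ///
    /// A minimal sketch with an illustrative path; the universe may be a plain
    /// or gzipped BED file, and special tokens fall back to the defaults.
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::from_bed("path/to/universe.bed")
    ///     .expect("failed to build tokenizer from BED file");
    /// assert!(tokenizer.get_vocab_size() > 0);
    /// ```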
    pub fn from_bed<P: AsRef<Path>>(bed_path: P) -> Result<Self, TokenizerError> {
        let special_tokens = SpecialTokens::default();
        let (universe, special_tokens) =
            prepare_universe_and_special_tokens(bed_path.as_ref(), special_tokens)?;
        let core = create_tokenize_core_from_universe(&universe, TokenizerType::Bits);

        Ok(Tokenizer {
            core,
            special_tokens,
            universe,
        })
    }

    ///
    /// Create a new tokenizer from a pre-trained model
    ///
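    /// If the path exists locally, the universe file is read from that
    /// directory; otherwise the path is treated as a Hugging Face Hub
    /// repository id and the universe file is downloaded from it.
    ///
    /// # Example
    ///
    /// A sketch only; the repository id below is hypothetical, and this
    /// constructor requires the `huggingface` feature.
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::from_pretrained("some-org/some-universe")
    ///     .expect("failed to fetch pre-trained universe");
    /// ```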
    #[cfg(feature = "huggingface")]
    pub fn from_pretrained<P: AsRef<Path>>(path: P) -> Result<Self, TokenizerError> {
        // if the path exists locally, read the universe from that directory;
        // otherwise treat it as a Hugging Face Hub repo id and download the universe
        let universe_file_path: PathBuf = if path.as_ref().exists() {
            path.as_ref().join(DEFAULT_UNIVERSE_FILENAME)
        } else {
            let api = Api::new().unwrap();
            let repo = api.model(
                path.as_ref()
                    .to_str()
                    .expect("Path is not valid UTF-8")
                    .to_string(),
            );
            repo.get(DEFAULT_UNIVERSE_FILENAME).unwrap()
        };
        let file_type = TokenizerInputFileType::from_path(universe_file_path.as_path())?;
        match file_type {
            TokenizerInputFileType::Toml => Tokenizer::from_config(universe_file_path),
            TokenizerInputFileType::Bed => Tokenizer::from_bed(universe_file_path),
            TokenizerInputFileType::BedGz => Tokenizer::from_bed(universe_file_path),
        }
    }

    ///
    /// Create a new tokenizer from a file, automatically detecting the type
    ///
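    /// # Example
    ///
    /// A sketch with illustrative paths; the input type (TOML config, BED, or
    /// gzipped BED) is detected from the path.
    ///
    /// ```ignore
    /// let from_toml = Tokenizer::from_auto("path/to/tokenizer.toml")
    ///     .expect("failed to build tokenizer");
    /// let from_bed = Tokenizer::from_auto("path/to/universe.bed.gz")
    ///     .expect("failed to build tokenizer");
    /// ```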
    pub fn from_auto<P: AsRef<Path>>(path: P) -> Result<Self, TokenizerError> {
        let file_type = TokenizerInputFileType::from_path(path.as_ref())?;
        match file_type {
            TokenizerInputFileType::Toml => Tokenizer::from_config(path),
            TokenizerInputFileType::Bed => Tokenizer::from_bed(path),
            TokenizerInputFileType::BedGz => Tokenizer::from_bed(path),
        }
    }

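    ///
    /// Tokenize a slice of regions into universe token strings. Each query
    /// region is overlapped against the per-chromosome core; if nothing
    /// overlaps (or the chromosome is unknown), a single `unk` token is
    /// returned.
    ///
    /// # Example
    ///
    /// A sketch only; the coordinates are illustrative.
    ///
    /// ```ignore
    /// let regions = vec![Region {
    ///     chr: "chr1".to_string(),
    ///     start: 100,
    ///     end: 200,
    ///     rest: None,
    /// }];
    /// let tokens = tokenizer.tokenize(&regions).expect("tokenization failed");
    /// ```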
    pub fn tokenize(&self, regions: &[Region]) -> Result<Vec<String>, TokenizerError> {
        let tokens = regions
            .iter()
            .filter_map(|region| {
                self.core.get(&region.chr).map(|tree| {
                    tree.find(region.start, region.end)
                        .iter()
                        .map(|i| i.val)
                        .collect::<Vec<u32>>()
                })
            })
            .flatten()
            .map(|id| Token {
                value: self.universe.convert_id_to_token(id).unwrap(),
                id,
            })
            .collect::<Vec<Token>>();

        if tokens.is_empty() {
            return Ok(vec![self.special_tokens.unk.clone()]);
        }

        Ok(tokens.into_iter().map(|t| t.value).collect())
    }

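    ///
    /// Tokenize a slice of regions and convert the resulting tokens to their
    /// numeric ids.
    ///
    /// # Example
    ///
    /// A sketch only; the region is illustrative.
    ///
    /// ```ignore
    /// let ids: Vec<u32> = tokenizer
    ///     .encode(&[Region {
    ///         chr: "chr1".to_string(),
    ///         start: 100,
    ///         end: 200,
    ///         rest: None,
    ///     }])
    ///     .expect("encoding failed");
    /// ```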
    pub fn encode(&self, regions: &[Region]) -> Result<Vec<u32>, TokenizerError> {
        let tokenized = self.tokenize(regions)?;
        Ok(tokenized
            .into_iter()
            .map(|t| self.convert_token_to_id(&t).unwrap())
            .collect())
    }

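    ///
    /// Convert numeric ids back into token strings; ids not present in the
    /// universe decode to the `unk` token.
    ///
    /// # Example
    ///
    /// A sketch only; the ids are illustrative.
    ///
    /// ```ignore
    /// let tokens = tokenizer.decode(&[6, 7]).expect("decoding failed");
    /// ```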
    pub fn decode(&self, ids: &[u32]) -> Result<Vec<String>, TokenizerError> {
        let decoded: Vec<String> = ids
            .iter()
            .map(|id| {
                self.universe
                    .convert_id_to_token(*id)
                    .unwrap_or_else(|| self.special_tokens.unk.clone())
            })
            .collect();
        Ok(decoded)
    }

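    ///
    /// Look up the id for a token string, if the token exists in the universe.
    ///
    /// # Example
    ///
    /// A sketch only; the region token is illustrative and may not be present
    /// in your universe.
    ///
    /// ```ignore
    /// let id = tokenizer.convert_token_to_id("chr1:100-200");
    /// let round_trip = id.and_then(|id| tokenizer.convert_id_to_token(id));
    /// ```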
    pub fn convert_token_to_id(&self, token: &str) -> Option<u32> {
        self.universe.convert_token_to_id(token)
    }

    pub fn convert_id_to_token(&self, id: u32) -> Option<String> {
        self.universe.convert_id_to_token(id)
    }

    pub fn get_vocab_size(&self) -> usize {
        self.universe.len()
    }

    pub fn get_vocab(&self) -> StdHashMap<String, u32> {
        self.universe.region_to_id.clone()
    }

    pub fn get_unk_token(&self) -> String {
        self.special_tokens.unk.clone()
    }

    pub fn get_pad_token(&self) -> String {
        self.special_tokens.pad.clone()
    }

    pub fn get_mask_token(&self) -> String {
        self.special_tokens.mask.clone()
    }

    pub fn get_cls_token(&self) -> String {
        self.special_tokens.cls.clone()
    }

    pub fn get_eos_token(&self) -> String {
        self.special_tokens.eos.clone()
    }

    pub fn get_bos_token(&self) -> String {
        self.special_tokens.bos.clone()
    }

    pub fn get_sep_token(&self) -> String {
        self.special_tokens.sep.clone()
    }

    // special token ids
    pub fn get_unk_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.unk).unwrap()
    }

    pub fn get_pad_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.pad).unwrap()
    }

    pub fn get_mask_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.mask).unwrap()
    }

    pub fn get_cls_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.cls).unwrap()
    }

    pub fn get_eos_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.eos).unwrap()
    }

    pub fn get_bos_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.bos).unwrap()
    }

    pub fn get_sep_token_id(&self) -> u32 {
        self.convert_token_to_id(&self.special_tokens.sep).unwrap()
    }

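    ///
    /// Return a boolean mask over `tokens`, `true` where a token matches one
    /// of the special tokens (unk, pad, mask, cls, eos, bos, sep).
    ///
    /// # Example
    ///
    /// A sketch only; assumes the default special-token strings.
    ///
    /// ```ignore
    /// let mask = tokenizer.get_special_tokens_mask(&[
    ///     "chr1:100-200".to_string(),
    ///     tokenizer.get_pad_token(),
    /// ]);
    /// assert_eq!(mask, vec![false, true]);
    /// ```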
    pub fn get_special_tokens_mask(&self, tokens: &[String]) -> Vec<bool> {
        tokens
            .iter()
            .map(|token| {
                token == &self.special_tokens.unk
                    || token == &self.special_tokens.pad
                    || token == &self.special_tokens.mask
                    || token == &self.special_tokens.cls
                    || token == &self.special_tokens.eos
                    || token == &self.special_tokens.bos
                    || token == &self.special_tokens.sep
            })
            .collect()
    }

    pub fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    pub fn get_universe(&self) -> &Universe {
        &self.universe
    }
}

#[cfg(test)]
mod tokenizer_tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rstest::*;

    #[fixture]
    fn anndata_path() -> String {
        "../tests/data/tokenizers/pbmc_hg38.h5ad".to_string()
    }

    #[rstest]
    fn test_tokenizer_creation_from_config() {
        let cfg_path = "../tests/data/tokenizers/tokenizer.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");
        assert_eq!(tokenizer.get_vocab_size(), 32); // 25 regions + 7 special tokens
    }

    #[rstest]
    fn test_tokenizer_creation_from_bed() {
        let bed_path = "../tests/data/tokenizers/peaks.bed";
        let tokenizer =
            Tokenizer::from_bed(bed_path).expect("Failed to create tokenizer from BED file.");
        assert_eq!(tokenizer.get_vocab_size(), 32); // 25 regions + 7 special tokens
    }

    #[rstest]
    fn test_tokenizer_creation_from_bed_gz() {
        let bed_path = "../tests/data/tokenizers/peaks.bed.gz";
        let tokenizer =
            Tokenizer::from_bed(bed_path).expect("Failed to create tokenizer from BED file.");
        assert_eq!(tokenizer.get_vocab_size(), 32); // 25 regions + 7 special tokens
    }

    #[rstest]
    fn test_tokenizer_creation_auto_all() {
        let bed_path = "../tests/data/tokenizers/peaks.bed";
        let tokenizer =
            Tokenizer::from_auto(bed_path).expect("Failed to create tokenizer from BED file.");
        assert_eq!(tokenizer.get_vocab_size(), 32); // 25 regions + 7 special tokens

        let cfg_path = "../tests/data/tokenizers/tokenizer.toml";
        let tokenizer =
            Tokenizer::from_auto(cfg_path).expect("Failed to create tokenizer from config.");
        assert_eq!(tokenizer.get_vocab_size(), 32); // 25 regions + 7 special tokens

        let bed_path = "../tests/data/tokenizers/peaks.bed.gz";
        let tokenizer =
            Tokenizer::from_auto(bed_path).expect("Failed to create tokenizer from BED file.");
        assert_eq!(tokenizer.get_vocab_size(), 32); // 25 regions + 7 special tokens
    }

    #[rstest]
    fn test_tokenizer_bad_tokenizer_type() {
        let cfg_path = "../tests/data/tokenizers/tokenizer_bad_ttype.toml";
        let tokenizer = Tokenizer::from_config(cfg_path);
        assert!(tokenizer.is_err());
    }

    #[rstest]
    fn test_tokenizer_custom_special_tokens() {
        let cfg_path = "../tests/data/tokenizers/tokenizer_custom_specials.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");

        assert_eq!(tokenizer.get_vocab_size(), 32); // 25 regions + 7 special tokens

        // check that unk was overridden
        assert_eq!(tokenizer.get_unk_token(), "<UNKNOWN>");

        // check that pad didn't change
        assert_eq!(tokenizer.get_pad_token(), "<pad>");
    }

    #[rstest]
    fn test_tokenize_single_region_not_overlapping() {
        let cfg_path = "../tests/data/tokenizers/tokenizer.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");

        let regions = vec![Region {
            chr: "chr1".to_string(),
            start: 50,
            end: 150,
            rest: None,
        }];

        let tokenized = tokenizer.tokenize(&regions);
        assert!(tokenized.is_ok());
        let tokenized = tokenized.unwrap();
        assert_eq!(tokenized.len(), 1);
        assert_eq!(tokenized[0], "<unk>");
    }

    #[rstest]
    fn test_tokenize_unk_chrom() {
        let cfg_path = "../tests/data/tokenizers/tokenizer.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");

        let regions = vec![Region {
            chr: "chr999".to_string(),
            start: 50,
            end: 150,
            rest: None,
        }];

        let tokenized = tokenizer.tokenize(&regions);
        assert!(tokenized.is_ok());
        let tokenized = tokenized.unwrap();

        assert_eq!(tokenized.len(), 1);
    }

    #[rstest]
    fn test_tokenize_on_two_chroms() {
        let cfg_path = "../tests/data/tokenizers/tokenizer.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");

        let regions = vec![
            Region {
                chr: "chr1".to_string(),
                start: 151399441,
                end: 151399547,
                rest: None,
            },
            Region {
                chr: "chr2".to_string(),
                start: 203871220,
                end: 203871381,
                rest: None,
            },
        ];

        let tokenized = tokenizer.tokenize(&regions);
        assert!(tokenized.is_ok());

        let tokenized = tokenized.unwrap();
        assert_eq!(tokenized.len(), 2);

        // chr1:151399432-151399527 -- CONFIRMED IN IGV
        assert_eq!(tokenized[0], "chr1:151399431-151399527");
        assert_eq!(tokenizer.convert_token_to_id(&tokenized[0]), Some(6));

        // chr2:203871201-203871375 -- CONFIRMED IN IGV
        assert_eq!(tokenized[1], "chr2:203871200-203871375");
        assert_eq!(tokenizer.convert_token_to_id(&tokenized[1]), Some(7));
    }

    #[rstest]
    fn test_tokenize_on_two_chroms_ailist() {
        let cfg_path = "../tests/data/tokenizers/tokenizer_ailist.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");

        let regions = vec![
            Region {
                chr: "chr1".to_string(),
                start: 151399441,
                end: 151399547,
                rest: None,
            },
            Region {
                chr: "chr2".to_string(),
                start: 203871220,
                end: 203871381,
                rest: None,
            },
        ];

        let tokenized = tokenizer.tokenize(&regions);
        assert!(tokenized.is_ok());

        let tokenized = tokenized.unwrap();
        assert_eq!(tokenized.len(), 2);

        // chr1:151399432-151399527 -- CONFIRMED IN IGV
        assert_eq!(tokenized[0], "chr1:151399431-151399527");
        assert_eq!(tokenizer.convert_token_to_id(&tokenized[0]), Some(6));

        // chr2:203871201-203871375 -- CONFIRMED IN IGV
        assert_eq!(tokenized[1], "chr2:203871200-203871375");
        assert_eq!(tokenizer.convert_token_to_id(&tokenized[1]), Some(7));
    }

    #[rstest]
    fn test_tokenize_with_multi_overlap() {
        let cfg_path = "../tests/data/tokenizers/tokenizer.toml";
        let tokenizer =
            Tokenizer::from_config(cfg_path).expect("Failed to create tokenizer from config.");

        let regions = vec![Region {
            chr: "chr2".to_string(),
            start: 203871346,
            end: 203871616,
            rest: None,
        }];

        let tokenized = tokenizer.tokenize(&regions);
        assert!(tokenized.is_ok());

        let tokenized = tokenized.unwrap();
        assert_eq!(tokenized.len(), 2);

        // chr2:203871201-203871375 -- CONFIRMED IN IGV
        assert_eq!(tokenized[0], "chr2:203871200-203871375");
        assert_eq!(tokenizer.convert_token_to_id(&tokenized[0]), Some(7));

        // chr2:203871388-203871588 -- CONFIRMED IN IGV
        assert_eq!(tokenized[1], "chr2:203871387-203871588");
        assert_eq!(tokenizer.convert_token_to_id(&tokenized[1]), Some(8));
    }
}