gtars_tokenizers/config.rs

use std::ffi::OsStr;
use std::fs::read_to_string;
use std::path::Path;

use serde::{Deserialize, Serialize};
use thiserror::Error;
8
/// The special-token roles a tokenizer can assign.
///
/// Serialized/deserialized in lowercase (`unk`, `pad`, `mask`, ...) via
/// `rename_all = "lowercase"`.
#[derive(Deserialize, Serialize, Debug, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum SpecialToken {
    Unk,
    Pad,
    Mask,
    Cls,
    Bos,
    Eos,
    Sep,
}
20
/// Binds one special-token role to the concrete token string it maps to.
#[derive(Deserialize, Serialize, Debug, PartialEq)]
pub struct SpecialTokenAssignment {
    /// Which special-token role is being assigned.
    pub name: SpecialToken,
    /// The token text; expected to be a genomic region in `chr:start-end`
    /// form (not validated here — TODO confirm where validation happens).
    pub token: String, // must be valid chr:start-end
}
26
27#[derive(Deserialize, Serialize, Debug, PartialEq)]
28#[serde(rename_all = "lowercase")]
29pub enum TokenizerType {
30    #[serde(rename = "bits")]
31    Bits,
32    #[serde(rename = "ailist")]
33    AIList,
34}
35
/// Top-level tokenizer configuration, deserialized from a TOML file
/// (see the `TryFrom<&Path>` impl below).
#[derive(Deserialize, Serialize, Debug, PartialEq)]
pub struct TokenizerConfig {
    /// Name/path of the universe (peaks) file, e.g. `peaks.bed.gz`.
    pub universe: String,
    /// Custom special-token assignments; `None` when the TOML omits them.
    pub special_tokens: Option<Vec<SpecialTokenAssignment>>,
    /// Requested backend; `None` when the TOML does not specify one.
    pub tokenizer_type: Option<TokenizerType>,
}
42
/// Kind of file accepted as tokenizer input, classified by extension.
#[derive(Debug)]
pub enum TokenizerInputFileType {
    /// A `.toml` tokenizer configuration file.
    Toml,
    /// A plain-text `.bed` file.
    Bed,
    /// A gzip-compressed BED file (`.bed.gz`).
    BedGz,
}
49
/// Errors produced while classifying or loading a tokenizer config file.
#[derive(Error, Debug)]
pub enum TokenizerConfigError {
    /// The path's extension is not one of the accepted input types.
    #[error(
        "Missing or invalid file extension in tokenizer config file. It must be `toml`, `bed` or `bed.gz`"
    )]
    InvalidFileType,
    /// The config named a tokenizer type that is not recognized.
    #[error("Invalid tokenizer type in config file")]
    InvalidTokenizerType,
    /// Underlying filesystem error (e.g. failing to read the config file).
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// TOML deserialization failure.
    #[error(transparent)]
    Toml(#[from] toml::de::Error),
}
63
64pub type TokenizerConfigResult<T> = std::result::Result<T, TokenizerConfigError>;
65
66impl TokenizerInputFileType {
67    ///
68    /// Determine the type of the tokenizer input file based on its extension.
69    /// # Arguments
70    /// * `path` - A reference to a `Path` object representing the file path.
71    /// # Returns
72    /// * `TokenizerInputFileType` - An enum representing the type of the tokenizer input file.
73    ///
74    pub fn from_path(path: &Path) -> TokenizerConfigResult<Self> {
75        match path.extension().and_then(OsStr::to_str) {
76            Some("gz") => {
77                let file_stem = path
78                    .file_stem()
79                    .ok_or(TokenizerConfigError::InvalidFileType)?;
80                let ext2 = Path::new(file_stem)
81                    .extension()
82                    .and_then(OsStr::to_str)
83                    .ok_or(TokenizerConfigError::InvalidFileType)?;
84                if ext2 == "bed" {
85                    Ok(TokenizerInputFileType::BedGz)
86                } else {
87                    Err(TokenizerConfigError::InvalidFileType)
88                }
89            }
90            Some("toml") => Ok(TokenizerInputFileType::Toml),
91            Some("bed") => Ok(TokenizerInputFileType::Bed),
92            _ => Err(TokenizerConfigError::InvalidFileType),
93        }
94    }
95}
96
97impl TryFrom<&Path> for TokenizerConfig {
98    type Error = TokenizerConfigError;
99
100    fn try_from(path: &Path) -> Result<Self, Self::Error> {
101        let toml_str = read_to_string(path)?;
102        let config = toml::from_str(&toml_str)?;
103        Ok(config)
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use std::path::PathBuf;

    // Boolean conditions are asserted directly with `assert!` instead of
    // `assert_eq!(x, true)` (clippy: bool_assert_comparison).

    #[rstest]
    fn test_try_from_toml() {
        let path = PathBuf::from("../tests/data/tokenizers/tokenizer.toml");
        let result = TokenizerConfig::try_from(path.as_path());
        assert!(result.is_ok());
    }

    #[rstest]
    fn test_from_path_for_toml_extension() {
        let path = PathBuf::from("dummy.toml");
        let file_type = TokenizerInputFileType::from_path(path.as_path());
        assert!(matches!(file_type, Ok(TokenizerInputFileType::Toml)));
    }

    #[rstest]
    fn test_from_path_for_invalid_extension() {
        let path = PathBuf::from("invalid.xyz");
        let file_type = TokenizerInputFileType::from_path(&path);
        assert!(file_type.is_err());
    }

    #[rstest]
    fn test_get_universe_name() {
        let path = PathBuf::from("../tests/data/tokenizers/tokenizer.toml");
        let result = TokenizerConfig::try_from(path.as_path()).unwrap();

        assert_eq!(result.universe, "peaks.bed.gz");
    }

    #[rstest]
    fn test_get_special_tokens() {
        let path = PathBuf::from("../tests/data/tokenizers/tokenizer_custom_specials.toml");
        let result = TokenizerConfig::try_from(path.as_path()).unwrap();
        let special_tokens = result.special_tokens;

        assert!(special_tokens.is_some());
    }
}