gtars_tokenizers/
config.rs1use std::fs::read_to_string;
2use std::path::Path;
3
4use thiserror::Error;
5
6use serde::{Deserialize, Serialize};
7use std::ffi::OsStr;
8
#[derive(Deserialize, Serialize, Debug, PartialEq)]
#[serde(rename_all = "lowercase")]
/// The recognized special-token roles a tokenizer universe may assign.
///
/// Serialized in lowercase (e.g. `Unk` <-> `"unk"`) via `rename_all`.
pub enum SpecialToken {
    /// Unknown token.
    Unk,
    /// Padding token.
    Pad,
    /// Mask token (e.g. for masked-language-model style training).
    Mask,
    /// Classification token.
    Cls,
    /// Beginning-of-sequence token.
    Bos,
    /// End-of-sequence token.
    Eos,
    /// Separator token.
    Sep,
}
20
#[derive(Deserialize, Serialize, Debug, PartialEq)]
/// Binds one [`SpecialToken`] role to the concrete token string used for it.
pub struct SpecialTokenAssignment {
    /// Which special-token role is being assigned.
    pub name: SpecialToken,
    /// The literal token text to use for that role.
    pub token: String,
}
#[derive(Deserialize, Serialize, Debug, PartialEq)]
#[serde(rename_all = "lowercase")]
/// Backend interval-lookup structure the tokenizer should use.
///
/// NOTE(review): the per-variant `rename` attributes duplicate what
/// `rename_all = "lowercase"` already produces; kept for explicitness.
pub enum TokenizerType {
    /// Binary Interval Search Tree backend, serialized as `"bits"`.
    #[serde(rename = "bits")]
    Bits,
    /// Augmented Interval List backend, serialized as `"ailist"`.
    #[serde(rename = "ailist")]
    AIList,
}
35
#[derive(Deserialize, Serialize, Debug, PartialEq)]
/// Top-level tokenizer configuration, deserialized from a TOML file
/// (see the `TryFrom<&Path>` impl below).
pub struct TokenizerConfig {
    /// Path or name of the universe file (e.g. `"peaks.bed.gz"`).
    pub universe: String,
    /// Optional custom special-token assignments; `None` means defaults apply.
    pub special_tokens: Option<Vec<SpecialTokenAssignment>>,
    /// Optional backend selection; `None` means the implementation default.
    pub tokenizer_type: Option<TokenizerType>,
}
42
#[derive(Debug)]
/// The file formats accepted as tokenizer input, detected by extension
/// in [`TokenizerInputFileType::from_path`].
pub enum TokenizerInputFileType {
    /// A `.toml` configuration file.
    Toml,
    /// A plain `.bed` file.
    Bed,
    /// A gzip-compressed BED file (`.bed.gz`).
    BedGz,
}
49
#[derive(Error, Debug)]
/// Errors produced while detecting or parsing a tokenizer config file.
pub enum TokenizerConfigError {
    /// The path's extension was missing or not one of the supported kinds.
    #[error(
        "Missing or invalid file extension in tokenizer config file. It must be `toml`, `bed` or `bed.gz`"
    )]
    InvalidFileType,
    /// The config file named a tokenizer type that is not recognized.
    #[error("Invalid tokenizer type in config file")]
    InvalidTokenizerType,
    /// Underlying filesystem error while reading the config.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// The file was read but was not valid TOML for [`TokenizerConfig`].
    #[error(transparent)]
    Toml(#[from] toml::de::Error),
}

/// Convenience alias for results whose error is [`TokenizerConfigError`].
pub type TokenizerConfigResult<T> = std::result::Result<T, TokenizerConfigError>;
65
66impl TokenizerInputFileType {
67 pub fn from_path(path: &Path) -> TokenizerConfigResult<Self> {
75 match path.extension().and_then(OsStr::to_str) {
76 Some("gz") => {
77 let file_stem = path
78 .file_stem()
79 .ok_or(TokenizerConfigError::InvalidFileType)?;
80 let ext2 = Path::new(file_stem)
81 .extension()
82 .and_then(OsStr::to_str)
83 .ok_or(TokenizerConfigError::InvalidFileType)?;
84 if ext2 == "bed" {
85 Ok(TokenizerInputFileType::BedGz)
86 } else {
87 Err(TokenizerConfigError::InvalidFileType)
88 }
89 }
90 Some("toml") => Ok(TokenizerInputFileType::Toml),
91 Some("bed") => Ok(TokenizerInputFileType::Bed),
92 _ => Err(TokenizerConfigError::InvalidFileType),
93 }
94 }
95}
96
97impl TryFrom<&Path> for TokenizerConfig {
98 type Error = TokenizerConfigError;
99
100 fn try_from(path: &Path) -> Result<Self, Self::Error> {
101 let toml_str = read_to_string(path)?;
102 let config = toml::from_str(&toml_str)?;
103 Ok(config)
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use std::path::PathBuf;

    // NOTE: several tests read fixtures relative to the crate root
    // (`../tests/data/tokenizers/...`); they require those files on disk.

    /// A valid TOML fixture deserializes without error.
    #[rstest]
    fn test_try_from_toml() {
        let path = PathBuf::from("../tests/data/tokenizers/tokenizer.toml");
        let result = TokenizerConfig::try_from(path.as_path());
        // Idiomatic boolean assertion (clippy: bool_assert_comparison).
        assert!(result.is_ok());
    }

    /// A `.toml` extension maps to `TokenizerInputFileType::Toml`.
    #[rstest]
    fn test_from_path_for_toml_extension() {
        let path = PathBuf::from("dummy.toml");
        let file_type = TokenizerInputFileType::from_path(path.as_path());
        assert!(matches!(file_type, Ok(TokenizerInputFileType::Toml)));
    }

    /// An unsupported extension is rejected.
    #[rstest]
    fn test_from_path_for_invalid_extension() {
        let path = PathBuf::from("invalid.xyz");
        let file_type = TokenizerInputFileType::from_path(&path);
        assert!(file_type.is_err());
    }

    /// The `universe` field round-trips from the fixture file.
    #[rstest]
    fn test_get_universe_name() {
        let path = PathBuf::from("../tests/data/tokenizers/tokenizer.toml");
        let result = TokenizerConfig::try_from(path.as_path()).unwrap();

        assert_eq!(result.universe, "peaks.bed.gz");
    }

    /// Custom special tokens in the fixture are parsed into `Some(...)`.
    #[rstest]
    fn test_get_special_tokens() {
        let path = PathBuf::from("../tests/data/tokenizers/tokenizer_custom_specials.toml");
        let result = TokenizerConfig::try_from(path.as_path()).unwrap();
        let special_tokens = result.special_tokens;

        assert!(special_tokens.is_some());
    }
}