//! Vocabulary management for the OxiBonsai tokenizer.
//!
//! Provides bidirectional mapping between token strings and integer IDs,
//! plus a separate special-token registry used for BOS/EOS/PAD/UNK handling.
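//!
//! # Examples
//!
//! A minimal usage sketch. The `oxibonsai::vocab` import path is an
//! assumption made for illustration; adjust it to the actual crate layout.
//!
//! ```ignore
//! use oxibonsai::vocab::Vocabulary;
//!
//! let mut vocab = Vocabulary::new();
//! vocab.add_special("<unk>", 0);
//! vocab.insert("hello", 1);
//! assert_eq!(vocab.get_id("hello"), Some(1));
//! assert_eq!(vocab.get_token(0), Some("<unk>"));
//! assert!(vocab.is_special_token("<unk>"));
//! ```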

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::error::{TokenizerError, TokenizerResult};

/// Bidirectional token ↔ ID vocabulary with special-token support.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Vocabulary {
    token_to_id: HashMap<String, u32>,
    id_to_token: HashMap<u32, String>,
    special_tokens: HashMap<String, u32>,
}

impl Vocabulary {
    /// Create an empty vocabulary.
    pub fn new() -> Self {
        Self::default()
    }

    /// Insert a regular (non-special) token ↔ ID pair.
    ///
    /// Later insertions win: re-inserting an existing token or ID overwrites
    /// the old pair and evicts any stale entry in the reverse map, so the two
    /// maps stay mutually consistent.
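    ///
    /// # Examples
    ///
    /// A sketch of the overwrite behavior; the `oxibonsai::vocab` import
    /// path is an assumption.
    ///
    /// ```ignore
    /// use oxibonsai::vocab::Vocabulary;
    ///
    /// let mut v = Vocabulary::new();
    /// v.insert("hello", 1);
    /// v.insert("hello", 2); // "hello" now maps to 2 ...
    /// assert_eq!(v.get_id("hello"), Some(2));
    /// assert_eq!(v.get_token(1), None); // ... and ID 1 no longer resolves
    /// ```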
    pub fn insert(&mut self, token: &str, id: u32) {
        // Evict stale entries in both directions so the two maps always agree.
        if let Some(old_id) = self.token_to_id.insert(token.to_owned(), id) {
            self.id_to_token.remove(&old_id);
        }
        if let Some(old_token) = self.id_to_token.insert(id, token.to_owned()) {
            if old_token != token {
                self.token_to_id.remove(&old_token);
            }
        }
    }

    /// Register a special token (e.g. `<s>`, `</s>`, `<unk>`, `<pad>`).
    ///
    /// Special tokens are also inserted into the main maps so they can be
    /// looked up by the standard `get_id` / `get_token` interface.
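    ///
    /// # Examples
    ///
    /// Illustrative sketch; the `oxibonsai::vocab` import path is an
    /// assumption.
    ///
    /// ```ignore
    /// use oxibonsai::vocab::Vocabulary;
    ///
    /// let mut v = Vocabulary::new();
    /// v.add_special("<pad>", 0);
    /// assert_eq!(v.get_id("<pad>"), Some(0)); // visible via the main maps
    /// assert!(v.is_special_token("<pad>"));
    /// ```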
    pub fn add_special(&mut self, token: &str, id: u32) {
        self.special_tokens.insert(token.to_owned(), id);
        self.insert(token, id);
    }

    /// Look up the integer ID for a token string.
    ///
    /// Returns `None` if the token is not present in the vocabulary.
    pub fn get_id(&self, token: &str) -> Option<u32> {
        self.token_to_id.get(token).copied()
    }

    /// Look up the token string for a given integer ID.
    ///
    /// Returns `None` if the ID is not present.
    pub fn get_token(&self, id: u32) -> Option<&str> {
        self.id_to_token.get(&id).map(|s| s.as_str())
    }

    /// Total number of tokens in the vocabulary (including special tokens).
    pub fn size(&self) -> usize {
        self.token_to_id.len()
    }

    /// Returns `true` if the vocabulary contains no tokens.
    pub fn is_empty(&self) -> bool {
        self.token_to_id.is_empty()
    }

    /// Returns `true` if the given token string is registered as a special token.
    pub fn is_special_token(&self, token: &str) -> bool {
        self.special_tokens.contains_key(token)
    }

    /// Returns `true` if the given ID corresponds to a special token.
    ///
    /// Note: this is a linear scan over the special-token registry, which is
    /// cheap for the handful of special tokens a tokenizer typically defines.
    pub fn is_special_id(&self, id: u32) -> bool {
        self.special_tokens.values().any(|&v| v == id)
    }

    /// Iterate over all (token, id) pairs (regular + special).
    pub fn iter(&self) -> impl Iterator<Item = (&str, u32)> {
        self.token_to_id.iter().map(|(k, &v)| (k.as_str(), v))
    }

    /// Deserialize a vocabulary from a JSON object mapping token → id.
    ///
    /// The JSON must be a flat object: `{ "<token>": <id>, ... }`.
    /// Tokens that both start with `<` and end with `>` are automatically
    /// promoted to the special-token registry.
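    ///
    /// # Examples
    ///
    /// Illustrative sketch; the `oxibonsai::vocab` import path is an
    /// assumption.
    ///
    /// ```ignore
    /// use oxibonsai::vocab::Vocabulary;
    ///
    /// let v = Vocabulary::from_json(r#"{"<unk>": 0, "hello": 1}"#).unwrap();
    /// assert_eq!(v.get_id("hello"), Some(1));
    /// assert!(v.is_special_token("<unk>")); // promoted by the `<...>` heuristic
    /// ```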
    pub fn from_json(json: &str) -> TokenizerResult<Self> {
        let raw: HashMap<String, u32> =
            serde_json::from_str(json).map_err(|e| TokenizerError::InvalidJson(e.to_string()))?;
        if raw.is_empty() {
            return Err(TokenizerError::InvalidVocab(
                "vocabulary JSON must not be empty".into(),
            ));
        }
        let mut vocab = Self::new();
        for (token, id) in raw {
            // Heuristic: treat tokens that look like `<something>` as special.
            if token.starts_with('<') && token.ends_with('>') {
                vocab.add_special(&token, id);
            } else {
                vocab.insert(&token, id);
            }
        }
        Ok(vocab)
    }

    /// Serialize the vocabulary to a compact JSON object (token → id).
    ///
    /// The output is always sorted by token string for determinism.
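    ///
    /// # Examples
    ///
    /// Illustrative sketch; the `oxibonsai::vocab` import path is an
    /// assumption.
    ///
    /// ```ignore
    /// use oxibonsai::vocab::Vocabulary;
    ///
    /// let mut v = Vocabulary::new();
    /// v.insert("b", 2);
    /// v.insert("a", 1);
    /// assert_eq!(v.to_json(), r#"{"a":1,"b":2}"#); // sorted by token
    /// ```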
    pub fn to_json(&self) -> String {
        // Collect and sort for deterministic output.
        let mut entries: Vec<(&str, u32)> = self.iter().collect();
        entries.sort_by(|a, b| a.0.cmp(b.0));
        let mut out = String::from('{');
        for (i, (token, id)) in entries.iter().enumerate() {
            if i > 0 {
                out.push(',');
            }
            // Manual JSON string escaping; control characters without a short
            // escape form fall back to `\u00XX` so the output is always valid JSON.
            out.push('"');
            for ch in token.chars() {
                match ch {
                    '"' => out.push_str("\\\""),
                    '\\' => out.push_str("\\\\"),
                    '\n' => out.push_str("\\n"),
                    '\r' => out.push_str("\\r"),
                    '\t' => out.push_str("\\t"),
                    c if (c as u32) < 0x20 => {
                        out.push_str(&format!("\\u{:04x}", c as u32))
                    }
                    c => out.push(c),
                }
            }
            out.push_str("\":");
            out.push_str(&id.to_string());
        }
        out.push('}');
        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insert_and_lookup() {
        let mut v = Vocabulary::new();
        v.insert("hello", 1);
        v.insert("world", 2);
        assert_eq!(v.get_id("hello"), Some(1));
        assert_eq!(v.get_id("world"), Some(2));
        assert_eq!(v.get_token(1), Some("hello"));
        assert_eq!(v.get_token(99), None);
    }

    #[test]
    fn special_tokens_are_found_in_main_maps() {
        let mut v = Vocabulary::new();
        v.add_special("<bos>", 0);
        assert_eq!(v.get_id("<bos>"), Some(0));
        assert!(v.is_special_token("<bos>"));
        assert!(v.is_special_id(0));
    }

    #[test]
    fn json_roundtrip() {
        let mut v = Vocabulary::new();
        v.insert("a", 3);
        v.insert("b", 4);
        v.add_special("<unk>", 0);
        let json = v.to_json();
        let v2 = Vocabulary::from_json(&json).expect("parse should succeed");
        assert_eq!(v2.get_id("a"), Some(3));
        assert_eq!(v2.get_id("b"), Some(4));
        assert_eq!(v2.get_id("<unk>"), Some(0));
    }

    #[test]
    fn empty_json_fails() {
        assert!(Vocabulary::from_json("{}").is_err());
    }

    #[test]
    fn invalid_json_fails() {
        assert!(Vocabulary::from_json("not json").is_err());
    }
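
    // Added sketch: exercises the stale-mapping eviction behavior documented
    // on `insert`; assumes the bidirectional consistency fix in this file.
    #[test]
    fn reinsert_evicts_stale_reverse_mapping() {
        let mut v = Vocabulary::new();
        v.insert("hello", 1);
        v.insert("hello", 2); // "hello" moves to ID 2; ID 1 becomes unused
        assert_eq!(v.get_id("hello"), Some(2));
        assert_eq!(v.get_token(1), None);
        assert_eq!(v.size(), 1);
    }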
}