1use std::collections::{HashMap, HashSet};
24
25use crate::detokenize::{Detokenizer, DetokenizeOptions};
26use crate::longest_match::Tokenize;
27use crate::map::TokenizerMap;
28use crate::tokenize::ITokenizer;
29
30pub struct Translator {
32 pub from_id: String,
33 pub to_id: String,
34 from_detok: Detokenizer,
35 to_tok: Box<dyn ITokenizer>,
36 text_buffer: String,
37}
38
39impl Translator {
40 pub fn new(from_map: &TokenizerMap, to_map: &TokenizerMap) -> Self {
42 Self {
43 from_id: from_map.id.clone(),
44 to_id: to_map.id.clone(),
45 from_detok: Detokenizer::new(from_map),
46 to_tok: Tokenize::pick(to_map),
47 text_buffer: String::new(),
48 }
49 }
50
51 pub fn translate(&mut self, ids: &[u32], partial: bool) -> Vec<u32> {
57 let text = self
58 .from_detok
59 .render(ids, DetokenizeOptions { partial, render_special: false });
60 if !text.is_empty() {
61 self.text_buffer.push_str(&text);
62 }
63
64 if !partial {
65 let all_text = std::mem::take(&mut self.text_buffer);
66 return self.to_tok.encode(&all_text);
67 }
68
69 let safe = find_last_safe_boundary(&self.text_buffer);
70 if safe == 0 {
71 return Vec::new();
72 }
73 let to_encode: String = self.text_buffer.drain(..safe).collect();
75 self.to_tok.encode(&to_encode)
76 }
77
78 pub fn finish(&mut self) -> Vec<u32> {
80 self.translate(&[], false)
81 }
82
83 pub fn reset(&mut self) {
85 self.from_detok.reset();
86 self.text_buffer.clear();
87 }
88}
89
/// Returns `true` when `c` is one of the code points this module treats as
/// whitespace: U+0009..=U+000D (tab, LF, VT, FF, CR), space, U+00A0,
/// U+2028, U+2029 and U+3000. Every other character yields `false`.
fn is_whitespace_cp(c: char) -> bool {
    match c {
        '\u{0009}'..='\u{000D}' => true,
        '\u{0020}' | '\u{00A0}' => true,
        '\u{2028}'..='\u{2029}' => true,
        '\u{3000}' => true,
        _ => false,
    }
}
106
107fn find_last_safe_boundary(buf: &str) -> usize {
110 let mut last_after: usize = 0;
111 for (i, c) in buf.char_indices() {
112 if is_whitespace_cp(c) {
113 last_after = i + c.len_utf8();
114 }
115 }
116 last_after
117}
118
119pub fn translate_one_shot(
121 from_map: &TokenizerMap,
122 to_map: &TokenizerMap,
123 ids: &[u32],
124) -> Vec<u32> {
125 let mut tr = Translator::new(from_map, to_map);
126 tr.translate(ids, false)
127}
128
129pub fn static_translation_table(
138 from_map: &TokenizerMap,
139 to_map: &TokenizerMap,
140) -> HashMap<u32, Vec<u32>> {
141 let mut detok = Detokenizer::new(from_map);
142 let tok = Tokenize::pick(to_map);
143 let mut result: HashMap<u32, Vec<u32>> = HashMap::new();
144
145 let mut special_ids: HashSet<u32> = HashSet::new();
146 if let Some(specials) = &from_map.special_tokens {
147 for &v in specials.values() {
148 special_ids.insert(v);
149 }
150 }
151
152 if let Some(vocab) = &from_map.vocab {
153 for &id in vocab.values() {
154 if special_ids.contains(&id) {
155 continue;
156 }
157 let text = detok.render(
158 &[id],
159 DetokenizeOptions { partial: false, render_special: false },
160 );
161 if text.is_empty() {
162 detok.reset();
163 continue;
164 }
165 result.insert(id, tok.encode(&text));
166 detok.reset();
167 }
168 }
169
170 if let Some(tokens) = &from_map.tokens {
171 for id_str in tokens.keys() {
172 let Ok(id) = id_str.parse::<u32>() else {
173 continue;
174 };
175 if special_ids.contains(&id) || result.contains_key(&id) {
176 continue;
177 }
178 let text = detok.render(
179 &[id],
180 DetokenizeOptions { partial: false, render_special: false },
181 );
182 if text.is_empty() {
183 detok.reset();
184 continue;
185 }
186 result.insert(id, tok.encode(&text));
187 detok.reset();
188 }
189 }
190
191 result
192}