rten_generate/
text_decoder.rs1use rten_text::models::DecodeError;
4use rten_text::{TokenId, Tokenizer, TokenizerError};
5
6use crate::generator::{GeneratorError, GeneratorItem};
7
8pub struct TextDecoder<'a, G: Iterator<Item = GeneratorItem>> {
14 generator: G,
15 tokenizer: &'a Tokenizer,
16}
17
18impl<'a, G> TextDecoder<'a, G>
19where
20 G: Iterator<Item = GeneratorItem>,
21{
22 pub fn wrap(generator: G, tokenizer: &'a Tokenizer) -> TextDecoder<'a, G> {
24 TextDecoder {
25 generator,
26 tokenizer,
27 }
28 }
29
30 pub fn with_ids(self) -> TextDecoderWithIds<'a, G> {
32 TextDecoderWithIds(self)
33 }
34
35 fn next_with_ids(&mut self) -> Option<Result<(Vec<TokenId>, String), GeneratorError>> {
36 let mut token_buf = Vec::new();
39
40 for token in self.generator.by_ref() {
41 let token = match token {
42 Ok(tok) => tok,
43 Err(err) => return Some(Err(err)),
44 };
45
46 token_buf.push(token);
47
48 let text = self.tokenizer.decode(&token_buf);
49 match text {
50 Ok(text) => return Some(Ok((token_buf, text))),
51 Err(TokenizerError::DecodeError(DecodeError::InvalidUtf8)) => {
52 continue;
55 }
56 Err(err) => {
57 return Some(Err(GeneratorError::DecodeError(err)));
58 }
59 }
60 }
61
62 if !token_buf.is_empty() {
63 return Some(Ok((token_buf, String::new())));
64 }
65
66 None
67 }
68}
69
70impl<G: Iterator<Item = GeneratorItem>> Iterator for TextDecoder<'_, G> {
71 type Item = Result<String, GeneratorError>;
73
74 fn next(&mut self) -> Option<Self::Item> {
81 let next = self.next_with_ids()?;
82 Some(next.map(|(_id, text)| text))
83 }
84}
85
86pub struct TextDecoderWithIds<'a, G: Iterator<Item = GeneratorItem>>(TextDecoder<'a, G>);
89
90impl<G: Iterator<Item = GeneratorItem>> Iterator for TextDecoderWithIds<'_, G> {
91 type Item = Result<(Vec<TokenId>, String), GeneratorError>;
94
95 fn next(&mut self) -> Option<Self::Item> {
102 self.0.next_with_ids()
103 }
104}
105
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use rten_text::models::{Bpe, BpeOptions, WordPiece};
    use rten_text::pre_tokenizers::Split;
    use rten_text::{TokenId, Tokenizer};

    use crate::{GeneratorError, GeneratorUtils};

    /// Build a WordPiece tokenizer with a three-word vocabulary
    /// ("one" => 1, "two" => 2, "three" => 3).
    fn create_tokenizer() -> Tokenizer {
        let mut vocab: HashMap<String, TokenId> = HashMap::new();
        for (word, id) in [("one", 1), ("two", 2), ("three", 3)] {
            vocab.insert(word.to_string(), id);
        }
        Tokenizer::new(
            WordPiece::from_vocab(vocab, Default::default()),
            Default::default(),
        )
    }

    /// Build a BPE tokenizer from default options with a GPT-2 style split
    /// pre-tokenizer.
    fn create_bpe_tokenizer() -> Tokenizer {
        let bpe = Bpe::new(BpeOptions::default()).unwrap();
        let pre_tokenizer = Box::new(Split::gpt2());
        Tokenizer::new(bpe, Default::default()).with_pre_tokenizer(pre_tokenizer)
    }

    /// Each known token ID decodes to its own word.
    #[test]
    fn test_decode() {
        let tokenizer = create_tokenizer();

        let decoded: Vec<Result<String, String>> = [1, 2, 3]
            .into_iter()
            .map(Ok)
            .decode(&tokenizer)
            .map(|chunk| chunk.map_err(|err| err.to_string()))
            .collect();

        let expected: Vec<Result<String, String>> = vec![
            Ok("one".to_string()),
            Ok("two".to_string()),
            Ok("three".to_string()),
        ];
        assert_eq!(decoded, expected);
    }

    /// `with_ids` pairs each decoded word with the token ID it came from.
    #[test]
    fn test_decode_with_ids() {
        let tokenizer = create_tokenizer();

        let decoded: Vec<Result<(Vec<TokenId>, String), String>> = [1, 2, 3]
            .into_iter()
            .map(Ok)
            .decode(&tokenizer)
            .with_ids()
            .map(|chunk| chunk.map_err(|err| err.to_string()))
            .collect();

        let expected: Vec<Result<(Vec<TokenId>, String), String>> = vec![
            Ok((vec![1], "one".to_string())),
            Ok((vec![2], "two".to_string())),
            Ok((vec![3], "three".to_string())),
        ];
        assert_eq!(decoded, expected);
    }

    /// A character that spans several tokens is emitted as a single chunk.
    #[test]
    fn test_decode_partial_utf8() {
        let tokenizer = create_bpe_tokenizer();

        let emoji_ids = tokenizer.encode("😊", None).unwrap().into_token_ids();
        // The test is only meaningful if the emoji needs more than one token.
        assert!(emoji_ids.len() > 1);

        let decoded: Vec<Result<String, String>> = emoji_ids
            .into_iter()
            .map(|tok_id| Ok(tok_id as u32))
            .decode(&tokenizer)
            .map(|chunk| chunk.map_err(|err| err.to_string()))
            .collect();

        let expected: Vec<Result<String, String>> = vec![Ok("😊".to_string())];
        assert_eq!(decoded, expected);
    }

    /// When only the first token of a multi-token character is generated, the
    /// leftover ID is returned together with an empty string.
    #[test]
    fn test_decode_ids_partial_utf8() {
        let tokenizer = create_bpe_tokenizer();

        let emoji_ids = tokenizer.encode("😊", None).unwrap().into_token_ids();
        // The test is only meaningful if the emoji needs more than one token.
        assert!(emoji_ids.len() > 1);

        // Feed just the first token so the sequence is never completed.
        let decoded: Vec<Result<(Vec<TokenId>, String), String>> = emoji_ids
            .into_iter()
            .take(1)
            .map(|tok_id| Ok(tok_id as u32))
            .decode(&tokenizer)
            .with_ids()
            .map(|chunk| chunk.map_err(|err| err.to_string()))
            .collect();

        let expected: Vec<Result<(Vec<TokenId>, String), String>> =
            vec![Ok((vec![172], String::new()))];
        assert_eq!(decoded, expected);
    }

    /// An error from the generator is passed through, and decoding continues
    /// with the tokens that follow it.
    #[test]
    fn test_generate_error() {
        let tokenizer = create_tokenizer();
        let items = [
            Ok(1),
            Err(GeneratorError::GenerateError("oh no".to_string().into())),
            Ok(3),
        ];

        let decoded: Vec<Result<String, String>> = items
            .into_iter()
            .decode(&tokenizer)
            .map(|chunk| chunk.map_err(|err| err.to_string()))
            .collect();

        let expected: Vec<Result<String, String>> = vec![
            Ok("one".to_string()),
            Err("generation error: oh no".to_string()),
            Ok("three".to_string()),
        ];
        assert_eq!(decoded, expected);
    }

    /// A token ID missing from the vocabulary yields a decode error, and
    /// decoding continues with the tokens that follow it.
    #[test]
    fn test_decode_error() {
        let tokenizer = create_tokenizer();

        // ID 5 is not part of the vocabulary built by `create_tokenizer`.
        let decoded: Vec<Result<String, String>> = [1, 5, 3]
            .into_iter()
            .map(Ok)
            .decode(&tokenizer)
            .map(|chunk| chunk.map_err(|err| err.to_string()))
            .collect();

        let expected: Vec<Result<String, String>> = vec![
            Ok("one".to_string()),
            Err("decode error: decoding failed: cannot decode unknown token ID 5".to_string()),
            Ok("three".to_string()),
        ];
        assert_eq!(decoded, expected);
    }
}