rten_generate/
text_decoder.rs

1//! Iterator adapters to decode token IDs into text using `rten-text`.
2
3use rten_text::models::DecodeError;
4use rten_text::{TokenId, Tokenizer, TokenizerError};
5
6use crate::generator::{GeneratorError, GeneratorItem};
7
8/// Wraps a [`Generator`](crate::Generator) to decode the output token IDs from
9/// the model into text using a [`Tokenizer`].
10///
11/// This is normally created by calling [`decode`](crate::GeneratorUtils::decode)
12/// on a `Generator`.
13pub struct TextDecoder<'a, G: Iterator<Item = GeneratorItem>> {
14    generator: G,
15    tokenizer: &'a Tokenizer,
16}
17
18impl<'a, G> TextDecoder<'a, G>
19where
20    G: Iterator<Item = GeneratorItem>,
21{
22    /// Wrap a token generator and decode its outputs using `tokenizer`.
23    pub fn wrap(generator: G, tokenizer: &'a Tokenizer) -> TextDecoder<'a, G> {
24        TextDecoder {
25            generator,
26            tokenizer,
27        }
28    }
29
30    /// Return an iterator that yields both the decoded text and token IDs.
31    pub fn with_ids(self) -> TextDecoderWithIds<'a, G> {
32        TextDecoderWithIds(self)
33    }
34
35    fn next_with_ids(&mut self) -> Option<Result<(Vec<TokenId>, String), GeneratorError>> {
36        // Buffer that holds model output tokens until it forms a valid UTF-8
37        // sequence.
38        let mut token_buf = Vec::new();
39
40        for token in self.generator.by_ref() {
41            let token = match token {
42                Ok(tok) => tok,
43                Err(err) => return Some(Err(err)),
44            };
45
46            token_buf.push(token);
47
48            let text = self.tokenizer.decode(&token_buf);
49            match text {
50                Ok(text) => return Some(Ok((token_buf, text))),
51                Err(TokenizerError::DecodeError(DecodeError::InvalidUtf8)) => {
52                    // If the current token sequence doesn't correspond to a
53                    // complete UTF-8 sequence, add more tokens until it does.
54                    continue;
55                }
56                Err(err) => {
57                    return Some(Err(GeneratorError::DecodeError(err)));
58                }
59            }
60        }
61
62        if !token_buf.is_empty() {
63            return Some(Ok((token_buf, String::new())));
64        }
65
66        None
67    }
68}
69
70impl<G: Iterator<Item = GeneratorItem>> Iterator for TextDecoder<'_, G> {
71    /// The decoded string, or the error that occurred during generation.
72    type Item = Result<String, GeneratorError>;
73
74    /// Run the model repeatedly until it generates a sequence of tokens which
75    /// can be decoded into a valid UTF-8 sequence.
76    ///
77    /// Returns `Some(Ok(text))` if successful, `Some(Err(error))` if an error
78    /// occurs during generation or `None` if the end of output has been
79    /// reached.
80    fn next(&mut self) -> Option<Self::Item> {
81        let next = self.next_with_ids()?;
82        Some(next.map(|(_id, text)| text))
83    }
84}
85
86/// A variant of [`TextDecoder`] that yields both the token IDs and the decoded
87/// string.
88pub struct TextDecoderWithIds<'a, G: Iterator<Item = GeneratorItem>>(TextDecoder<'a, G>);
89
90impl<G: Iterator<Item = GeneratorItem>> Iterator for TextDecoderWithIds<'_, G> {
91    /// A pair of (token IDs, decoded string), or the error that occurred during
92    /// generation.
93    type Item = Result<(Vec<TokenId>, String), GeneratorError>;
94
95    /// Run the model repeatedly until it generates a sequence of tokens which
96    /// can be decoded into a valid UTF-8 sequence.
97    ///
98    /// Returns `Some(Ok((token_ids, text)))` if successful, `Some(Err(error))`
99    /// if an error occurs during generation or `None` if the end of output has
100    /// been reached.
101    fn next(&mut self) -> Option<Self::Item> {
102        self.0.next_with_ids()
103    }
104}
105
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use rten_text::models::{Bpe, BpeOptions, WordPiece};
    use rten_text::pre_tokenizers::Split;
    use rten_text::{TokenId, Tokenizer};

    use crate::{GeneratorError, GeneratorUtils};

    /// Item type produced by `decode(..)` once errors are rendered as strings.
    type TextResult = Result<String, String>;

    /// Item type produced by `decode(..).with_ids()` once errors are rendered
    /// as strings.
    type IdsTextResult = Result<(Vec<TokenId>, String), String>;

    /// Replace the error of a decoder item with its display string so that
    /// items can be compared with `assert_eq!`.
    fn stringify_err<T>(item: Result<T, GeneratorError>) -> Result<T, String> {
        item.map_err(|err| err.to_string())
    }

    /// Build a small WordPiece tokenizer whose vocabulary maps the IDs 1, 2
    /// and 3 to the words "one", "two" and "three".
    fn create_tokenizer() -> Tokenizer {
        let mut vocab: HashMap<String, TokenId> = HashMap::new();
        for (word, id) in [("one", 1), ("two", 2), ("three", 3)] {
            vocab.insert(word.to_string(), id);
        }
        let model = WordPiece::from_vocab(vocab, Default::default());
        Tokenizer::new(model, Default::default())
    }

    /// Build a BPE tokenizer with an empty vocabulary. Such a tokenizer can
    /// encode and decode arbitrary Unicode text using one token per UTF-8
    /// byte.
    fn create_bpe_tokenizer() -> Tokenizer {
        let model = Bpe::new(BpeOptions::default()).unwrap();
        let tokenizer = Tokenizer::new(model, Default::default());
        tokenizer.with_pre_tokenizer(Box::new(Split::gpt2()))
    }

    #[test]
    fn test_decode() {
        let tokenizer = create_tokenizer();
        let generator = [1, 2, 3].into_iter().map(Ok);

        let decoded: Vec<_> = generator.decode(&tokenizer).map(stringify_err).collect();

        let expected: [TextResult; 3] = ["one", "two", "three"].map(|s| Ok(s.to_string()));
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_decode_with_ids() {
        let tokenizer = create_tokenizer();
        let generator = [1, 2, 3].into_iter().map(Ok);

        let decoded: Vec<_> = generator
            .decode(&tokenizer)
            .with_ids()
            .map(stringify_err)
            .collect();

        let expected: [IdsTextResult; 3] = [
            Ok((vec![1], "one".to_string())),
            Ok((vec![2], "two".to_string())),
            Ok((vec![3], "three".to_string())),
        ];
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_decode_partial_utf8() {
        let tokenizer = create_bpe_tokenizer();

        // This character encodes to several token IDs, so the decoder has to
        // keep accumulating tokens until they form a valid UTF-8 sequence.
        let token_ids = tokenizer.encode("😊", None).unwrap().into_token_ids();
        assert!(token_ids.len() > 1);
        let generator = token_ids.into_iter().map(|tok_id| Ok(tok_id as u32));

        let decoded: Vec<_> = generator.decode(&tokenizer).map(stringify_err).collect();

        let expected: [TextResult; 1] = ["😊"].map(|s| Ok(s.to_string()));
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_decode_ids_partial_utf8() {
        let tokenizer = create_bpe_tokenizer();

        // Feed the decoder only the first of the several token IDs that this
        // character encodes to. Generation then ends with a buffered ID that
        // cannot be decoded on its own.
        let token_ids = tokenizer.encode("😊", None).unwrap().into_token_ids();
        assert!(token_ids.len() > 1);
        let generator = token_ids
            .into_iter()
            .take(1)
            .map(|tok_id| Ok(tok_id as u32));

        let decoded: Vec<_> = generator
            .decode(&tokenizer)
            .with_ids()
            .map(stringify_err)
            .collect();

        let expected: [IdsTextResult; 1] = [Ok((vec![172], String::new()))];
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_generate_error() {
        let tokenizer = create_tokenizer();
        let generator = [
            Ok(1),
            Err(GeneratorError::GenerateError("oh no".to_string().into())),
            Ok(3),
        ]
        .into_iter();

        let decoded: Vec<_> = generator.decode(&tokenizer).map(stringify_err).collect();

        // The generation error is passed through and decoding continues with
        // the tokens that follow it.
        let expected: [TextResult; 3] = [
            Ok("one".to_string()),
            Err("generation error: oh no".to_string()),
            Ok("three".to_string()),
        ];
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_decode_error() {
        let tokenizer = create_tokenizer();
        // ID 5 is not in the vocabulary built by `create_tokenizer`.
        let generator = [1, 5, 3].into_iter().map(Ok);

        let decoded: Vec<_> = generator.decode(&tokenizer).map(stringify_err).collect();

        let expected: [TextResult; 3] = [
            Ok("one".to_string()),
            Err("decode error: decoding failed: cannot decode unknown token ID 5".to_string()),
            Ok("three".to_string()),
        ];
        assert_eq!(decoded, expected);
    }
}