// tokenizers/tokenizer/pre_tokenizer.rs

1use crate::{
2    normalizer::Range, Encoding, NormalizedString, OffsetReferential, Offsets, Result, Token,
3};
4use std::collections::HashMap;
5
/// Various possible types of offsets
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OffsetType {
    /// Offsets expressed in bytes of the original string
    Byte,
    /// Offsets expressed in characters; byte offsets get converted through a
    /// `BytesToCharOffsetConverter` built over the original string
    Char,
    /// No offset tracking: `into_encoding` takes a fast path that emits
    /// placeholder `(0, 0)` offsets and empty values, keeping only token ids
    None,
}
13
/// Wrapper for a subpart of a `NormalizedString`.
///
/// This Split contains the underlying `NormalizedString` as well as its offsets
/// in the original string. These offsets are in the `original` referential.
/// It also contains any `Token` associated to the current split
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Split {
    /// The underlying `NormalizedString`. Each SubString is represented by a `NormalizedString`
    /// and in the end we might be carrying a lot of SubString representing various parts of the
    /// original input string.
    normalized: NormalizedString,
    /// Optional Tokens associated to this Split. `None` until `tokenize` runs on
    /// this split (or tokens were attached when the split was produced); once
    /// set, the split is considered final and is skipped by `split`/`normalize`.
    tokens: Option<Vec<Token>>,
}
28
29impl From<NormalizedString> for Split {
30    fn from(n: NormalizedString) -> Self {
31        Self {
32            normalized: n,
33            tokens: None,
34        }
35    }
36}
37
38impl From<(NormalizedString, Option<Vec<Token>>)> for Split {
39    fn from(f: (NormalizedString, Option<Vec<Token>>)) -> Self {
40        Self {
41            normalized: f.0,
42            tokens: f.1,
43        }
44    }
45}
46
/// The `PreTokenizedString` is in charge of splitting an underlying string,
/// making sure everything is fine while doing so, and providing ways to normalize
/// and tokenize these splits.
/// Once everything has been normalized and tokenized, the `PreTokenizedString` is able
/// to build an `Encoding` with all the relevant offsets and word ids, relative to the
/// original string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreTokenizedString {
    /// The full original string; all `original`-referential offsets point into it
    original: String,
    /// The current list of splits; starts as a single split covering `original`
    /// and is refined by successive calls to `split`
    splits: Vec<Split>,
}
58
59impl PreTokenizedString {
60    /// Split the `PreTokenizedString` by providing a `split_fn` in charge of splitting
61    /// each substring (`NormalizedString`) into multiple parts.
62    ///
63    /// `split_fn` takes a `NormalizedString` and is in charge of returning an iterator
64    /// over the produced `NormalizedString`. `split_fn` is free of modifying these
65    /// `NormalizedString` as relevant, as long as it respects the constraint stated below.
66    ///
67    /// There are only one constraint that *MUST* be respected:
68    /// > The produced `NormalizedString`, if combined back together, must have the
69    /// > same `original` string as the original one given to `split_fn`. This concretely
70    /// > means that for the offset tracking to work as expected, `split_fn` must produce
71    /// > "splits" of the original string.
72    pub fn split<F, U, R>(&mut self, mut split_fn: F) -> Result<()>
73    where
74        F: FnMut(usize, NormalizedString) -> Result<U>,
75        U: IntoIterator<Item = R>,
76        R: Into<Split>,
77    {
78        // new_splits is at least as big as self.splits
79        let mut new_splits = Vec::with_capacity(self.splits.len());
80        for (i, original_split) in self.splits.drain(..).enumerate() {
81            if original_split.tokens.is_some() {
82                new_splits.push(original_split);
83                continue;
84            }
85
86            new_splits.extend(
87                split_fn(i, original_split.normalized)?
88                    .into_iter()
89                    .filter_map(|split| {
90                        let split: Split = split.into();
91                        if split.normalized.is_empty() {
92                            None
93                        } else {
94                            Some(split)
95                        }
96                    }),
97            );
98        }
99        self.splits = new_splits;
100
101        Ok(())
102    }
103
104    /// Normalized all the splits that do not have attached `Tokens`, using the provided
105    /// `normalize` function.
106    pub fn normalize<F>(&mut self, normalize: F) -> Result<()>
107    where
108        F: Fn(&mut NormalizedString) -> Result<()>,
109    {
110        for split in self.splits.iter_mut().filter(|s| s.tokens.is_none()) {
111            normalize(&mut split.normalized)?;
112        }
113        Ok(())
114    }
115
116    /// Tokenize all the splits that do not have attached `Tokens`, using the provided
117    /// `tokenize` function
118    pub fn tokenize<F>(&mut self, tokenize: F) -> Result<()>
119    where
120        F: Fn(&NormalizedString) -> Result<Vec<Token>>,
121    {
122        for split in self.splits.iter_mut().filter(|s| s.tokens.is_none()) {
123            split.tokens = Some(tokenize(&split.normalized)?);
124        }
125
126        Ok(())
127    }
128
129    /// Transform the current `PreTokenizedString` into an `Encoding`.
130    ///
131    /// If a `word_idx` is provided, any word in the generated `Encoding`
132    /// will be set to this value. This is generally used with pre-tokenized
133    /// input, that do not need the `PreTokenizedString` to generate word ids.
134    ///
135    /// This method will fail if some splits do not have associated `Token`.
136    pub fn into_encoding(
137        self,
138        word_idx: Option<u32>,
139        type_id: u32,
140        offset_type: OffsetType,
141    ) -> Result<Encoding> {
142        if self.splits.is_empty() {
143            Ok(Encoding::default())
144        } else if !self.splits.iter().all(|split| split.tokens.is_some()) {
145            Err("Split has not been tokenized, call `PreTokenizedString::tokenize` first".into())
146        } else {
147            let offset_converter = match offset_type {
148                OffsetType::Char => Some(BytesToCharOffsetConverter::new(&self.original)),
149                OffsetType::Byte => None,
150                OffsetType::None => {
151                    let tokens = self
152                        .splits
153                        .into_iter()
154                        .flat_map(|split| {
155                            split.tokens.unwrap().into_iter().map(|token| {
156                                // Replace this with the actual fields you need for the Encoding type
157                                (token.id, String::with_capacity(0), (0, 0), None, 0)
158                            })
159                        })
160                        .collect();
161                    return Ok(tokens);
162                }
163            };
164
165            Ok(self
166                .splits
167                .into_iter()
168                .enumerate()
169                .flat_map(|(idx, split)| {
170                    let normalized = split.normalized;
171                    let offsets = normalized.offsets_original();
172                    let offset_converter = &offset_converter;
173
174                    split.tokens.unwrap().into_iter().map(move |token| {
175                        let mut offsets = normalized
176                            .convert_offsets(Range::Normalized(token.offsets.0..token.offsets.1))
177                            .map_or(token.offsets, |range| {
178                                (offsets.0 + range.start, offsets.0 + range.end)
179                            });
180
181                        // Convert to char offsets if relevant
182                        if let Some(converter) = offset_converter {
183                            offsets = converter.convert(offsets).unwrap_or(offsets);
184                        }
185
186                        (
187                            token.id,
188                            token.value,
189                            offsets,
190                            if word_idx.is_some() {
191                                word_idx
192                            } else {
193                                Some(idx as u32)
194                            },
195                            type_id,
196                        )
197                    })
198                })
199                .collect())
200        }
201    }
202
203    /// Returns a list of splits, each of them being a slice of the normalized
204    /// string, the associated offsets either in original or normalized
205    /// referential, as well as the potention tokens
206    pub fn get_splits(
207        &self,
208        offset_ref: OffsetReferential,
209        offset_type: OffsetType,
210    ) -> Vec<(&str, Offsets, &Option<Vec<Token>>)> {
211        let offset_converter = match offset_type {
212            OffsetType::Char => Some(BytesToCharOffsetConverter::new(&self.original)),
213            OffsetType::Byte => None,
214            OffsetType::None => None,
215        };
216
217        let mut offset = 0;
218        self.splits
219            .iter()
220            .map(|split| {
221                let mut offsets = match offset_ref {
222                    OffsetReferential::Original => split.normalized.offsets_original(),
223                    OffsetReferential::Normalized => {
224                        let len = split.normalized.len();
225                        offset += len;
226                        (offset - len, offset)
227                    }
228                };
229
230                // Convert to char offsets if relevant
231                if let Some(ref converter) = offset_converter {
232                    offsets = converter.convert(offsets).unwrap_or(offsets);
233                }
234
235                (split.normalized.get(), offsets, &split.tokens)
236            })
237            .collect()
238    }
239}
240
241impl From<NormalizedString> for PreTokenizedString {
242    fn from(s: NormalizedString) -> Self {
243        Self {
244            original: s.get_original().to_owned(),
245            splits: vec![Split {
246                normalized: s,
247                tokens: None,
248            }],
249        }
250    }
251}
252
253impl From<&str> for PreTokenizedString {
254    fn from(s: &str) -> Self {
255        let normalized: NormalizedString = s.into();
256        normalized.into()
257    }
258}
259
260impl From<String> for PreTokenizedString {
261    fn from(s: String) -> Self {
262        let normalized: NormalizedString = s.into();
263        normalized.into()
264    }
265}
266
/// Converter from byte offsets to char offsets, built once over a given string
/// and then queried through `convert`.
struct BytesToCharOffsetConverter {
    // Maps every byte position of the sequence to the index of the char containing it
    map: HashMap<usize, usize>,
}
270
271impl BytesToCharOffsetConverter {
272    pub fn new(sequence: &str) -> Self {
273        Self {
274            map: sequence
275                .char_indices()
276                .enumerate()
277                .flat_map(|(i, (b, c))| {
278                    let mut n = 0;
279                    std::iter::repeat_with(move || {
280                        let o = (b + n, i);
281                        n += 1;
282                        o
283                    })
284                    .take(c.len_utf8())
285                })
286                .collect(),
287        }
288    }
289
290    pub fn convert(&self, offsets: Offsets) -> Option<Offsets> {
291        match (self.map.get(&offsets.0), self.map.get(&offsets.1)) {
292            (Some(start), Some(end)) => Some((*start, *end)),
293            // If we reached the end, `end` is not in the map
294            (Some(start), None) => {
295                // But the one just before should be
296                let last = self.map.get(&(offsets.1 - 1)).copied().unwrap_or(start + 1);
297                Some((*start, last + 1))
298            }
299            _ => None,
300        }
301    }
302}