_fastok/
lib.rs

1use pyo3::prelude::*;
2use pyo3::exceptions::PyTypeError;
3use regex::Regex;
4
5mod pyregex;
6use crate::pyregex::PyRegex;
7
8// The `#[pyo3(get)]` is a macro attribute that automatically implements a getter for the
9// attributes within a struct so that those can be accessed via Python.
10
11#[pyclass]
12struct PreTokenizer {
13    #[pyo3(get)]
14    model_name: String,
15    #[pyo3(get)]
16    regex: PyRegex,
17}
18
19#[pymethods]
20impl PreTokenizer {
21    #[new]
22    fn new(model: &str) -> Self {
23        let regex_str = match model {
24            "gpt2" => r#"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+|\s+"#,
25            _ => r#"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+|\s+"#,
26        };
27        let regex_ref = Regex::new(regex_str).unwrap();
28
29        PreTokenizer {
30            model_name: model.to_string(),
31            regex: PyRegex {
32                pattern: regex_ref,
33                pattern_str: regex_str,
34            },
35        }
36    }
37
38    pub fn pre_tokenize(&self, input: &PyAny) -> PyResult<Vec<Vec<String>>> {
39        // If the input is a string, then we can simply call the `pre_tokenize_str` function.
40        if let Ok(s) = input.extract::<String>() {
41            return Ok(vec![self.pre_tokenize_str(&s)?]);
42        }
43
44        // If the input is a list, then we need to iterate over the elements of the list and
45        // call the `pre_tokenize_str` function on each element.
46        if let Ok(s) = input.extract::<Vec<String>>() {
47            let mut result = Vec::new();
48            for item in s {
49                result.push(self.pre_tokenize_str(&item)?);
50            }
51            return Ok(result);
52        }
53
54        // If the input is neither a string nor a list, then we raise a `TypeError` exception.
55        Err(PyErr::new::<PyTypeError, _>("Expected a string or a list of strings"))
56    }
57
58    pub fn pre_tokenize_str(&self, input: &str) -> PyResult<Vec<String>> {    
59        Ok(self.regex.pattern.find_iter(input).map(|m| m.as_str().to_string()).collect::<Vec<_>>())
60    }
61}
62
63// A Python module implemented in Rust. The name of this function must match
64// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
65// import the module.
66#[pymodule]
67fn _fastok(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
68    m.add_class::<PreTokenizer>()?;
69    Ok(())
70}