use pyo3::prelude::*;
use pyo3::exceptions::PyTypeError;
use regex::Regex;

mod pyregex;
use crate::pyregex::PyRegex;
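// `pyregex` is defined in a sibling file not shown here. The code below
// assumes it exposes roughly this shape (a sketch, not the real definition):
//
//     #[pyclass]
//     #[derive(Clone)]
//     pub struct PyRegex {
//         pub pattern: Regex,
//         #[pyo3(get)]
//         pub pattern_str: &'static str,
//     }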
#[pyclass]
struct PreTokenizer {
    #[pyo3(get)]
    model_name: String,
    #[pyo3(get)]
    regex: PyRegex,
}
#[pymethods]
impl PreTokenizer {
    #[new]
    fn new(model: &str) -> Self {
        // The upstream GPT-2 pattern ends in `\s+(?!\S)|\s+`, but the `regex`
        // crate does not support look-ahead, so a plain `\s+` stands in for
        // both alternatives.
        let regex_str = match model {
            "gpt2" => r#"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+"#,
            // Unknown model names fall back to the GPT-2 pattern.
            _ => r#"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+"#,
        };
        let pattern = Regex::new(regex_str).expect("pre-tokenizer pattern should compile");

        PreTokenizer {
            model_name: model.to_string(),
            regex: PyRegex {
                pattern,
                pattern_str: regex_str,
            },
        }
    }
    /// Pre-tokenize either a single string or a list of strings.
    pub fn pre_tokenize(&self, input: &PyAny) -> PyResult<Vec<Vec<String>>> {
        // Try a single string first: a Python `str` would otherwise also
        // extract as a sequence of one-character strings below.
        if let Ok(s) = input.extract::<String>() {
            return Ok(vec![self.pre_tokenize_str(&s)?]);
        }

        if let Ok(items) = input.extract::<Vec<String>>() {
            let mut result = Vec::with_capacity(items.len());
            for item in items {
                result.push(self.pre_tokenize_str(&item)?);
            }
            return Ok(result);
        }

        Err(PyTypeError::new_err("Expected a string or a list of strings"))
    }
    /// Split a single string into pre-tokens with the model's regex.
    pub fn pre_tokenize_str(&self, input: &str) -> PyResult<Vec<String>> {
        Ok(self
            .regex
            .pattern
            .find_iter(input)
            .map(|m| m.as_str().to_string())
            .collect())
    }
}
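// A hypothetical usage sketch from Python, assuming the extension is built
// (e.g. with maturin) and importable as `_fastok`:
//
//     from _fastok import PreTokenizer
//     pt = PreTokenizer("gpt2")
//     pt.pre_tokenize("hello world")     # [["hello", " world"]]
//     pt.pre_tokenize(["I'm", "ok"])     # [["I", "'m"], ["ok"]]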
62
#[pymodule]
fn _fastok(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_class::<PreTokenizer>()?;
    Ok(())
}
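#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sanity check of the GPT-2-style splitting above; note that
    // `cargo test` may require building without pyo3's `extension-module`
    // feature so the Python symbols resolve at link time.
    #[test]
    fn splits_contractions_and_leading_spaces() {
        let tok = PreTokenizer::new("gpt2");
        let pieces = tok.pre_tokenize_str("I'm a test").unwrap();
        assert_eq!(pieces, vec!["I", "'m", " a", " test"]);
    }
}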