py_regex/
lib.rs

1pub extern crate pyo3;
2use pyo3::PyResult;
3use pyo3::prelude::*;
4use pyo3::types::{PyAny, PyDict, PyIterator, PyModule};
5use std::collections::HashMap;
6
7/// A wrapper for a compiled regular expression from the Python `regex` library.
8#[derive(Debug)]
9pub struct PyRegex {
10    compiled: Py<PyAny>,
11}
12impl PyRegex {
13    /// Creates a new regular expression by compiling the pattern via Python's `regex.compile`.
14    pub fn new(pattern: &str) -> PyResult<Self> {
15        Python::with_gil(|py| {
16            Ok(PyRegex {
17                compiled: PyModule::import(py, "regex")?
18                    .call_method("compile", (pattern,), None)?
19                    .into(),
20            })
21        })
22    }
23
24    /// Constructs kwargs with `concurrent=True`.
25    fn kwargs(py: Python) -> Option<Bound<PyDict>> {
26        let kwargs = PyDict::new(py);
27        kwargs.set_item("concurrent", true).ok()?;
28        Some(kwargs)
29    }
30
31    /// Performs a search for the first match and returns a `PyRegexMatch` object.
32    pub fn search_match(&self, text: &str) -> PyResult<Option<PyRegexMatch>> {
33        Python::with_gil(|py| {
34            let result =
35                self.compiled
36                    .call_method(py, "search", (text,), Self::kwargs(py).as_ref())?;
37
38            Ok(if result.is_none(py) {
39                None
40            } else {
41                Some(PyRegexMatch { inner: result })
42            })
43        })
44    }
45
46    /// Returns a list of `PyRegexMatch` objects from `finditer()`.
47    pub fn find_iter(&self, text: &str) -> PyResult<Vec<PyRegexMatch>> {
48        Python::with_gil(|py| {
49            let mut matches = Vec::new();
50            let binding =
51                self.compiled
52                    .call_method(py, "finditer", (text,), Self::kwargs(py).as_ref())?;
53            let iter = binding.downcast_bound::<PyIterator>(py)?;
54            for item in iter {
55                let match_obj = item?;
56                matches.push(PyRegexMatch {
57                    inner: match_obj.into(),
58                });
59            }
60            Ok(matches)
61        })
62    }
63
64    // Other methods remain unchanged.
65    pub fn is_match(&self, text: &str) -> PyResult<bool> {
66        Python::with_gil(|py| {
67            Ok(!self
68                .compiled
69                .call_method(py, "search", (text,), Self::kwargs(py).as_ref())?
70                .is_none(py))
71        })
72    }
73
74    pub fn find_all(&self, text: &str) -> PyResult<Vec<String>> {
75        Python::with_gil(|py| {
76            self.compiled
77                .call_method(py, "findall", (text,), Self::kwargs(py).as_ref())?
78                .extract::<Vec<String>>(py)
79        })
80    }
81
82    pub fn replace(&self, text: &str, replacement: &str) -> PyResult<String> {
83        Python::with_gil(|py| {
84            self.compiled
85                .call_method(py, "sub", (replacement, text), Self::kwargs(py).as_ref())?
86                .extract::<String>(py)
87        })
88    }
89
90    pub fn split(&self, text: &str) -> PyResult<Vec<String>> {
91        Python::with_gil(|py| {
92            self.compiled
93                .call_method(py, "split", (text,), Self::kwargs(py).as_ref())?
94                .extract::<Vec<String>>(py)
95        })
96    }
97
98    /// Escapes a string.
99    pub fn escape(str: &str, special_only: bool, literal_spaces: bool) -> PyResult<String> {
100        Python::with_gil(|py| {
101            let kwargs = PyDict::new(py);
102            kwargs.set_item("special_only", special_only)?;
103            kwargs.set_item("literal_spaces", literal_spaces)?;
104            PyModule::import(py, "regex")?
105                .call_method("escape", (str,), Some::<Bound<PyDict>>(kwargs).as_ref())?
106                .extract::<String>()
107        })
108    }
109}
110
111/// A wrapper for the match object from the Python `regex` module.
112pub struct PyRegexMatch {
113    inner: Py<PyAny>,
114}
115
116impl PyRegexMatch {
117    /// Returns the match for the specified group.
118    /// For example, `group(0)` is the entire match, `group(1)` is the first subgroup, etc.
119    pub fn group(&self, group: u16) -> PyResult<Option<String>> {
120        Python::with_gil(|py| {
121            self.inner
122                .call_method1(py, "group", (group as usize,))?
123                .extract::<Option<String>>(py)
124        })
125    }
126
127    /// Returns all captured groups as a vector.
128    /// Analogous to Python's `groups()` method, which returns a tuple of all subgroups (starting from 1).
129    pub fn groups(&self) -> PyResult<Vec<Option<String>>> {
130        Python::with_gil(|py| {
131            self.inner
132                .call_method1(py, "groups", ())?
133                .extract::<Vec<Option<String>>>(py)
134        })
135    }
136
137    /// Returns the named groups dictionary (`groupdict()`) as a `HashMap`.
138    pub fn groupdict(&self) -> PyResult<HashMap<String, Option<String>>> {
139        Python::with_gil(|py| {
140            self.inner
141                .call_method1(py, "groupdict", ())?
142                .extract::<HashMap<String, Option<String>>>(py)
143        })
144    }
145
146    /// Returns the start position of the match for the specified group.
147    pub fn start(&self, group: u16) -> PyResult<isize> {
148        Python::with_gil(|py| {
149            self.inner
150                .call_method1(py, "start", (group as usize,))?
151                .extract::<isize>(py)
152        })
153    }
154
155    /// Returns the end position of the match for the specified group.
156    pub fn end(&self, group: u16) -> PyResult<isize> {
157        Python::with_gil(|py| {
158            self.inner
159                .call_method1(
160                    py,
161                    "end",
162                    (group as usize,), /* Option<&pyo3::Bound<'_, PyDict>> */
163                )?
164                .extract::<isize>(py)
165        })
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// `escape` with both flags disabled backslash-escapes the bracket characters.
    #[test]
    fn test_escape() -> PyResult<()> {
        // Start the embedded Python interpreter (does nothing if it is already running).
        pyo3::prepare_freethreaded_python();

        assert_eq!(PyRegex::escape("[]", false, false)?, "\\[\\]");

        Ok(())
    }

    /// Exercises every accessor of `PyRegexMatch` on a single successful search.
    #[test]
    fn test_pyregex_match_methods() -> PyResult<()> {
        // Start the embedded Python interpreter (does nothing if it is already running).
        pyo3::prepare_freethreaded_python();

        // Group 1 is the named group `word`; group 2 is the unnamed digit group.
        let pattern = r"(?P<word>\w+)-(\d+)";
        let text = "Test-123";
        let re = PyRegex::new(pattern)?;

        let m = re.search_match(text)?.expect("No match found");

        // Group 0 is the full match.
        assert_eq!(m.group(0)?, Some("Test-123".to_string()));

        // Group 1 is the named group `word`.
        assert_eq!(m.group(1)?, Some("Test".to_string()));

        // Group 2 is the unnamed group holding the number.
        assert_eq!(m.group(2)?, Some("123".to_string()));

        // `groups()` lists the capture groups starting from group 1.
        assert_eq!(
            m.groups()?,
            vec![Some("Test".to_string()), Some("123".to_string())]
        );

        // `groupdict()` contains only the named group.
        let gd = m.groupdict()?;
        assert_eq!(gd.get("word").cloned(), Some(Some("Test".to_string())));
        assert_eq!(gd.len(), 1);

        // The full match spans the whole 8-character input.
        assert_eq!(m.start(0)?, 0);
        assert_eq!(m.end(0)?, 8);

        // The digit group starts right after "Test-".
        assert_eq!(m.start(2)?, 5);
        assert_eq!(m.end(2)?, 8);

        Ok(())
    }
}
216}