py_regex/
lib.rs

1extern crate pyo3;
2use pyo3::prelude::*;
3use pyo3::types::{PyAny, PyDict, PyIterator, PyModule};
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::thread;
7
8/// A wrapper for a compiled regular expression from the Python `regex` library.
9#[derive(Debug)]
10pub struct PyRegex {
11    compiled: Py<PyAny>,
12}
13impl PyRegex {
14    /// Creates a new regular expression by compiling the pattern via Python's `regex.compile`.
15    pub fn new(pattern: &str) -> PyResult<Self> {
16        Python::with_gil(|py| {
17            let regex_module = PyModule::import(py, "regex")?;
18            let compiled = regex_module.call_method("compile", (pattern,), None)?;
19            Ok(PyRegex {
20                compiled: compiled.into(),
21            })
22        })
23    }
24
25    /// Constructs kwargs with `concurrent=True`.
26    fn kwargs(py: Python) -> Option<Bound<PyDict>> {
27        let kwargs = PyDict::new(py);
28        kwargs.set_item("concurrent", true).ok()?;
29        Some(kwargs)
30    }
31
32    /// Performs a search for the first match and returns a `PyRegexMatch` object.
33    pub fn search_match(&self, text: &str) -> PyResult<Option<PyRegexMatch>> {
34        Python::with_gil(|py| {
35            let result =
36                self.compiled
37                    .call_method(py, "search", (text,), Self::kwargs(py).as_ref())?;
38
39            Ok(if result.is_none(py) {
40                None
41            } else {
42                Some(PyRegexMatch {
43                    inner: result.into(),
44                })
45            })
46        })
47    }
48
49    /// Returns a list of `PyRegexMatch` objects from `finditer()`.
50    pub fn find_iter(&self, text: &str) -> PyResult<Vec<PyRegexMatch>> {
51        Python::with_gil(|py| {
52            let mut matches = Vec::new();
53            let binding =
54                self.compiled
55                    .call_method(py, "finditer", (text,), Self::kwargs(py).as_ref())?;
56            let iter = binding.downcast_bound::<PyIterator>(py)?;
57            for item in iter {
58                let match_obj = item?;
59                matches.push(PyRegexMatch {
60                    inner: match_obj.into(),
61                });
62            }
63            Ok(matches)
64        })
65    }
66
67    // Other methods remain unchanged.
68    pub fn is_match(&self, text: &str) -> PyResult<bool> {
69        Python::with_gil(|py| {
70            Ok(!self
71                .compiled
72                .call_method(py, "search", (text,), Self::kwargs(py).as_ref())?
73                .is_none(py))
74        })
75    }
76
77    pub fn find_all(&self, text: &str) -> PyResult<Vec<String>> {
78        Python::with_gil(|py| {
79            self.compiled
80                .call_method(py, "findall", (text,), Self::kwargs(py).as_ref())?
81                .extract::<Vec<String>>(py)
82        })
83    }
84
85    pub fn replace(&self, text: &str, replacement: &str) -> PyResult<String> {
86        Python::with_gil(|py| {
87            self.compiled
88                .call_method(py, "sub", (replacement, text), Self::kwargs(py).as_ref())?
89                .extract::<String>(py)
90        })
91    }
92
93    pub fn split(&self, text: &str) -> PyResult<Vec<String>> {
94        Python::with_gil(|py| {
95            let kwargs = Self::kwargs(py);
96            self.compiled
97                .call_method(py, "split", (text,), kwargs.as_ref())?
98                .extract::<Vec<String>>(py)
99        })
100    }
101}
102
103/// A wrapper for the match object from the Python `regex` module.
104pub struct PyRegexMatch {
105    inner: Py<PyAny>,
106}
107
108impl PyRegexMatch {
109    /// Returns the match for the specified group.
110    /// For example, `group(0)` is the entire match, `group(1)` is the first subgroup, etc.
111    pub fn group(&self, group: u16) -> PyResult<Option<String>> {
112        Python::with_gil(|py| {
113            self.inner
114                .call_method1(py, "group", (group as usize,))?
115                .extract::<Option<String>>(py)
116        })
117    }
118
119    /// Returns all captured groups as a vector.
120    /// Analogous to Python's `groups()` method, which returns a tuple of all subgroups (starting from 1).
121    pub fn groups(&self) -> PyResult<Vec<Option<String>>> {
122        Python::with_gil(|py| {
123            self.inner
124                .call_method1(py, "groups", ())?
125                .extract::<Vec<Option<String>>>(py)
126        })
127    }
128
129    /// Returns the named groups dictionary (`groupdict()`) as a `HashMap`.
130    pub fn groupdict(&self) -> PyResult<HashMap<String, Option<String>>> {
131        Python::with_gil(|py| {
132            self.inner
133                .call_method1(py, "groupdict", ())?
134                .extract::<HashMap<String, Option<String>>>(py)
135        })
136    }
137
138    /// Returns the start position of the match for the specified group.
139    pub fn start(&self, group: u16) -> PyResult<isize> {
140        Python::with_gil(|py| {
141            self.inner
142                .call_method1(py, "start", (group as usize,))?
143                .extract::<isize>(py)
144        })
145    }
146
147    /// Returns the end position of the match for the specified group.
148    pub fn end(&self, group: u16) -> PyResult<isize> {
149        Python::with_gil(|py| {
150            self.inner
151                .call_method1(
152                    py,
153                    "end",
154                    (group as usize,), /* Option<&pyo3::Bound<'_, PyDict>> */
155                )?
156                .extract::<isize>(py)
157        })
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every `PyRegexMatch` accessor against one known match.
    #[test]
    fn test_pyregex_match_methods() -> PyResult<()> {
        // Initialize Python for multithreaded usage.
        pyo3::prepare_freethreaded_python();

        // Group 1 is the named group `word`; group 2 is an unnamed numeric group.
        let pattern = r"(?P<word>\w+)-(\d+)";
        let text = "Test-123";
        let re = PyRegex::new(pattern)?;

        let m = re.search_match(text)?.expect("No match found");

        // Group 0 is the full match.
        assert_eq!(m.group(0)?.as_deref(), Some("Test-123"));

        // Group 1 is the named group `word`, also reachable by index.
        assert_eq!(m.group(1)?.as_deref(), Some("Test"));

        // Group 2 is the unnamed number group.
        assert_eq!(m.group(2)?.as_deref(), Some("123"));

        // `groups()` lists the capture groups starting from group 1.
        assert_eq!(
            m.groups()?,
            vec![Some("Test".to_string()), Some("123".to_string())]
        );

        // `groupdict()` contains only the named group.
        let gd = m.groupdict()?;
        assert_eq!(gd.get("word").cloned(), Some(Some("Test".to_string())));
        assert_eq!(gd.len(), 1);

        // The full match spans the whole 8-character input.
        assert_eq!(m.start(0)?, 0);
        assert_eq!(m.end(0)?, 8);

        // The number group covers the last three characters.
        assert_eq!(m.start(2)?, 5);
        assert_eq!(m.end(2)?, 8);

        Ok(())
    }
}
200
201#[allow(dead_code)]
202fn example() -> PyResult<()> {
203    // Initialize Python for multithreaded usage.
204    pyo3::prepare_freethreaded_python();
205
206    let pattern = r"(?P<id>\d+)";
207    let text = "IDs: 101, 202, 303";
208    let re = Arc::new(PyRegex::new(pattern)?);
209
210    // Example usage of `search_match()` to obtain a `PyRegexMatch` object.
211    if let Some(m) = re.search_match(text)? {
212        println!("Full match (group 0): {:?}", m.group(0)?);
213        println!("Group 'id' (as group 0 here): {:?}", m.group(0)?);
214        println!("Groupdict: {:?}", m.groupdict()?);
215        println!("Span for group 0: {}..{}", m.start(0)?, m.end(0)?);
216    }
217
218    // Example of multithreaded usage of `find_iter()`, returning a `Vec<PyRegexMatch>`.
219    let mut handles = vec![];
220    for i in 0..4 {
221        let re_clone = Arc::clone(&re);
222        let text_clone = text.to_string();
223        let handle = thread::spawn(move || -> PyResult<()> {
224            let matches = re_clone.find_iter(&text_clone)?;
225            println!("Thread {}: found {} matches.", i, matches.len());
226            for m in matches {
227                println!("Thread {}: match group 0: {:?}", i, m.group(0)?);
228            }
229            Ok(())
230        });
231        handles.push(handle);
232    }
233
234    for handle in handles {
235        handle.join().unwrap()?;
236    }
237
238    Ok(())
239}