1use crate::error;
2use crate::parser::findall_iter::FindallIter;
3use crate::parser::format_parser::{CompiledFields, FormatParser};
4use fancy_regex::Regex;
5use formatparse_core::count_capturing_groups;
6use formatparse_core::parser::validate_input_length;
7use pyo3::exceptions::PyValueError;
8use pyo3::prelude::*;
9use pyo3::types::{PyAnyMethods, PyString, PyTuple};
10use pyo3::IntoPyObjectExt;
11use std::collections::HashMap;
12
13#[pymethods]
14impl FormatParser {
15 #[new]
16 #[pyo3(signature = (pattern=None, extra_types=None))]
17 fn new_py(
18 pattern: Option<&str>,
19 extra_types: Option<HashMap<String, PyObject>>,
20 ) -> PyResult<Self> {
21 match pattern {
22 Some(p) => Self::new_with_extra_types(p, extra_types),
23 None => {
24 let empty_regex =
27 Regex::new("^$").map_err(|e| crate::error::regex_error(&e.to_string()))?;
28 Ok(Self {
29 pattern: String::new(),
30 regex: empty_regex.clone(),
31 regex_str: String::new(),
32 regex_case_insensitive: None,
33 search_regex: empty_regex.clone(),
34 search_regex_case_insensitive: None,
35 fields: CompiledFields {
36 field_specs: Vec::new(),
37 field_names: Vec::new(),
38 normalized_names: Vec::new(),
39 custom_type_groups: Vec::new(),
40 has_nested_dict_fields: Vec::new(),
41 nested_parsers: Vec::new(),
42 field_count: 0,
43 },
44 name_mapping: HashMap::new(),
45 stored_extra_types: None,
46 allows_empty_default_string_match: false,
47 })
48 }
49 }
50 }
51
52 #[pyo3(signature = (string, case_sensitive=false, extra_types=None, evaluate_result=true))]
57 fn parse(
58 &self,
59 string: &str,
60 case_sensitive: bool,
61 extra_types: Option<HashMap<String, PyObject>>,
62 evaluate_result: bool,
63 ) -> PyResult<Option<PyObject>> {
64 validate_input_length(string).map_err(PyValueError::new_err)?;
66
67 if string.contains('\0') {
69 return Err(PyValueError::new_err("Input string contains null byte"));
70 }
71 let merged_extra_types =
73 Python::with_gil(|py| -> PyResult<Option<HashMap<String, PyObject>>> {
74 let mut merged = if let Some(ref stored) = self.stored_extra_types {
75 stored
76 .iter()
77 .map(|(k, v)| (k.clone(), v.clone_ref(py)))
78 .collect()
79 } else {
80 HashMap::new()
81 };
82 if let Some(ref provided) = extra_types {
83 for (k, v) in provided {
84 merged.insert(k.clone(), v.clone_ref(py));
85 }
86 }
87 Ok(Some(merged))
88 })?;
89 self.parse_internal(
90 string,
91 case_sensitive,
92 merged_extra_types.as_ref(),
93 evaluate_result,
94 )
95 }
96
97 #[getter]
99 fn named_fields(&self) -> Vec<String> {
100 self.fields
102 .normalized_names
103 .iter()
104 .filter_map(|n| n.clone())
105 .collect()
106 }
107
108 #[getter]
115 fn regex_subpattern(&self) -> String {
116 self.regex_str.clone()
117 }
118
119 #[getter]
124 fn regex_capturing_group_count(&self) -> usize {
125 count_capturing_groups(&self.regex_str)
126 }
127
128 #[getter]
131 fn _expression(&self) -> String {
132 let mut result = self.regex_str.clone();
133
134 result = result.replace(r")\s+(", ") (");
137 result = result.replace(r")\s*(", ") (");
139
140 result = result.replace(
145 r"([+-]?(?:\d+\.\d+|\.\d+|\d+\.)(?:[eE][+-]?\d+)?)",
146 r"([-+ ]?\d*\.\d+)",
147 );
148
149 if result.starts_with("(") && result.ends_with(")") {
153 let inner = &result[1..result.len() - 1];
154 if inner.starts_with(" *(") && inner.ends_with(")") {
156 result = inner.to_string();
158 }
159 }
160
161 result
162 }
163
164 #[getter]
166 fn format(&self) -> Format {
167 Format {
168 pattern: self.pattern.clone(),
169 }
170 }
171
172 #[pyo3(signature = (string, case_sensitive=true, extra_types=None, evaluate_result=true))]
177 fn search(
178 &self,
179 string: &str,
180 case_sensitive: bool,
181 extra_types: Option<HashMap<String, PyObject>>,
182 evaluate_result: bool,
183 ) -> PyResult<Option<PyObject>> {
184 validate_input_length(string).map_err(PyValueError::new_err)?;
186
187 if string.contains('\0') {
189 return Err(PyValueError::new_err("Input string contains null byte"));
190 }
191
192 let merged_extra_types =
193 Python::with_gil(|py| -> PyResult<Option<HashMap<String, PyObject>>> {
194 let mut merged = if let Some(ref stored) = self.stored_extra_types {
195 stored
196 .iter()
197 .map(|(k, v)| (k.clone(), v.clone_ref(py)))
198 .collect()
199 } else {
200 HashMap::new()
201 };
202 if let Some(ref provided) = extra_types {
203 for (k, v) in provided {
204 merged.insert(k.clone(), v.clone_ref(py));
205 }
206 }
207 Ok(Some(merged))
208 })?;
209
210 self.search_pattern(string, case_sensitive, merged_extra_types, evaluate_result)
211 }
212
213 #[pyo3(signature = (string, case_sensitive=false, extra_types=None, evaluate_result=true))]
220 fn findall_iter(
221 &self,
222 py: Python<'_>,
223 string: &str,
224 case_sensitive: bool,
225 extra_types: Option<HashMap<String, PyObject>>,
226 evaluate_result: bool,
227 ) -> PyResult<Py<FindallIter>> {
228 validate_input_length(string).map_err(PyValueError::new_err)?;
229
230 if string.contains('\0') {
231 return Err(PyValueError::new_err("Input string contains null byte"));
232 }
233
234 let merged_extra_types =
235 Python::with_gil(|py| -> PyResult<Option<HashMap<String, PyObject>>> {
236 let mut merged = if let Some(ref stored) = self.stored_extra_types {
237 stored
238 .iter()
239 .map(|(k, v)| (k.clone(), v.clone_ref(py)))
240 .collect()
241 } else {
242 HashMap::new()
243 };
244 if let Some(ref provided) = extra_types {
245 for (k, v) in provided {
246 merged.insert(k.clone(), v.clone_ref(py));
247 }
248 }
249 Ok(Some(merged))
250 })?;
251
252 let merged_map = merged_extra_types.unwrap_or_default();
253 let arc = std::sync::Arc::new(self.clone());
254 Py::new(
255 py,
256 FindallIter::new(
257 arc,
258 string.to_string(),
259 case_sensitive,
260 evaluate_result,
261 merged_map,
262 ),
263 )
264 }
265
266 fn __reduce__(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
271 let m = py.import("_formatparse")?;
272 let compile_fn = m.getattr("compile")?;
273 let args = PyTuple::new(py, [&self.pattern])?;
274 PyTuple::new(py, [compile_fn.as_any(), args.as_any()])?.into_py_any(py)
275 }
276
277 fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
279 use pyo3::types::PyDict;
280 let state = PyDict::new(py);
281 state.set_item("pattern", &self.pattern)?;
282 state.into_py_any(py)
283 }
284
285 fn __setstate__(&mut self, _py: Python, state: &Bound<'_, PyAny>) -> PyResult<()> {
287 use pyo3::types::PyDict;
288 let dict = state.downcast::<PyDict>()?;
289 let pattern: String = dict
290 .get_item("pattern")?
291 .ok_or_else(|| error::missing_field_error("pattern"))?
292 .extract()?;
293
294 let reconstructed = Self::new_with_extra_types(&pattern, None)?;
296
297 self.pattern = reconstructed.pattern;
299 self.regex_str = reconstructed.regex_str;
300 self.regex = reconstructed.regex;
301 self.regex_case_insensitive = reconstructed.regex_case_insensitive;
302 self.search_regex = reconstructed.search_regex;
303 self.search_regex_case_insensitive = reconstructed.search_regex_case_insensitive;
304 self.fields = reconstructed.fields;
305 self.name_mapping = reconstructed.name_mapping;
306 self.stored_extra_types = reconstructed.stored_extra_types;
307 self.allows_empty_default_string_match = reconstructed.allows_empty_default_string_match;
308 Ok(())
309 }
310}
311
/// Python-visible wrapper around a format pattern string.
///
/// Instances are produced by the `format` getter of `FormatParser`, which
/// copies the parser's pattern into the new object.
#[pyclass]
pub struct Format {
    /// Pattern text used as the receiver of Python's `str.format`.
    pattern: String,
}
317
318#[pymethods]
319impl Format {
320 fn format(&self, py: Python, args: &Bound<'_, PyAny>) -> PyResult<String> {
322 let pattern_obj = PyString::new(py, &self.pattern);
324 let format_method = pattern_obj.getattr("format")?;
325
326 let result = if let Ok(tuple) = args.downcast::<PyTuple>() {
328 format_method.call1(tuple)?
329 } else {
330 format_method.call1((args,))?
332 };
333 result.extract()
334 }
335}