ftl_jiter/
python.rs

1use ahash::AHashSet;
2use std::marker::PhantomData;
3
4use pyo3::exceptions::{PyTypeError, PyValueError};
5use pyo3::ffi;
6use pyo3::prelude::*;
7use pyo3::types::{PyBool, PyDict, PyList, PyString};
8use pyo3::ToPyObject;
9
10use smallvec::SmallVec;
11
12use crate::errors::{json_err, json_error, JsonError, JsonResult, DEFAULT_RECURSION_LIMIT};
13use crate::number_decoder::{AbstractNumberDecoder, NumberAny, NumberRange};
14use crate::parse::{Parser, Peek};
15use crate::py_lossless_float::{get_decimal_type, FloatMode};
16use crate::py_string_cache::{StringCacheAll, StringCacheKeys, StringCacheMode, StringMaybeCache, StringNoCache};
17use crate::string_decoder::{StringDecoder, Tape};
18use crate::{JsonErrorType, LosslessFloat};
19
/// Options controlling how JSON is parsed into Python objects.
///
/// Build a value (all fields are public, and `Default` is derived) and then call
/// [`PythonParse::python_parse`] to run the parse.
#[derive(Default)]
#[allow(clippy::struct_excessive_bools)]
pub struct PythonParse {
    /// Whether to allow `(-)Infinity` and `NaN` values.
    pub allow_inf_nan: bool,
    /// How strings are cached to avoid constructing new Python objects; each
    /// [`StringCacheMode`] variant selects a different string cache implementation.
    pub cache_mode: StringCacheMode,
    /// Whether to allow partial JSON data.
    pub partial_mode: PartialMode,
    /// Whether to catch duplicate keys in objects.
    pub catch_duplicate_keys: bool,
    /// How to return floats: as a `float` (`'float'`), `Decimal` (`'decimal'`) or
    /// [`LosslessFloat`] (`'lossless-float'`)
    pub float_mode: FloatMode,
}
35
36impl PythonParse {
37    /// Parse a JSON value from a byte slice and return a Python object.
38    ///
39    /// # Arguments
40    ///
41    /// - `py`: [Python](https://docs.rs/pyo3/latest/pyo3/marker/struct.Python.html) marker token.
42    /// - `json_data`: The JSON data to parse.
43    ///   this should have a significant improvement on performance but increases memory slightly.
44    ///
45    /// # Returns
46    ///
47    /// A [PyObject](https://docs.rs/pyo3/latest/pyo3/type.PyObject.html) representing the parsed JSON value.
48    pub fn python_parse<'py>(self, py: Python<'py>, json_data: &[u8]) -> JsonResult<Bound<'py, PyAny>> {
49        macro_rules! ppp {
50            ($string_cache:ident, $key_check:ident, $parse_number:ident) => {
51                PythonParser::<$string_cache, $key_check, $parse_number>::parse(
52                    py,
53                    json_data,
54                    self.allow_inf_nan,
55                    self.partial_mode,
56                )
57            };
58        }
59        macro_rules! ppp_group {
60            ($string_cache:ident) => {
61                match (self.catch_duplicate_keys, self.float_mode) {
62                    (true, FloatMode::Float) => ppp!($string_cache, DuplicateKeyCheck, ParseNumberLossy),
63                    (true, FloatMode::Decimal) => ppp!($string_cache, DuplicateKeyCheck, ParseNumberDecimal),
64                    (true, FloatMode::LosslessFloat) => ppp!($string_cache, DuplicateKeyCheck, ParseNumberLossless),
65                    (false, FloatMode::Float) => ppp!($string_cache, NoopKeyCheck, ParseNumberLossy),
66                    (false, FloatMode::Decimal) => ppp!($string_cache, NoopKeyCheck, ParseNumberDecimal),
67                    (false, FloatMode::LosslessFloat) => ppp!($string_cache, NoopKeyCheck, ParseNumberLossless),
68                }
69            };
70        }
71
72        match self.cache_mode {
73            StringCacheMode::All => ppp_group!(StringCacheAll),
74            StringCacheMode::Keys => ppp_group!(StringCacheKeys),
75            StringCacheMode::None => ppp_group!(StringNoCache),
76        }
77    }
78}
79
80/// Map a `JsonError` to a `PyErr` which can be raised as an exception in Python as a `ValueError`.
81pub fn map_json_error(json_data: &[u8], json_error: &JsonError) -> PyErr {
82    PyValueError::new_err(json_error.description(json_data))
83}
84
/// Internal parser state; the three type parameters select the string cache, duplicate-key check
/// and number conversion strategies at compile time.
struct PythonParser<'j, StringCache, KeyCheck, ParseNumber> {
    // Zero-sized markers: the strategy types are only used through associated functions,
    // so no values of them are stored.
    _string_cache: PhantomData<StringCache>,
    _key_check: PhantomData<KeyCheck>,
    _parse_number: PhantomData<ParseNumber>,
    /// Low-level JSON parser over the input bytes.
    parser: Parser<'j>,
    /// Buffer handed to the string-decoding calls (presumably scratch space for decoded
    /// strings; confirm in `string_decoder`).
    tape: Tape,
    /// Remaining nesting depth; decremented around each nested value and restored afterwards.
    recursion_limit: u8,
    /// Forwarded to number parsing; see `PythonParse::allow_inf_nan`.
    allow_inf_nan: bool,
    /// Controls whether end-of-input style errors are tolerated.
    partial_mode: PartialMode,
}
95
96impl<'j, StringCache: StringMaybeCache, KeyCheck: MaybeKeyCheck, ParseNumber: MaybeParseNumber>
97    PythonParser<'j, StringCache, KeyCheck, ParseNumber>
98{
99    fn parse<'py>(
100        py: Python<'py>,
101        json_data: &[u8],
102        allow_inf_nan: bool,
103        partial_mode: PartialMode,
104    ) -> JsonResult<Bound<'py, PyAny>> {
105        let mut slf = PythonParser {
106            _string_cache: PhantomData::<StringCache>,
107            _key_check: PhantomData::<KeyCheck>,
108            _parse_number: PhantomData::<ParseNumber>,
109            parser: Parser::new(json_data),
110            tape: Tape::default(),
111            recursion_limit: DEFAULT_RECURSION_LIMIT,
112            allow_inf_nan,
113            partial_mode,
114        };
115
116        let peek = slf.parser.peek()?;
117        let v = slf.py_take_value(py, peek)?;
118        if !slf.partial_mode.is_active() {
119            slf.parser.finish()?;
120        }
121        Ok(v)
122    }
123
124    fn py_take_value<'py>(&mut self, py: Python<'py>, peek: Peek) -> JsonResult<Bound<'py, PyAny>> {
125        match peek {
126            Peek::Null => {
127                self.parser.consume_null()?;
128                Ok(py.None().into_bound(py))
129            }
130            Peek::True => {
131                self.parser.consume_true()?;
132                Ok(true.to_object(py).into_bound(py))
133            }
134            Peek::False => {
135                self.parser.consume_false()?;
136                Ok(false.to_object(py).into_bound(py))
137            }
138            Peek::String => {
139                let s = self
140                    .parser
141                    .consume_string::<StringDecoder>(&mut self.tape, self.partial_mode.allow_trailing_str())?;
142                Ok(StringCache::get_value(py, s.as_str(), s.ascii_only()).into_any())
143            }
144            Peek::Array => {
145                let peek_first = match self.parser.array_first() {
146                    Ok(Some(peek)) => peek,
147                    Err(e) if !self._allow_partial_err(&e) => return Err(e),
148                    Ok(None) | Err(_) => return Ok(PyList::empty_bound(py).into_any()),
149                };
150
151                let mut vec: SmallVec<[Bound<'_, PyAny>; 8]> = SmallVec::with_capacity(8);
152                if let Err(e) = self._parse_array(py, peek_first, &mut vec) {
153                    if !self._allow_partial_err(&e) {
154                        return Err(e);
155                    }
156                }
157
158                Ok(PyList::new_bound(py, vec).into_any())
159            }
160            Peek::Object => {
161                let dict = PyDict::new_bound(py);
162                if let Err(e) = self._parse_object(py, &dict) {
163                    if !self._allow_partial_err(&e) {
164                        return Err(e);
165                    }
166                }
167                Ok(dict.into_any())
168            }
169            _ => ParseNumber::parse_number(py, &mut self.parser, peek, self.allow_inf_nan),
170        }
171    }
172
173    fn _parse_array<'py>(
174        &mut self,
175        py: Python<'py>,
176        peek_first: Peek,
177        vec: &mut SmallVec<[Bound<'py, PyAny>; 8]>,
178    ) -> JsonResult<()> {
179        let v = self._check_take_value(py, peek_first)?;
180        vec.push(v);
181        while let Some(peek) = self.parser.array_step()? {
182            let v = self._check_take_value(py, peek)?;
183            vec.push(v);
184        }
185        Ok(())
186    }
187
188    fn _parse_object<'py>(&mut self, py: Python<'py>, dict: &Bound<'py, PyDict>) -> JsonResult<()> {
189        let set_item = |key: Bound<'py, PyString>, value: Bound<'py, PyAny>| {
190            let r = unsafe { ffi::PyDict_SetItem(dict.as_ptr(), key.as_ptr(), value.as_ptr()) };
191            // AFAIK this shouldn't happen since the key will always be a string  which is hashable
192            // we panic here rather than returning a result and using `?` below as it's up to 14% faster
193            // presumably because there are fewer branches
194            assert_ne!(r, -1, "PyDict_SetItem failed");
195        };
196        let mut check_keys = KeyCheck::default();
197        if let Some(first_key) = self.parser.object_first::<StringDecoder>(&mut self.tape)? {
198            let first_key_s = first_key.as_str();
199            check_keys.check(first_key_s, self.parser.index)?;
200            let first_key = StringCache::get_key(py, first_key_s, first_key.ascii_only());
201            let peek = self.parser.peek()?;
202            let first_value = self._check_take_value(py, peek)?;
203            set_item(first_key, first_value);
204            while let Some(key) = self.parser.object_step::<StringDecoder>(&mut self.tape)? {
205                let key_s = key.as_str();
206                check_keys.check(key_s, self.parser.index)?;
207                let key = StringCache::get_key(py, key_s, key.ascii_only());
208                let peek = self.parser.peek()?;
209                let value = self._check_take_value(py, peek)?;
210                set_item(key, value);
211            }
212        }
213        Ok(())
214    }
215
216    fn _allow_partial_err(&self, e: &JsonError) -> bool {
217        if self.partial_mode.is_active() {
218            matches!(
219                e.error_type,
220                JsonErrorType::EofWhileParsingList
221                    | JsonErrorType::EofWhileParsingObject
222                    | JsonErrorType::EofWhileParsingString
223                    | JsonErrorType::EofWhileParsingValue
224                    | JsonErrorType::ExpectedListCommaOrEnd
225                    | JsonErrorType::ExpectedObjectCommaOrEnd
226            )
227        } else {
228            false
229        }
230    }
231
232    fn _check_take_value<'py>(&mut self, py: Python<'py>, peek: Peek) -> JsonResult<Bound<'py, PyAny>> {
233        self.recursion_limit = match self.recursion_limit.checked_sub(1) {
234            Some(limit) => limit,
235            None => return json_err!(RecursionLimitExceeded, self.parser.index),
236        };
237
238        let r = self.py_take_value(py, peek);
239
240        self.recursion_limit += 1;
241        r
242    }
243}
244
/// How incomplete JSON input is handled.
///
/// Converts from a Python `bool` or one of the strings `'off'`, `'on'`, `'trailing-strings'`.
#[derive(Debug, Clone, Copy)]
pub enum PartialMode {
    /// Partial parsing is disabled (the default).
    Off,
    /// Partial parsing is active.
    On,
    /// Partial parsing is active and, additionally, the "allow trailing string" flag is passed
    /// to string parsing.
    TrailingStrings,
}
251
252impl Default for PartialMode {
253    fn default() -> Self {
254        Self::Off
255    }
256}
257
/// Message shared by the `ValueError` and `TypeError` raised when a Python value cannot be
/// converted into a [`PartialMode`].
const PARTIAL_ERROR: &str = "Invalid partial mode, should be `'off'`, `'on'`, `'trailing-strings'` or a `bool`";
259
260impl<'py> FromPyObject<'py> for PartialMode {
261    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
262        if let Ok(bool_mode) = ob.downcast::<PyBool>() {
263            Ok(bool_mode.is_true().into())
264        } else if let Ok(str_mode) = ob.extract::<&str>() {
265            match str_mode {
266                "off" => Ok(Self::Off),
267                "on" => Ok(Self::On),
268                "trailing-strings" => Ok(Self::TrailingStrings),
269                _ => Err(PyValueError::new_err(PARTIAL_ERROR)),
270            }
271        } else {
272            Err(PyTypeError::new_err(PARTIAL_ERROR))
273        }
274    }
275}
276
277impl From<bool> for PartialMode {
278    fn from(mode: bool) -> Self {
279        if mode {
280            Self::On
281        } else {
282            Self::Off
283        }
284    }
285}
286
287impl PartialMode {
288    fn is_active(self) -> bool {
289        !matches!(self, Self::Off)
290    }
291
292    fn allow_trailing_str(self) -> bool {
293        matches!(self, Self::TrailingStrings)
294    }
295}
296
/// Compile-time strategy for validating object keys as they are parsed.
///
/// A fresh instance is created via `Default` for every JSON object.
trait MaybeKeyCheck: Default {
    /// Inspect `key`; `index` is the parser position taken after the key was read, for
    /// implementations that need to report an error.
    fn check(&mut self, key: &str, index: usize) -> JsonResult<()>;
}
300
/// Key check that accepts every key; used when duplicate-key detection is disabled.
#[derive(Default)]
struct NoopKeyCheck;

impl MaybeKeyCheck for NoopKeyCheck {
    /// Always succeeds without looking at the key.
    fn check(&mut self, _key: &str, _index: usize) -> JsonResult<()> {
        Ok(())
    }
}
309
310#[derive(Default)]
311struct DuplicateKeyCheck(AHashSet<String>);
312
313impl MaybeKeyCheck for DuplicateKeyCheck {
314    fn check(&mut self, key: &str, index: usize) -> JsonResult<()> {
315        if self.0.insert(key.to_owned()) {
316            Ok(())
317        } else {
318            Err(JsonError::new(JsonErrorType::DuplicateKey(key.to_owned()), index))
319        }
320    }
321}
322
/// Compile-time strategy for turning a JSON number into a Python object.
trait MaybeParseNumber {
    /// Consume the number starting at the current parser position.
    ///
    /// `peek` is the already-peeked first byte of the value and `allow_inf_nan` is forwarded to
    /// the number decoder.
    fn parse_number<'py>(
        py: Python<'py>,
        parser: &mut Parser,
        peek: Peek,
        allow_inf_nan: bool,
    ) -> JsonResult<Bound<'py, PyAny>>;
}
331
332struct ParseNumberLossy;
333
334impl MaybeParseNumber for ParseNumberLossy {
335    fn parse_number<'py>(
336        py: Python<'py>,
337        parser: &mut Parser,
338        peek: Peek,
339        allow_inf_nan: bool,
340    ) -> JsonResult<Bound<'py, PyAny>> {
341        match parser.consume_number::<NumberAny>(peek.into_inner(), allow_inf_nan) {
342            Ok(number) => Ok(number.to_object(py).into_bound(py)),
343            Err(e) => {
344                if !peek.is_num() {
345                    Err(json_error!(ExpectedSomeValue, parser.index))
346                } else {
347                    Err(e)
348                }
349            }
350        }
351    }
352}
353
354struct ParseNumberLossless;
355
356impl MaybeParseNumber for ParseNumberLossless {
357    fn parse_number<'py>(
358        py: Python<'py>,
359        parser: &mut Parser,
360        peek: Peek,
361        allow_inf_nan: bool,
362    ) -> JsonResult<Bound<'py, PyAny>> {
363        match parser.consume_number::<NumberRange>(peek.into_inner(), allow_inf_nan) {
364            Ok(number_range) => {
365                let bytes = parser.slice(number_range.range).unwrap();
366                let obj = if number_range.is_int {
367                    NumberAny::decode(bytes, 0, peek.into_inner(), allow_inf_nan)?
368                        .0
369                        .to_object(py)
370                } else {
371                    LosslessFloat::new_unchecked(bytes.to_vec()).into_py(py)
372                };
373                Ok(obj.into_bound(py))
374            }
375            Err(e) => {
376                if !peek.is_num() {
377                    Err(json_error!(ExpectedSomeValue, parser.index))
378                } else {
379                    Err(e)
380                }
381            }
382        }
383    }
384}
385
386struct ParseNumberDecimal;
387
388impl MaybeParseNumber for ParseNumberDecimal {
389    fn parse_number<'py>(
390        py: Python<'py>,
391        parser: &mut Parser,
392        peek: Peek,
393        allow_inf_nan: bool,
394    ) -> JsonResult<Bound<'py, PyAny>> {
395        match parser.consume_number::<NumberRange>(peek.into_inner(), allow_inf_nan) {
396            Ok(number_range) => {
397                let bytes = parser.slice(number_range.range).unwrap();
398                if number_range.is_int {
399                    let obj = NumberAny::decode(bytes, 0, peek.into_inner(), allow_inf_nan)?
400                        .0
401                        .to_object(py);
402                    Ok(obj.into_bound(py))
403                } else {
404                    let decimal_type = get_decimal_type(py)
405                        .map_err(|e| JsonError::new(JsonErrorType::InternalError(e.to_string()), parser.index))?;
406                    // SAFETY: NumberRange::decode has already confirmed that bytes are a valid JSON number,
407                    // and therefore valid str
408                    let float_str = unsafe { std::str::from_utf8_unchecked(bytes) };
409                    decimal_type
410                        .call1((float_str,))
411                        .map_err(|e| JsonError::new(JsonErrorType::InternalError(e.to_string()), parser.index))
412                }
413            }
414            Err(e) => {
415                if !peek.is_num() {
416                    Err(json_error!(ExpectedSomeValue, parser.index))
417                } else {
418                    Err(e)
419                }
420            }
421        }
422    }
423}