rsonpath-lib 0.6.1

Blazing fast JSONPath query engine powered by SIMD. Core library of `rsonpath`.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
//! This module can only be included if the code is compiled with AVX2 support
//! and on x86/x86_64 architecture for safety.
// Build-time guard: the condition below holds whenever the target is NOT
// (x86 or x86_64) with the `simd = "avx2"` cfg set. In that case `compile_error!`
// aborts the build with a message asking the user to report the issue.
cfg_if::cfg_if! {
    if #[cfg(not(all(
        any(target_arch = "x86", target_arch = "x86_64"),
        simd = "avx2")
    ))] {
        compile_error!{
            "internal error: AVX2 code included on unsupported target; \
            please report this issue at https://github.com/V0ldek/rsonpath/issues/new?template=bug_report.md"
        }
    }
}

use super::*;
use crate::debug;
use crate::input::{InputBlock, InputBlockIterator};
use crate::FallibleIterator;
use crate::{bin, block};

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use std::marker::PhantomData;

/// Block size, in bytes, used as the const parameter of the input block iterator
/// and of the classified blocks produced in this module.
const SIZE: usize = 64;

/// Quote classifier working on blocks of `SIZE` bytes.
///
/// Wraps an input block iterator and, for every block it yields, computes the mask
/// of positions that lie within quotes.
pub(crate) struct Avx2QuoteClassifier<'i, I>
where
    I: InputBlockIterator<'i, SIZE>,
{
    /// Source of the input blocks to classify.
    iter: I,
    /// Per-block classifier; carries state between consecutive blocks.
    classifier: BlockAvx2Classifier,
    /// Ties the lifetime `'i`, which no other field mentions, to the struct.
    phantom: PhantomData<&'i ()>,
}

impl<'i, I> Avx2QuoteClassifier<'i, I>
where
    I: InputBlockIterator<'i, SIZE>,
{
    /// Creates a classifier that reads blocks from `iter`, starting with cleared inter-block state.
    #[inline]
    pub(crate) fn new(iter: I) -> Self {
        // SAFETY: target_feature invariant
        let classifier = unsafe { BlockAvx2Classifier::new() };

        Self {
            iter,
            classifier,
            phantom: PhantomData,
        }
    }

    /// Creates a classifier over `iter` and, if `first_block` is given, classifies it right away.
    ///
    /// Returns the classifier together with the classified first block, or `None`
    /// in its place when no first block was supplied.
    #[inline]
    pub(crate) fn resume(
        iter: I,
        first_block: Option<I::Block>,
    ) -> (Self, Option<QuoteClassifiedBlock<I::Block, SIZE>>) {
        let mut resumed = Self::new(iter);

        let classified = match first_block {
            Some(block) => {
                // SAFETY: target feature invariant
                let within_quotes_mask = unsafe { resumed.classifier.classify(&block) };
                Some(QuoteClassifiedBlock {
                    block,
                    within_quotes_mask,
                })
            }
            None => None,
        };

        (resumed, classified)
    }
}

impl<'i, I> FallibleIterator for Avx2QuoteClassifier<'i, I>
where
    I: InputBlockIterator<'i, SIZE>,
{
    type Item = QuoteClassifiedBlock<I::Block, SIZE>;
    type Error = InputError;

    /// Fetches the next block from the underlying iterator and classifies it.
    ///
    /// Errors of the underlying iterator are propagated; `Ok(None)` is returned
    /// once that iterator yields no more blocks.
    fn next(&mut self) -> Result<Option<Self::Item>, Self::Error> {
        let maybe_block = self.iter.next()?;

        Ok(maybe_block.map(|block| {
            // SAFETY: target_feature invariant
            let within_quotes_mask = unsafe { self.classifier.classify(&block) };
            QuoteClassifiedBlock {
                block,
                within_quotes_mask,
            }
        }))
    }
}

impl<'i, I> QuoteClassifiedIterator<'i, I, SIZE> for Avx2QuoteClassifier<'i, I>
where
    I: InputBlockIterator<'i, SIZE>,
{
    fn get_offset(&self) -> usize {
        self.iter.get_offset() - 64
    }

    fn offset(&mut self, count: isize) {
        debug_assert!(count >= 0);
        debug!("Offsetting by {count}");

        if count == 0 {
            return;
        }

        self.iter.offset(count);
    }

    fn flip_quotes_bit(&mut self) {
        self.classifier.flip_prev_quote_mask();
    }
}

impl<'i, I> InnerIter<I> for Avx2QuoteClassifier<'i, I>
where
    I: InputBlockIterator<'i, SIZE>,
{
    /// Consumes the classifier and hands back the wrapped block iterator.
    fn into_inner(self) -> I {
        let Self { iter, .. } = self;
        iter
    }
}

/// Classifier of a single 64-byte block that remembers how the previous block ended.
struct BlockAvx2Classifier {
    /// Compressed information about the state from the previous block.
    /// The first bit (`0x01`) is lit iff the previous block ended with an unescaped escape character.
    /// The second bit (`0x02`) mirrors the most significant bit of the previous block's
    /// within-quotes mask, i.e. it is lit iff that block ended inside a quoted sequence
    /// that has not been closed yet. It can also be toggled explicitly.
    prev_block_mask: u8,
}

/// Raw character bitmasks computed for a single 32-byte half of a block.
struct BlockClassification {
    /// Bit `i` is set iff byte `i` of the half-block is a backslash (`\`).
    slashes: u32,
    /// Bit `i` is set iff byte `i` of the half-block is a double quote (`"`).
    quotes: u32,
}

impl BlockAvx2Classifier {
    /// Bitmask selecting bits on even positions when indexing from zero
    /// (that is, the 1st, 3rd, 5th, ... bit when counting from one, hence the name).
    const ODD: u64 = 0b0101_0101_0101_0101_0101_0101_0101_0101_0101_0101_0101_0101_0101_0101_0101_0101_u64;
    /// Bitmask selecting bits on odd positions when indexing from zero
    /// (that is, the 2nd, 4th, 6th, ... bit when counting from one, hence the name).
    const EVEN: u64 = 0b1010_1010_1010_1010_1010_1010_1010_1010_1010_1010_1010_1010_1010_1010_1010_1010_u64;

    /// Set the inter-block state based on slash overflow and the quotes mask.
    ///
    /// `set_slash_mask` becomes the slash bit (`0x01`). The most significant bit of `quotes`
    /// (position 63, the last character of the block) becomes the quote bit (`0x02`).
    fn update_prev_block_mask(&mut self, set_slash_mask: bool, quotes: u64) {
        let slash_mask = u8::from(set_slash_mask);
        let quote_mask = (((quotes & (1 << 63)) >> 62) as u8) & 0x02;
        self.prev_block_mask = slash_mask | quote_mask;
    }

    /// Flip the inter-block state bit representing the quote state.
    fn flip_prev_quote_mask(&mut self) {
        self.prev_block_mask ^= 0x02;
    }

    /// Returns 0x01 if the last character of the previous block was an unescaped escape character,
    /// zero otherwise.
    fn get_prev_slash_mask(&self) -> u64 {
        u64::from(self.prev_block_mask & 0x01)
    }

    /// Returns 0x01 if the quote bit of the inter-block state is lit, zero otherwise.
    ///
    /// The bit is lit when the last position of the previous block was marked as within quotes
    /// (or when it was toggled through `flip_prev_quote_mask`).
    fn get_prev_quote_mask(&self) -> u64 {
        u64::from((self.prev_block_mask & 0x02) >> 1)
    }

    /// Creates a classifier with cleared inter-block state.
    ///
    /// # Safety
    /// Declared with `target_feature(enable = "avx2")`, so the caller must ensure
    /// that AVX2 is available on the executing CPU.
    #[target_feature(enable = "avx2")]
    unsafe fn new() -> Self {
        Self { prev_block_mask: 0 }
    }

    /// Vector with every byte set to the double quote character.
    ///
    /// # Safety
    /// The caller must ensure that AVX2 is available on the executing CPU.
    #[target_feature(enable = "avx2")]
    unsafe fn quote_mask() -> __m256i {
        _mm256_set1_epi8(b'"' as i8)
    }

    /// Vector with every byte set to the backslash character.
    ///
    /// # Safety
    /// The caller must ensure that AVX2 is available on the executing CPU.
    #[target_feature(enable = "avx2")]
    unsafe fn slash_mask() -> __m256i {
        _mm256_set1_epi8(b'\\' as i8)
    }

    /// 128-bit vector with all bits set.
    ///
    /// # Safety
    /// The caller must ensure that AVX2 is available on the executing CPU.
    #[target_feature(enable = "avx2")]
    unsafe fn all_ones128() -> __m128i {
        _mm_set1_epi8(0xFF_u8 as i8)
    }

    /// Classifies a 64-byte block and returns a mask in which bit `i` is lit
    /// iff byte `i` of the block is considered to be within quotes.
    /// Updates the inter-block state for the next call.
    ///
    /// # Safety
    /// The caller must ensure that both AVX2 and PCLMULQDQ are available on the executing CPU.
    #[target_feature(enable = "avx2")]
    #[target_feature(enable = "pclmulqdq")]
    unsafe fn classify<'a, B: InputBlock<'a, 64>>(&mut self, two_blocks: &B) -> u64 {
        /* For a 64-bit architecture, we classify two adjacent 32-byte blocks and combine their masks
         * into a single 64-bit mask, which is significantly more performant.
         *
         * The step-by-step algorithm for determining which characters are within quotes is as follows:
         *   I. Determine which characters are escaped.
         *      1. Find all backslashes '\' and produce a 64-bit bitmask marking their positions.
         *      2. Identify backslashes not preceded by any other backslashes, the "starts".
         *      3. Find the "ends", positions right after contiguous sequences of backslashes.
         *          a) Use the "add-carry trick".
         *          b) Do this separately for "starts" at even and odd positions.
         *      4. If an "end" of an even-position "start" occurs at an odd position, it is escaped.
         *         Analogously for "ends" of odd-position "starts" occurring at even positions.
         *   II. Determine quoted sequences.
         *      1. Find all quotes '"' and produce a 64-bit bitmask marking their positions.
         *      2. Exclude escaped quotes based on step I.
         *      3. Mark all characters between quotes by running a cumulative XOR on the bitmask.
         */

        // Steps I.1., II.1.
        let (block1, block2) = two_blocks.halves();
        let classification1 = Self::classify_block(block1);
        let classification2 = Self::classify_block(block2);

        // Masks are combined by shifting the latter block's 32-bit masks left by 32 bits.
        // From now on when we refer to a "block" we mean the combined 64 bytes of the input.
        let slashes = u64::from(classification1.slashes) | (u64::from(classification2.slashes) << 32);
        let quotes = u64::from(classification1.quotes) | (u64::from(classification2.quotes) << 32);

        let (escaped, set_prev_slash_mask) = if slashes == 0 {
            // If there are no slashes in the input steps I.2, I.3, I.4 can be skipped.
            (self.get_prev_slash_mask(), false)
        } else {
            /* Step I.2.
             *
             * A character is a start of the sequence if it is not preceded by a backslash.
             * We also check whether the last character of the previous block was an unescaped backslash
             * to correctly classify the first character in the block.
             *
             * Visualization for 8-byte-long blocks:
             *                  | prev bl.|curr bl. |
             *  bitmask index   | 76543210 76543210 |
             *  input           | \x\\\\x\ \x\\\x\\ |
             *  slashes         | 10111101 10111011 |
             *  slashes << 1    | 01011110 01011101 |
             *  prev_slash      | 00000000 10000000 |
             *  starts          | 10100001 00100010 |
             *  even_starts     | 00000001 00000000 |
             *  odd_starts      | 10100000 00100010 |
             */

            let slashes_excluding_escaped_first = slashes & !self.get_prev_slash_mask();
            let starts = slashes_excluding_escaped_first & !(slashes_excluding_escaped_first << 1);
            let odd_starts = Self::ODD & starts;
            let even_starts = Self::EVEN & starts;

            /* Step I.3.
             *
             * Recall that in binary arithmetic an addition of two ones at the same place
             * causes a carry - the result bit is set to zero, and the one is carried forward to the next place.
             * To find an end of a contiguous sequence of ones we can use an "add-carry trick" - by adding a number
             * with a bit set exactly at the start of the sequence and adding it to the original mask we cause a carry
             * that propagates up until the end of the sequence.
             *
             * This can overflow, so we use `wrapping_add` to ignore that. In case of the slashes starting at even
             * positions we want to explicitly check for that overflow - if it occurs, it means that all the bits
             * from some even position `i` up to the position `0` were lit, and thus the backslash at position `0`
             * is _not_ escaped (since there was an even number of backslashes preceding it).
             * We should therefore set the `prev_slash` mask if and only if an overflow occurs here.
             *
             * Visualization for 8-byte-long blocks:
             *                    | prev bl.|curr bl. |
             *  bitmask index     | 76543210 76543210 |
             *  input             | \x\\\\x\ \x\\\x\\ |
             *  slashes           | 10111101 10111011 |
             *  even_starts       | 00000001 00000000 |
             *  even_starts_carry | 10111100 10111011 | <-- Overflow occurs!
             *  slashes           | 10111101 10111011 |
             *  odd_starts        | 10100000 00100010 |
             *  odd_starts_carry  | 01000011 10000100 | <-- Overflow occurs, but is inconsequential.
             */

            let odd_starts_carry = odd_starts.wrapping_add(slashes);
            let (even_starts_carry, set_prev_slash_mask) = even_starts.overflowing_add(slashes);

            // We need to exclude `slashes`, as the ones from the opposite-parity positions
            // cause noise in the mask. Note in the above how `even_starts_carry` contains
            // almost all bits copied over from slashes that did not cause a carry, but
            // in actuality the only "end of an even start" is the one lost to overflow.
            let ends_of_odd_starts = odd_starts_carry & !slashes;
            let ends_of_even_starts = even_starts_carry & !slashes;

            /* Step I.4.
             *
             * Find the characters preceded by a contiguous sequence of backslashes of odd length.
             * Note that the `escaped` mask is completely arbitrary for the backslash characters themselves,
             * but that is irrelevant to any further processing steps.
             *
             * Visualization for 8-byte-long blocks:
             *                      | prev bl.|curr bl. |
             *  bitmask index       | 76543210 76543210 |
             *  input               | \x\\\\x\ \x\\\x\\ |
             *  ends_of_odd_starts  | 01000010 00000100 |
             *  ends_of_even_starts | 00000001 00000000 |
             *  prev_slash          | 00000000 10000000 |
             *  escaped             | 01000000 10000100 |
             */
            let escaped =
                (ends_of_odd_starts & Self::EVEN) | (ends_of_even_starts & Self::ODD) | self.get_prev_slash_mask();

            (escaped, set_prev_slash_mask)
        };

        /* Step II.2.
         *
         * Select only unescaped quotes.
         *
         * We also check whether the last character of the previous block was still within quotes
         * and flip the first bit if it was. Assume that is the case - then there are two possibilities:
         *  1. The first character of the current block was a quote.
         *     That quote is then not marked as an unescaped quote, but clearly it was a closing quote,
         *     so it can be safely ignored.
         *  2. The first character of the current block was not a quote.
         *     As it follows from the clmul operation, the first character in the current block will then
         *     correctly be marked as quoted.
         */
        let nonescaped_quotes = (quotes & !escaped) ^ self.get_prev_quote_mask();

        /*
         * Step II.3.
         *
         * The clmul operation's semantics when given a 128-bit vector `a` as the first operand and
         * an all-ones 128-bit vector `b` as the second operand are the same as a cumulative XOR.
         * Therefore, a lit bit of `nonescaped_quotes` will be "spread" up until a pairing lit bit
         * occurs in the mask, which exactly corresponds to marking all characters after a quote
         * up until the pairing closing quote is found.
         *
         * We only use the lower 64 bits of the vector, so we first copy the mask with `_mm_set_epi64x`
         * and then extract the 64-bit result with `_mm_cvtsi128_si64`.
         *
         * Again, note that the quoted classification for the delimiting quotes themselves can be arbitrary.
         *
         * Visualization for 8-byte-long blocks:
         *                      | prev bl.|curr bl. |
         *  bitmask index       | 76543210 76543210 |
         *  input               | "xx"xxx" xx"xx"xx |
         *  quotes              | 10010001 00100100 |
         *  prev_quote          | 00000000 10000000 |
         *  nonescaped_quotes   | 10010001 10100100 |
         *  cumulative_xor      | 11100001 11000111 |
         */
        let nonescaped_quotes_vector = _mm_set_epi64x(0, nonescaped_quotes as i64);
        let cumulative_xor = _mm_clmulepi64_si128::<0>(nonescaped_quotes_vector, Self::all_ones128());

        let within_quotes = _mm_cvtsi128_si64(cumulative_xor) as u64;
        self.update_prev_block_mask(set_prev_slash_mask, within_quotes);

        // Debug dumps. Note that the inter-block state was already updated above, so the
        // `prev_*_bit` values printed here are the ones that will be seen by the next block.
        block!(two_blocks[..64]);
        bin!("slashes", slashes);
        bin!("quotes", quotes);
        bin!("prev_slash_bit", self.get_prev_slash_mask());
        bin!("prev_quote_bit", self.get_prev_quote_mask());
        bin!("escaped", escaped);
        bin!("quotes & !escaped", quotes & !escaped);
        bin!("nonescaped_quotes", nonescaped_quotes);
        bin!("within_quotes", within_quotes);

        within_quotes
    }

    /// Computes the backslash and quote bitmasks for one 32-byte half of a block.
    ///
    /// # Safety
    /// The caller must ensure that AVX2 is available on the executing CPU and that `block`
    /// holds at least 32 bytes, since a full 256-bit unaligned load is performed from its start.
    #[target_feature(enable = "avx2")]
    unsafe fn classify_block(block: &[u8]) -> BlockClassification {
        // The load below reads 32 bytes regardless of the slice length, so a shorter slice
        // would be read out of bounds. Catch violations of the contract in debug builds.
        debug_assert!(
            block.len() >= core::mem::size_of::<__m256i>(),
            "half-block must be at least 32 bytes long"
        );
        let byte_vector = _mm256_loadu_si256(block.as_ptr().cast::<__m256i>());

        let slash_cmp = _mm256_cmpeq_epi8(byte_vector, Self::slash_mask());
        let slashes = _mm256_movemask_epi8(slash_cmp) as u32;

        let quote_cmp = _mm256_cmpeq_epi8(byte_vector, Self::quote_mask());
        let quotes = _mm256_movemask_epi8(quote_cmp) as u32;

        BlockClassification { slashes, quotes }
    }
}

#[cfg(test)]
mod tests {
    use super::Avx2QuoteClassifier;
    use crate::{
        input::{Input, OwnedBytes},
        result::empty::EmptyRecorder,
        FallibleIterator,
    };
    use test_case::test_case;

    #[test_case("" => None)]
    #[test_case("abcd" => Some(0))]
    #[test_case(r#""abcd""# => Some(0b01_1111))]
    #[test_case(r#""number": 42, "string": "something" "# => Some(0b0011_1111_1111_0001_1111_1100_0000_0111_1111))]
    #[test_case(r#"abc\"abc\""# => Some(0b00_0000_0000))]
    #[test_case(r#"abc\\"abc\\""# => Some(0b0111_1110_0000))]
    #[test_case(r#"{"aaa":[{},{"b":{"c":[1,2,3]}}],"e":{"a":[[],[1,2,3],"# => Some(0b0_0000_0000_0000_0110_0011_0000_0000_0000_0110_0011_0000_0001_1110))]
    fn single_block(str: &str) -> Option<u64> {
        // Build an owned input from the test string and classify only its first block.
        let contents = str.to_owned();
        let bytes = OwnedBytes::new(&contents).unwrap();
        let blocks = bytes.iter_blocks::<_, 64>(&EmptyRecorder);
        let mut quote_classifier = Avx2QuoteClassifier::new(blocks);

        // `None` means the input produced no block at all.
        let first = quote_classifier.next().unwrap();
        first.map(|classified| classified.within_quotes_mask)
    }
}