1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
//! Ergonomic Rust bindings to the JavaScript standard built-in `RegExp` object
//!
//! ### Basic usage
//! ```
//! use js_regexp::{RegExp, Flags};
//!
//! let re = RegExp::new(
//!     r#"(?<greeting>\w+), (?<name>\w+)"#,
//!     Flags::new("d").unwrap(),
//! )
//! .unwrap();
//!
//! let result = re.exec("Hello, Alice!").unwrap();
//! let named_captures = result.captures.unwrap();
//! let named_captures = named_captures.get_named_captures_map();
//!
//! assert_eq!("Hello, Alice", result.match_slice);
//! assert_eq!(0, result.match_index);
//! assert_eq!(12, result.match_length);
//! assert_eq!("Hello", named_captures.get("greeting").unwrap().slice);
//! assert_eq!(7, named_captures.get("name").unwrap().index);
//! ```

use anyhow::Context;
use js_sys::{Function, JsString};
use std::{
    collections::HashMap,
    hash::{Hash, Hasher},
};
use thiserror::Error;
use wasm_bindgen::{JsCast, JsValue};

/// A wrapped JavaScript `RegExp`. The main type of this crate.
///
/// [MDN documentation](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/RegExp)
#[derive(Debug)]
pub struct RegExp<'p> {
    inner: js_sys::RegExp,
    pattern: PatternSource<'p>,
    flags: Flags,
}
impl<'p> RegExp<'p> {
    /// Constructs a new regular expression, backed by a `RegExp` in JavaScript. \
    /// Returns an error if JavaScript throws a SyntaxError exception. \
    /// When constructed by this function, the returned value's lifetime becomes tied to the
    /// provided `&str` pattern. See [`new_with_ownership`](RegExp::new_with_ownership)
    /// for an alternative that takes ownership of a `String` pattern instead.
    pub fn new(pattern: &'p str, flags: Flags) -> Result<Self, JsValue> {
        Ok(Self {
            inner: construct_regexp_panicking(pattern, flags.build())?,
            pattern: PatternSource::Ref(pattern),
            flags,
        })
    }
    /// Constructs a new regular expression, backed by a `RegExp` in JavaScript. \
    /// Returns an error if JavaScript throws a SyntaxError exception. \
    /// Takes ownership of the provided `String` pattern. Use [`new`](RegExp::new) instead if you have a `&'static str`,
    /// or if it otherwise makes sense for the constructed value to store only a reference to your pattern.
    pub fn new_with_ownership(pattern: String, flags: Flags) -> Result<Self, JsValue> {
        Ok(Self {
            inner: construct_regexp_panicking(&pattern, flags.build())?,
            pattern: PatternSource::Owned(pattern),
            flags,
        })
    }
    /// Constructs a new regular expression, backed by a `RegExp` in JavaScript. \
    /// Returns an error if JavaScript throws a SyntaxError exception. \
    /// Unlike with [`new`](RegExp::new), the returned structure does not hold on to a reference to the provided
    /// `&str` pattern. This is achieved by copying any group names from the JavaScript heap every time the regular expression
    /// is used.
    pub fn new_with_copying(pattern: &str, flags: Flags) -> Result<Self, JsValue> {
        Ok(Self {
            inner: construct_regexp_panicking(pattern, flags.build())?,
            pattern: PatternSource::Copy,
            flags,
        })
    }

    /// Calls the underlying JavaScript `RegExp`'s `exec` method. \
    /// Returns `None` if the JavaScript call returns null.
    /// The returned [`ExecResult`]'s `captures` member is `None` if the underlying JavaScript call returns an object
    /// that does not have an `indices` property, which is only present when the [`d` flag](FlagSets.set_has_indices)
    /// is set for the expression.
    ///
    /// [MDN documentation](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/RegExp/exec)
    pub fn exec<'h>(&'p self, haystack: &'h str) -> Option<ExecResult<'h, 'p>> {
        match self.exec_internal(haystack) {
            Ok(v) => v,
            Err(v) => panic!("{:?}", v),
        }
    }
    /// Returns a read-only reference to the flags set for this regular expression.
    pub fn inspect_flags(&self) -> &FlagSets {
        &self.flags.sets
    }
    fn exec_internal<'h>(
        &'p self,
        haystack: &'h str,
    ) -> Result<Option<ExecResult<'h, 'p>>, JsError> {
        let result = match self.inner.exec(haystack) {
            Some(v) => v,
            None => return Ok(None),
        };

        let utf16_match_index = get_value_property_str("index", &result)?
            .as_f64()
            .option_context()? as usize;
        let utf8_match_index =
            count_bytes_from_utf16_units(haystack, utf16_match_index).option_context()?;
        let matched = &haystack[utf8_match_index..];
        let string_match_js = result.iter().next().option_context()?;
        let string_match_js: &JsString = string_match_js.dyn_ref().option_context()?;
        let utf16_match_length = string_match_js.length() as usize;
        let utf8_match_length =
            count_bytes_from_utf16_units(matched, utf16_match_length).option_context()?;
        let matched = &matched[..utf8_match_length];
        let indices_array_js = get_value_property_str("indices", &result)?;

        let mut exec_result = ExecResult {
            match_slice: matched,
            match_index: utf8_match_index,
            match_length: utf16_match_length,
            captures: None,
        };
        if !indices_array_js.is_array() {
            return Ok(Some(exec_result));
        }

        let mut captures_vec = Vec::new();
        let js_array_iter = js_sys::try_iter(&indices_array_js)?.option_context()?;
        for indices_js in js_array_iter.skip(1) {
            let indices_js = indices_js?;
            let capture = slice_capture(haystack, &indices_js)?;
            captures_vec.push(capture);
        }
        let named_indices_js = get_value_property_str("groups", &indices_array_js)?;
        if !named_indices_js.is_object() {
            let _ = exec_result
                .captures
                .insert(CapturesList { vec: captures_vec });
            return Ok(Some(exec_result));
        }

        let group_names = js_sys::Reflect::own_keys(&named_indices_js)?;
        for group_name_js in group_names.iter() {
            let group_name_js: JsString = group_name_js.dyn_into()?;
            let indices_js = js_sys::Reflect::get(&named_indices_js, &group_name_js)?;
            let capture = slice_capture(haystack, &indices_js)?;
            let group_name = match self.pattern.get() {
                Some(pattern) => {
                    GroupName::Ref(find_js_string(pattern, &group_name_js).option_context()?)
                }
                None => GroupName::Owned(group_name_js.as_string().option_context()?),
            };
            let _ = captures_vec
                .iter_mut()
                .find(|v| v.index == capture.index && v.length == capture.length)
                .option_context()?
                .group_name
                .insert(group_name);
        }

        let _ = exec_result
            .captures
            .insert(CapturesList { vec: captures_vec });
        Ok(Some(exec_result))
    }
}

/// Failures that are not expected to happen during normal operation while
/// talking to JavaScript.
#[derive(Debug, Error)]
enum JsError {
    /// JavaScript threw; the thrown value is kept unchanged.
    #[error("JavaScript exception")]
    JavaScript(JsValue),
    /// Any other unexpected condition, such as an `Option` that was unexpectedly `None`.
    #[error("Other error")]
    Other(#[from] anyhow::Error),
}
impl From<JsValue> for JsError {
    /// Wraps a value thrown by JavaScript.
    fn from(thrown: JsValue) -> Self {
        Self::JavaScript(thrown)
    }
}

trait OptionContext<T, E> {
    fn option_context(self) -> Result<T, anyhow::Error>
    where
        Self: Context<T, E>;
}
impl<T, E> OptionContext<T, E> for Option<T> {
    fn option_context(self) -> Result<T, anyhow::Error>
    where
        Self: Context<T, E>,
    {
        self.context("Unexpectedly failed to unwrap an option while interacting with JavaScript")
    }
}

/// Where a `RegExp` gets group names from when reporting named captures.
#[derive(Debug)]
enum PatternSource<'a> {
    /// The pattern string is owned by the `RegExp`.
    Owned(String),
    /// The pattern string is borrowed from the caller.
    Ref(&'a str),
    /// No pattern is kept; group names are copied out of JavaScript instead.
    Copy,
}
impl PatternSource<'_> {
    /// Returns the stored pattern text, or `None` for [`PatternSource::Copy`].
    ///
    /// The returned borrow lives only as long as `&self`, rather than being forced to
    /// the enum's full lifetime parameter.
    fn get(&self) -> Option<&str> {
        match self {
            PatternSource::Owned(owned) => Some(owned.as_str()),
            PatternSource::Ref(borrowed) => Some(*borrowed),
            PatternSource::Copy => None,
        }
    }
}

/// Boolean fields representing regular expression flags.
///
/// `FlagSets::default()` has every flag cleared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct FlagSets {
    /// The `d` flag, which causes capture indices to be returned when matching.
    /// [`ExecResult`]'s `captures` field is `None` when this flag is not set.
    /// ([MDN](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/RegExp/hasIndices#description))
    pub has_indices: bool,
    /// The `i` flag, which enables case-insensitive matching.
    /// ([MDN](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/RegExp/ignoreCase#description))
    pub ignore_case: bool,
    /// The `g` flag, which enables global matching.
    /// ([MDN](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/RegExp/global#description))
    pub global: bool,
    /// The `s` flag, which causes the `.` special character to match additional line terminators.
    /// ([MDN](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/RegExp/dotAll#description))
    pub dot_all: bool,
    /// The `m` flag, which enables multiline matching.
    /// ([MDN](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/RegExp/multiline#description))
    pub multiline: bool,
    /// The `y` flag, which enables sticky matching.
    /// ([MDN](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/RegExp/sticky#description))
    pub sticky: bool,
    /// The `u` flag, which enables some unicode-related features.
    /// Can't be set at the same time as the `v` (`unicode_sets`) flag.
    /// ([MDN](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/RegExp/unicode#description))
    pub unicode: bool,
    /// The `v` flag, which enables a superset of the features enabled by the `u` (`unicode`) flag.
    /// Can't be set at the same time as the `u` flag.
    /// ([MDN](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/RegExp/unicodeSets#description))
    pub unicode_sets: bool,
}

/// Restrictive interface for setting regular expression flags.
///
/// A `Flags` value can only be created through [`Flags::new`], which rejects unknown,
/// repeated, or conflicting flags.
#[derive(Debug)]
pub struct Flags {
    sets: FlagSets,
}
impl Flags {
    /// Parses `source` using the same flag letters as the JavaScript `RegExp` constructor. \
    /// Returns `None` if `source` contains an unknown letter, repeats a letter, or combines
    /// `u` with `v`.
    pub fn new(source: &str) -> Option<Self> {
        let mut sets = FlagSets {
            has_indices: false,
            ignore_case: false,
            global: false,
            dot_all: false,
            multiline: false,
            sticky: false,
            unicode: false,
            unicode_sets: false,
        };
        for letter in source.chars() {
            let slot = match letter {
                'd' => &mut sets.has_indices,
                'i' => &mut sets.ignore_case,
                'g' => &mut sets.global,
                's' => &mut sets.dot_all,
                'm' => &mut sets.multiline,
                'y' => &mut sets.sticky,
                // `u` and `v` are mutually exclusive, and each may appear at most once.
                'u' | 'v' if sets.unicode || sets.unicode_sets => return None,
                'u' => &mut sets.unicode,
                'v' => &mut sets.unicode_sets,
                _ => return None,
            };
            if *slot {
                // The same letter was already given.
                return None;
            }
            *slot = true;
        }
        Some(Self { sets })
    }
    /// Returns a read-only reference to the inner `FlagSets`
    pub fn inspect(&self) -> &FlagSets {
        &self.sets
    }
    /// Renders the enabled flags as a JavaScript string, always in the order `digsmyuv`.
    fn build(&self) -> JsValue {
        let s = &self.sets;
        let letters: String = [
            (s.has_indices, 'd'),
            (s.ignore_case, 'i'),
            (s.global, 'g'),
            (s.dot_all, 's'),
            (s.multiline, 'm'),
            (s.sticky, 'y'),
            (s.unicode, 'u'),
            (s.unicode_sets, 'v'),
        ]
        .iter()
        .filter(|(enabled, _)| *enabled)
        .map(|(_, letter)| *letter)
        .collect();
        JsValue::from_str(&letters)
    }
}

/// The result of a successful [`RegExp::exec`] call.
#[derive(Debug)]
pub struct ExecResult<'h, 'p> {
    /// The matched text, borrowed from the haystack passed to `exec`.
    pub match_slice: &'h str,
    /// Byte offset of the match within the haystack.
    pub match_index: usize,
    /// Length of the match.
    pub match_length: usize,
    /// The capture groups. This is `None` unless the expression has the `d`
    /// (`has_indices`) flag.
    pub captures: Option<CapturesList<'h, 'p>>,
}

/// A list of [`Capture`]s.
#[derive(Debug)]
pub struct CapturesList<'h, 'p> {
    /// The captures in capture-group order, starting with group 1.
    pub vec: Vec<Capture<'h, 'p>>,
}
impl<'h, 'p> CapturesList<'h, 'p> {
    /// Builds a map from group name to capture. Only captures in `vec` that carry a
    /// group name are included. If a name occurs more than once, the last capture wins.
    pub fn get_named_captures_map(&self) -> HashMap<&str, &Capture<'h, 'p>> {
        self.vec
            .iter()
            .filter_map(|capture| {
                let name: &str = match capture.group_name.as_ref()? {
                    GroupName::Owned(owned) => owned.as_str(),
                    GroupName::Ref(borrowed) => *borrowed,
                };
                Some((name, capture))
            })
            .collect()
    }
}

/// An index, length, slice, and optional group name of a capture in a haystack.
#[derive(Debug)]
pub struct Capture<'h, 'p> {
    /// The group's name, if [`RegExp::exec`] associated this capture with a named group.
    pub group_name: Option<GroupName<'p>>,
    /// Byte offset of the capture within the haystack.
    pub index: usize,
    /// Length of the capture in UTF-8 bytes.
    pub length: usize,
    /// The captured text, borrowed from the haystack.
    pub slice: &'h str,
}

/// A name of a named capture group, backed either by a slice of a pattern or
/// an owned `String` copied from JavaScript.
///
/// Equality and hashing compare only the name text, so an `Owned` and a `Ref`
/// holding the same text are equal and hash identically.
#[derive(Debug, Clone)]
pub enum GroupName<'a> {
    /// A name copied from JavaScript into a Rust `String`.
    Owned(String),
    /// A name borrowed from the pattern text.
    Ref(&'a str),
}
impl GroupName<'_> {
    /// Returns the group name as a string slice, however it is stored.
    pub fn as_str(&self) -> &str {
        match self {
            GroupName::Owned(owned) => owned.as_str(),
            GroupName::Ref(borrowed) => *borrowed,
        }
    }
}
impl PartialEq for GroupName<'_> {
    fn eq(&self, other: &Self) -> bool {
        self.as_str() == other.as_str()
    }
}
impl Eq for GroupName<'_> {}
impl Hash for GroupName<'_> {
    // Hash only the text so that `Hash` agrees with the text-based `PartialEq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.as_str().hash(state);
    }
}
impl<'a> From<&'a str> for GroupName<'a> {
    fn from(value: &'a str) -> Self {
        Self::Ref(value)
    }
}
impl From<String> for GroupName<'_> {
    fn from(value: String) -> Self {
        Self::Owned(value)
    }
}
// Implementing `From` (rather than `Into` directly) also provides
// `Into<&str> for &GroupName` through the standard blanket impl.
impl<'a> From<&'a GroupName<'_>> for &'a str {
    fn from(value: &'a GroupName<'_>) -> Self {
        value.as_str()
    }
}

/// Ways that constructing a JavaScript `RegExp` can fail.
#[derive(Debug, Error)]
enum NewRegExpError {
    /// The `RegExp` constructor threw; the thrown value is kept. Presumably a
    /// `SyntaxError` for an invalid pattern or flags — the value itself is not inspected.
    #[error("Syntax error")]
    SyntaxError(JsValue),
    /// Something unexpected went wrong while looking up the constructor or converting its result.
    #[error("Unexpected error")]
    JsError(#[from] JsError),
}
/// Constructs a JavaScript `RegExp`. Returns the thrown value if the constructor throws,
/// and panics on any other, unexpected failure.
fn construct_regexp_panicking(pattern: &str, flags: JsValue) -> Result<js_sys::RegExp, JsValue> {
    match construct_regexp(pattern, flags) {
        Ok(regexp) => Ok(regexp),
        Err(NewRegExpError::SyntaxError(thrown)) => Err(thrown),
        Err(NewRegExpError::JsError(unexpected)) => panic!("{:?}", unexpected),
    }
}
fn construct_regexp(pattern: &str, flags: JsValue) -> Result<js_sys::RegExp, NewRegExpError> {
    let global = js_sys::global();
    let regexp_object = get_value_property_str("RegExp", &global).map_err(Into::<JsError>::into)?;
    let regexp_object: &Function = regexp_object
        .dyn_ref()
        .option_context()
        .map_err(Into::<JsError>::into)?;
    let args = js_sys::Array::new_with_length(2);
    args.set(0, JsValue::from_str(pattern));
    args.set(1, flags);
    let regexp = js_sys::Reflect::construct(regexp_object, &args)
        .map_err(|e| NewRegExpError::SyntaxError(e))?;
    let regexp = regexp.dyn_into().map_err(Into::<JsError>::into)?;
    Ok(regexp)
}

/// Reads the property at integer index `key` of `target` (`Reflect.get(target, key)`
/// in JavaScript). `key` is cast to `u32` with `as`, so values above `u32::MAX` are truncated.
fn get_value_property_usize(key: usize, target: &JsValue) -> Result<JsValue, JsValue> {
    js_sys::Reflect::get_u32(target, key as u32)
}

/// Reads the property named `key` from `target` (`Reflect.get(target, key)` in JavaScript),
/// returning the thrown value if the lookup throws.
fn get_value_property_str(key: &str, target: &JsValue) -> Result<JsValue, JsValue> {
    js_sys::Reflect::get(target, &JsValue::from_str(key))
}

/// Converts a JavaScript `[start, end]` index pair, measured in UTF-16 code units, into a
/// [`Capture`] over `haystack` whose `index` and `length` are UTF-8 byte counts.
/// The returned capture has no group name.
fn slice_capture<'h, 'p>(haystack: &'h str, indices: &JsValue) -> Result<Capture<'h, 'p>, JsError> {
    // Reads element `position` of the index pair as a number.
    let read_index = |position: usize| -> Result<usize, JsError> {
        let value = get_value_property_usize(position, indices)?;
        Ok(value.as_f64().option_context()? as usize)
    };
    let start_utf16 = read_index(0)?;
    let end_utf16 = read_index(1)?;
    let length_utf16 = end_utf16 - start_utf16;

    let index = count_bytes_from_utf16_units(haystack, start_utf16).option_context()?;
    let tail = &haystack[index..];
    let length = count_bytes_from_utf16_units(tail, length_utf16).option_context()?;
    Ok(Capture {
        group_name: None,
        index,
        length,
        slice: &tail[..length],
    })
}

/// Finds the first occurrence of the JavaScript string `js_str` in `s` and returns the
/// matching slice of `s`.
///
/// This lets a group name reported by JavaScript be held as a borrow of the pattern text
/// instead of an owned copy. Returns `None` if `js_str` cannot be converted to a Rust
/// string or does not occur in `s`. An empty `js_str` matches at the start of `s`.
fn find_js_string<'a>(s: &'a str, js_str: &JsString) -> Option<&'a str> {
    // Copy the needle out of JavaScript once and let `str::find` search, rather than
    // re-reading the JS string unit by unit for every candidate start position.
    let needle = js_str.as_string()?;
    let start = s.find(needle.as_str())?;
    // `start` is a char boundary and `needle` is valid UTF-8, so the end is a boundary too.
    s.get(start..start + needle.len())
}

/// Converts a length of `n_units` UTF-16 code units, measured from the start of `s`, into
/// a length in UTF-8 bytes.
///
/// If `n_units` ends inside a character (between the halves of a surrogate pair), that
/// whole character is counted. Returns `None` if `s` has fewer than `n_units` UTF-16
/// code units.
fn count_bytes_from_utf16_units(s: &str, n_units: usize) -> Option<usize> {
    // Count down in `usize` with saturation. The previous `as isize` cast made values
    // above `isize::MAX` wrap negative and wrongly return `Some(0)`.
    let mut remaining = n_units;
    let mut chars = s.char_indices();
    while remaining > 0 {
        let (_, ch) = chars.next()?;
        remaining = remaining.saturating_sub(ch.len_utf16());
    }
    // The byte offset of the first unconsumed character is the number of bytes consumed.
    Some(chars.next().map_or(s.len(), |(offset, _)| offset))
}

#[cfg(test)]
mod tests {
    use wasm_bindgen::{JsCast, JsValue};
    use wasm_bindgen_test::wasm_bindgen_test;

    use crate::{count_bytes_from_utf16_units, find_js_string, slice_capture};
    // Configure `wasm_bindgen_test` to run these tests in a browser.
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    /// `Flags::new` validation and the flag string produced by `build`.
    #[wasm_bindgen_test]
    fn test_flags() {
        let flags = super::Flags::new("x");
        // Rejects invalid flag
        assert!(flags.is_none());
        let flags = super::Flags::new("uv");
        // Rejects invalid combination
        assert!(flags.is_none());
        let flags = super::Flags::new("digs").unwrap();
        let sets = flags.inspect();
        assert!(sets.has_indices);
        assert!(sets.ignore_case);
        assert!(sets.global);
        assert!(sets.dot_all);
        assert!(!sets.unicode);
        // Constructs the correct flags string
        assert_eq!(flags.build().as_string().unwrap(), "digs");
    }

    /// UTF-16 to UTF-8 length conversion on text mixing 1- to 4-byte characters,
    /// surrogate pairs and a ZWJ sequence.
    #[wasm_bindgen_test]
    fn test_count_bytes_from_utf16_units() {
        let s = "cool string with fun characters such as: üöä, 宿, 漢字, and even 💙 as well as 🏳‍⚧, which is a ZWJ sequence";
        let utf16_length = 105;
        let utf8_length = s.len();
        assert_eq!(utf8_length, 122);
        assert_eq!(count_bytes_from_utf16_units(s, 87).unwrap(), 104);
        assert_eq!(
            count_bytes_from_utf16_units(s, utf16_length).unwrap(),
            utf8_length
        );
        // Asking for more units than the string contains fails.
        assert!(count_bytes_from_utf16_units(s, utf16_length + 1).is_none())
    }

    /// Slicing the haystack from a JavaScript `[start, end]` pair given in UTF-16 units.
    #[wasm_bindgen_test]
    fn test_slice_capture() {
        let haystack = "cool string with fun characters such as: üöä, 宿, 漢字, and even 💙 as well as 🏳‍⚧, which is a ZWJ sequence";
        let begin_index_utf16 = 57;
        let end_index_utf16 = 81;
        let js_array = js_sys::Array::new_with_length(2);
        js_array.set(0, JsValue::from_f64(begin_index_utf16 as f64));
        js_array.set(1, JsValue::from_f64(end_index_utf16 as f64));
        let capture = slice_capture(haystack, &js_array).unwrap();
        assert_eq!("even 💙 as well as 🏳‍⚧,", capture.slice)
    }

    /// Locating a JavaScript string inside a Rust string and borrowing the match.
    #[wasm_bindgen_test]
    fn test_find_js_string() {
        let s = "cool string with fun characters such as: üöä, 宿, 漢字, and even 💙 as well as 🏳‍⚧, which is a ZWJ sequence";
        let slice = find_js_string(
            s,
            &JsValue::from_str("even 💙 as well as 🏳‍⚧,")
                .dyn_into()
                .unwrap(),
        )
        .unwrap();
        assert_eq!("even 💙 as well as 🏳‍⚧,", slice)
    }
}