1use js_regexp_macros::EnumConvert;
24use js_sys::{Function, JsString};
25use std::collections::HashMap;
26use wasm_bindgen::{JsCast, JsValue};
27
28pub use js_regexp_macros::flags;
29
/// A JavaScript `RegExp` object together with the pattern source and flags it was
/// constructed from.
///
/// `'p` is the lifetime of a borrowed pattern (see [`RegExp::new`]); capture-group names
/// handed out in match results may be slices of that pattern.
#[derive(Debug)]
pub struct RegExp<'p> {
    // Handle to the underlying JS RegExp; matching is delegated to it.
    inner: js_sys::RegExp,
    // Pattern text that group names are sliced from; `PatternSource::Copy` means names are
    // copied out of the JS match result instead.
    pattern: PatternSource<'p>,
    // Flags the regex was built with; exposed through `RegExp::flags`.
    flags: Flags,
}
39impl<'p> RegExp<'p> {
40 pub fn new(pattern: &'p str, flags: Flags) -> Result<Self, JsValue> {
46 Ok(Self {
47 inner: construct_regexp_panicking(pattern, flags.build())?,
48 pattern: PatternSource::Ref(pattern),
49 flags,
50 })
51 }
52 pub fn new_with_owned_pattern(pattern: String, flags: Flags) -> Result<Self, JsValue> {
57 Ok(Self {
58 inner: construct_regexp_panicking(&pattern, flags.build())?,
59 pattern: PatternSource::Owned(pattern),
60 flags,
61 })
62 }
63 pub fn new_with_copied_names(pattern: &str, flags: Flags) -> Result<Self, JsValue> {
69 Ok(Self {
70 inner: construct_regexp_panicking(pattern, flags.build())?,
71 pattern: PatternSource::Copy,
72 flags,
73 })
74 }
75
76 pub fn exec<'h, 'a>(&'a mut self, haystack: &'h str) -> Option<ExecResult<'h, 'a>> {
81 match self.exec_internal(haystack) {
82 Ok(v) => v,
83 Err(v) => panic!("{:?}", v),
84 }
85 }
86
87 pub fn flags(&self) -> &FlagSets {
89 &self.flags.sets
90 }
91
92 pub fn stream<'s, 'h>(&'s mut self, haystack: &'h str) -> RegExpStream<'s, 'h, 'p> {
94 RegExpStream {
95 regex: self,
96 haystack,
97 }
98 }
99
100 pub fn inner(&mut self) -> &mut js_sys::RegExp {
108 &mut self.inner
109 }
110
111 fn exec_internal<'h>(
112 &'p self,
113 haystack: &'h str,
114 ) -> Result<Option<ExecResult<'h, 'p>>, JsError> {
115 let result = match self.inner.exec(haystack) {
116 Some(v) => v,
117 None => return Ok(None),
118 };
119
120 let utf16_match_index = get_value_property_str("index", &result)?
121 .as_f64()
122 .whatever()? as usize;
123 let utf8_match_index =
124 count_bytes_from_utf16_units(haystack, utf16_match_index).whatever()?;
125 let matched = &haystack[utf8_match_index..];
126 let string_match_js = result.iter().next().whatever()?;
127 let string_match_js: &JsString = string_match_js.dyn_ref().whatever()?;
128 let utf16_match_length = string_match_js.length() as usize;
129 let utf8_match_length =
130 count_bytes_from_utf16_units(matched, utf16_match_length).whatever()?;
131 let matched = &matched[..utf8_match_length];
132 let indices_array_js = get_value_property_str("indices", &result)?;
133
134 let mut exec_result = ExecResult {
135 match_slice: matched,
136 match_index: utf8_match_index,
137 match_length: utf16_match_length,
138 captures: None,
139 };
140 if !indices_array_js.is_array() {
141 return Ok(Some(exec_result));
142 }
143
144 let mut captures_vec = Vec::new();
145 let js_array_iter = js_sys::try_iter(&indices_array_js)?.whatever()?;
146 for indices_js in js_array_iter.skip(1) {
147 let indices_js = indices_js?;
148 let capture = slice_capture(haystack, &indices_js)?;
149 captures_vec.push(capture);
150 }
151 let named_indices_js = get_value_property_str("groups", &indices_array_js)?;
152 if !named_indices_js.is_object() {
153 let _ = exec_result.captures.insert(captures_vec);
154 return Ok(Some(exec_result));
155 }
156
157 let group_names = js_sys::Reflect::own_keys(&named_indices_js)?;
158 for group_name_js in group_names.iter() {
159 let group_name_js: JsString = group_name_js.dyn_into()?;
160 let indices_js = js_sys::Reflect::get(&named_indices_js, &group_name_js)?;
161 let capture = slice_capture(haystack, &indices_js)?;
162 let group_name = match self.pattern.get() {
163 Some(pattern) => {
164 GroupName::Ref(find_js_string(pattern, &group_name_js).whatever()?)
165 }
166 None => GroupName::Owned(group_name_js.as_string().whatever()?),
167 };
168 let _ = captures_vec
169 .iter_mut()
170 .find(|v| v.index == capture.index && v.length == capture.length)
171 .whatever()?
172 .group_name
173 .insert(group_name);
174 }
175
176 let _ = exec_result.captures.insert(captures_vec);
177 Ok(Some(exec_result))
178 }
179}
180
/// Internal error raised while interpreting JavaScript values.
///
/// The `?` conversions from `JsValue` and `OptionFail` results elsewhere in this file rely
/// on the `From` impls that `EnumConvert` with `#[enum_convert(from)]` generates here.
#[derive(Debug, EnumConvert)]
#[enum_convert(from)]
enum JsError {
    /// A value thrown by (or unexpectedly returned from) JavaScript.
    JavaScript(JsValue),
    /// An expected value was missing (see `OptionContext::whatever`).
    Other(OptionFail),
}
/// Error produced when an `Option` that had to be `Some` turned out to be `None`.
#[derive(Debug)]
struct OptionFail {}

impl std::fmt::Display for OptionFail {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Something weird happened idk")
    }
}

impl std::error::Error for OptionFail {}

/// Turns a missing value into an [`OptionFail`] so it can be propagated with `?`.
trait OptionContext<T> {
    /// Returns the contained value, or `Err(OptionFail)` when there is none.
    fn whatever(self) -> Result<T, OptionFail>;
}

impl<T> OptionContext<T> for Option<T> {
    fn whatever(self) -> Result<T, OptionFail> {
        match self {
            Some(value) => Ok(value),
            None => Err(OptionFail {}),
        }
    }
}
205
/// Where a regex keeps the pattern text that capture-group names are sliced from.
#[derive(Debug)]
enum PatternSource<'a> {
    /// The pattern is owned by the regex.
    Owned(String),
    /// The pattern is borrowed from the caller.
    Ref(&'a str),
    /// The pattern is not retained; group names have to be copied.
    Copy,
}
impl<'a> PatternSource<'a> {
    /// Returns the retained pattern text, or `None` for [`PatternSource::Copy`].
    fn get(&'a self) -> Option<&'a str> {
        match self {
            Self::Owned(owned) => Some(owned.as_str()),
            Self::Ref(borrowed) => Some(*borrowed),
            Self::Copy => None,
        }
    }
}
221
/// The individual flags of a JavaScript regular expression, one `bool` per flag letter.
///
/// The default value has every flag cleared.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct FlagSets {
    /// Flag `d`: the match result carries an `indices` array (needed for capture positions).
    pub has_indices: bool,
    /// Flag `i`.
    pub ignore_case: bool,
    /// Flag `g`.
    pub global: bool,
    /// Flag `s`.
    pub dot_all: bool,
    /// Flag `m`.
    pub multiline: bool,
    /// Flag `y`.
    pub sticky: bool,
    /// Flag `u`; mutually exclusive with `unicode_sets` when parsed by `Flags::new`.
    pub unicode: bool,
    /// Flag `v`; mutually exclusive with `unicode` when parsed by `Flags::new`.
    pub unicode_sets: bool,
}
impl FlagSets {
    /// Returns a `FlagSets` with every flag cleared; identical to [`FlagSets::default`].
    fn new_empty_flagsets() -> FlagSets {
        FlagSets::default()
    }
}
268
269#[derive(Debug)]
273pub struct Flags {
274 sets: FlagSets,
275}
276impl Flags {
277 pub fn new(source: &str) -> Option<Self> {
280 let mut flags = FlagSets::new_empty_flagsets();
281 for ch in source.chars() {
282 match ch {
283 'd' => (!flags.has_indices).then(|| flags.has_indices = true)?,
284 'i' => (!flags.ignore_case).then(|| flags.ignore_case = true)?,
285 'g' => (!flags.global).then(|| flags.global = true)?,
286 's' => (!flags.dot_all).then(|| flags.dot_all = true)?,
287 'm' => (!flags.multiline).then(|| flags.multiline = true)?,
288 'y' => (!flags.sticky).then(|| flags.sticky = true)?,
289 'u' => (!flags.unicode && !flags.unicode_sets).then(|| flags.unicode = true)?,
290 'v' => {
291 (!flags.unicode && !flags.unicode_sets).then(|| flags.unicode_sets = true)?
292 }
293 _ => return None,
294 };
295 }
296 Some(Self { sets: flags })
297 }
298 pub fn new_unchecked(source: &str) -> Self {
302 let mut flags = FlagSets::new_empty_flagsets();
303 for ch in source.chars() {
304 match ch {
305 'd' => flags.has_indices = true,
306 'i' => flags.ignore_case = true,
307 'g' => flags.global = true,
308 's' => flags.dot_all = true,
309 'm' => flags.multiline = true,
310 'y' => flags.sticky = true,
311 'u' => flags.unicode = true,
312 'v' => flags.unicode_sets = true,
313 _ => {}
314 }
315 }
316 Self { sets: flags }
317 }
318 pub fn inspect(&self) -> &FlagSets {
320 &self.sets
321 }
322 fn build(&self) -> JsValue {
323 let mut bytes_rep = [0u8; 8];
327 let mut idx = 0;
328 fn set_fn(bytes: &mut [u8; 8], idx: &mut usize, v: u8) {
329 bytes[*idx] = v;
330 *idx += 1;
331 }
332 let mut set = |v: u8| set_fn(&mut bytes_rep, &mut idx, v);
333 self.sets.has_indices.then(|| set(b'd'));
334 self.sets.ignore_case.then(|| set(b'i'));
335 self.sets.global.then(|| set(b'g'));
336 self.sets.dot_all.then(|| set(b's'));
337 self.sets.multiline.then(|| set(b'm'));
338 self.sets.sticky.then(|| set(b'y'));
339 self.sets.unicode.then(|| set(b'u'));
340 self.sets.unicode_sets.then(|| set(b'v'));
341 JsValue::from_str(std::str::from_utf8(&bytes_rep[..idx]).unwrap())
342 }
343}
344impl From<&str> for Flags {
345 fn from(value: &str) -> Self {
347 Self::new_unchecked(value)
348 }
349}
350
351#[derive(Debug)]
353pub struct ExecResult<'h, 'p> {
354 pub match_slice: &'h str,
355 pub match_index: usize,
356 pub match_length: usize,
357 captures: Option<Vec<Capture<'h, 'p>>>,
358}
359impl ExecResult<'_, '_> {
360 pub fn captures(&self) -> Option<&Vec<Capture<'_, '_>>> {
363 self.captures.as_ref()
364 }
365 pub fn named_captures(&self) -> Option<HashMap<&str, &Capture<'_, '_>>> {
368 let captures = self.captures.as_ref()?;
369 let mut map = HashMap::new();
370 for capture in captures.iter() {
371 if let Some(v) = capture.name() {
372 map.insert(v, capture);
373 };
374 }
375 Some(map)
376 }
377}
378
379#[derive(Debug)]
381pub struct Capture<'h, 'p> {
382 group_name: Option<GroupName<'p>>,
383 pub index: usize,
384 pub length: usize,
385 pub slice: &'h str,
386}
387impl Capture<'_, '_> {
388 pub fn name(&self) -> Option<&str> {
389 Some(self.group_name.as_ref()?.into())
390 }
391}
392
/// Name of a capture group: a slice of the stored pattern, or a copy taken from JavaScript.
#[derive(Debug)]
enum GroupName<'a> {
    /// Copied out of the JavaScript match result.
    Owned(String),
    /// Borrowed from the pattern the regex was created with.
    Ref(&'a str),
}

// Implemented as `From` rather than a hand-written `Into` (clippy::from_over_into); the
// standard blanket impl still provides `Into<&str> for &GroupName`, so `.into()` keeps working.
impl<'a> From<&'a GroupName<'a>> for &'a str {
    fn from(name: &'a GroupName<'a>) -> Self {
        match name {
            GroupName::Owned(s) => s.as_str(),
            GroupName::Ref(s) => *s,
        }
    }
}
406
407pub struct RegExpStream<'r, 'h, 'p> {
409 regex: &'r mut RegExp<'p>,
410 haystack: &'h str,
411}
412impl RegExpStream<'_, '_, '_> {
413 pub fn next<'s>(&'s mut self) -> Option<ExecResult<'s, 's>> {
415 self.regex.exec(self.haystack)
416 }
417}
418
/// Failure modes of `construct_regexp`.
// NOTE(review): with `enum_convert(from)` a bare `JsValue` presumably converts into
// `SyntaxError`; that is why `construct_regexp` turns its other JS failures into `JsError`
// explicitly before propagating them.
#[derive(Debug, EnumConvert)]
#[enum_convert(from)]
enum NewRegExpError {
    /// The value thrown by the JavaScript `RegExp` constructor.
    SyntaxError(JsValue),
    /// Any other failure: the constructor could not be obtained or the result was not a `RegExp`.
    JsError(JsError),
}
425fn construct_regexp_panicking(pattern: &str, flags: JsValue) -> Result<js_sys::RegExp, JsValue> {
426 construct_regexp(pattern, flags).map_err(|e| match e {
427 NewRegExpError::SyntaxError(v) => v,
428 NewRegExpError::JsError(e) => panic!("{:?}", e),
429 })
430}
431fn construct_regexp(pattern: &str, flags: JsValue) -> Result<js_sys::RegExp, NewRegExpError> {
432 let global = js_sys::global();
433 let regexp_object = get_value_property_str("RegExp", &global).map_err(Into::<JsError>::into)?;
434 let regexp_object: &Function = regexp_object
435 .dyn_ref()
436 .whatever()
437 .map_err(Into::<JsError>::into)?;
438 let args = js_sys::Array::new_with_length(2);
439 args.set(0, JsValue::from_str(pattern));
440 args.set(1, flags);
441 let regexp = js_sys::Reflect::construct(regexp_object, &args)
442 .map_err(|e| NewRegExpError::SyntaxError(e))?;
443 let regexp = regexp.dyn_into().map_err(Into::<JsError>::into)?;
444 Ok(regexp)
445}
446
447fn get_value_property_usize(key: usize, target: &JsValue) -> Result<JsValue, JsValue> {
448 let key = key as u32;
449 js_sys::Reflect::get_u32(target, key)
450}
451
452fn get_value_property_str(key: &str, target: &JsValue) -> Result<JsValue, JsValue> {
453 let key = JsValue::from_str(key);
454 js_sys::Reflect::get(target, &key)
455}
456
457fn slice_capture<'h, 'p>(haystack: &'h str, indices: &JsValue) -> Result<Capture<'h, 'p>, JsError> {
458 let utf16_index = get_value_property_usize(0, indices)?.as_f64().whatever()? as usize;
459 let utf16_end = get_value_property_usize(1, indices)?.as_f64().whatever()? as usize;
460 let utf16_length = utf16_end - utf16_index;
461 let capture = haystack;
462 let utf8_begin = count_bytes_from_utf16_units(capture, utf16_index).whatever()?;
463 let capture = &capture[utf8_begin..];
464 let utf8_length = count_bytes_from_utf16_units(capture, utf16_length).whatever()?;
465 let capture = &capture[..utf8_length];
466 Ok(Capture {
467 group_name: None,
468 index: utf8_begin,
469 length: utf8_length,
470 slice: capture,
471 })
472}
473
474fn find_js_string<'a>(s: &'a str, js_str: &JsString) -> Option<&'a str> {
475 let mut utf16_buf = [0u16, 2];
476 let mut s = s;
477 let end_index = 'lvl0: loop {
478 let mut js_str_iter = js_str.iter();
479 let mut s_iter = s.char_indices();
480 'lvl1: loop {
481 let (idx, ch) = match s_iter.next() {
482 Some(v) => v,
483 None => {
484 break 'lvl0 match js_str_iter.next() {
485 Some(_) => None,
486 None => Some(s.len()),
487 }
488 }
489 };
490 let units = ch.encode_utf16(&mut utf16_buf);
491 for unit in units.iter() {
492 let should_match = match js_str_iter.next() {
493 Some(v) => v,
494 None => {
495 break 'lvl0 Some(idx);
496 }
497 };
498 if unit != &should_match {
499 break 'lvl1;
500 }
501 }
502 }
503 match s.char_indices().nth(1) {
504 Some((v, _)) => s = &s[v..],
505 None => break None,
506 }
507 };
508 Some(&s[0..end_index?])
509}
510
/// Converts `n_units` UTF-16 code units, counted from the start of `s`, into a byte offset
/// into `s`.
///
/// Whole characters are consumed until at least `n_units` units are covered. If `n_units`
/// ends inside a surrogate pair, the offset is rounded up to the end of that character.
/// Returns `None` when `s` holds fewer than `n_units` UTF-16 code units.
fn count_bytes_from_utf16_units(s: &str, n_units: usize) -> Option<usize> {
    // Count upwards in `usize`: casting `n_units` to a signed type would wrap huge values to
    // negatives and report offset 0 instead of `None`.
    let mut covered = 0usize;
    let mut chars = s.char_indices();
    while covered < n_units {
        let (_, ch) = chars.next()?;
        covered += ch.len_utf16();
    }
    // Byte offset of the first unconsumed character, or the end of `s`.
    Some(chars.next().map_or(s.len(), |(byte_offset, _)| byte_offset))
}
521
#[cfg(test)]
mod tests {
    use wasm_bindgen::{JsCast, JsValue};
    use wasm_bindgen_test::wasm_bindgen_test;

    use crate::{count_bytes_from_utf16_units, find_js_string, slice_capture};
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    // Sample mixing ASCII, 2- and 3-byte UTF-8 characters, a surrogate-pair emoji and the
    // ZWJ sequence U+1F3F3 U+200D U+26A7. The joiner is written as an escape so it cannot be
    // lost by editors.
    const SAMPLE: &str = "cool string with fun characters such as: üöä, 宿, 漢字, and even 💙 as well as 🏳\u{200D}⚧, which is a ZWJ sequence";
    // The part of `SAMPLE` spanning UTF-16 units 57..81.
    const EXCERPT: &str = "even 💙 as well as 🏳\u{200D}⚧,";

    #[wasm_bindgen_test]
    fn test_flags() {
        // Unknown letters and the mutually exclusive `u`/`v` pair are rejected.
        assert!(super::Flags::new("x").is_none());
        assert!(super::Flags::new("uv").is_none());

        let flags = super::Flags::new("digs").unwrap();
        let sets = flags.inspect();
        assert!(sets.has_indices);
        assert!(sets.ignore_case);
        assert!(sets.global);
        assert!(sets.dot_all);
        assert!(!sets.unicode);
        assert_eq!(flags.build().as_string().unwrap(), "digs");
    }

    #[wasm_bindgen_test]
    fn test_count_bytes_from_utf16_units() {
        const UTF16_LEN: usize = 105;
        assert_eq!(SAMPLE.len(), 122);
        assert_eq!(count_bytes_from_utf16_units(SAMPLE, 87), Some(104));
        assert_eq!(
            count_bytes_from_utf16_units(SAMPLE, UTF16_LEN),
            Some(SAMPLE.len())
        );
        assert_eq!(count_bytes_from_utf16_units(SAMPLE, UTF16_LEN + 1), None);
    }

    #[wasm_bindgen_test]
    fn test_slice_capture() {
        let bounds = js_sys::Array::new_with_length(2);
        bounds.set(0, JsValue::from_f64(57.0));
        bounds.set(1, JsValue::from_f64(81.0));
        let capture = slice_capture(SAMPLE, &bounds).unwrap();
        assert_eq!(EXCERPT, capture.slice);
    }

    #[wasm_bindgen_test]
    fn test_find_js_string() {
        let needle: js_sys::JsString = JsValue::from_str(EXCERPT).dyn_into().unwrap();
        assert_eq!(find_js_string(SAMPLE, &needle), Some(EXCERPT));
    }
}