hstr/
wtf8_atom.rs

1use std::{
2    fmt::Debug,
3    hash::Hash,
4    mem::{forget, transmute, ManuallyDrop},
5    ops::Deref,
6};
7
8use debug_unreachable::debug_unreachable;
9
10use crate::{
11    macros::{get_hash, impl_from_alias, partial_eq},
12    tagged_value::TaggedValue,
13    wtf8::Wtf8,
14    Atom, DYNAMIC_TAG, INLINE_TAG, LEN_MASK, LEN_OFFSET, TAG_MASK,
15};
16
17/// A WTF-8 encoded atom. This is like [Atom], but can contain unpaired
18/// surrogates.
19///
20/// [Atom]: crate::Atom
21#[repr(transparent)]
22pub struct Wtf8Atom {
23    pub(crate) unsafe_data: TaggedValue,
24}
25
26impl Wtf8Atom {
27    #[inline(always)]
28    pub fn new<S>(s: S) -> Self
29    where
30        Self: From<S>,
31    {
32        Self::from(s)
33    }
34
35    /// Try to convert this to a UTF-8 [Atom].
36    ///
37    /// Returns [Atom] if the string is valid UTF-8, otherwise returns
38    /// the original [Wtf8Atom].
39    pub fn try_into_atom(self) -> Result<Atom, Wtf8Atom> {
40        if self.as_str().is_some() {
41            let atom = ManuallyDrop::new(self);
42            Ok(Atom {
43                unsafe_data: atom.unsafe_data,
44            })
45        } else {
46            Err(self)
47        }
48    }
49
50    #[inline(always)]
51    fn tag(&self) -> u8 {
52        self.unsafe_data.tag() & TAG_MASK
53    }
54
55    /// Return true if this is a dynamic Atom.
56    #[inline(always)]
57    fn is_dynamic(&self) -> bool {
58        self.tag() == DYNAMIC_TAG
59    }
60}
61
62impl Default for Wtf8Atom {
63    #[inline(never)]
64    fn default() -> Self {
65        Wtf8Atom::new("")
66    }
67}
68
69/// Immutable, so it's safe to be shared between threads
70unsafe impl Send for Wtf8Atom {}
71
72/// Immutable, so it's safe to be shared between threads
73unsafe impl Sync for Wtf8Atom {}
74
75impl Debug for Wtf8Atom {
76    #[inline]
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        Debug::fmt(&**self, f)
79    }
80}
81
// Serde text format for `Wtf8Atom`
// ---------------------------------
// Unpaired surrogates cannot be stored in a `str`, so they are written as
// `\uXXXX` (uppercase hex). To keep that unambiguous, a maximal run of `m`
// literal backslashes in the atom is written as:
//
//   what follows the run in the atom          backslashes written
//   ---------------------------------------   -------------------------------
//   an unpaired surrogate                     2m (then the surrogate's `\uXXXX`)
//   `u` + four ASCII hex digits               2m
//   `u` not followed by four hex digits       2m - 1
//   anything else, or the end of the string   m (verbatim)
//
// The deserializer reads a maximal run of `n` backslashes and inverts this:
// - Before `u` + four hex digits: an odd `n` ends in a surrogate escape and
//   stands for (n - 1) / 2 backslashes; an even `n` stands for n / 2.
// - Before a bare `u`: it stands for ceil(n / 2).
// - Otherwise: it is taken verbatim.
//
// Ordinary text is written unchanged, including backslashes in paths such as
// `C:\util`. This keeps previously written data decoding as before, except in
// cases where the old format was ambiguous and did not round-trip.

/// What follows a run of backslashes, as far as the escape scheme cares.
///
/// Both the serializer and the deserializer branch on this classification,
/// which keeps the two directions of the scheme in lock-step.
#[cfg(feature = "serde")]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum AfterBackslashes {
    /// `u` followed by exactly four ASCII hex digits (e.g. `u00e9`).
    UnicodeEscape,
    /// `u` that is *not* followed by four ASCII hex digits (e.g. `util`).
    BareU,
    /// Anything else, including the end of the string.
    Other,
}

/// Classifies the text that follows a run of backslashes.
///
/// `rest` yields the upcoming characters. A `None` item stands for a code
/// point that is not a `char` (an unpaired surrogate) and never counts as a
/// hex digit. At most five items are consumed.
#[cfg(feature = "serde")]
fn classify_after_backslashes(
    mut rest: impl Iterator<Item = Option<char>>,
) -> AfterBackslashes {
    if rest.next() != Some(Some('u')) {
        return AfterBackslashes::Other;
    }
    // Validate with `is_ascii_hexdigit` instead of relying on
    // `from_str_radix`, which would also accept a leading `+` (`\u+123`).
    let four_hex_digits =
        (0..4).all(|_| matches!(rest.next(), Some(Some(d)) if d.is_ascii_hexdigit()));
    if four_hex_digits {
        AfterBackslashes::UnicodeEscape
    } else {
        AfterBackslashes::BareU
    }
}

#[cfg(feature = "serde")]
impl serde::ser::Serialize for Wtf8Atom {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::ser::Serializer,
    {
        /// Encodes `s` into the serde text format described above.
        fn convert_wtf8_to_raw(s: &Wtf8) -> String {
            use std::fmt::Write as _;

            let mut result = String::new();
            let mut iter = s.code_points().peekable();

            while let Some(code_point) = iter.next() {
                match code_point.to_char() {
                    Some('\\') => {
                        // Consume the whole run: how many backslashes to write
                        // depends on what follows the run, not on each one.
                        let mut run = 1;
                        while iter.next_if(|cp| cp.to_char() == Some('\\')).is_some() {
                            run += 1;
                        }

                        let next_is_surrogate =
                            matches!(iter.peek(), Some(cp) if cp.to_char().is_none());
                        let written = if next_is_surrogate {
                            // The surrogate's own `\uXXXX` makes the output run odd.
                            2 * run
                        } else {
                            let rest = iter.clone().map(|cp| cp.to_char());
                            match classify_after_backslashes(rest) {
                                AfterBackslashes::UnicodeEscape => 2 * run,
                                AfterBackslashes::BareU => 2 * run - 1,
                                AfterBackslashes::Other => run,
                            }
                        };
                        result.extend(std::iter::repeat('\\').take(written));
                    }
                    Some(c) => result.push(c),
                    None => {
                        // Unpaired surrogate: not representable in UTF-8, so
                        // emit it as a JavaScript-style `\uXXXX` escape.
                        // Writing into a `String` cannot fail.
                        let _ = write!(result, "\\u{:04X}", code_point.to_u32());
                    }
                }
            }

            result
        }

        serializer.serialize_str(&convert_wtf8_to_raw(self))
    }
}

#[cfg(feature = "serde")]
impl<'de> serde::de::Deserialize<'de> for Wtf8Atom {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use crate::wtf8::{CodePoint, Wtf8Buf};

        /// Appends `count` backslashes to `buf`.
        fn push_backslashes(buf: &mut Wtf8Buf, count: usize) {
            for _ in 0..count {
                buf.push_char('\\');
            }
        }

        /// Decodes text in the serde format described above, reversing
        /// `Wtf8Atom`'s `Serialize` impl.
        fn convert_wtf8_string_to_wtf8(s: &str) -> Wtf8Buf {
            let mut result = Wtf8Buf::with_capacity(s.len());
            let mut chars = s.chars().peekable();

            while let Some(c) = chars.next() {
                if c != '\\' {
                    result.push_char(c);
                    continue;
                }

                let mut run = 1;
                while chars.next_if_eq(&'\\').is_some() {
                    run += 1;
                }

                // Inspect a clone so nothing is consumed until the shape of
                // the following text is known.
                match classify_after_backslashes(chars.clone().map(Some)) {
                    AfterBackslashes::UnicodeEscape => {
                        push_backslashes(&mut result, run / 2);
                        if run % 2 == 1 {
                            chars.next(); // the `u`
                            let value = chars
                                .by_ref()
                                .take(4)
                                .filter_map(|d| d.to_digit(16))
                                .fold(0u32, |acc, digit| acc * 16 + digit);
                            // SAFETY: four hex digits give at most 0xFFFF, which is
                            // always a valid code point (surrogates included).
                            result.push(unsafe { CodePoint::from_u32_unchecked(value) });
                        }
                    }
                    // ceil(run / 2)
                    AfterBackslashes::BareU => push_backslashes(&mut result, run - run / 2),
                    AfterBackslashes::Other => push_backslashes(&mut result, run),
                }
            }

            result
        }

        String::deserialize(deserializer).map(|v| convert_wtf8_string_to_wtf8(&v).into())
    }
}
239
240impl PartialEq for Wtf8Atom {
241    #[inline(never)]
242    fn eq(&self, other: &Self) -> bool {
243        partial_eq!(self, other);
244
245        // If the store is different, the string may be the same, even though the
246        // `unsafe_data` is different
247        self.as_wtf8() == other.as_wtf8()
248    }
249}
250
251impl Eq for Wtf8Atom {}
252
253impl Hash for Wtf8Atom {
254    #[inline(always)]
255    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
256        state.write_u64(self.get_hash());
257    }
258}
259
260impl Drop for Wtf8Atom {
261    #[inline(always)]
262    fn drop(&mut self) {
263        if self.is_dynamic() {
264            unsafe { drop(crate::dynamic::restore_arc(self.unsafe_data)) }
265        }
266    }
267}
268
269impl Clone for Wtf8Atom {
270    #[inline(always)]
271    fn clone(&self) -> Self {
272        Self::from_alias(self.unsafe_data)
273    }
274}
275
276impl Deref for Wtf8Atom {
277    type Target = Wtf8;
278
279    #[inline(always)]
280    fn deref(&self) -> &Self::Target {
281        self.as_wtf8()
282    }
283}
284
285impl AsRef<Wtf8> for Wtf8Atom {
286    #[inline(always)]
287    fn as_ref(&self) -> &Wtf8 {
288        self.as_wtf8()
289    }
290}
291
292impl PartialEq<Wtf8> for Wtf8Atom {
293    #[inline]
294    fn eq(&self, other: &Wtf8) -> bool {
295        self.as_wtf8() == other
296    }
297}
298
299impl PartialEq<crate::Atom> for Wtf8Atom {
300    #[inline]
301    fn eq(&self, other: &crate::Atom) -> bool {
302        self.as_str() == Some(other.as_str())
303    }
304}
305
306impl PartialEq<&'_ Wtf8> for Wtf8Atom {
307    #[inline]
308    fn eq(&self, other: &&Wtf8) -> bool {
309        self.as_wtf8() == *other
310    }
311}
312
313impl PartialEq<Wtf8Atom> for Wtf8 {
314    #[inline]
315    fn eq(&self, other: &Wtf8Atom) -> bool {
316        self == other.as_wtf8()
317    }
318}
319
320impl PartialEq<str> for Wtf8Atom {
321    #[inline]
322    fn eq(&self, other: &str) -> bool {
323        matches!(self.as_str(), Some(s) if s == other)
324    }
325}
326
327impl PartialEq<&str> for Wtf8Atom {
328    #[inline]
329    fn eq(&self, other: &&str) -> bool {
330        matches!(self.as_str(), Some(s) if s == *other)
331    }
332}
333
334impl Wtf8Atom {
335    pub(super) fn get_hash(&self) -> u64 {
336        get_hash!(self)
337    }
338
339    fn as_wtf8(&self) -> &Wtf8 {
340        match self.tag() {
341            DYNAMIC_TAG => unsafe {
342                let item = crate::dynamic::deref_from(self.unsafe_data);
343                Wtf8::from_bytes_unchecked(transmute::<&[u8], &'static [u8]>(&item.slice))
344            },
345            INLINE_TAG => {
346                let len = (self.unsafe_data.tag() & LEN_MASK) >> LEN_OFFSET;
347                let src = self.unsafe_data.data();
348                unsafe { Wtf8::from_bytes_unchecked(&src[..(len as usize)]) }
349            }
350            _ => unsafe { debug_unreachable!() },
351        }
352    }
353}
354
// Expands the crate's `impl_from_alias!` macro for `Wtf8Atom`. Presumably this
// defines `Wtf8Atom::from_alias`, which `Clone` calls with a copy of
// `unsafe_data`; confirm the reference-counting behaviour in the macro itself.
impl_from_alias!(Wtf8Atom);
356
#[cfg(test)]
impl Wtf8Atom {
    /// Test helper: returns the strong count of a dynamic atom's `ThinArc`
    /// entry. Atoms with any other tag report `1`.
    pub(crate) fn ref_count(&self) -> usize {
        if !self.is_dynamic() {
            return 1;
        }
        // SAFETY: the tag says `unsafe_data` refers to a live dynamic entry.
        let entry = unsafe { crate::dynamic::deref_from(self.unsafe_data) };
        triomphe::ThinArc::strong_count(&entry.0)
    }
}
370
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wtf8::{CodePoint, Wtf8Buf};

    /// Returns the code point `value`, which may be a lone surrogate.
    fn cp(value: u32) -> CodePoint {
        // SAFETY: every value used by these tests is at most 0xFFFF.
        unsafe { CodePoint::from_u32_unchecked(value) }
    }

    /// Builds an atom holding the single code point `value`.
    fn lone(value: u32) -> Wtf8Atom {
        let mut buf = Wtf8Buf::new();
        buf.push(cp(value));
        Wtf8Atom::from(buf)
    }

    /// Serializes `atom` to JSON, panicking on failure.
    fn to_json(atom: &Wtf8Atom) -> String {
        serde_json::to_string(atom).unwrap()
    }

    /// Deserializes an atom from JSON text, panicking on failure.
    fn from_json(json: &str) -> Wtf8Atom {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn test_serialize_normal_utf8() {
        assert_eq!(to_json(&Wtf8Atom::new("Hello, world!")), "\"Hello, world!\"");
    }

    #[test]
    fn test_deserialize_normal_utf8() {
        let atom = from_json("\"Hello, world!\"");
        assert_eq!(atom.as_str(), Some("Hello, world!"));
    }

    #[test]
    fn test_serialize_unpaired_high_surrogate() {
        // serde_json escapes the backslash of our `\uD800` escape.
        assert_eq!(to_json(&lone(0xd800)), "\"\\\\uD800\"");
    }

    #[test]
    fn test_serialize_unpaired_low_surrogate() {
        // serde_json escapes the backslash of our `\uDC00` escape.
        assert_eq!(to_json(&lone(0xdc00)), "\"\\\\uDC00\"");
    }

    #[test]
    fn test_serialize_multiple_surrogates() {
        let mut buf = Wtf8Buf::new();
        buf.push_str("Hello ");
        buf.push(cp(0xd800));
        buf.push_str(" World ");
        buf.push(cp(0xdc00));

        assert_eq!(
            to_json(&Wtf8Atom::from(buf)),
            "\"Hello \\\\uD800 World \\\\uDC00\""
        );
    }

    #[test]
    fn test_serialize_literal_backslash_u() {
        // A literal `\u0041` is written as `\\u0041`, which serde_json then
        // escapes again, giving four backslashes in the JSON text.
        assert_eq!(to_json(&Wtf8Atom::new("\\u0041")), "\"\\\\\\\\u0041\"");
    }

    #[test]
    fn test_deserialize_escaped_backslash_u() {
        // After JSON unescaping this is `\uD800`: a lone surrogate escape.
        let atom = from_json("\"\\\\uD800\"");
        assert_eq!(atom.as_str(), None);
        assert_eq!(atom.to_string_lossy(), "\u{FFFD}");
    }

    #[test]
    fn test_deserialize_unpaired_surrogates() {
        // Same escaped form the serializer produces for a lone surrogate.
        let atom = from_json("\"\\\\uD800\"");
        // Not valid UTF-8, but the lossy conversion still works.
        assert_eq!(atom.as_str(), None);
        assert_eq!(atom.to_string_lossy(), "\u{FFFD}");
    }

    #[test]
    fn test_round_trip_normal_string() {
        let original = Wtf8Atom::new("Hello, δΈ–η•Œ! 🌍");
        let deserialized = from_json(&to_json(&original));
        assert_eq!(original.as_str(), deserialized.as_str());
    }

    #[test]
    fn test_round_trip_unpaired_surrogates() {
        let mut buf = Wtf8Buf::new();
        buf.push_str("Before ");
        buf.push(cp(0xd800));
        buf.push_str(" Middle ");
        buf.push(cp(0xdc00));
        buf.push_str(" After");
        let original = Wtf8Atom::from(buf);

        let deserialized = from_json(&to_json(&original));

        assert_eq!(original, deserialized);
        assert_eq!(original.to_string_lossy(), deserialized.to_string_lossy());
    }

    #[test]
    fn test_round_trip_mixed_content() {
        let mut buf = Wtf8Buf::new();
        buf.push_str("Hello δΈ–η•Œ 🌍 ");
        buf.push(cp(0xd83d)); // unpaired high surrogate
        buf.push_str(" test ");
        buf.push(cp(0xdca9)); // unpaired low surrogate
        let original = Wtf8Atom::from(buf);

        let deserialized = from_json(&to_json(&original));
        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_empty_string() {
        assert_eq!(to_json(&Wtf8Atom::new("")), "\"\"");

        let deserialized = from_json("\"\"");
        assert_eq!(deserialized.as_str(), Some(""));
    }

    #[test]
    fn test_special_characters() {
        let cases = [
            ("\"", "\"\\\"\""),
            ("\n\r\t", "\"\\n\\r\\t\""), // serde_json escapes control characters
            ("\\", "\"\\\\\""),
            ("/", "\"/\""),
        ];

        for (input, expected) in cases {
            let serialized = to_json(&Wtf8Atom::new(input));
            assert_eq!(serialized, expected, "Failed for input: {input:?}");

            let deserialized = from_json(&serialized);
            assert_eq!(deserialized.as_str(), Some(input));
        }
    }

    #[test]
    fn test_consecutive_surrogates_not_paired() {
        // Two high surrogates in a row never combine into a pair.
        let mut buf = Wtf8Buf::new();
        buf.push(cp(0xd800));
        buf.push(cp(0xd800));
        let atom = Wtf8Atom::from(buf);

        let serialized = to_json(&atom);
        assert_eq!(serialized, "\"\\\\uD800\\\\uD800\"");
        assert_eq!(atom, from_json(&serialized));
    }

    #[test]
    fn test_deserialize_incomplete_escape() {
        // JSON unescaping yields two backslashes followed by `u123`; with only
        // three hex digits the atom deserializer collapses the pair into one
        // literal backslash.
        let atom = from_json("\"\\\\\\\\u123\"");
        assert_eq!(atom.as_str(), Some("\\u123"));
    }

    #[test]
    fn test_deserialize_invalid_hex() {
        // JSON unescaping yields two backslashes followed by `uGGGG`; since
        // `GGGG` is not hex, the pair collapses into one literal backslash.
        let atom = from_json("\"\\\\\\\\uGGGG\"");
        assert_eq!(atom.as_str(), Some("\\uGGGG"));
    }

    #[test]
    fn test_try_into_atom_valid_utf8() {
        let result = Wtf8Atom::new("Valid UTF-8 string").try_into_atom();
        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_str(), "Valid UTF-8 string");
    }

    #[test]
    fn test_try_into_atom_invalid_utf8() {
        let result = lone(0xd800).try_into_atom();
        assert!(result.is_err());

        // The original atom comes back untouched.
        let err_atom = result.unwrap_err();
        assert_eq!(err_atom.to_string_lossy(), "\u{FFFD}");
    }

    #[test]
    fn test_backslash_util_issue_11214() {
        let atom =
            Wtf8Atom::from("C:\\github\\swc-plugin-coverage-instrument\\spec\\util\\verifier.ts");
        let serialized = to_json(&atom);

        assert!(
            !serialized.contains("spec\\\\\\\\util"),
            "Found quadruple backslashes in spec segment! Serialized: {serialized}"
        );

        assert!(
            serialized.contains("spec\\\\util"),
            "Expected double backslashes in spec segment not found! Serialized: {serialized}",
        );

        // Every path separator must be escaped exactly once by serde_json.
        let expected = r#""C:\\github\\swc-plugin-coverage-instrument\\spec\\util\\verifier.ts""#;
        assert_eq!(
            serialized, expected,
            "Serialized value should have consistent backslash escaping"
        );

        assert_eq!(atom, from_json(&serialized));
    }
}