Skip to main content

hstr/
wtf8_atom.rs

1use std::{
2    fmt::Debug,
3    hash::Hash,
4    mem::{forget, transmute, ManuallyDrop},
5    ops::Deref,
6};
7
8use debug_unreachable::debug_unreachable;
9
10use crate::{
11    macros::{get_hash, impl_from_alias},
12    tagged_value::TaggedValue,
13    wtf8::Wtf8,
14    Atom, DYNAMIC_TAG, INLINE_TAG, LEN_MASK, LEN_OFFSET, TAG_MASK,
15};
16
17/// A WTF-8 encoded atom. This is like [Atom], but can contain unpaired
18/// surrogates.
19///
20/// [Atom]: crate::Atom
21#[repr(transparent)]
22pub struct Wtf8Atom {
23    pub(crate) unsafe_data: TaggedValue,
24}
25
26impl Wtf8Atom {
27    #[inline(always)]
28    pub fn new<S>(s: S) -> Self
29    where
30        Self: From<S>,
31    {
32        Self::from(s)
33    }
34
35    /// Try to convert this to a UTF-8 [Atom].
36    ///
37    /// Returns [Atom] if the string is valid UTF-8, otherwise returns
38    /// the original [Wtf8Atom].
39    pub fn try_into_atom(self) -> Result<Atom, Wtf8Atom> {
40        if self.as_str().is_some() {
41            let atom = ManuallyDrop::new(self);
42            Ok(Atom {
43                unsafe_data: atom.unsafe_data,
44            })
45        } else {
46            Err(self)
47        }
48    }
49
50    #[inline(always)]
51    fn tag(&self) -> u8 {
52        self.unsafe_data.tag() & TAG_MASK
53    }
54
55    /// Return true if this is a dynamic Atom.
56    #[inline(always)]
57    fn is_dynamic(&self) -> bool {
58        self.tag() == DYNAMIC_TAG
59    }
60}
61
62impl Default for Wtf8Atom {
63    #[inline(never)]
64    fn default() -> Self {
65        Wtf8Atom::new("")
66    }
67}
68
69/// Immutable, so it's safe to be shared between threads
70unsafe impl Send for Wtf8Atom {}
71
72/// Immutable, so it's safe to be shared between threads
73unsafe impl Sync for Wtf8Atom {}
74
75impl Debug for Wtf8Atom {
76    #[inline]
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        Debug::fmt(&**self, f)
79    }
80}
81
82#[cfg(feature = "serde")]
83impl serde::ser::Serialize for Wtf8Atom {
84    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
85    where
86        S: serde::ser::Serializer,
87    {
88        use crate::wtf8::Wtf8;
89        fn convert_wtf8_to_raw(s: &Wtf8) -> String {
90            let mut result = String::new();
91            let mut iter = s.code_points().peekable();
92
93            while let Some(code_point) = iter.next() {
94                if let Some(c) = code_point.to_char() {
95                    // Escape literal '\u' sequences to avoid ambiguity with surrogate encoding.
96                    // Without this escaping, we couldn't distinguish between:
97                    // - JavaScript's "\uD800" (actual unpaired surrogate)
98                    // - JavaScript's "\\uD800" (literal text '\uD800')
99                    //
100                    // By escaping literal '\u' to '\\u', we ensure:
101                    // - Unpaired surrogates serialize as '\uXXXX'
102                    // - Literal '\u' text serializes as '\\uXXXX'
103                    //
104                    // However, we should only escape '\u' if it's followed by exactly 4 hex digits,
105                    // which would indicate a Unicode escape sequence. Otherwise, '\u' followed by
106                    // non-hex characters (like '\util') should not be escaped.
107                    if c == '\\' && iter.peek().map(|cp| cp.to_u32()) == Some('u' as u32) {
108                        // Look ahead to see if this is followed by exactly 4 hex digits
109                        let mut lookahead = iter.clone();
110                        lookahead.next(); // skip 'u'
111
112                        let mut hex_count = 0;
113                        let mut all_hex = true;
114                        for _ in 0..4 {
115                            if let Some(next_cp) = lookahead.next() {
116                                if let Some(next_c) = next_cp.to_char() {
117                                    if next_c.is_ascii_hexdigit() {
118                                        hex_count += 1;
119                                    } else {
120                                        all_hex = false;
121                                        break;
122                                    }
123                                } else {
124                                    all_hex = false;
125                                    break;
126                                }
127                            } else {
128                                all_hex = false;
129                                break;
130                            }
131                        }
132
133                        // Only escape if we have exactly 4 hex digits after '\u'
134                        if hex_count == 4 && all_hex {
135                            iter.next(); // skip 'u'
136                            result.push_str("\\\\u");
137                        } else {
138                            result.push(c);
139                        }
140                    } else {
141                        result.push(c)
142                    }
143                } else {
144                    // Unpaired surrogates can't be represented in valid UTF-8,
145                    // so encode them as '\uXXXX' for JavaScript compatibility
146                    result.push_str(format!("\\u{:04X}", code_point.to_u32()).as_str());
147                }
148            }
149
150            result
151        }
152
153        serializer.serialize_str(&convert_wtf8_to_raw(self))
154    }
155}
156
157#[cfg(feature = "serde")]
158impl<'de> serde::de::Deserialize<'de> for Wtf8Atom {
159    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
160    where
161        D: serde::Deserializer<'de>,
162    {
163        use crate::wtf8::{CodePoint, Wtf8Buf};
164        fn convert_wtf8_string_to_wtf8(s: String) -> Wtf8Buf {
165            let mut iter = s.chars().peekable();
166            let mut result = Wtf8Buf::with_capacity(s.len());
167
168            // This function reverses the encoding done in serialize.
169            // It handles two cases:
170            // 1. '\uXXXX' - Decode as an unpaired surrogate code point
171            // 2. '\\uXXXX' - Treat as literal text '\uXXXX'
172            while let Some(c) = iter.next() {
173                if c == '\\' {
174                    if iter.peek() == Some(&'u') {
175                        // Found '\u' - might be a surrogate encoding
176                        let _ = iter.next(); // skip 'u'
177
178                        // Try to read 4 hex digits
179                        let d1 = iter.next();
180                        let d2 = iter.next();
181                        let d3 = iter.next();
182                        let d4 = iter.next();
183
184                        match (d1, d2, d3, d4) {
185                            (Some(d1), Some(d2), Some(d3), Some(d4)) => {
186                                let hex = format!("{d1}{d2}{d3}{d4}");
187                                if let Ok(code_point) = u16::from_str_radix(&hex, 16) {
188                                    result.push(unsafe {
189                                        CodePoint::from_u32_unchecked(code_point as u32)
190                                    });
191                                    continue;
192                                }
193
194                                result.push_char('\\');
195                                result.push_char('u');
196                                result.push_char(d1);
197                                result.push_char(d2);
198                                result.push_char(d3);
199                                result.push_char(d4);
200                            }
201                            (d1, d2, d3, d4) => {
202                                result.push_char('\\');
203                                result.push_char('u');
204
205                                macro_rules! push_if_some {
206                                    ($expr:expr) => {
207                                        if let Some(c) = $expr {
208                                            result.push_char(c);
209                                        }
210                                    };
211                                }
212
213                                push_if_some!(d1);
214                                push_if_some!(d2);
215                                push_if_some!(d3);
216                                push_if_some!(d4);
217                            }
218                        }
219                    } else if iter.peek() == Some(&'\\') {
220                        // Found '\\' - this is an escaped backslash
221                        // '\\u' should become literal '\u' text
222                        let _ = iter.next(); // skip the second '\'
223                        if iter.peek() == Some(&'u') {
224                            let _ = iter.next(); // skip 'u'
225                            result.push_char('\\');
226                            result.push_char('u');
227                        } else {
228                            result.push_str("\\\\");
229                        }
230                    } else {
231                        result.push_char(c);
232                    }
233                } else {
234                    result.push_char(c);
235                }
236            }
237            result
238        }
239
240        String::deserialize(deserializer).map(|v| convert_wtf8_string_to_wtf8(v).into())
241    }
242}
243
244impl PartialEq for Wtf8Atom {
245    #[inline]
246    fn eq(&self, other: &Self) -> bool {
247        let unsafe_data = self.unsafe_data;
248        let other_unsafe_data = other.unsafe_data;
249
250        if unsafe_data == other_unsafe_data {
251            return true;
252        }
253
254        let tag = unsafe_data.tag() & TAG_MASK;
255
256        if tag != (other_unsafe_data.tag() & TAG_MASK) {
257            return false;
258        }
259
260        match tag {
261            // Inline atoms encode both their length and bytes in `unsafe_data`, so
262            // different raw values mean different strings.
263            INLINE_TAG => false,
264            DYNAMIC_TAG => {
265                let this = unsafe { crate::dynamic::deref_from(unsafe_data) };
266                let other = unsafe { crate::dynamic::deref_from(other_unsafe_data) };
267
268                if this.header.header.hash != other.header.header.hash {
269                    return false;
270                }
271
272                this.slice == other.slice
273            }
274            _ => unsafe { debug_unreachable!() },
275        }
276    }
277}
278
279impl Eq for Wtf8Atom {}
280
281impl Hash for Wtf8Atom {
282    #[inline(always)]
283    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
284        state.write_u64(self.get_hash());
285    }
286}
287
288impl Drop for Wtf8Atom {
289    #[inline(always)]
290    fn drop(&mut self) {
291        if self.is_dynamic() {
292            unsafe { drop(crate::dynamic::restore_arc(self.unsafe_data)) }
293        }
294    }
295}
296
297impl Clone for Wtf8Atom {
298    #[inline(always)]
299    fn clone(&self) -> Self {
300        Self::from_alias(self.unsafe_data)
301    }
302}
303
304impl Deref for Wtf8Atom {
305    type Target = Wtf8;
306
307    #[inline(always)]
308    fn deref(&self) -> &Self::Target {
309        self.as_wtf8()
310    }
311}
312
313impl AsRef<Wtf8> for Wtf8Atom {
314    #[inline(always)]
315    fn as_ref(&self) -> &Wtf8 {
316        self.as_wtf8()
317    }
318}
319
320impl PartialEq<Wtf8> for Wtf8Atom {
321    #[inline]
322    fn eq(&self, other: &Wtf8) -> bool {
323        self.as_wtf8() == other
324    }
325}
326
327impl PartialEq<crate::Atom> for Wtf8Atom {
328    #[inline]
329    fn eq(&self, other: &crate::Atom) -> bool {
330        self.as_str() == Some(other.as_str())
331    }
332}
333
334impl PartialEq<&'_ Wtf8> for Wtf8Atom {
335    #[inline]
336    fn eq(&self, other: &&Wtf8) -> bool {
337        self.as_wtf8() == *other
338    }
339}
340
341impl PartialEq<Wtf8Atom> for Wtf8 {
342    #[inline]
343    fn eq(&self, other: &Wtf8Atom) -> bool {
344        self == other.as_wtf8()
345    }
346}
347
348impl PartialEq<str> for Wtf8Atom {
349    #[inline]
350    fn eq(&self, other: &str) -> bool {
351        matches!(self.as_str(), Some(s) if s == other)
352    }
353}
354
355impl PartialEq<&str> for Wtf8Atom {
356    #[inline]
357    fn eq(&self, other: &&str) -> bool {
358        matches!(self.as_str(), Some(s) if s == *other)
359    }
360}
361
362impl Wtf8Atom {
363    pub(super) fn get_hash(&self) -> u64 {
364        get_hash!(self)
365    }
366
367    fn as_wtf8(&self) -> &Wtf8 {
368        match self.tag() {
369            DYNAMIC_TAG => unsafe {
370                let item = crate::dynamic::deref_from(self.unsafe_data);
371                Wtf8::from_bytes_unchecked(transmute::<&[u8], &'static [u8]>(&item.slice))
372            },
373            INLINE_TAG => {
374                let len = (self.unsafe_data.tag() & LEN_MASK) >> LEN_OFFSET;
375                let src = self.unsafe_data.data();
376                unsafe { Wtf8::from_bytes_unchecked(&src[..(len as usize)]) }
377            }
378            _ => unsafe { debug_unreachable!() },
379        }
380    }
381}
382
383impl_from_alias!(Wtf8Atom);
384
385#[cfg(test)]
386impl Wtf8Atom {
387    pub(crate) fn ref_count(&self) -> usize {
388        match self.tag() {
389            DYNAMIC_TAG => {
390                let ptr = unsafe { crate::dynamic::deref_from(self.unsafe_data) };
391
392                triomphe::ThinArc::strong_count(&ptr.0)
393            }
394            _ => 1,
395        }
396    }
397}
398
399#[cfg(test)]
400mod tests {
401    use super::*;
402    use crate::wtf8::{CodePoint, Wtf8Buf};
403
404    #[cfg(feature = "serde")]
405    #[test]
406    fn test_serialize_normal_utf8() {
407        let atom = Wtf8Atom::new("Hello, world!");
408        let serialized = serde_json::to_string(&atom).unwrap();
409        assert_eq!(serialized, "\"Hello, world!\"");
410    }
411
412    #[cfg(feature = "serde")]
413    #[test]
414    fn test_deserialize_normal_utf8() {
415        let json = "\"Hello, world!\"";
416        let atom: Wtf8Atom = serde_json::from_str(json).unwrap();
417        assert_eq!(atom.as_str(), Some("Hello, world!"));
418    }
419
420    #[cfg(feature = "serde")]
421    #[test]
422    fn test_serialize_unpaired_high_surrogate() {
423        // Create a WTF-8 string with an unpaired high surrogate (U+D800)
424        let mut wtf8 = Wtf8Buf::new();
425        wtf8.push(unsafe { CodePoint::from_u32_unchecked(0xd800) });
426        let atom = Wtf8Atom::from(wtf8);
427
428        let serialized = serde_json::to_string(&atom).unwrap();
429        // The serialized output will have double escaping due to serde_json
430        assert_eq!(serialized, "\"\\\\uD800\"");
431    }
432
433    #[cfg(feature = "serde")]
434    #[test]
435    fn test_serialize_unpaired_low_surrogate() {
436        // Create a WTF-8 string with an unpaired low surrogate (U+DC00)
437        let mut wtf8 = Wtf8Buf::new();
438        wtf8.push(unsafe { CodePoint::from_u32_unchecked(0xdc00) });
439        let atom = Wtf8Atom::from(wtf8);
440
441        let serialized = serde_json::to_string(&atom).unwrap();
442        // The serialized output will have double escaping due to serde_json
443        assert_eq!(serialized, "\"\\\\uDC00\"");
444    }
445
446    #[cfg(feature = "serde")]
447    #[test]
448    fn test_serialize_multiple_surrogates() {
449        // Create a WTF-8 string with multiple unpaired surrogates
450        let mut wtf8 = Wtf8Buf::new();
451        wtf8.push_str("Hello ");
452        wtf8.push(unsafe { CodePoint::from_u32_unchecked(0xd800) });
453        wtf8.push_str(" World ");
454        wtf8.push(unsafe { CodePoint::from_u32_unchecked(0xdc00) });
455        let atom = Wtf8Atom::from(wtf8);
456
457        let serialized = serde_json::to_string(&atom).unwrap();
458        // The serialized output will have double escaping due to serde_json
459        assert_eq!(serialized, "\"Hello \\\\uD800 World \\\\uDC00\"");
460    }
461
462    #[cfg(feature = "serde")]
463    #[test]
464    fn test_serialize_literal_backslash_u() {
465        // Test that literal "\u" in the string gets escaped properly
466        let atom = Wtf8Atom::new("\\u0041");
467        let serialized = serde_json::to_string(&atom).unwrap();
468        // serde_json escapes the backslash, resulting in 4 backslashes
469        assert_eq!(serialized, "\"\\\\\\\\u0041\"");
470    }
471
472    #[cfg(feature = "serde")]
473    #[test]
474    fn test_deserialize_escaped_backslash_u() {
475        // Test deserializing the escaped format for unpaired surrogates
476        let json = "\"\\\\uD800\"";
477        let atom: Wtf8Atom = serde_json::from_str(json).unwrap();
478        // This should be parsed as an unpaired surrogate
479        assert_eq!(atom.as_str(), None);
480        assert_eq!(atom.to_string_lossy(), "\u{FFFD}");
481    }
482
483    #[cfg(feature = "serde")]
484    #[test]
485    fn test_deserialize_unpaired_surrogates() {
486        let json = "\"\\\\uD800\""; // Use escaped format that matches serialization
487        let atom: Wtf8Atom = serde_json::from_str(json).unwrap();
488        // Should contain an unpaired surrogate, so as_str() returns None
489        assert_eq!(atom.as_str(), None);
490        // But to_string_lossy should work
491        assert_eq!(atom.to_string_lossy(), "\u{FFFD}");
492    }
493
494    #[cfg(feature = "serde")]
495    #[test]
496    fn test_round_trip_normal_string() {
497        let original = Wtf8Atom::new("Hello, δΈ–η•Œ! 🌍");
498        let serialized = serde_json::to_string(&original).unwrap();
499        let deserialized: Wtf8Atom = serde_json::from_str(&serialized).unwrap();
500        assert_eq!(original.as_str(), deserialized.as_str());
501    }
502
503    #[cfg(feature = "serde")]
504    #[test]
505    fn test_round_trip_unpaired_surrogates() {
506        // Create a string with unpaired surrogates
507        let mut wtf8 = Wtf8Buf::new();
508        wtf8.push_str("Before ");
509        wtf8.push(unsafe { CodePoint::from_u32_unchecked(0xd800) });
510        wtf8.push_str(" Middle ");
511        wtf8.push(unsafe { CodePoint::from_u32_unchecked(0xdc00) });
512        wtf8.push_str(" After");
513        let original = Wtf8Atom::from(wtf8);
514
515        let serialized = serde_json::to_string(&original).unwrap();
516        let deserialized: Wtf8Atom = serde_json::from_str(&serialized).unwrap();
517
518        // Both should be equal when compared as WTF-8
519        assert_eq!(original, deserialized);
520
521        // Both should produce the same lossy string
522        assert_eq!(original.to_string_lossy(), deserialized.to_string_lossy());
523    }
524
525    #[cfg(feature = "serde")]
526    #[test]
527    fn test_round_trip_mixed_content() {
528        // Create a complex string with normal text, emojis, and unpaired surrogates
529        let mut wtf8 = Wtf8Buf::new();
530        wtf8.push_str("Hello δΈ–η•Œ 🌍 ");
531        wtf8.push(unsafe { CodePoint::from_u32_unchecked(0xd83d) }); // Unpaired high
532        wtf8.push_str(" test ");
533        wtf8.push(unsafe { CodePoint::from_u32_unchecked(0xdca9) }); // Unpaired low
534        let original = Wtf8Atom::from(wtf8);
535
536        let serialized = serde_json::to_string(&original).unwrap();
537        let deserialized: Wtf8Atom = serde_json::from_str(&serialized).unwrap();
538
539        assert_eq!(original, deserialized);
540    }
541
542    #[cfg(feature = "serde")]
543    #[test]
544    fn test_empty_string() {
545        let atom = Wtf8Atom::new("");
546        let serialized = serde_json::to_string(&atom).unwrap();
547        assert_eq!(serialized, "\"\"");
548
549        let deserialized: Wtf8Atom = serde_json::from_str("\"\"").unwrap();
550        assert_eq!(deserialized.as_str(), Some(""));
551    }
552
553    #[cfg(feature = "serde")]
554    #[test]
555    fn test_special_characters() {
556        let test_cases = vec![
557            ("\"", "\"\\\"\""),
558            ("\n\r\t", "\"\\n\\r\\t\""), // serde_json escapes control characters
559            ("\\", "\"\\\\\""),
560            ("/", "\"/\""),
561        ];
562
563        for (input, expected) in test_cases {
564            let atom = Wtf8Atom::new(input);
565            let serialized = serde_json::to_string(&atom).unwrap();
566            assert_eq!(serialized, expected, "Failed for input: {input:?}");
567
568            let deserialized: Wtf8Atom = serde_json::from_str(&serialized).unwrap();
569            assert_eq!(deserialized.as_str(), Some(input));
570        }
571    }
572
573    #[cfg(feature = "serde")]
574    #[test]
575    fn test_consecutive_surrogates_not_paired() {
576        // Test that consecutive surrogates that don't form a valid pair
577        // are handled correctly
578        let mut wtf8 = Wtf8Buf::new();
579        wtf8.push(unsafe { CodePoint::from_u32_unchecked(0xd800) }); // High surrogate
580        wtf8.push(unsafe { CodePoint::from_u32_unchecked(0xd800) }); // Another high surrogate
581        let atom = Wtf8Atom::from(wtf8);
582
583        let serialized = serde_json::to_string(&atom).unwrap();
584        // The serialized output will have double escaping due to serde_json
585        assert_eq!(serialized, "\"\\\\uD800\\\\uD800\"");
586
587        let deserialized: Wtf8Atom = serde_json::from_str(&serialized).unwrap();
588        assert_eq!(atom, deserialized);
589    }
590
591    #[cfg(feature = "serde")]
592    #[test]
593    fn test_deserialize_incomplete_escape() {
594        // Test handling of incomplete escape sequences from our custom format
595        let json = "\"\\\\\\\\u123\""; // Escaped backslash + incomplete sequence
596        let atom: Wtf8Atom = serde_json::from_str(json).unwrap();
597        // JSON decodes \\\\u123 to \\u123, then our deserializer sees \u123 and treats
598        // it as literal
599        assert_eq!(atom.as_str(), Some("\\u123"));
600    }
601
602    #[cfg(feature = "serde")]
603    #[test]
604    fn test_deserialize_invalid_hex() {
605        // Test handling of invalid hex in escape sequences from our custom format
606        let json = "\"\\\\\\\\uGGGG\""; // Escaped backslash + invalid hex
607        let atom: Wtf8Atom = serde_json::from_str(json).unwrap();
608        // JSON decodes \\\\uGGGG to \\uGGGG, then our deserializer sees \uGGGG and
609        // treats it as literal
610        assert_eq!(atom.as_str(), Some("\\uGGGG"));
611    }
612
613    #[test]
614    fn test_try_into_atom_valid_utf8() {
615        let wtf8_atom = Wtf8Atom::new("Valid UTF-8 string");
616        let result = wtf8_atom.try_into_atom();
617        assert!(result.is_ok());
618        assert_eq!(result.unwrap().as_str(), "Valid UTF-8 string");
619    }
620
621    #[test]
622    fn test_try_into_atom_invalid_utf8() {
623        // Create an atom with unpaired surrogates
624        let mut wtf8 = Wtf8Buf::new();
625        wtf8.push(unsafe { CodePoint::from_u32_unchecked(0xd800) });
626        let wtf8_atom = Wtf8Atom::from(wtf8);
627
628        let result = wtf8_atom.try_into_atom();
629        assert!(result.is_err());
630        // Should return the original Wtf8Atom
631        let err_atom = result.unwrap_err();
632        assert_eq!(err_atom.to_string_lossy(), "\u{FFFD}");
633    }
634
635    #[cfg(feature = "serde")]
636    #[test]
637    fn test_backslash_util_issue_11214() {
638        let atom =
639            Wtf8Atom::from("C:\\github\\swc-plugin-coverage-instrument\\spec\\util\\verifier.ts");
640        let serialized = serde_json::to_string(&atom).unwrap();
641
642        assert!(
643            !serialized.contains("spec\\\\\\\\util"),
644            "Found quadruple backslashes in spec segment! Serialized: {serialized}"
645        );
646
647        assert!(
648            serialized.contains("spec\\\\util"),
649            "Expected double backslashes in spec segment not found! Serialized: {serialized}",
650        );
651
652        // The expected serialized value should have consistent escaping
653        let expected = r#""C:\\github\\swc-plugin-coverage-instrument\\spec\\util\\verifier.ts""#;
654        assert_eq!(
655            serialized, expected,
656            "Serialized value should have consistent backslash escaping"
657        );
658
659        // Test round-trip
660        let deserialized: Wtf8Atom = serde_json::from_str(&serialized).unwrap();
661        assert_eq!(atom, deserialized);
662    }
663}