// Source: markdown_it_rs_url/urlencode.rs

use std::borrow::Cow;

/// Upper-case hexadecimal digits, indexed by nibble value (`0x0..=0xF`).
///
/// `encode` looks up the high and low nibble of each escaped byte here to
/// produce the two digits that follow `%`.
const HEX_UPPER: &[u8; 16] = &[
    b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', // decimal digits
    b'A', b'B', b'C', b'D', b'E', b'F', // upper-case letters
];

/// A set of ASCII bytes stored as a 128-bit bitmap: bit `n` is set when byte `n`
/// is a member.
///
/// Bytes `>= 128` can never be members. `encode` consults such a set to decide
/// which non-alphanumeric ASCII bytes to leave unescaped.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct AsciiSet {
    allowed: u128,
}

impl AsciiSet {
    /// Returns a set with no members.
    #[must_use]
    pub const fn empty() -> Self {
        Self { allowed: 0 }
    }

    /// Returns the default set of URL characters that `encode` leaves unescaped:
    /// `` ;,/?:@&=+$-_.!~*'()# ``.
    // Deliberately an inherent `const fn` rather than `impl Default`: trait methods
    // cannot be called in a `const` initializer such as `ENCODE_DEFAULT_CHARS`.
    #[allow(clippy::should_implement_trait)]
    #[must_use]
    pub const fn default() -> Self {
        Self::empty()
            .add(b';')
            .add(b',')
            .add(b'/')
            .add(b'?')
            .add(b':')
            .add(b'@')
            .add(b'&')
            .add(b'=')
            .add(b'+')
            .add(b'$')
            .add(b'-')
            .add(b'_')
            .add(b'.')
            .add(b'!')
            .add(b'~')
            .add(b'*')
            .add(b'\'')
            .add(b'(')
            .add(b')')
            .add(b'#')
    }

    /// Returns a copy of this set with `byte` added.
    ///
    /// Non-ASCII bytes (`>= 128`) are silently ignored and the set is returned
    /// unchanged. Because `AsciiSet` is `Copy` and this takes `self` by value,
    /// the receiver is never modified — use the returned set.
    #[must_use = "`add` returns the updated set; the receiver is not modified"]
    pub const fn add(mut self, byte: u8) -> Self {
        // Guard the shift: `1u128 << byte` would overflow for `byte >= 128`.
        if byte < 128 {
            self.allowed |= 1u128 << byte;
        }
        self
    }

    /// Returns `true` if `byte` is a member of this set; always `false` for
    /// bytes `>= 128`.
    #[must_use]
    pub const fn contains(self, byte: u8) -> bool {
        // Range check first so the shift below never overflows.
        byte < 128 && (self.allowed & (1u128 << byte)) != 0
    }
}

/// The default `exclude` set for `encode`; equal to [`AsciiSet::default`].
pub const ENCODE_DEFAULT_CHARS: AsciiSet = AsciiSet::default();

53pub fn encode(input: &str, exclude: AsciiSet, keep_escaped: bool) -> Cow<'_, str> {
54    let mut result: Option<String> = None;
55    let bytes = input.as_bytes();
56
57    let mut i = 0;
58    while i < bytes.len() {
59        let cur = bytes[i];
60
61        if keep_escaped
62            && cur == b'%'
63            && i + 2 < bytes.len()
64            && bytes[i + 1].is_ascii_hexdigit()
65            && bytes[i + 2].is_ascii_hexdigit()
66        {
67            if let Some(ref mut r) = result {
68                r.push('%');
69                r.push(bytes[i + 1] as char);
70                r.push(bytes[i + 2] as char);
71            }
72            i += 3;
73            continue;
74        }
75
76        if cur.is_ascii_alphanumeric() || exclude.contains(cur) {
77            if let Some(ref mut r) = result {
78                r.push(cur as char);
79            }
80        } else {
81            // needs escap
82            if result.is_none() {
83                // if empty
84                let mut new_result = String::with_capacity(input.len() * 3);
85                new_result.push_str(
86                    input
87                        .get(..i)
88                        .expect("encode only copies prefixes on UTF-8 boundaries"),
89                );
90                result = Some(new_result);
91            }
92            if let Some(ref mut r) = result {
93                r.push('%');
94                r.push(HEX_UPPER[(cur >> 4) as usize] as char);
95                r.push(HEX_UPPER[(cur & 0x0F) as usize] as char);
96            }
97        }
98
99        i += 1;
100    }
101
102    match result {
103        Some(r) => Cow::Owned(r),
104        None => Cow::Borrowed(input),
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    // Imported explicitly so the Cow-variant assertions don't depend on the
    // parent module's imports being re-exported through the glob.
    use std::borrow::Cow;

    #[test]
    fn leaves_alphanumerics_unchanged() {
        let input: String = ('a'..='z').chain('A'..='Z').chain('0'..='9').collect();
        let out = encode(&input, ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, input.as_str());
        assert!(matches!(out, Cow::Borrowed(_)));
    }

    #[test]
    fn returns_borrowed_when_nothing_needs_escaping() {
        assert!(matches!(
            encode("a/b?c=d#e", ENCODE_DEFAULT_CHARS, true),
            Cow::Borrowed("a/b?c=d#e")
        ));
        assert!(matches!(
            encode("a%20b", ENCODE_DEFAULT_CHARS, true),
            Cow::Borrowed("a%20b")
        ));
    }

    #[test]
    fn encodes_spaces() {
        let out = encode("a b", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "a%20b");
    }

    #[test]
    fn encodes_a_single_escape_char() {
        let out = encode(" ", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "%20");
    }

    #[test]
    fn encodes_consecutive_spaces() {
        let out = encode("   ", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "%20%20%20");
    }

    #[test]
    fn encodes_ascii_control_chars() {
        assert_eq!(encode("a\nb", ENCODE_DEFAULT_CHARS, true), "a%0Ab");
        assert_eq!(encode("\t", ENCODE_DEFAULT_CHARS, true), "%09");
    }

    #[test]
    fn keeps_default_allowed_ascii() {
        let out = encode("a/b?c=d#e", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "a/b?c=d#e");
    }

    #[test]
    fn encodes_reserved_htmlish_chars() {
        let out = encode("<tag>", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "%3Ctag%3E");
    }

    #[test]
    fn encodes_unicode_as_utf8_bytes() {
        let out = encode("你好", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "%E4%BD%A0%E5%A5%BD");
    }

    #[test]
    fn encodes_utf8_with_ascii() {
        let out = encode("hi你好a", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "hi%E4%BD%A0%E5%A5%BDa");
    }

    #[test]
    fn preserves_valid_escape_sequences_when_enabled() {
        let out = encode("a%20b", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "a%20b");
    }

    #[test]
    fn preserves_lowercase_escape_sequences() {
        assert_eq!(encode("%2f", ENCODE_DEFAULT_CHARS, true), "%2f");
        assert_eq!(encode("a%c3%a9", ENCODE_DEFAULT_CHARS, true), "a%c3%a9");
    }

    #[test]
    fn reencodes_percent_when_keep_escaped_is_disabled() {
        let out = encode("a%20b", ENCODE_DEFAULT_CHARS, false);
        assert_eq!(out, "a%2520b");
    }

    #[test]
    fn encodes_invalid_escape_sequences() {
        let out = encode("a%2g", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "a%252g");
    }

    #[test]
    fn encodes_lone_trailing_percent() {
        let out = encode("100%", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "100%25");
    }

    #[test]
    fn handles_escapes_at_end_of_input() {
        let out1 = encode("%", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out1, "%25");

        let out2 = encode("%2", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out2, "%252");

        let out3 = encode("%25", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out3, "%25");
    }

    #[test]
    fn encodes_percent_before_valid_escape() {
        let out = encode("%%20", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "%25%20");
    }

    #[test]
    fn encodes_valid_escaped_and_next_needs_escape_char() {
        let out = encode("%20<", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "%20%3C");
    }

    #[test]
    fn preserves_escape_after_non_ascii() {
        let out = encode("你%20", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "%E4%BD%A0%20");
    }

    #[test]
    fn encodes_an_empty_string() {
        let out = encode("", ENCODE_DEFAULT_CHARS, true);
        assert_eq!(out, "");
    }

    #[test]
    fn respects_custom_exclude_set() {
        assert_eq!(
            encode("a b/c", AsciiSet::empty().add(b' '), true),
            "a b%2Fc"
        );
        assert_eq!(encode("a.b", AsciiSet::empty(), true), "a%2Eb");
    }
}