Skip to main content

str_utils/
escape_characters.rs

1use alloc::{borrow::Cow, str::from_utf8_unchecked, string::String, vec::Vec};
2
/// Extends `str` and `Cow<str>` with the `escape_characters` and `escape_ascii_characters`
/// methods.
///
/// Typical use cases include preparing strings for SQL `LIKE` queries or other contexts where
/// certain characters need to be escaped.
///
/// When nothing needs escaping, the `&str` implementation returns the input unchanged as
/// `Cow::Borrowed`, without allocating.
pub trait EscapeCharacters<'a> {
    /// Escapes all occurrences of the specified characters within a string.
    ///
    /// Scans the input for any character that equals `escape_character` or one of
    /// `escaped_characters`, and prefixes each such character with `escape_character`.
    /// The escape character itself is therefore always escaped.
    fn escape_characters(self, escape_character: char, escaped_characters: &[char])
        -> Cow<'a, str>;

    /// Escapes ASCII characters within a UTF-8 string.
    ///
    /// Similar to [`EscapeCharacters::escape_characters`], but compares raw bytes instead of
    /// Unicode scalar values, which avoids decoding the string. Only ASCII bytes (0x00 to 0x7F)
    /// are ever matched. Non-ASCII values in `escaped_characters` are ignored, because in valid
    /// UTF-8 such bytes only occur inside multi-byte sequences, which must not be split.
    ///
    /// # Panics
    ///
    /// Panics if `escape_character` is not ASCII (i.e., outside the range 0x00 to 0x7F).
    /// Inserting such a byte on its own would produce invalid UTF-8.
    fn escape_ascii_characters(
        self,
        escape_character: u8,
        escaped_characters: &[u8],
    ) -> Cow<'a, str>;
}

impl<'a> EscapeCharacters<'a> for &'a str {
    fn escape_characters(
        self,
        escape_character: char,
        escaped_characters: &[char],
    ) -> Cow<'a, str> {
        let s = self;

        let need_escape = |c: char| c == escape_character || escaped_characters.contains(&c);

        let mut chars = s.char_indices();

        // Scan until the first character that needs escaping. If there is none, hand the
        // input back as-is without allocating. `find` leaves `chars` positioned right after
        // the match, so the loop below resumes from there.
        let (p, first_c) = match chars.find(|&(_, c)| need_escape(c)) {
            Some(found) => found,
            None => return Cow::Borrowed(s),
        };

        // At least one escape character will be inserted.
        let mut new_s = String::with_capacity(s.len() + escape_character.len_utf8());

        // SAFETY: `p` was produced by `char_indices`, so it lies on a char boundary of `s`
        // and `s.as_bytes()[..p]` is valid UTF-8.
        new_s.push_str(unsafe { from_utf8_unchecked(&s.as_bytes()[..p]) });
        new_s.push(escape_character);
        new_s.push(first_c);

        for (_, c) in chars {
            if need_escape(c) {
                new_s.push(escape_character);
            }

            new_s.push(c);
        }

        Cow::Owned(new_s)
    }

    fn escape_ascii_characters(
        self,
        escape_character: u8,
        escaped_characters: &[u8],
    ) -> Cow<'a, str> {
        let s = self;

        // This must be a hard check, not a debug-only one: the output is built with
        // `String::from_utf8_unchecked`, and inserting a lone non-ASCII byte would create an
        // invalid `String` (undefined behavior) in release builds.
        assert!(escape_character.is_ascii(), "escape_character must be ASCII");

        let bytes = s.as_bytes();

        // In valid UTF-8 a byte below 0x80 is always a complete ASCII character, never part of
        // a multi-byte sequence. Restricting matches to ASCII bytes therefore never splits a
        // character, even if `escaped_characters` contains non-ASCII values.
        let need_escape = |b: u8| {
            b.is_ascii() && (b == escape_character || escaped_characters.contains(&b))
        };

        let first = match bytes.iter().position(|&b| need_escape(b)) {
            Some(i) => i,
            None => return Cow::Borrowed(s),
        };

        let mut new_v = Vec::with_capacity(bytes.len() + 1);

        // Start of the pending run of bytes that have not been copied yet.
        let mut start = 0;

        for (i, &b) in bytes.iter().enumerate().skip(first) {
            if need_escape(b) {
                new_v.extend_from_slice(&bytes[start..i]);
                new_v.push(escape_character);
                new_v.push(b);
                start = i + 1;
            }
        }

        new_v.extend_from_slice(&bytes[start..]);

        // SAFETY: `new_v` consists of slices of the valid UTF-8 input, cut only around ASCII
        // bytes (which are always char boundaries), interleaved with ASCII bytes
        // (`escape_character` is asserted ASCII above; escaped bytes pass `is_ascii`).
        Cow::Owned(unsafe { String::from_utf8_unchecked(new_v) })
    }
}
146
147impl<'a> EscapeCharacters<'a> for Cow<'a, str> {
148    #[inline]
149    fn escape_characters(
150        self,
151        escape_character: char,
152        escaped_characters: &[char],
153    ) -> Cow<'a, str> {
154        match self {
155            Cow::Borrowed(s) => s.escape_characters(escape_character, escaped_characters),
156            Cow::Owned(s) => Cow::Owned(cow_into_owned!(
157                s,
158                s.as_str().escape_characters(escape_character, escaped_characters),
159            )),
160        }
161    }
162
163    #[inline]
164    fn escape_ascii_characters(
165        self,
166        escape_character: u8,
167        escaped_characters: &[u8],
168    ) -> Cow<'a, str> {
169        match self {
170            Cow::Borrowed(s) => s.escape_ascii_characters(escape_character, escaped_characters),
171            Cow::Owned(s) => Cow::Owned(cow_into_owned!(
172                s,
173                s.as_str().escape_ascii_characters(escape_character, escaped_characters),
174            )),
175        }
176    }
177}