str_utils/
escape_characters.rs

1use alloc::{borrow::Cow, str, vec::Vec};
2
/// Extends every type that implements `AsRef<str>` with the `escape_characters` and
/// `escape_ascii_characters` methods.
///
/// Typical use cases include preparing strings for SQL `LIKE` queries or other contexts where
/// certain characters need to be escaped.
pub trait EscapeCharacters {
    /// Escapes all occurrences of the specified characters within a string.
    ///
    /// Every character that equals `escape_character` or is contained in `escaped_characters`
    /// is prefixed with `escape_character`. For example, with `\` as the escape character and
    /// `%` and `_` as the escaped characters, `50%_off` becomes `50\%\_off`.
    ///
    /// The input is returned borrowed, without allocating, when `escaped_characters` is empty
    /// (even if the string contains `escape_character`) or when no character has to be escaped.
    fn escape_characters(
        &self,
        escape_character: char,
        escaped_characters: &[char],
    ) -> Cow<'_, str>;

    /// Escapes ASCII characters within a UTF-8 string.
    ///
    /// Similar to [`EscapeCharacters::escape_characters`], but compares bytes instead of
    /// Unicode scalar values. Only ASCII bytes of the string are ever considered, so non-ASCII
    /// values in `escaped_characters` never match anything.
    ///
    /// The input is returned borrowed, without allocating, when `escaped_characters` is empty
    /// or when no byte has to be escaped.
    ///
    /// # Panics
    ///
    /// Panics if a byte has to be escaped and `escape_character` is not an ASCII byte, because
    /// inserting such a byte would produce invalid UTF-8.
    fn escape_ascii_characters(
        &self,
        escape_character: u8,
        escaped_characters: &[u8],
    ) -> Cow<'_, str>;
}

impl<T: AsRef<str>> EscapeCharacters for T {
    fn escape_characters(
        &self,
        escape_character: char,
        escaped_characters: &[char],
    ) -> Cow<'_, str> {
        let s = self.as_ref();

        if escaped_characters.is_empty() {
            return Cow::Borrowed(s);
        }

        let need_escape = |c: char| c == escape_character || escaped_characters.contains(&c);

        // Byte offset of the first character that has to be escaped. Without one, the input is
        // handed back untouched.
        let first = match s.find(|c: char| need_escape(c)) {
            Some(first) => first,
            None => return Cow::Borrowed(s),
        };

        // `first` comes from `find`, so it always lies on a character boundary.
        let (head, tail) = s.split_at(first);

        // At least one escape character is going to be inserted.
        let mut escaped = String::with_capacity(s.len() + escape_character.len_utf8());

        escaped.push_str(head);

        for c in tail.chars() {
            if need_escape(c) {
                escaped.push(escape_character);
            }

            escaped.push(c);
        }

        Cow::Owned(escaped)
    }

    fn escape_ascii_characters(
        &self,
        escape_character: u8,
        escaped_characters: &[u8],
    ) -> Cow<'_, str> {
        let s = self.as_ref();

        if escaped_characters.is_empty() {
            return Cow::Borrowed(s);
        }

        // In valid UTF-8 every byte below 0x80 is a complete character; lead and continuation
        // bytes of multi-byte sequences are always 0x80 or above. Restricting matches to ASCII
        // bytes therefore never touches the inside of a multi-byte character.
        let need_escape = |b: u8| {
            b.is_ascii() && (b == escape_character || escaped_characters.contains(&b))
        };

        let bytes = s.as_bytes();

        let first = match bytes.iter().position(|&b| need_escape(b)) {
            Some(first) => first,
            None => return Cow::Borrowed(s),
        };

        // From here on the escape byte is written into the output. A non-ASCII byte on its own
        // is not valid UTF-8, so refuse it instead of building a corrupt `String`.
        assert!(
            escape_character.is_ascii(),
            "escape_character must be an ASCII byte, got {:#04x}",
            escape_character
        );

        let escape = char::from(escape_character);

        // At least one escape byte is going to be inserted.
        let mut escaped = String::with_capacity(s.len() + 1);

        // Start of the run of input that has not been copied to the output yet.
        let mut run_start = 0;

        for (offset, &b) in bytes[first..].iter().enumerate() {
            if need_escape(b) {
                let position = first + offset;

                // `position` is the index of an ASCII byte and thus a character boundary. The
                // matched byte itself stays in the pending run and is copied with the next chunk.
                escaped.push_str(&s[run_start..position]);
                escaped.push(escape);

                run_start = position;
            }
        }

        escaped.push_str(&s[run_start..]);

        Cow::Owned(escaped)
    }
}
151}