const_serialize/
str.rs

1use crate::*;
2use std::{char, fmt::Debug, hash::Hash, mem::MaybeUninit};
3
/// Fixed capacity, in bytes, of the buffer inside every [`ConstStr`]; `push_str` asserts that
/// the resulting string still fits.
const MAX_STR_SIZE: usize = 256;

/// A string that is stored in a constant sized buffer that can be serialized and deserialized at compile time
///
/// The UTF-8 text lives in a `MAX_STR_SIZE`-byte inline array, which is what lets the type be
/// `Copy` and be built and manipulated inside `const fn`s.
#[derive(Clone, Copy)]
pub struct ConstStr {
    /// Backing storage. Only the first `len` bytes are guaranteed to be initialized: the tail is
    /// zero-filled when the `const-serialize-07` feature is enabled and left uninitialized otherwise.
    bytes: [MaybeUninit<u8>; MAX_STR_SIZE],
    /// Number of initialized bytes at the start of `bytes`, i.e. the string's length in bytes.
    len: u32,
}
12
13impl Debug for ConstStr {
14    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
15        f.debug_struct("ConstStr")
16            .field("str", &self.as_str())
17            .finish()
18    }
19}
20
#[cfg(feature = "serde")]
mod serde_bytes {
    use serde::{Deserialize, Serialize, Serializer};

    use crate::ConstStr;

    impl Serialize for ConstStr {
        /// Serialize as a plain string.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            serializer.serialize_str(self.as_str())
        }
    }

    impl<'de> Deserialize<'de> for ConstStr {
        /// Deserialize from a string.
        ///
        /// # Errors
        ///
        /// Returns a deserialization error if the string is longer than the `MAX_STR_SIZE`-byte
        /// capacity of a `ConstStr`. `ConstStr::new` panics on such input, and the input here
        /// may be untrusted, so the length is checked before constructing the value.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let s = String::deserialize(deserializer)?;
            if s.len() > super::MAX_STR_SIZE {
                return Err(<D::Error as serde::de::Error>::custom(format!(
                    "string of {} bytes exceeds the ConstStr capacity of {} bytes",
                    s.len(),
                    super::MAX_STR_SIZE
                )));
            }
            Ok(ConstStr::new(&s))
        }
    }
}
46
/// Describes `ConstStr` to this crate's `SerializeConst` machinery as a `Layout::List`.
///
/// Every offset and size is computed with `size_of`/`offset_of!` on the real fields, so the
/// description follows the struct definition even though `ConstStr` has the default (Rust)
/// representation.
// NOTE(review): the argument roles below are inferred from the values passed (whole size,
// then `len`'s offset + a `u32` primitive, then `bytes`' offset + a `MAX_STR_SIZE` array of
// bytes) — confirm against the `ListLayout::new` signature. The trait is `unsafe`, presumably
// because the serializer trusts these numbers when reading raw memory.
unsafe impl SerializeConst for ConstStr {
    const MEMORY_LAYOUT: Layout = Layout::List(ListLayout::new(
        // Size of the whole struct
        std::mem::size_of::<Self>(),
        // Location and width of the length field
        std::mem::offset_of!(Self, len),
        PrimitiveLayout {
            size: std::mem::size_of::<u32>(),
        },
        // Location and shape of the byte storage
        std::mem::offset_of!(Self, bytes),
        ArrayLayout {
            len: MAX_STR_SIZE,
            item_layout: &Layout::Primitive(PrimitiveLayout {
                size: std::mem::size_of::<u8>(),
            }),
        },
    ));
}
63
/// Lets `ConstStr` be serialized by `const_serialize_07` (presumably the 0.7 release of
/// `const_serialize`, going by the name) when the `const-serialize-07` feature is enabled.
///
/// The value is described as a struct with two fields: `bytes`, a list of `MAX_STR_SIZE`
/// one-byte primitives, and `len`, a `u32`-sized primitive. The entire `bytes` buffer is
/// described, not just the first `len` bytes. That is why `ConstStr::new` zero-fills the buffer
/// when this feature is on: this format requires all of that memory to be initialized.
#[cfg(feature = "const-serialize-07")]
unsafe impl const_serialize_07::SerializeConst for ConstStr {
    const MEMORY_LAYOUT: const_serialize_07::Layout =
        const_serialize_07::Layout::Struct(const_serialize_07::StructLayout::new(
            // Size of the whole struct
            std::mem::size_of::<Self>(),
            &[
                // `bytes`: the full fixed-size buffer, one byte per entry
                const_serialize_07::StructFieldLayout::new(
                    std::mem::offset_of!(Self, bytes),
                    const_serialize_07::Layout::List(const_serialize_07::ListLayout::new(
                        MAX_STR_SIZE,
                        &const_serialize_07::Layout::Primitive(
                            const_serialize_07::PrimitiveLayout::new(std::mem::size_of::<u8>()),
                        ),
                    )),
                ),
                // `len`: the number of meaningful bytes, stored as a `u32`
                const_serialize_07::StructFieldLayout::new(
                    std::mem::offset_of!(Self, len),
                    const_serialize_07::Layout::Primitive(
                        const_serialize_07::PrimitiveLayout::new(std::mem::size_of::<u32>()),
                    ),
                ),
            ],
        ));
}
88
89impl ConstStr {
90    /// Create a new constant string
91    pub const fn new(s: &str) -> Self {
92        let str_bytes = s.as_bytes();
93        // This is serialized as a constant sized array in const-serialize-07 which requires all memory to be initialized
94        let mut bytes = if cfg!(feature = "const-serialize-07") {
95            [MaybeUninit::new(0); MAX_STR_SIZE]
96        } else {
97            [MaybeUninit::uninit(); MAX_STR_SIZE]
98        };
99        let mut i = 0;
100        while i < str_bytes.len() {
101            bytes[i] = MaybeUninit::new(str_bytes[i]);
102            i += 1;
103        }
104        Self {
105            bytes,
106            len: str_bytes.len() as u32,
107        }
108    }
109
110    /// Get the bytes of the initialized portion of the string
111    const fn bytes(&self) -> &[u8] {
112        // Safety: All bytes up to the pointer are initialized
113        unsafe {
114            &*(self.bytes.split_at(self.len as usize).0 as *const [MaybeUninit<u8>]
115                as *const [u8])
116        }
117    }
118
119    /// Get a reference to the string
120    pub const fn as_str(&self) -> &str {
121        let str_bytes = self.bytes();
122        match std::str::from_utf8(str_bytes) {
123            Ok(s) => s,
124            Err(_) => panic!(
125                "Invalid utf8; ConstStr should only ever be constructed from valid utf8 strings"
126            ),
127        }
128    }
129
130    /// Get the length of the string
131    pub const fn len(&self) -> usize {
132        self.len as usize
133    }
134
135    /// Check if the string is empty
136    pub const fn is_empty(&self) -> bool {
137        self.len == 0
138    }
139
140    /// Push a character onto the string
141    pub const fn push(self, byte: char) -> Self {
142        assert!(byte.is_ascii(), "Only ASCII bytes are supported");
143        let (bytes, len) = char_to_bytes(byte);
144        let (str, _) = bytes.split_at(len);
145        let Ok(str) = std::str::from_utf8(str) else {
146            panic!("Invalid utf8; char_to_bytes should always return valid utf8 bytes")
147        };
148        self.push_str(str)
149    }
150
151    /// Push a str onto the string
152    pub const fn push_str(self, str: &str) -> Self {
153        let Self { mut bytes, len } = self;
154        assert!(
155            str.len() + len as usize <= MAX_STR_SIZE,
156            "String is too long"
157        );
158        let str_bytes = str.as_bytes();
159        let new_len = len as usize + str_bytes.len();
160        let mut i = 0;
161        while i < str_bytes.len() {
162            bytes[len as usize + i] = MaybeUninit::new(str_bytes[i]);
163            i += 1;
164        }
165        Self {
166            bytes,
167            len: new_len as u32,
168        }
169    }
170
171    /// Split the string at a byte index. The byte index must be a char boundary
172    pub const fn split_at(self, index: usize) -> (Self, Self) {
173        let (left, right) = self.bytes().split_at(index);
174        let left = match std::str::from_utf8(left) {
175            Ok(s) => s,
176            Err(_) => {
177                panic!("Invalid utf8; you cannot split at a byte that is not a char boundary")
178            }
179        };
180        let right = match std::str::from_utf8(right) {
181            Ok(s) => s,
182            Err(_) => {
183                panic!("Invalid utf8; you cannot split at a byte that is not a char boundary")
184            }
185        };
186        (Self::new(left), Self::new(right))
187    }
188
189    /// Split the string at the last occurrence of a character
190    pub const fn rsplit_once(&self, char: char) -> Option<(Self, Self)> {
191        let str = self.as_str();
192        let mut index = str.len() - 1;
193        // First find the bytes we are searching for
194        let (char_bytes, len) = char_to_bytes(char);
195        let (char_bytes, _) = char_bytes.split_at(len);
196        let bytes = str.as_bytes();
197
198        // Then walk backwards from the end of the string
199        loop {
200            let byte = bytes[index];
201            // Look for char boundaries in the string and check if the bytes match
202            if let Some(char_boundary_len) = utf8_char_boundary_to_char_len(byte) {
203                // Split up the string into three sections: [before_char, in_char, after_char]
204                let (before_char, after_index) = bytes.split_at(index);
205                let (in_char, after_char) = after_index.split_at(char_boundary_len as usize);
206                if in_char.len() != char_boundary_len as usize {
207                    panic!("in_char.len() should always be equal to char_boundary_len as usize")
208                }
209                // Check if the bytes for the current char and the target char match
210                let mut in_char_eq = true;
211                let mut i = 0;
212                let min_len = if in_char.len() < char_bytes.len() {
213                    in_char.len()
214                } else {
215                    char_bytes.len()
216                };
217                while i < min_len {
218                    in_char_eq &= in_char[i] == char_bytes[i];
219                    i += 1;
220                }
221                // If they do, convert the bytes to strings and return the split strings
222                if in_char_eq {
223                    let Ok(before_char_str) = std::str::from_utf8(before_char) else {
224                        panic!("Invalid utf8; utf8_char_boundary_to_char_len should only return Some when the byte is a character boundary")
225                    };
226                    let Ok(after_char_str) = std::str::from_utf8(after_char) else {
227                        panic!("Invalid utf8; utf8_char_boundary_to_char_len should only return Some when the byte is a character boundary")
228                    };
229                    return Some((Self::new(before_char_str), Self::new(after_char_str)));
230                }
231            }
232            match index.checked_sub(1) {
233                Some(new_index) => index = new_index,
234                None => return None,
235            }
236        }
237    }
238
239    /// Split the string at the first occurrence of a character
240    pub const fn split_once(&self, char: char) -> Option<(Self, Self)> {
241        let str = self.as_str();
242        let mut index = 0;
243        // First find the bytes we are searching for
244        let (char_bytes, len) = char_to_bytes(char);
245        let (char_bytes, _) = char_bytes.split_at(len);
246        let bytes = str.as_bytes();
247
248        // Then walk forwards from the start of the string
249        while index < bytes.len() {
250            let byte = bytes[index];
251            // Look for char boundaries in the string and check if the bytes match
252            if let Some(char_boundary_len) = utf8_char_boundary_to_char_len(byte) {
253                // Split up the string into three sections: [before_char, in_char, after_char]
254                let (before_char, after_index) = bytes.split_at(index);
255                let (in_char, after_char) = after_index.split_at(char_boundary_len as usize);
256                if in_char.len() != char_boundary_len as usize {
257                    panic!("in_char.len() should always be equal to char_boundary_len as usize")
258                }
259                // Check if the bytes for the current char and the target char match
260                let mut in_char_eq = true;
261                let mut i = 0;
262                let min_len = if in_char.len() < char_bytes.len() {
263                    in_char.len()
264                } else {
265                    char_bytes.len()
266                };
267                while i < min_len {
268                    in_char_eq &= in_char[i] == char_bytes[i];
269                    i += 1;
270                }
271                // If they do, convert the bytes to strings and return the split strings
272                if in_char_eq {
273                    let Ok(before_char_str) = std::str::from_utf8(before_char) else {
274                        panic!("Invalid utf8; utf8_char_boundary_to_char_len should only return Some when the byte is a character boundary")
275                    };
276                    let Ok(after_char_str) = std::str::from_utf8(after_char) else {
277                        panic!("Invalid utf8; utf8_char_boundary_to_char_len should only return Some when the byte is a character boundary")
278                    };
279                    return Some((Self::new(before_char_str), Self::new(after_char_str)));
280                }
281            }
282            index += 1
283        }
284        None
285    }
286}
287
288impl PartialEq for ConstStr {
289    fn eq(&self, other: &Self) -> bool {
290        self.as_str() == other.as_str()
291    }
292}
293
294impl Eq for ConstStr {}
295
296impl PartialOrd for ConstStr {
297    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
298        Some(self.cmp(other))
299    }
300}
301
302impl Ord for ConstStr {
303    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
304        self.as_str().cmp(other.as_str())
305    }
306}
307
308impl Hash for ConstStr {
309    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
310        self.as_str().hash(state);
311    }
312}
313
/// Checks `ConstStr::rsplit_once` on fixed ASCII/emoji inputs and against `str::rsplit_once`
/// on random strings.
#[test]
fn test_rsplit_once() {
    let str = ConstStr::new("hello world");
    assert_eq!(
        str.rsplit_once(' '),
        Some((ConstStr::new("hello"), ConstStr::new("world")))
    );

    let unicode_str = ConstStr::new("hi😀hello😀world😀world");
    assert_eq!(
        unicode_str.rsplit_once('😀'),
        Some((ConstStr::new("hi😀hello😀world"), ConstStr::new("world")))
    );
    assert_eq!(unicode_str.rsplit_once('❌'), None);

    for _ in 0..100 {
        // At most 49 chars of at most 4 bytes each, so this always fits in a ConstStr
        let random_str: String = (0..rand::random::<u8>() % 50)
            .map(|_| rand::random::<char>())
            .collect();
        let konst = ConstStr::new(&random_str);
        let mut seen_chars = std::collections::HashSet::new();
        for char in random_str.chars().rev() {
            let (char_bytes, len) = char_to_bytes(char);
            let char_bytes = &char_bytes[..len];
            assert_eq!(char_bytes, char.to_string().as_bytes());
            if !seen_chars.insert(char) {
                continue;
            }
            let (correct_left, correct_right) = random_str.rsplit_once(char).unwrap();
            let (left, right) = konst.rsplit_once(char).unwrap();
            println!("splitting {random_str:?} at {char:?}");
            assert_eq!(left.as_str(), correct_left);
            assert_eq!(right.as_str(), correct_right);
        }
    }
}

/// Mirror of `test_rsplit_once` for `ConstStr::split_once`, checked against `str::split_once`.
#[test]
fn test_split_once() {
    let str = ConstStr::new("hello world");
    assert_eq!(
        str.split_once(' '),
        Some((ConstStr::new("hello"), ConstStr::new("world")))
    );

    let unicode_str = ConstStr::new("hi😀hello😀world😀world");
    assert_eq!(
        unicode_str.split_once('😀'),
        Some((ConstStr::new("hi"), ConstStr::new("hello😀world😀world")))
    );
    assert_eq!(unicode_str.split_once('❌'), None);

    for _ in 0..100 {
        // At most 49 chars of at most 4 bytes each, so this always fits in a ConstStr
        let random_str: String = (0..rand::random::<u8>() % 50)
            .map(|_| rand::random::<char>())
            .collect();
        let konst = ConstStr::new(&random_str);
        let mut seen_chars = std::collections::HashSet::new();
        for char in random_str.chars() {
            if !seen_chars.insert(char) {
                continue;
            }
            let (correct_left, correct_right) = random_str.split_once(char).unwrap();
            let (left, right) = konst.split_once(char).unwrap();
            println!("splitting {random_str:?} at {char:?}");
            assert_eq!(left.as_str(), correct_left);
            assert_eq!(right.as_str(), correct_right);
        }
    }
}
351
/// Tag bits carried by every UTF-8 continuation byte (`0b10xxxxxx`).
const CONTINUED_CHAR_MASK: u8 = 0b10000000;
/// Tag bits of the leading byte, indexed by encoded length minus one (1, 2, 3 and 4 bytes).
const BYTE_CHAR_BOUNDARIES: [u8; 4] = [0b00000000, 0b11000000, 0b11100000, 0b11110000];

/// Const-compatible equivalent of `char::encode_utf8`.
///
/// Returns a 4-byte buffer plus the number of leading bytes that hold the encoding. Bytes past
/// that length are left as zero.
const fn char_to_bytes(ch: char) -> ([u8; 4], usize) {
    let len = ch.len_utf8();
    if len > 4 {
        panic!(
            "encode_utf8: need more than 4 bytes to encode the unicode character, but the buffer has 4 bytes"
        );
    }

    let mut bytes = [0; 4];
    // Peel six payload bits at a time into the continuation bytes, filling them back to front
    let mut remaining = ch as u32;
    let mut position = len;
    while position > 1 {
        position -= 1;
        bytes[position] = (remaining & 0x3F) as u8 | CONTINUED_CHAR_MASK;
        remaining >>= 6;
    }
    // Whatever bits are left fit in the leading byte next to its length tag
    bytes[0] = remaining as u8 | BYTE_CHAR_BOUNDARIES[len - 1];
    (bytes, len)
}
385
/// Checks `char_to_bytes` against the standard library's `char::encode_utf8` for every Unicode
/// scalar value.
///
/// The check is exhaustive rather than random so that every encoding length (1–4 bytes) is
/// covered deterministically on every run.
#[test]
fn fuzz_char_to_bytes() {
    let mut buf = [0u8; 4];
    // `RangeInclusive<char>` skips the surrogate gap, yielding only valid scalar values
    for ch in '\0'..=char::MAX {
        let (bytes, len) = char_to_bytes(ch);
        let expected = ch.encode_utf8(&mut buf).as_bytes();
        assert_eq!(&bytes[..len], expected, "wrong UTF-8 encoding for {ch:?}");
    }
}
396
/// If `byte` can begin a UTF-8 encoded character, returns the total number of bytes in that
/// character; otherwise returns `None`.
///
/// `None` covers continuation bytes (`0b10xxxxxx`) and `0b11111xxx`, a bit pattern that never
/// appears in UTF-8. Previously `0b11111xxx` was misreported as starting a 4-byte character.
const fn utf8_char_boundary_to_char_len(byte: u8) -> Option<u8> {
    match byte {
        0b00000000..=0b01111111 => Some(1),
        0b11000000..=0b11011111 => Some(2),
        0b11100000..=0b11101111 => Some(3),
        // Only the 0b11110xxx pattern leads a 4-byte sequence
        0b11110000..=0b11110111 => Some(4),
        // 0b10xxxxxx continuation bytes and the never-valid 0b11111xxx bytes
        _ => None,
    }
}
406
/// Checks `utf8_char_boundary_to_char_len` for every Unicode scalar value.
///
/// The first byte of each encoding must report that encoding's length, and every continuation
/// byte must report `None`. Any valid UTF-8 string is a concatenation of these encodings, and
/// the function inspects one byte at a time. So this deterministically covers every byte
/// position that can occur in valid UTF-8.
#[test]
fn fuzz_utf8_byte_to_char_len() {
    let mut buf = [0u8; 4];
    for ch in '\0'..=char::MAX {
        let encoded = ch.encode_utf8(&mut buf).as_bytes();
        let (&first, rest) = encoded
            .split_first()
            .expect("a UTF-8 encoding is never empty");
        assert_eq!(
            utf8_char_boundary_to_char_len(first),
            Some(encoded.len() as u8),
            "{first:b} is the character boundary of {ch:?}"
        );
        for &byte in rest {
            assert_eq!(
                utf8_char_boundary_to_char_len(byte),
                None,
                "{byte:b} is a continuation byte of {ch:?}, not a character boundary"
            );
        }
    }
}