Skip to main content

bounded_str/
lib.rs

1#![no_std]
2#[cfg(feature = "alloc")]
3extern crate alloc;
4#[cfg(feature = "alloc")]
5use alloc::{vec::Vec};
6
7use core::{
8    fmt::{self, Display, Formatter},
9    hash::{Hash, Hasher},
10    marker::PhantomData,
11    ops::Deref,
12    str::{self, FromStr},
13};
14
/// Strategy for measuring the "logical" length of a string.
///
/// The `MIN`/`MAX` bounds of [`BoundedStr`] are compared against this value.
pub trait LengthPolicy {
    /// Returns the length of `s` as defined by this policy.
    fn logical_len(s: &str) -> usize;
}

/// Length policy that counts UTF-8 bytes.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct Bytes;

impl LengthPolicy for Bytes {
    #[inline(always)]
    fn logical_len(text: &str) -> usize {
        text.len()
    }
}

/// Length policy that counts Unicode scalar values (`char`s).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct Chars;

impl LengthPolicy for Chars {
    #[inline(always)]
    fn logical_len(text: &str) -> usize {
        text.chars().count()
    }
}
30
/// Content validator applied to every candidate value of a [`BoundedStr`].
pub trait FormatPolicy {
    /// Returns `true` if `s` is acceptable content.
    fn check(s: &str) -> bool;
}

/// Format policy that accepts every string.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct AllowAll;

impl FormatPolicy for AllowAll {
    #[inline(always)]
    fn check(_text: &str) -> bool {
        true
    }
}

/// Format policy that accepts only strings consisting entirely of ASCII.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct AsciiOnly;

impl FormatPolicy for AsciiOnly {
    #[inline(always)]
    fn check(text: &str) -> bool {
        text.as_bytes().is_ascii()
    }
}
46
/// Reasons a [`BoundedStr`] could not be constructed or mutated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundedStrError {
    /// The logical length is below `MIN`.
    TooShort,
    /// The logical length is above `MAX`.
    TooLong,
    /// The value needs more bytes than the available storage can hold.
    TooManyBytes,
    /// The content was rejected by the format policy.
    InvalidContent,
    /// A mutation produced bytes that are not a valid value.
    MutationFailed,
}

/// Human-readable description of the error, so callers can report it
/// without matching on every variant themselves.
impl Display for BoundedStrError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            BoundedStrError::TooShort => "string too short",
            BoundedStrError::TooLong => "string too long",
            BoundedStrError::TooManyBytes => "too many bytes for buffer",
            BoundedStrError::InvalidContent => "invalid content format",
            BoundedStrError::MutationFailed => "mutation failed",
        })
    }
}
55
/// Backing bytes of a bounded string.
///
/// Holds either an inline fixed-size buffer or, with the `alloc` feature, a heap
/// vector. `Clone` is derived: arrays implement `Clone` for every length, so
/// the derive needs no extra bounds and duplicates the inline buffer verbatim.
#[derive(Clone)]
enum Storage<const MAX_BYTES: usize> {
    /// Inline buffer; only `buf[..len]` holds the string's bytes.
    Stack { buf: [u8; MAX_BYTES], len: usize },
    /// Heap buffer; every element is part of the string.
    #[cfg(feature = "alloc")]
    Heap(Vec<u8>),
}
71
72pub struct BoundedStr<
73    const MIN: usize,
74    const MAX: usize,
75    const MAX_BYTES: usize,
76    L: LengthPolicy = Bytes,
77    F: FormatPolicy = AllowAll,
78	const Z: bool = false,
79> {
80    storage: Storage<MAX_BYTES>,
81    _marker: PhantomData<(L, F, core::convert::Infallible)>, 
82}
83
84impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool>
85    BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z>
86{
87    const _CHECK: () = {
88        assert!(MIN <= MAX, "MIN must be <= MAX");
89    };
90
91    #[inline(always)]
92	pub fn len_bytes(&self) -> usize {
93        match &self.storage {
94            Storage::Stack { len, .. } => *len,
95            #[cfg(feature = "alloc")]
96            Storage::Heap(v) => v.len(),
97        }
98    }
99
100    #[inline(always)]
101    pub fn len_logical(&self) -> usize {
102        L::logical_len(self.as_str())
103    }
104
105    pub fn new(s: &str) -> Result<Self, BoundedStrError> {
106        let logical_len = L::logical_len(s);
107        if logical_len < MIN { return Err(BoundedStrError::TooShort); }
108        if logical_len > MAX { return Err(BoundedStrError::TooLong); }
109        if !F::check(s) { return Err(BoundedStrError::InvalidContent); }
110
111        let byte_len = s.len();
112
113        #[cfg(feature = "alloc")]
114        if byte_len > MAX_BYTES {
115            return Ok(Self {
116                storage: Storage::Heap(s.as_bytes().to_vec()),
117                _marker: PhantomData,
118            });
119        }
120
121        if byte_len > MAX_BYTES {
122            return Err(BoundedStrError::TooManyBytes);
123        }
124
125        let mut buf = [0u8; MAX_BYTES];
126        buf[..byte_len].copy_from_slice(s.as_bytes());
127        Ok(Self {
128            storage: Storage::Stack { buf, len: byte_len },
129            _marker: PhantomData,
130        })
131    }
132
133    pub fn mutate<Mut, R>(&mut self, mutator: Mut) -> Result<R, BoundedStrError>
134    where
135        Mut: FnOnce(&mut [u8], &mut usize) -> R, 
136    {
137        match &mut self.storage {
138            Storage::Stack { buf, len } => {
139                let mut temp_buf = *buf;
140                let mut temp_len = *len;
141                let res = mutator(&mut temp_buf, &mut temp_len);
142				
143                if temp_len > MAX_BYTES { return Err(BoundedStrError::TooManyBytes); }
144
145                if let Ok(s) = str::from_utf8(&temp_buf[..temp_len]) {
146                    let l_len = L::logical_len(s);
147                    
148                    if l_len >= MIN && l_len <= MAX && F::check(s) {
149                        *buf = temp_buf;
150                        *len = temp_len;
151                        return Ok(res);
152                    }
153                }
154                Err(BoundedStrError::MutationFailed)
155            }
156
157            #[cfg(feature = "alloc")]            
158            Storage::Heap(v) => {
159                let mut temp_vec = v.clone();                
160                let limit = core::cmp::max(MAX, MAX_BYTES);
161                
162                let old_len = temp_vec.len();
163
164                if temp_vec.len() < limit {
165                    temp_vec.resize(limit, 0); 
166                }
167                
168                let mut temp_len = old_len;
169                let res = mutator(&mut temp_vec, &mut temp_len);
170
171                if temp_len > limit { 
172                    Self::clear_temp_vec::<Z>(&mut temp_vec);
173                    return Err(BoundedStrError::TooManyBytes); 
174                }
175
176                temp_vec.truncate(temp_len);
177				
178                if let Ok(s) = str::from_utf8(&temp_vec) {
179                    let l_len = L::logical_len(s);
180                    if l_len >= MIN && l_len <= MAX && F::check(s) {
181                        *v = temp_vec;
182                        return Ok(res);
183                    }
184                }
185
186                Self::clear_temp_vec::<Z>(&mut temp_vec);
187                Err(BoundedStrError::MutationFailed)
188            }
189
190        }
191    }
192
193    #[inline(always)]
194	pub fn as_str(&self) -> &str {
195        match &self.storage {
196            Storage::Stack { buf, len } => unsafe { str::from_utf8_unchecked(&buf[..*len]) },
197            #[cfg(feature = "alloc")]
198            Storage::Heap(v) => unsafe { str::from_utf8_unchecked(v) },
199        }
200    }
201	
202	#[inline(always)]
203    pub fn as_bytes(&self) -> &[u8] {
204        match &self.storage {
205            Storage::Stack { buf, len } => &buf[..*len],
206            #[cfg(feature = "alloc")]
207            Storage::Heap(v) => v.as_slice(),
208        }
209    }
210	
211	#[cfg(feature = "constant-time")]
212	#[inline(never)]
213    fn constant_time_eq(&self, other: &[u8]) -> bool {
214        let a = self.as_bytes();
215        let b = other;
216
217        if a.len() != b.len() {
218            return false;
219        }
220
221        let mut result = 0u8;
222        for i in 0..a.len() {            
223            result |= a[i] ^ b[i];
224        }
225        result == 0
226    }
227	
228	#[inline(always)]
229    fn clear_temp_vec<const ZERO: bool>(v: &mut Vec<u8>) {
230        #[cfg(feature = "zeroize")]
231        if ZERO {
232            for byte in v.iter_mut() {
233                unsafe { core::ptr::write_volatile(byte, 0) };
234            }
235        }
236    }
237}
238
239
240impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool>
241    PartialEq for BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z>
242{
243    fn eq(&self, other: &Self) -> bool {
244        #[cfg(feature = "constant-time")]
245        {
246            self.constant_time_eq(other.as_bytes())
247        }
248        #[cfg(not(feature = "constant-time"))]
249        {
250            self.as_str() == other.as_str()
251        }
252    }
253}
254
255impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool> 
256    Clone for BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z> {
257    fn clone(&self) -> Self {
258        Self { storage: self.storage.clone(), _marker: PhantomData }
259    }
260}
261
262impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool>
263    Eq for BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z> {}
264	
265impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool>
266    PartialEq<&str> for BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z>
267{
268    fn eq(&self, other: &&str) -> bool { self.as_str() == *other }
269}
270
271impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool> 
272    Deref for BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z> {
273    type Target = str;
274    fn deref(&self) -> &str { self.as_str() }
275}
276
277impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool>
278    TryFrom<&str> for BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z>
279{
280    type Error = BoundedStrError;
281    fn try_from(s: &str) -> Result<Self, Self::Error> { Self::new(s) }
282}
283
284impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool>
285    FromStr for BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z>
286{
287    type Err = BoundedStrError;
288    fn from_str(s: &str) -> Result<Self, Self::Err> { Self::new(s) }
289}
290
291impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool>
292    Hash for BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z>
293{
294    fn hash<H: Hasher>(&self, state: &mut H) { self.as_str().hash(state) }
295}
296
297impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool>
298    Display for BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z>
299{
300    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.write_str(self.as_str()) }
301}
302
303impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool>
304    fmt::Debug for BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z>
305{
306    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
307        f.debug_struct("BoundedStr")
308            .field("value", &self.as_str())
309            .field("len_bytes", &self.len_bytes())
310            .field("len_logical", &self.len_logical())
311            .finish()
312    }
313}
314
315impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool> 
316    Drop for BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z> 
317{
318    #[inline(always)]
319    fn drop(&mut self) {
320        #[cfg(feature = "zeroize")]
321        if Z {
322            match &mut self.storage {
323                Storage::Stack { buf, .. } => {
324                    for byte in buf.iter_mut() {
325                        unsafe { core::ptr::write_volatile(byte, 0) };
326                    }
327                }
328                #[cfg(feature = "alloc")]
329                Storage::Heap(v) => {
330                    for byte in v.iter_mut() {
331                        unsafe { core::ptr::write_volatile(byte, 0) };
332                    }
333                }
334            }
335        }
336    }
337}
338
339
/// Serde visitor that builds a validated `BoundedStr` from any string form
/// (borrowed, transient or owned) a deserializer hands over.
#[cfg(feature = "serde")]
struct BoundedStrVisitor<T>(PhantomData<T>);

#[cfg(feature = "serde")]
impl<'de, const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool>
    serde::de::Visitor<'de> for BoundedStrVisitor<BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z>>
{
    type Value = BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z>;

    fn expecting(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("a string")
    }

    /// Validates the string with `BoundedStr::new`. Borrowed and owned strings
    /// are routed here by serde's default `visit_borrowed_str`/`visit_string`.
    fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        BoundedStr::new(s).map_err(|e| {
            E::custom(match e {
                BoundedStrError::TooShort => "string too short",
                BoundedStrError::TooLong => "string too long",
                BoundedStrError::TooManyBytes => "too many bytes for buffer",
                BoundedStrError::InvalidContent => "invalid content format",
                BoundedStrError::MutationFailed => "mutation failed",
            })
        })
    }

    /// Accepts UTF-8 byte input, as the previous `&str`-based implementation did.
    fn visit_bytes<E>(self, bytes: &[u8]) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        match str::from_utf8(bytes) {
            Ok(s) => self.visit_str(s),
            Err(_) => Err(E::invalid_value(serde::de::Unexpected::Bytes(bytes), &self)),
        }
    }
}

#[cfg(feature = "serde")]
impl<'de, const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool>
    serde::Deserialize<'de> for BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z>
{
    /// Deserializes a string and validates it with `BoundedStr::new`.
    ///
    /// Uses a visitor rather than deserializing `&str`. A `&str` only succeeds
    /// when the input can be borrowed zero-copy, which rejects, for example,
    /// JSON strings containing escape sequences or data read from a reader.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        deserializer.deserialize_str(BoundedStrVisitor(PhantomData))
    }
}
361
#[cfg(feature = "serde")]
impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L: LengthPolicy, F: FormatPolicy, const Z: bool>
    serde::Serialize for BoundedStr<MIN, MAX, MAX_BYTES, L, F, Z>
{
    /// Serializes the value as a plain string.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text: &str = self;
        serializer.serialize_str(text)
    }
}
373
374
375pub type StackStr<const MIN: usize, const MAX: usize, const MAXB: usize = MAX, L = Bytes, F = AllowAll, const Z: bool = false > = BoundedStr<MIN, MAX, MAXB, L, F, Z>;
376
377#[cfg(feature = "alloc")]
378pub type FlexStr<const MIN: usize, const MAX: usize, const MAXB: usize = 4096, L = Bytes, F = AllowAll, const Z: bool = false > = BoundedStr<MIN, MAX, MAXB, L, F, Z>;