Skip to main content

bounded_str/
lib.rs

1#![no_std]
2
3use core::{
4    fmt::{self, Display, Formatter},
5    hash::{Hash, Hasher},
6    marker::PhantomData,
7    ops::Deref,
8    str::{self, FromStr},
9};
10
/// Strategy for measuring the "logical" length of a string.
///
/// `BoundedStr` compares the value returned by this policy against its
/// `MIN`/`MAX` bounds; the byte capacity is checked separately.
pub trait LengthPolicy {
    /// Returns the logical length of `s` under this policy.
    fn logical_len(s: &str) -> usize;

    /// Returns the logical length of a `'static` string.
    ///
    /// Forwards to [`LengthPolicy::logical_len`] by default, so implementors
    /// only need to override it when the two measurements must differ.
    #[inline]
    fn const_logical_len(s: &'static str) -> usize {
        Self::logical_len(s)
    }
}
15
16#[derive(Clone, Copy, Debug, Default)]
17pub struct Bytes;
18impl LengthPolicy for Bytes {
19    #[inline(always)]
20    fn logical_len(s: &str) -> usize { s.len() }
21    #[inline(always)]
22    fn const_logical_len(s: &'static str) -> usize { s.len() }
23}
24
25#[derive(Clone, Copy, Debug, Default)]
26pub struct Chars;
27impl LengthPolicy for Chars {
28    #[inline(always)]
29    fn logical_len(s: &str) -> usize { s.chars().count() }
30    #[inline(always)]
31    fn const_logical_len(s: &'static str) -> usize { s.chars().count() }
32}
33
/// Strategy deciding whether a string's content is acceptable.
pub trait FormatPolicy {
    /// Returns `true` when `s` satisfies this policy.
    fn check_format(s: &str) -> bool;

    /// Returns `true` when the `'static` string `s` satisfies this policy.
    ///
    /// Forwards to [`FormatPolicy::check_format`] by default, so implementors
    /// only need to override it when the two checks must differ.
    #[inline]
    fn const_check_format(s: &'static str) -> bool {
        Self::check_format(s)
    }
}
38
39#[derive(Clone, Copy, Debug, Default)]
40pub struct AllowAll;
41impl FormatPolicy for AllowAll {
42    #[inline(always)]
43    fn check_format(_: &str) -> bool { true }
44    #[inline(always)]
45    fn const_check_format(_: &'static str) -> bool { true }
46}
47
48#[derive(Clone, Copy, Debug, Default)]
49pub struct AsciiOnly;
50impl FormatPolicy for AsciiOnly {
51    #[inline(always)]
52    fn check_format(s: &str) -> bool { s.is_ascii() }
53    #[inline(always)]
54    fn const_check_format(s: &'static str) -> bool { s.is_ascii() }
55}
56
/// Reasons a value can be rejected by `BoundedStr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundedStrError {
    /// The logical length is below `MIN`.
    TooShort,
    /// The logical length is above `MAX`.
    TooLong,
    /// The byte length is above `MAX_BYTES`.
    TooManyBytes,
    /// The format policy rejected the content.
    InvalidContent,
    /// A mutation produced content that fails validation.
    MutationFailed,
}

impl Display for BoundedStrError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::TooShort => "string is shorter than the minimum length",
            Self::TooLong => "string is longer than the maximum length",
            Self::TooManyBytes => "string does not fit in the byte capacity",
            Self::InvalidContent => "string content rejected by the format policy",
            Self::MutationFailed => "mutation produced an invalid string",
        };
        f.pad(message)
    }
}
65
66#[derive(Clone)]
67pub struct BoundedStr<
68    const MIN: usize,
69    const MAX: usize,
70    const MAX_BYTES: usize,
71    L: LengthPolicy = Bytes,
72    F: FormatPolicy = AllowAll,
73> {
74    len: usize,
75    buf: [u8; MAX_BYTES],
76    _marker: PhantomData<(L, F)>,
77}
78
79impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F>
80    BoundedStr<MIN, MAX, MAX_BYTES, L, F>
81where
82    L: LengthPolicy,
83    F: FormatPolicy,
84{
85 
86     const _CHECK_BOUNDS: () = {
87        assert!(MIN <= MAX, "MIN must be <= MAX");
88        assert!(MAX <= MAX_BYTES, "MAX must be <= MAX_BYTES");
89    };
90
91	
92    #[inline]
93    pub fn new(input: &str) -> Result<Self, BoundedStrError> {
94        let byte_len = input.len();
95        if byte_len > MAX_BYTES {
96            return Err(BoundedStrError::TooManyBytes);
97        }
98        let logical_len = L::logical_len(input);
99        if logical_len < MIN { return Err(BoundedStrError::TooShort); }
100        if logical_len > MAX { return Err(BoundedStrError::TooLong); }
101        if !F::check_format(input) { return Err(BoundedStrError::InvalidContent); }
102        let mut buf = [0u8; MAX_BYTES];
103        buf[..byte_len].copy_from_slice(input.as_bytes());
104        Ok(Self { len: byte_len, buf, _marker: PhantomData })
105    }
106
107    #[inline(always)]
108    pub fn const_new(input: &'static str) -> Result<Self, BoundedStrError> {
109        let byte_len = input.len();
110        if byte_len > MAX_BYTES {
111            return Err(BoundedStrError::TooManyBytes);
112        }
113        let logical_len = L::const_logical_len(input);
114        if logical_len < MIN { return Err(BoundedStrError::TooShort); }
115        if logical_len > MAX { return Err(BoundedStrError::TooLong); }
116        if !F::const_check_format(input) { return Err(BoundedStrError::InvalidContent); }
117        let mut buf = [0u8; MAX_BYTES];
118        let src = input.as_bytes();
119        let mut i = 0;
120        while i < byte_len {
121            buf[i] = src[i];
122            i += 1;
123        }
124        Ok(Self { len: byte_len, buf, _marker: PhantomData })
125    }
126
127    #[inline(always)]
128    pub fn as_str(&self) -> &str {
129        debug_assert!(core::str::from_utf8(&self.buf[..self.len]).is_ok());
130        unsafe { core::str::from_utf8_unchecked(&self.buf[..self.len]) }
131    }
132
133    #[inline(always)] pub fn len_bytes(&self) -> usize { self.len }
134    #[inline(always)] pub fn len_logical(&self) -> usize { L::logical_len(self.as_str()) }
135
136    pub fn mutate<Mut, Res>(&mut self, mutator: Mut) -> Result<Res, BoundedStrError>
137    where Mut: FnOnce(&mut [u8]) -> Res {
138        let old_len = self.len;
139        let res = mutator(&mut self.buf[..old_len]);
140        if let Ok(s) = core::str::from_utf8(&self.buf[..old_len]) {
141            if s.len() != old_len {
142                return Err(BoundedStrError::MutationFailed);
143            }
144            let logical_len = L::logical_len(s);
145            if logical_len < MIN || logical_len > MAX || !F::check_format(s) {
146                return Err(BoundedStrError::MutationFailed);
147            }
148            Ok(res)
149        } else {
150            Err(BoundedStrError::MutationFailed)
151        }
152    }
153
154    #[cfg(feature = "constant-time")]
155    #[inline]
156    pub fn ct_eq(&self, other: &Self) -> bool {
157        if self.len != other.len { return false; }
158        let mut diff: u8 = 0;
159        for i in 0..MAX_BYTES {
160            let a = if i < self.len { self.buf[i] } else { 0 };
161            let b = if i < other.len { other.buf[i] } else { 0 };
162            diff |= a ^ b;
163        }
164        diff == 0
165    }
166}
167
168impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> PartialEq
169    for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
170where L: LengthPolicy, F: FormatPolicy
171{
172    fn eq(&self, other: &Self) -> bool { self.as_str() == other.as_str() }
173}
174
175impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Eq
176    for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
177where L: LengthPolicy, F: FormatPolicy {}
178
179impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> PartialEq<&str>
180    for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
181where L: LengthPolicy, F: FormatPolicy
182{
183    fn eq(&self, other: &&str) -> bool { self.as_str() == *other }
184}
185
186impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Deref
187    for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
188where L: LengthPolicy, F: FormatPolicy
189{
190    type Target = str;
191    fn deref(&self) -> &str { self.as_str() }
192}
193
194impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> TryFrom<&str>
195    for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
196where L: LengthPolicy, F: FormatPolicy
197{
198    type Error = BoundedStrError;
199    fn try_from(value: &str) -> Result<Self, Self::Error> { Self::new(value) }
200}
201
202impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> FromStr
203    for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
204where L: LengthPolicy, F: FormatPolicy
205{
206    type Err = BoundedStrError;
207    fn from_str(s: &str) -> Result<Self, Self::Err> { Self::new(s) }
208}
209
210impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Hash
211    for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
212where L: LengthPolicy, F: FormatPolicy
213{
214    fn hash<H: Hasher>(&self, state: &mut H) { self.as_str().hash(state); }
215}
216
217impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Display
218    for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
219where L: LengthPolicy, F: FormatPolicy
220{
221    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.write_str(self.as_str()) }
222}
223
224impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> fmt::Debug
225    for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
226where L: LengthPolicy, F: FormatPolicy
227{
228    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
229        f.debug_tuple("BoundedStr").field(&self.as_str()).finish()
230    }
231}
232
/// Overwrites the whole byte buffer with zeros when the value is dropped.
#[cfg(feature = "zeroize")]
impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Drop
    for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
where
    L: LengthPolicy,
    F: FormatPolicy,
{
    fn drop(&mut self) {
        // Ordinary stores to memory that is about to be freed are dead stores
        // and may be removed by the optimizer; volatile writes may not be.
        for byte in self.buf.iter_mut() {
            // SAFETY: `byte` is a valid, aligned, exclusive reference to one
            // element of `self.buf`.
            unsafe { core::ptr::write_volatile(byte, 0) };
        }
        // Keeps the compiler from reordering later memory accesses ahead of
        // the wipe.
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}
242
/// `serde` support: a bounded string (de)serializes as a plain string and is
/// validated with `BoundedStr::new` on the way in.
#[cfg(feature = "serde")]
mod serde_impl {
    use super::*;
    use serde::{de, Deserialize, Deserializer, Serialize, Serializer};

    /// Visitor that turns a deserialized string into a validated `BoundedStr`.
    pub struct Visitor<const MIN: usize, const MAX: usize, const MAXB: usize, L, F> {
        _marker: PhantomData<(L, F)>,
    }

    impl<'de, const MIN: usize, const MAX: usize, const MAXB: usize, L, F> de::Visitor<'de>
        for Visitor<MIN, MAX, MAXB, L, F>
    where
        L: LengthPolicy + 'static,
        F: FormatPolicy + 'static,
    {
        type Value = BoundedStr<MIN, MAX, MAXB, L, F>;

        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "string [{MIN}..={MAX}]")
        }

        /// Validates `v` and maps each rejection to the closest serde error.
        ///
        /// Owned strings need no separate handling: serde's default
        /// `visit_string` forwards here, so no `String` type has to be named
        /// in this `no_std` crate.
        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            BoundedStr::new(v).map_err(|err| match err {
                // Report the length in the unit that `MIN`/`MAX` are measured in.
                BoundedStrError::TooShort | BoundedStrError::TooLong => {
                    de::Error::invalid_length(L::logical_len(v), &self)
                }
                // The byte capacity was exceeded, so report the byte length.
                BoundedStrError::TooManyBytes => de::Error::invalid_length(v.len(), &self),
                BoundedStrError::InvalidContent => {
                    de::Error::invalid_value(de::Unexpected::Str(v), &self)
                }
                // Not produced by `BoundedStr::new`; kept so the match stays exhaustive.
                BoundedStrError::MutationFailed => de::Error::custom("unexpected error"),
            })
        }
    }

    impl<'de, const MIN: usize, const MAX: usize, const MAXB: usize, L, F> Deserialize<'de>
        for BoundedStr<MIN, MAX, MAXB, L, F>
    where
        L: LengthPolicy + 'static,
        F: FormatPolicy + 'static,
    {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            deserializer.deserialize_str(Visitor::<MIN, MAX, MAXB, L, F> { _marker: PhantomData })
        }
    }

    impl<const MIN: usize, const MAX: usize, const MAXB: usize, L, F> Serialize
        for BoundedStr<MIN, MAX, MAXB, L, F>
    where
        L: LengthPolicy,
        F: FormatPolicy,
    {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            serializer.serialize_str(self.as_str())
        }
    }
}
309
310
311