1#![no_std]
2
3use core::{
4 fmt::{self, Display, Formatter},
5 hash::{Hash, Hasher},
6 marker::PhantomData,
7 ops::Deref,
8 str::{self, FromStr},
9};
10
/// Strategy for measuring the "logical" length of a string that
/// [`BoundedStr`] checks against its `MIN..=MAX` bounds.
pub trait LengthPolicy {
    /// Returns the logical length of `s` in this policy's unit.
    fn logical_len(s: &str) -> usize;

    /// Length measurement used by `BoundedStr::const_new` for `'static` input.
    ///
    /// Despite the name this is not a `const fn` (trait methods cannot be
    /// `const` on stable Rust). The default simply forwards to
    /// [`Self::logical_len`], so implementors only need to override it when
    /// they deliberately want different behaviour for `'static` input.
    fn const_logical_len(s: &'static str) -> usize {
        Self::logical_len(s)
    }
}
15
16#[derive(Clone, Copy, Debug, Default)]
17pub struct Bytes;
18impl LengthPolicy for Bytes {
19 #[inline(always)]
20 fn logical_len(s: &str) -> usize { s.len() }
21 #[inline(always)]
22 fn const_logical_len(s: &'static str) -> usize { s.len() }
23}
24
25#[derive(Clone, Copy, Debug, Default)]
26pub struct Chars;
27impl LengthPolicy for Chars {
28 #[inline(always)]
29 fn logical_len(s: &str) -> usize { s.chars().count() }
30 #[inline(always)]
31 fn const_logical_len(s: &'static str) -> usize { s.chars().count() }
32}
33
/// Content rule that a [`BoundedStr`] must satisfy in addition to its
/// length bounds.
pub trait FormatPolicy {
    /// Returns `true` if `s` is acceptable content.
    fn check_format(s: &str) -> bool;

    /// Content check used by `BoundedStr::const_new` for `'static` input.
    ///
    /// Despite the name this is not a `const fn` (trait methods cannot be
    /// `const` on stable Rust). The default simply forwards to
    /// [`Self::check_format`], so implementors only need to override it when
    /// they deliberately want different behaviour for `'static` input.
    fn const_check_format(s: &'static str) -> bool {
        Self::check_format(s)
    }
}
38
39#[derive(Clone, Copy, Debug, Default)]
40pub struct AllowAll;
41impl FormatPolicy for AllowAll {
42 #[inline(always)]
43 fn check_format(_: &str) -> bool { true }
44 #[inline(always)]
45 fn const_check_format(_: &'static str) -> bool { true }
46}
47
48#[derive(Clone, Copy, Debug, Default)]
49pub struct AsciiOnly;
50impl FormatPolicy for AsciiOnly {
51 #[inline(always)]
52 fn check_format(s: &str) -> bool { s.is_ascii() }
53 #[inline(always)]
54 fn const_check_format(s: &'static str) -> bool { s.is_ascii() }
55}
56
/// Reasons a [`BoundedStr`] could not be created or mutated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundedStrError {
    /// The logical length (per the `LengthPolicy`) is below `MIN`.
    TooShort,
    /// The logical length (per the `LengthPolicy`) is above `MAX`.
    TooLong,
    /// The UTF-8 encoding is longer than the `MAX_BYTES` buffer.
    TooManyBytes,
    /// The `FormatPolicy` rejected the content.
    InvalidContent,
    /// `BoundedStr::mutate` produced bytes that failed validation.
    MutationFailed,
}

impl Display for BoundedStrError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let msg = match self {
            Self::TooShort => "string is shorter than the minimum length",
            Self::TooLong => "string is longer than the maximum length",
            Self::TooManyBytes => "string exceeds the maximum byte capacity",
            Self::InvalidContent => "string content rejected by the format policy",
            Self::MutationFailed => "mutation produced an invalid string",
        };
        f.write_str(msg)
    }
}
65
66#[derive(Clone)]
67pub struct BoundedStr<
68 const MIN: usize,
69 const MAX: usize,
70 const MAX_BYTES: usize,
71 L: LengthPolicy = Bytes,
72 F: FormatPolicy = AllowAll,
73> {
74 len: usize,
75 buf: [u8; MAX_BYTES],
76 _marker: PhantomData<(L, F)>,
77}
78
79impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F>
80 BoundedStr<MIN, MAX, MAX_BYTES, L, F>
81where
82 L: LengthPolicy,
83 F: FormatPolicy,
84{
85
86 const _CHECK_BOUNDS: () = {
87 assert!(MIN <= MAX, "MIN must be <= MAX");
88 assert!(MAX <= MAX_BYTES, "MAX must be <= MAX_BYTES");
89 };
90
91
92 #[inline]
93 pub fn new(input: &str) -> Result<Self, BoundedStrError> {
94 let byte_len = input.len();
95 if byte_len > MAX_BYTES {
96 return Err(BoundedStrError::TooManyBytes);
97 }
98 let logical_len = L::logical_len(input);
99 if logical_len < MIN { return Err(BoundedStrError::TooShort); }
100 if logical_len > MAX { return Err(BoundedStrError::TooLong); }
101 if !F::check_format(input) { return Err(BoundedStrError::InvalidContent); }
102 let mut buf = [0u8; MAX_BYTES];
103 buf[..byte_len].copy_from_slice(input.as_bytes());
104 Ok(Self { len: byte_len, buf, _marker: PhantomData })
105 }
106
107 #[inline(always)]
108 pub fn const_new(input: &'static str) -> Result<Self, BoundedStrError> {
109 let byte_len = input.len();
110 if byte_len > MAX_BYTES {
111 return Err(BoundedStrError::TooManyBytes);
112 }
113 let logical_len = L::const_logical_len(input);
114 if logical_len < MIN { return Err(BoundedStrError::TooShort); }
115 if logical_len > MAX { return Err(BoundedStrError::TooLong); }
116 if !F::const_check_format(input) { return Err(BoundedStrError::InvalidContent); }
117 let mut buf = [0u8; MAX_BYTES];
118 let src = input.as_bytes();
119 let mut i = 0;
120 while i < byte_len {
121 buf[i] = src[i];
122 i += 1;
123 }
124 Ok(Self { len: byte_len, buf, _marker: PhantomData })
125 }
126
127 #[inline(always)]
128 pub fn as_str(&self) -> &str {
129 debug_assert!(core::str::from_utf8(&self.buf[..self.len]).is_ok());
130 unsafe { core::str::from_utf8_unchecked(&self.buf[..self.len]) }
131 }
132
133 #[inline(always)] pub fn len_bytes(&self) -> usize { self.len }
134 #[inline(always)] pub fn len_logical(&self) -> usize { L::logical_len(self.as_str()) }
135
136 pub fn mutate<Mut, Res>(&mut self, mutator: Mut) -> Result<Res, BoundedStrError>
137 where Mut: FnOnce(&mut [u8]) -> Res {
138 let old_len = self.len;
139 let res = mutator(&mut self.buf[..old_len]);
140 if let Ok(s) = core::str::from_utf8(&self.buf[..old_len]) {
141 if s.len() != old_len {
142 return Err(BoundedStrError::MutationFailed);
143 }
144 let logical_len = L::logical_len(s);
145 if logical_len < MIN || logical_len > MAX || !F::check_format(s) {
146 return Err(BoundedStrError::MutationFailed);
147 }
148 Ok(res)
149 } else {
150 Err(BoundedStrError::MutationFailed)
151 }
152 }
153
154 #[cfg(feature = "constant-time")]
155 #[inline]
156 pub fn ct_eq(&self, other: &Self) -> bool {
157 if self.len != other.len { return false; }
158 let mut diff: u8 = 0;
159 for i in 0..MAX_BYTES {
160 let a = if i < self.len { self.buf[i] } else { 0 };
161 let b = if i < other.len { other.buf[i] } else { 0 };
162 diff |= a ^ b;
163 }
164 diff == 0
165 }
166}
167
168impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> PartialEq
169 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
170where L: LengthPolicy, F: FormatPolicy
171{
172 fn eq(&self, other: &Self) -> bool { self.as_str() == other.as_str() }
173}
174
175impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Eq
176 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
177where L: LengthPolicy, F: FormatPolicy {}
178
179impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> PartialEq<&str>
180 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
181where L: LengthPolicy, F: FormatPolicy
182{
183 fn eq(&self, other: &&str) -> bool { self.as_str() == *other }
184}
185
186impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Deref
187 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
188where L: LengthPolicy, F: FormatPolicy
189{
190 type Target = str;
191 fn deref(&self) -> &str { self.as_str() }
192}
193
194impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> TryFrom<&str>
195 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
196where L: LengthPolicy, F: FormatPolicy
197{
198 type Error = BoundedStrError;
199 fn try_from(value: &str) -> Result<Self, Self::Error> { Self::new(value) }
200}
201
202impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> FromStr
203 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
204where L: LengthPolicy, F: FormatPolicy
205{
206 type Err = BoundedStrError;
207 fn from_str(s: &str) -> Result<Self, Self::Err> { Self::new(s) }
208}
209
210impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Hash
211 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
212where L: LengthPolicy, F: FormatPolicy
213{
214 fn hash<H: Hasher>(&self, state: &mut H) { self.as_str().hash(state); }
215}
216
217impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Display
218 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
219where L: LengthPolicy, F: FormatPolicy
220{
221 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.write_str(self.as_str()) }
222}
223
224impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> fmt::Debug
225 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
226where L: LengthPolicy, F: FormatPolicy
227{
228 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
229 f.debug_tuple("BoundedStr").field(&self.as_str()).finish()
230 }
231}
232
/// With the `zeroize` feature, the whole backing buffer (text and padding)
/// is overwritten with zeros when the value is dropped.
#[cfg(feature = "zeroize")]
impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Drop
    for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
where
    L: LengthPolicy,
    F: FormatPolicy,
{
    fn drop(&mut self) {
        for byte in self.buf.iter_mut() {
            // SAFETY: `byte` is a valid, aligned, exclusive reference into
            // `self.buf`. A volatile write is required because ordinary stores
            // to memory that is about to be released are dead stores the
            // optimizer is allowed to delete.
            unsafe { core::ptr::write_volatile(byte, 0) };
        }
        // Prevent the compiler from reordering the scrub past later operations.
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}
242
#[cfg(feature = "serde")]
mod serde_impl {
    use super::*;
    use serde::{de, Deserialize, Deserializer, Serialize, Serializer};

    /// Serde visitor that builds a `BoundedStr` from a string, applying the
    /// same validation as `BoundedStr::new`.
    pub struct Visitor<const MIN: usize, const MAX: usize, const MAXB: usize, L, F> {
        _marker: PhantomData<(L, F)>,
    }

    impl<'de, const MIN: usize, const MAX: usize, const MAXB: usize, L, F> de::Visitor<'de>
        for Visitor<MIN, MAX, MAXB, L, F>
    where
        L: LengthPolicy,
        F: FormatPolicy,
    {
        type Value = BoundedStr<MIN, MAX, MAXB, L, F>;

        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
            write!(f, "string [{MIN}..={MAX}]")
        }

        // serde's default `visit_borrowed_str` and `visit_string` forward here,
        // so owned input is handled without naming `String` (unavailable
        // under `#![no_std]`).
        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            BoundedStr::new(v).map_err(|err| match err {
                // The MIN/MAX bounds are in the policy's unit, so report that length.
                BoundedStrError::TooShort | BoundedStrError::TooLong => {
                    de::Error::invalid_length(L::logical_len(v), &self)
                }
                // The byte capacity bound is in bytes.
                BoundedStrError::TooManyBytes => de::Error::invalid_length(v.len(), &self),
                BoundedStrError::InvalidContent => {
                    de::Error::invalid_value(de::Unexpected::Str(v), &self)
                }
                // `new` never returns this; matched explicitly so a new variant
                // forces this mapping to be revisited.
                BoundedStrError::MutationFailed => de::Error::custom("unexpected error"),
            })
        }
    }

    impl<'de, const MIN: usize, const MAX: usize, const MAXB: usize, L, F> Deserialize<'de>
        for BoundedStr<MIN, MAX, MAXB, L, F>
    where
        L: LengthPolicy,
        F: FormatPolicy,
    {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            deserializer.deserialize_str(Visitor::<MIN, MAX, MAXB, L, F> {
                _marker: PhantomData,
            })
        }
    }

    /// Serializes as a plain string.
    impl<const MIN: usize, const MAX: usize, const MAXB: usize, L, F> Serialize
        for BoundedStr<MIN, MAX, MAXB, L, F>
    where
        L: LengthPolicy,
        F: FormatPolicy,
    {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            serializer.serialize_str(self.as_str())
        }
    }
}
309
310
311