1#![no_std]
2#[cfg(feature = "alloc")]
3extern crate alloc;
4
5#[cfg(feature = "alloc")]
6use alloc::{vec::Vec, string::String};
7
8use core::{
9 fmt::{self, Display, Formatter},
10 hash::{Hash, Hasher},
11 marker::PhantomData,
12 ops::Deref,
13 str::{self, FromStr},
14};
15
/// Strategy deciding how the `MIN`/`MAX` bounds of a bounded string are measured.
pub trait LengthPolicy {
    /// Returns the logical length of `s` under this policy.
    fn logical_len(s: &str) -> usize;
    /// Returns the logical length of a `'static` string; used by `const_new`.
    fn const_logical_len(s: &'static str) -> usize;
}

/// Length policy that counts UTF-8 bytes.
#[derive(Clone, Copy, Debug, Default)]
pub struct Bytes;

impl LengthPolicy for Bytes {
    #[inline]
    fn logical_len(s: &str) -> usize {
        s.as_bytes().len()
    }

    #[inline]
    fn const_logical_len(s: &'static str) -> usize {
        // A `'static` string is measured exactly like any other string.
        Self::logical_len(s)
    }
}

/// Length policy that counts Unicode scalar values (`char`s).
#[derive(Clone, Copy, Debug, Default)]
pub struct Chars;

impl LengthPolicy for Chars {
    #[inline]
    fn logical_len(s: &str) -> usize {
        s.chars().count()
    }

    #[inline]
    fn const_logical_len(s: &'static str) -> usize {
        Self::logical_len(s)
    }
}
38
/// Strategy deciding which string contents are acceptable.
pub trait FormatPolicy {
    /// Returns `true` when `s` is acceptable under this policy.
    fn check_format(s: &str) -> bool;
    /// Same check for a `'static` string; used by `const_new`.
    fn const_check_format(s: &'static str) -> bool;
}

/// Format policy that accepts every string.
#[derive(Clone, Copy, Debug, Default)]
pub struct AllowAll;

impl FormatPolicy for AllowAll {
    #[inline]
    fn check_format(_s: &str) -> bool {
        true
    }

    #[inline]
    fn const_check_format(s: &'static str) -> bool {
        Self::check_format(s)
    }
}

/// Format policy that accepts only strings made entirely of ASCII bytes.
#[derive(Clone, Copy, Debug, Default)]
pub struct AsciiOnly;

impl FormatPolicy for AsciiOnly {
    #[inline]
    fn check_format(s: &str) -> bool {
        s.bytes().all(|b| b.is_ascii())
    }

    #[inline]
    fn const_check_format(s: &'static str) -> bool {
        Self::check_format(s)
    }
}
62
63
/// Strategy deciding whether text that does not fit inline may spill to the heap.
pub trait StoragePolicy {
    /// `true` when heap storage is permitted.
    const ALLOW_HEAP: bool;
}

/// Storage policy that forbids heap storage.
#[derive(Clone, Copy, Debug, Default)]
pub struct StackOnly;

impl StoragePolicy for StackOnly {
    const ALLOW_HEAP: bool = false;
}

/// Storage policy that permits heap storage (effective only with the `alloc` feature).
#[derive(Clone, Copy, Debug, Default)]
pub struct HeapAllowed;

impl StoragePolicy for HeapAllowed {
    const ALLOW_HEAP: bool = true;
}
79
80
/// Reasons a bounded string could not be constructed or mutated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BoundedStrError {
    /// The logical length is below `MIN`.
    TooShort,
    /// The logical length is above `MAX`.
    TooLong,
    /// The text needs more than `MAX_BYTES` bytes and could not be placed on the heap.
    TooManyBytes,
    /// The format policy rejected the contents.
    InvalidContent,
    /// A `mutate` call produced bytes that are not valid UTF-8 or that violate
    /// the length/format bounds.
    MutationFailed,
}

impl Display for BoundedStrError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Exhaustive on purpose: adding a variant must force a new message.
        let msg = match self {
            BoundedStrError::TooShort => "string is shorter than the minimum length",
            BoundedStrError::TooLong => "string is longer than the maximum length",
            BoundedStrError::TooManyBytes => "string does not fit in the available byte capacity",
            BoundedStrError::InvalidContent => "string content rejected by the format policy",
            BoundedStrError::MutationFailed => "mutation produced an invalid string",
        };
        f.write_str(msg)
    }
}

/// Lets the error be boxed or propagated through `dyn Error` chains.
impl core::error::Error for BoundedStrError {}
89
90
91pub struct BoundedStr<
92 const MIN: usize,
93 const MAX: usize,
94 const MAX_BYTES: usize,
95 L: LengthPolicy = Bytes,
96 F: FormatPolicy = AllowAll,
97 S: StoragePolicy = StackOnly,
98> {
99 len: usize,
100 buf: [u8; MAX_BYTES],
101 #[cfg(feature = "alloc")]
102 heap_buf: Option<Vec<u8>>,
103 _marker: PhantomData<(L, F, S)>,
104}
105
106impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F, S>
107 BoundedStr<MIN, MAX, MAX_BYTES, L, F, S>
108where
109 L: LengthPolicy,
110 F: FormatPolicy,
111 S: StoragePolicy,
112{
113 const _CHECK_BOUNDS: () = {
114 assert!(MIN <= MAX, "MIN must be <= MAX");
115 assert!(MAX <= MAX_BYTES, "MAX must be <= MAX_BYTES");
116 };
117
118
119 #[inline]
120 pub fn new(input: &str) -> Result<Self, BoundedStrError> {
121 let byte_len = input.len();
122 let logical_len = L::logical_len(input);
123
124 if logical_len < MIN { return Err(BoundedStrError::TooShort); }
125 if logical_len > MAX { return Err(BoundedStrError::TooLong); }
126 if !F::check_format(input) { return Err(BoundedStrError::InvalidContent); }
127
128if byte_len <= MAX_BYTES || !S::ALLOW_HEAP {
129 let mut buf = [0u8; MAX_BYTES];
131 buf[..byte_len].copy_from_slice(input.as_bytes());
132 #[cfg(feature = "alloc")]
133 let heap_buf = None;
134 Ok(Self { len: byte_len, buf, #[cfg(feature = "alloc")] heap_buf, _marker: PhantomData })
135} else {
136 #[cfg(feature = "alloc")]
137 {
138 let heap_vec = input.as_bytes().to_vec();
139 let mut buf = [0u8; MAX_BYTES];
140 buf[..MAX_BYTES].copy_from_slice(&heap_vec[..MAX_BYTES]);
141 Ok(Self { len: byte_len, buf, heap_buf: Some(heap_vec), _marker: PhantomData })
142 }
143}
144 }
145
146
147 #[inline(always)]
148pub fn const_new(input: &'static str) -> Result<Self, BoundedStrError> {
149 let byte_len = input.len();
150 if byte_len > MAX_BYTES { return Err(BoundedStrError::TooManyBytes); }
151 let logical_len = L::const_logical_len(input);
152 if logical_len < MIN { return Err(BoundedStrError::TooShort); }
153 if logical_len > MAX { return Err(BoundedStrError::TooLong); }
154 if !F::const_check_format(input) { return Err(BoundedStrError::InvalidContent); }
155
156 let mut buf = [0u8; MAX_BYTES];
157 buf[..byte_len].copy_from_slice(input.as_bytes());
158
159 #[cfg(feature = "alloc")]
160 let heap_buf = None; Ok(Self { len: byte_len, buf, #[cfg(feature = "alloc")] heap_buf, _marker: PhantomData })
163}
164
165 #[inline(always)]
166 pub fn as_str(&self) -> &str {
167 if self.len <= MAX_BYTES || !S::ALLOW_HEAP {
168 debug_assert!(core::str::from_utf8(&self.buf[..self.len]).is_ok());
169 unsafe { core::str::from_utf8_unchecked(&self.buf[..self.len]) }
170 } else {
171 #[cfg(feature = "alloc")]
172{
173 if let Some(heap_vec) = &self.heap_buf {
174 unsafe { core::str::from_utf8_unchecked(heap_vec) }
175 } else {
176 unsafe { core::str::from_utf8_unchecked(&self.buf[..self.len]) }
177 }
178}
179 }
180 }
181
182 #[inline(always)] pub fn len_bytes(&self) -> usize { self.len }
183 #[inline(always)] pub fn len_logical(&self) -> usize { L::logical_len(self.as_str()) }
184
185 pub fn mutate<Mut, Res>(&mut self, mutator: Mut) -> Result<Res, BoundedStrError>
186 where Mut: FnOnce(&mut [u8]) -> Res {
187 let old_len = self.len;
188 let res = mutator(&mut self.buf[..old_len]);
189 if let Ok(s) = core::str::from_utf8(&self.buf[..old_len]) {
190 if s.len() != old_len {
191 return Err(BoundedStrError::MutationFailed);
192 }
193 let logical_len = L::logical_len(s);
194 if logical_len < MIN || logical_len > MAX || !F::check_format(s) {
195 return Err(BoundedStrError::MutationFailed);
196 }
197 Ok(res)
198 } else {
199 Err(BoundedStrError::MutationFailed)
200 }
201 }
202
203 #[cfg(feature = "constant-time")]
204 #[inline]
205 pub fn ct_eq(&self, other: &Self) -> bool {
206 if self.len != other.len { return false; }
207 let mut diff: u8 = 0;
208 for i in 0..MAX_BYTES {
209 let a = if i < self.len { self.buf[i] } else { 0 };
210 let b = if i < other.len { other.buf[i] } else { 0 };
211 diff |= a ^ b;
212 }
213 diff == 0
214 }
215}
216
217impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> PartialEq
218 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
219where L: LengthPolicy, F: FormatPolicy
220{
221 fn eq(&self, other: &Self) -> bool { self.as_str() == other.as_str() }
222}
223
224impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Eq
225 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
226where L: LengthPolicy, F: FormatPolicy {}
227
228impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> PartialEq<&str>
229 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
230where L: LengthPolicy, F: FormatPolicy
231{
232 fn eq(&self, other: &&str) -> bool { self.as_str() == *other }
233}
234
235impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Deref
236 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
237where L: LengthPolicy, F: FormatPolicy
238{
239 type Target = str;
240 fn deref(&self) -> &str { self.as_str() }
241}
242
243impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> TryFrom<&str>
244 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
245where L: LengthPolicy, F: FormatPolicy
246{
247 type Error = BoundedStrError;
248 fn try_from(value: &str) -> Result<Self, Self::Error> { Self::new(value) }
249}
250
251impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> FromStr
252 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
253where L: LengthPolicy, F: FormatPolicy
254{
255 type Err = BoundedStrError;
256 fn from_str(s: &str) -> Result<Self, Self::Err> { Self::new(s) }
257}
258
259impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Hash
260 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
261where L: LengthPolicy, F: FormatPolicy
262{
263 fn hash<H: Hasher>(&self, state: &mut H) { self.as_str().hash(state); }
264}
265
266impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F> Display
267 for BoundedStr<MIN, MAX, MAX_BYTES, L, F>
268where L: LengthPolicy, F: FormatPolicy
269{
270 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.write_str(self.as_str()) }
271}
272
273impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F, S> fmt::Debug
274 for BoundedStr<MIN, MAX, MAX_BYTES, L, F, S>
275where
276 L: LengthPolicy,
277 F: FormatPolicy,
278 S: StoragePolicy,
279{
280 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
281 f.debug_struct("BoundedStr")
282 .field("value", &self.as_str())
283 .field("len_bytes", &self.len_bytes())
284 .field("len_logical", &self.len_logical())
285 .finish()
286 }
287}
288
#[cfg(feature = "zeroize")]
impl<const MIN: usize, const MAX: usize, const MAX_BYTES: usize, L, F, S> Drop
    for BoundedStr<MIN, MAX, MAX_BYTES, L, F, S>
where
    L: LengthPolicy,
    F: FormatPolicy,
    S: StoragePolicy,
{
    /// Wipes the inline buffer and any heap copy before the memory is released.
    fn drop(&mut self) {
        /// Zeroes `bytes` with volatile stores. Plain stores to memory that is
        /// about to be freed are dead stores the optimizer may delete.
        fn scrub(bytes: &mut [u8]) {
            for byte in bytes.iter_mut() {
                // SAFETY: `byte` is a valid, aligned, exclusive `&mut u8`.
                unsafe { core::ptr::write_volatile(byte, 0) };
            }
        }

        scrub(&mut self.buf);

        // Scrub a heap copy whenever one exists, regardless of the policy flag.
        #[cfg(feature = "alloc")]
        {
            if let Some(heap_vec) = self.heap_buf.as_mut() {
                scrub(heap_vec);
            }
        }

        // Prevent the compiler from moving the wipes past later operations.
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}
308
#[cfg(feature = "serde")]
mod serde_impl {
    use super::*;
    use serde::{de, Deserialize, Deserializer, Serialize, Serializer};

    /// serde visitor that builds a bounded string through `BoundedStr::new`.
    pub struct Visitor<const MIN: usize, const MAX: usize, const MAXB: usize, L, F, S> {
        _marker: PhantomData<(L, F, S)>,
    }

    impl<'de, const MIN: usize, const MAX: usize, const MAXB: usize, L, F, S> de::Visitor<'de>
        for Visitor<MIN, MAX, MAXB, L, F, S>
    where
        L: LengthPolicy + 'static,
        F: FormatPolicy + 'static,
        S: StoragePolicy + 'static,
    {
        type Value = BoundedStr<MIN, MAX, MAXB, L, F, S>;

        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "string [{MIN}..={MAX}]")
        }

        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            BoundedStr::new(v).map_err(|err| match err {
                // The bounds are expressed in the length policy's units, so
                // report the logical length rather than the byte count.
                BoundedStrError::TooShort | BoundedStrError::TooLong => {
                    de::Error::invalid_length(L::logical_len(v), &self)
                }
                BoundedStrError::TooManyBytes => de::Error::invalid_length(v.len(), &self),
                BoundedStrError::InvalidContent => {
                    de::Error::invalid_value(de::Unexpected::Str(v), &self)
                }
                // `new` never produces this; kept exhaustive so new variants
                // are handled explicitly.
                BoundedStrError::MutationFailed => de::Error::custom("unexpected error"),
            })
        }

        #[cfg(feature = "alloc")]
        fn visit_string<E>(self, v: alloc::string::String) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            self.visit_str(&v)
        }
    }

    impl<'de, const MIN: usize, const MAX: usize, const MAXB: usize, L, F, S> Deserialize<'de>
        for BoundedStr<MIN, MAX, MAXB, L, F, S>
    where
        L: LengthPolicy + 'static,
        F: FormatPolicy + 'static,
        S: StoragePolicy + 'static,
    {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            deserializer.deserialize_str(Visitor::<MIN, MAX, MAXB, L, F, S> {
                _marker: PhantomData,
            })
        }
    }

    impl<const MIN: usize, const MAX: usize, const MAXB: usize, L, F, S> Serialize
        for BoundedStr<MIN, MAX, MAXB, L, F, S>
    where
        L: LengthPolicy,
        F: FormatPolicy,
        S: StoragePolicy,
    {
        fn serialize<Ser>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>
        where
            Ser: Serializer,
        {
            serializer.serialize_str(self.as_str())
        }
    }
}
375
376
377
378pub type StackStr<const MIN: usize, const MAX: usize, const MAXB: usize = MAX, L = Bytes, F = AllowAll> =
380 BoundedStr<MIN, MAX, MAXB, L, F, StackOnly>;
381
382pub type FlexStr<const MIN: usize, const MAX: usize, const MAXB: usize = 4096, L = Bytes, F = AllowAll> =
384 BoundedStr<MIN, MAX, MAXB, L, F, HeapAllowed>;