1use crate::limited::{LimitedVec, private::LimitedVisitor};
7use alloc::{
8 borrow::Cow,
9 string::{FromUtf8Error, String},
10};
11use derive_more::{AsRef, Deref, Display, Into};
12use parity_scale_codec::{Decode, Encode};
13use scale_decode::{
14 IntoVisitor, TypeResolver, Visitor,
15 error::ErrorKind,
16 visitor::{
17 TypeIdFor, Unexpected,
18 types::{Composite, Str, Tuple},
19 },
20};
21use scale_encode::EncodeAsType;
22use scale_info::TypeInfo;
23
24#[derive(
32 Debug,
33 Display,
34 Clone,
35 Default,
36 PartialEq,
37 Eq,
38 PartialOrd,
39 Ord,
40 Encode,
41 EncodeAsType,
42 Hash,
43 TypeInfo,
44 AsRef,
45 Deref,
46 Into,
47)]
48#[as_ref(forward)]
49#[deref(forward)]
50pub struct LimitedStr<'a, const N: usize = 1024>(Cow<'a, str>);
51
/// Returns the largest byte index that is not greater than `pos` (clamped to
/// `s.len()`) and lies on a UTF-8 character boundary of `s`.
///
/// The result is always a valid index for slicing `s[..result]`.
fn nearest_char_boundary(s: &str, pos: usize) -> usize {
    let mut index = pos.min(s.len());

    // Index 0 is always a char boundary, so the walk backwards terminates
    // before `index` could underflow.
    while !s.is_char_boundary(index) {
        index -= 1;
    }

    index
}
60
61impl<'a, const N: usize> LimitedStr<'a, N> {
62 pub const MAX_LEN: usize = N;
64
65 pub fn from_utf8(vec: LimitedVec<u8, N>) -> Result<Self, FromUtf8Error> {
68 String::from_utf8(vec.into_vec()).map(Cow::Owned).map(Self)
69 }
70
71 pub fn try_new<S: Into<Cow<'a, str>>>(s: S) -> Result<Self, LimitedStrError> {
75 let s = s.into();
76
77 if s.len() > Self::MAX_LEN {
78 Err(LimitedStrError)
79 } else {
80 Ok(Self(s))
81 }
82 }
83
84 pub fn truncated<S: Into<Cow<'a, str>>>(s: S) -> Self {
87 let s = s.into();
88 let truncation_pos = nearest_char_boundary(&s, Self::MAX_LEN);
89
90 match s {
91 Cow::Borrowed(s) => Self(s[..truncation_pos].into()),
92 Cow::Owned(mut s) => {
93 s.truncate(truncation_pos);
94 Self(s.into())
95 }
96 }
97 }
98
99 #[track_caller]
111 pub const fn from_small_str(s: &'static str) -> Self {
112 if s.len() > Self::MAX_LEN {
113 panic!("{}", LimitedStrError::MESSAGE)
114 }
115
116 Self(Cow::Borrowed(s))
117 }
118
119 pub fn as_str(&self) -> &str {
121 self.as_ref()
122 }
123
124 pub fn into_inner(self) -> Cow<'a, str> {
126 self.0
127 }
128}
129
130impl<'a, const N: usize> TryFrom<&'a str> for LimitedStr<'a, N> {
131 type Error = LimitedStrError;
132
133 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
134 Self::try_new(value)
135 }
136}
137
138impl<'a, const N: usize> TryFrom<String> for LimitedStr<'a, N> {
139 type Error = LimitedStrError;
140
141 fn try_from(value: String) -> Result<Self, Self::Error> {
142 Self::try_new(value)
143 }
144}
145
146impl<'a, const N: usize> Decode for LimitedStr<'a, N> {
147 fn decode<I: parity_scale_codec::Input>(
148 input: &mut I,
149 ) -> Result<Self, parity_scale_codec::Error> {
150 LimitedVec::decode(input)
151 .and_then(|vec| Self::from_utf8(vec).map_err(|_| "Invalid UTF-8 sequence".into()))
152 }
153}
154
155impl<'a, Resolver, const N: usize> Visitor for LimitedVisitor<LimitedStr<'a, N>, Resolver>
156where
157 Resolver: TypeResolver,
158{
159 type Value<'scale, 'resolver> = LimitedStr<'a, N>;
160 type Error = scale_decode::Error;
161 type TypeResolver = Resolver;
162
163 fn visit_str<'scale, 'resolver>(
164 self,
165 value: &mut Str<'scale>,
166 type_id: TypeIdFor<Self>,
167 ) -> Result<Self::Value<'scale, 'resolver>, Self::Error> {
168 if value.len() > N {
169 return Err(scale_decode::Error::new(ErrorKind::WrongLength {
170 actual_len: value.len(),
171 expected_len: N,
172 }));
173 }
174
175 String::into_visitor::<Resolver>()
176 .visit_str(value, type_id)
177 .map(Cow::Owned)
178 .map(LimitedStr)
179 }
180
181 fn visit_composite<'scale, 'resolver>(
182 self,
183 value: &mut Composite<'scale, 'resolver, Resolver>,
184 _type_id: TypeIdFor<Self>,
185 ) -> Result<Self::Value<'scale, 'resolver>, Self::Error> {
186 if value.remaining() != 1 {
187 return self.visit_unexpected(Unexpected::Composite);
188 }
189
190 value.decode_item(self).unwrap()
191 }
192
193 fn visit_tuple<'scale, 'resolver>(
194 self,
195 value: &mut Tuple<'scale, 'resolver, Resolver>,
196 _type_id: TypeIdFor<Self>,
197 ) -> Result<Self::Value<'scale, 'resolver>, Self::Error> {
198 if value.remaining() != 1 {
199 return self.visit_unexpected(Unexpected::Tuple);
200 }
201 value.decode_item(self).unwrap()
202 }
203}
204
205#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Display)]
207#[display("{}", Self::MESSAGE)]
208pub struct LimitedStrError;
209
210impl LimitedStrError {
211 pub const MESSAGE: &str = "string length limit is exceeded";
213
214 pub const fn as_str(&self) -> &'static str {
216 Self::MESSAGE
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{Rng, distributions::Standard};

    /// Asserts that cutting `string` at the nearest char boundary not
    /// exceeding `max_bytes` yields `expectation`.
    fn assert_result(string: &'static str, max_bytes: usize, expectation: &'static str) {
        let cut = nearest_char_boundary(string, max_bytes);
        assert_eq!(&string[..cut], expectation);
    }

    /// Slices `initial_string` for every limit in `0..=upper_boundary` to make
    /// sure no limit panics, and checks that limits covering the whole string
    /// leave it untouched.
    fn check_panicking(initial_string: &'static str, upper_boundary: usize) {
        let initial_size = initial_string.len();

        (0..=upper_boundary).for_each(|max_bytes| {
            let cut = nearest_char_boundary(initial_string, max_bytes);
            let string = &initial_string[..cut];

            if max_bytes >= initial_size {
                assert_eq!(string, initial_string);
            }
        });
    }

    #[test]
    fn truncate_test() {
        // ASCII only: one byte per char, so every index is a boundary.
        let utf_8 = "hello";
        assert_eq!(utf_8.len(), 5);
        assert_eq!(utf_8.chars().count(), 5);

        check_panicking(utf_8, utf_8.len().saturating_mul(2));

        // The array index is the byte limit passed to the helper.
        let ascii_expected = ["", "h", "he", "hel", "hell", "hello", "hello"];
        for (max_bytes, expectation) in ascii_expected.into_iter().enumerate() {
            assert_result(utf_8, max_bytes, expectation);
        }

        // CJK only: three bytes per char.
        let cjk = "你好吗";
        assert_eq!(cjk.len(), 9);
        assert_eq!(cjk.chars().count(), 3);
        assert!(cjk.chars().all(|c| c.len_utf8() == 3));

        check_panicking(cjk, cjk.len().saturating_mul(2));

        let cjk_expected = [
            "",
            "",
            "",
            "你",
            "你",
            "你",
            "你好",
            "你好",
            "你好",
            "你好吗",
            "你好吗",
        ];
        for (max_bytes, expectation) in cjk_expected.into_iter().enumerate() {
            assert_result(cjk, max_bytes, expectation);
        }

        // Mixed one-byte and three-byte chars.
        let mix = "你he好l吗lo";
        assert_eq!(mix.len(), utf_8.len() + cjk.len());
        assert_eq!(mix.len(), 14);
        assert_eq!(
            mix.chars().count(),
            utf_8.chars().count() + cjk.chars().count()
        );
        assert_eq!(mix.chars().count(), 8);

        check_panicking(mix, mix.len().saturating_mul(2));

        let mix_expected = [
            "",
            "",
            "",
            "你",
            "你h",
            "你he",
            "你he",
            "你he",
            "你he好",
            "你he好l",
            "你he好l",
            "你he好l",
            "你he好l吗",
            "你he好l吗l",
            "你he好l吗lo",
            "你he好l吗lo",
        ];
        for (max_bytes, expectation) in mix_expected.into_iter().enumerate() {
            assert_result(mix, max_bytes, expectation);
        }

        // The public constructor must agree, for owned and borrowed inputs.
        assert_eq!(LimitedStr::<1>::truncated(String::from(mix)).as_str(), "");
        assert_eq!(LimitedStr::<5>::truncated(mix).as_str(), "你he");
        assert_eq!(
            LimitedStr::<9>::truncated(String::from(mix)).as_str(),
            "你he好l"
        );
        assert_eq!(LimitedStr::<13>::truncated(mix).as_str(), "你he好l吗l");
    }

    #[test]
    fn truncate_test_fuzz() {
        for _ in 0..50 {
            let mut thread_rng = rand::thread_rng();

            let rand_len = thread_rng.gen_range(0..=100_000);
            let max_bytes = thread_rng.gen_range(0..=rand_len);
            let mut string: String = thread_rng
                .sample_iter::<char, _>(Standard)
                .take(rand_len)
                .collect();

            let cut = nearest_char_boundary(&string, max_bytes);
            string.truncate(cut);

            // Property: the truncated string never exceeds the byte limit.
            assert!(
                string.len() <= max_bytes,
                "String '{}' input invalidated algorithms property",
                string
            );
        }
    }

    #[test]
    fn test_decode() {
        let normal_str = "amogus attacks";
        let encoded_str = normal_str.encode();

        let decoded = LimitedStr::<20>::decode(&mut encoded_str.as_slice());
        let limited_str = decoded.unwrap();

        assert_eq!(normal_str, limited_str.as_str());
    }

    #[test]
    fn test_too_large_decode_fails() {
        let bad_str = "amogus attacks again, but this time it's much harder to defeat him";
        let encoded_str = bad_str.encode();

        let decoded = LimitedStr::<20>::decode(&mut encoded_str.as_slice());
        decoded.expect_err("The string must be too large");
    }
}