Skip to main content

gear_core/limited/
str.rs

1// Copyright (C) Gear Technologies Inc.
2// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
3
//! This module provides a string type with a limited length.
5
6use crate::limited::{LimitedVec, private::LimitedVisitor};
7use alloc::{
8    borrow::Cow,
9    string::{FromUtf8Error, String},
10};
11use derive_more::{AsRef, Deref, Display, Into};
12use parity_scale_codec::{Decode, Encode};
13use scale_decode::{
14    IntoVisitor, TypeResolver, Visitor,
15    error::ErrorKind,
16    visitor::{
17        TypeIdFor, Unexpected,
18        types::{Composite, Str, Tuple},
19    },
20};
21use scale_encode::EncodeAsType;
22use scale_info::TypeInfo;
23
/// Wrapped string that fits into a given amount of bytes.
///
/// The const parameter `N` is the maximum length of the string
/// in bytes (not in characters); it defaults to `1024`.
///
/// The [`Cow`] is used to avoid allocating a new `String` when
/// the `LimitedStr` is created from a `&str`.
///
/// Plain [`str`] is not used because it can't be properly
/// encoded/decoded via scale codec.
///
/// Note that `Decode` is not derived here: it is implemented by hand
/// further down in this module.
#[derive(
    Debug,
    Display,
    Clone,
    Default,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Encode,
    EncodeAsType,
    Hash,
    TypeInfo,
    AsRef,
    Deref,
    Into,
)]
// NOTE(review): `forward` presumably makes the derived `AsRef`/`Deref`
// delegate to the inner `Cow`'s own impls — confirm against derive_more docs.
#[as_ref(forward)]
#[deref(forward)]
pub struct LimitedStr<'a, const N: usize = 1024>(Cow<'a, str>);
51
/// Returns the greatest UTF-8 character boundary of `s`
/// that is not greater than `pos`.
///
/// Positions past the end of the string are clamped to `s.len()`,
/// so the result is always a valid index to slice or truncate `s` at.
fn nearest_char_boundary(s: &str, pos: usize) -> usize {
    let mut boundary = pos.min(s.len());

    // Walk left until a boundary is hit. Index 0 is always
    // a character boundary, so the subtraction cannot underflow.
    while !s.is_char_boundary(boundary) {
        boundary -= 1;
    }

    boundary
}
60
61impl<'a, const N: usize> LimitedStr<'a, N> {
62    /// Maximum length of the string.
63    pub const MAX_LEN: usize = N;
64
65    /// Constructs a limited string from a limited
66    /// vector of bytes of the same size.
67    pub fn from_utf8(vec: LimitedVec<u8, N>) -> Result<Self, FromUtf8Error> {
68        String::from_utf8(vec.into_vec()).map(Cow::Owned).map(Self)
69    }
70
71    /// Constructs a limited string from a string.
72    ///
73    /// Checks the size of the string.
74    pub fn try_new<S: Into<Cow<'a, str>>>(s: S) -> Result<Self, LimitedStrError> {
75        let s = s.into();
76
77        if s.len() > Self::MAX_LEN {
78            Err(LimitedStrError)
79        } else {
80            Ok(Self(s))
81        }
82    }
83
84    /// Constructs a limited string from a string
85    /// truncating it if it's too long.
86    pub fn truncated<S: Into<Cow<'a, str>>>(s: S) -> Self {
87        let s = s.into();
88        let truncation_pos = nearest_char_boundary(&s, Self::MAX_LEN);
89
90        match s {
91            Cow::Borrowed(s) => Self(s[..truncation_pos].into()),
92            Cow::Owned(mut s) => {
93                s.truncate(truncation_pos);
94                Self(s.into())
95            }
96        }
97    }
98
99    /// Constructs a limited string from a static
100    /// string literal small enough to fit the limit.
101    ///
102    /// Should be used only with static string literals.
103    /// In that case it can check the string length
104    /// in compile time.
105    ///
106    /// # Panics
107    ///
108    /// Can panic in runtime if the passed string is
109    /// not a static string literal and is too long.
110    #[track_caller]
111    pub const fn from_small_str(s: &'static str) -> Self {
112        if s.len() > Self::MAX_LEN {
113            panic!("{}", LimitedStrError::MESSAGE)
114        }
115
116        Self(Cow::Borrowed(s))
117    }
118
119    /// Return string slice.
120    pub fn as_str(&self) -> &str {
121        self.as_ref()
122    }
123
124    /// Return inner value.
125    pub fn into_inner(self) -> Cow<'a, str> {
126        self.0
127    }
128}
129
130impl<'a, const N: usize> TryFrom<&'a str> for LimitedStr<'a, N> {
131    type Error = LimitedStrError;
132
133    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
134        Self::try_new(value)
135    }
136}
137
138impl<'a, const N: usize> TryFrom<String> for LimitedStr<'a, N> {
139    type Error = LimitedStrError;
140
141    fn try_from(value: String) -> Result<Self, Self::Error> {
142        Self::try_new(value)
143    }
144}
145
146impl<'a, const N: usize> Decode for LimitedStr<'a, N> {
147    fn decode<I: parity_scale_codec::Input>(
148        input: &mut I,
149    ) -> Result<Self, parity_scale_codec::Error> {
150        LimitedVec::decode(input)
151            .and_then(|vec| Self::from_utf8(vec).map_err(|_| "Invalid UTF-8 sequence".into()))
152    }
153}
154
155impl<'a, Resolver, const N: usize> Visitor for LimitedVisitor<LimitedStr<'a, N>, Resolver>
156where
157    Resolver: TypeResolver,
158{
159    type Value<'scale, 'resolver> = LimitedStr<'a, N>;
160    type Error = scale_decode::Error;
161    type TypeResolver = Resolver;
162
163    fn visit_str<'scale, 'resolver>(
164        self,
165        value: &mut Str<'scale>,
166        type_id: TypeIdFor<Self>,
167    ) -> Result<Self::Value<'scale, 'resolver>, Self::Error> {
168        if value.len() > N {
169            return Err(scale_decode::Error::new(ErrorKind::WrongLength {
170                actual_len: value.len(),
171                expected_len: N,
172            }));
173        }
174
175        String::into_visitor::<Resolver>()
176            .visit_str(value, type_id)
177            .map(Cow::Owned)
178            .map(LimitedStr)
179    }
180
181    fn visit_composite<'scale, 'resolver>(
182        self,
183        value: &mut Composite<'scale, 'resolver, Resolver>,
184        _type_id: TypeIdFor<Self>,
185    ) -> Result<Self::Value<'scale, 'resolver>, Self::Error> {
186        if value.remaining() != 1 {
187            return self.visit_unexpected(Unexpected::Composite);
188        }
189
190        value.decode_item(self).unwrap()
191    }
192
193    fn visit_tuple<'scale, 'resolver>(
194        self,
195        value: &mut Tuple<'scale, 'resolver, Resolver>,
196        _type_id: TypeIdFor<Self>,
197    ) -> Result<Self::Value<'scale, 'resolver>, Self::Error> {
198        if value.remaining() != 1 {
199            return self.visit_unexpected(Unexpected::Tuple);
200        }
201        value.decode_item(self).unwrap()
202    }
203}
204
/// The error type returned when constructing a [`LimitedStr`] fails
/// because the source string (`&str`, `String`, ...) is longer than the limit.
///
/// Its `Display` output is [`LimitedStrError::MESSAGE`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Display)]
#[display("{}", Self::MESSAGE)]
pub struct LimitedStrError;
209
210impl LimitedStrError {
211    /// Static error message.
212    pub const MESSAGE: &str = "string length limit is exceeded";
213
214    /// Converts the error into a static error message.
215    pub const fn as_str(&self) -> &'static str {
216        Self::MESSAGE
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{Rng, distributions::Standard};

    /// Asserts that cutting `string` at the boundary found by
    /// `nearest_char_boundary` for `max_bytes` yields `expectation`.
    fn assert_result(string: &'static str, max_bytes: usize, expectation: &'static str) {
        let string = &string[..nearest_char_boundary(string, max_bytes)];
        assert_eq!(string, expectation);
    }

    /// Slices `initial_string` at the boundary found for every `max_bytes`
    /// in `0..=upper_boundary`; the slicing itself would panic if
    /// `nearest_char_boundary` returned a non-boundary index.
    fn check_panicking(initial_string: &'static str, upper_boundary: usize) {
        let initial_size = initial_string.len();

        for max_bytes in 0..=upper_boundary {
            let string = &initial_string[..nearest_char_boundary(initial_string, max_bytes)];

            // Extra check just for confidence: a limit that is not smaller
            // than the string must leave the string untouched.
            if max_bytes >= initial_size {
                assert_eq!(string, initial_string);
            }
        }
    }

    #[test]
    fn truncate_test() {
        // ASCII-only string: every char takes exactly one byte.
        let utf_8 = "hello";
        // Length in bytes.
        assert_eq!(utf_8.len(), 5);
        // Length in chars.
        assert_eq!(utf_8.chars().count(), 5);

        // Check that slicing at `nearest_char_boundary` never panics.
        //
        // It calls `nearest_char_boundary` with `max_bytes` arg in 0..= len * 2.
        check_panicking(utf_8, utf_8.len().saturating_mul(2));

        // Asserting results.
        assert_result(utf_8, 0, "");
        assert_result(utf_8, 1, "h");
        assert_result(utf_8, 2, "he");
        assert_result(utf_8, 3, "hel");
        assert_result(utf_8, 4, "hell");
        assert_result(utf_8, 5, "hello");
        assert_result(utf_8, 6, "hello");

        // String of CJK characters: every char takes three bytes.
        let cjk = "你好吗";
        // Length in bytes.
        assert_eq!(cjk.len(), 9);
        // Length in chars.
        assert_eq!(cjk.chars().count(), 3);
        // Byte length of each char.
        assert!(cjk.chars().all(|c| c.len_utf8() == 3));

        // Check that slicing at `nearest_char_boundary` never panics.
        //
        // It calls `nearest_char_boundary` with `max_bytes` arg in 0..= len * 2.
        check_panicking(cjk, cjk.len().saturating_mul(2));

        // Asserting results.
        assert_result(cjk, 0, "");
        assert_result(cjk, 1, "");
        assert_result(cjk, 2, "");
        assert_result(cjk, 3, "你");
        assert_result(cjk, 4, "你");
        assert_result(cjk, 5, "你");
        assert_result(cjk, 6, "你好");
        assert_result(cjk, 7, "你好");
        assert_result(cjk, 8, "你好");
        assert_result(cjk, 9, "你好吗");
        assert_result(cjk, 10, "你好吗");

        // String mixing single-byte ASCII and three-byte CJK characters:
        // an interleaving of "hello" and "你好吗".
        let mix = "你he好l吗lo";
        // Length in bytes.
        assert_eq!(mix.len(), utf_8.len() + cjk.len());
        assert_eq!(mix.len(), 14);
        // Length in chars.
        assert_eq!(
            mix.chars().count(),
            utf_8.chars().count() + cjk.chars().count()
        );
        assert_eq!(mix.chars().count(), 8);

        // Check that slicing at `nearest_char_boundary` never panics.
        //
        // It calls `nearest_char_boundary` with `max_bytes` arg in 0..= len * 2.
        check_panicking(mix, mix.len().saturating_mul(2));

        // Asserting results.
        assert_result(mix, 0, "");
        assert_result(mix, 1, "");
        assert_result(mix, 2, "");
        assert_result(mix, 3, "你");
        assert_result(mix, 4, "你h");
        assert_result(mix, 5, "你he");
        assert_result(mix, 6, "你he");
        assert_result(mix, 7, "你he");
        assert_result(mix, 8, "你he好");
        assert_result(mix, 9, "你he好l");
        assert_result(mix, 10, "你he好l");
        assert_result(mix, 11, "你he好l");
        assert_result(mix, 12, "你he好l吗");
        assert_result(mix, 13, "你he好l吗l");
        assert_result(mix, 14, "你he好l吗lo");
        assert_result(mix, 15, "你he好l吗lo");

        // The same truncation through the public constructor,
        // for both owned (`String`) and borrowed (`&str`) inputs.
        assert_eq!(LimitedStr::<1>::truncated(String::from(mix)).as_str(), "");
        assert_eq!(LimitedStr::<5>::truncated(mix).as_str(), "你he");
        assert_eq!(
            LimitedStr::<9>::truncated(String::from(mix)).as_str(),
            "你he好l"
        );
        assert_eq!(LimitedStr::<13>::truncated(mix).as_str(), "你he好l吗l");
    }

    #[test]
    fn truncate_test_fuzz() {
        for _ in 0..50 {
            let mut thread_rng = rand::thread_rng();

            // `rand_len` is the number of random chars to generate,
            // `max_bytes` is the byte limit to truncate the string to.
            let rand_len = thread_rng.gen_range(0..=100_000);
            let max_bytes = thread_rng.gen_range(0..=rand_len);
            let mut string = thread_rng
                .sample_iter::<char, _>(Standard)
                .take(rand_len)
                .collect::<String>();
            // `String::truncate` panics on a non-boundary index,
            // so this also checks the returned position is a boundary.
            string.truncate(nearest_char_boundary(&string, max_bytes));

            // The truncated string must never exceed the byte limit.
            if string.len() > max_bytes {
                panic!("String '{}' input invalidated algorithms property", string);
            }
        }
    }

    #[test]
    fn test_decode() {
        // Limited string is encoded just like a normal string
        let normal_str = "amogus attacks";
        let encoded_str = normal_str.encode();
        let limited_str = LimitedStr::<20>::decode(&mut &encoded_str[..]).unwrap();

        assert_eq!(normal_str, limited_str.as_str());
    }

    #[test]
    fn test_too_large_decode_fails() {
        // The string is longer than the 20-byte limit used below.
        let bad_str = "amogus attacks again, but this time it's much harder to defeat him";
        let encoded_str = bad_str.encode();

        LimitedStr::<20>::decode(&mut &encoded_str[..]).expect_err("The string must be too large");
    }
}