gear-core 2.0.0-pre.1

Gear core library
Documentation
// Copyright (C) Gear Technologies Inc.
// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0

//! This module provides a string type whose length is limited to a fixed number of bytes.

use crate::limited::{LimitedVec, private::LimitedVisitor};
use alloc::{
    borrow::Cow,
    string::{FromUtf8Error, String},
};
use derive_more::{AsRef, Deref, Display, Into};
use parity_scale_codec::{Decode, Encode};
use scale_decode::{
    IntoVisitor, TypeResolver, Visitor,
    error::ErrorKind,
    visitor::{
        TypeIdFor, Unexpected,
        types::{Composite, Str, Tuple},
    },
};
use scale_encode::EncodeAsType;
use scale_info::TypeInfo;

/// Wrapped string to fit given amount of bytes.
///
/// The limit `N` (default `1024`) is measured in UTF-8 bytes, not in
/// characters. It is checked by the constructors ([`LimitedStr::try_new`],
/// [`LimitedStr::truncated`], [`LimitedStr::from_small_str`],
/// [`LimitedStr::from_utf8`]) and by SCALE decoding. The inner field is
/// private, so values cannot be built around those checks from outside this
/// module.
///
/// The [`Cow`] is used to avoid allocating a new `String` when
/// the `LimitedStr` is created from a `&str`.
///
/// Plain [`str`] is not used because it can't be properly
/// encoded/decoded via scale codec.
#[derive(
    Debug,
    Display,
    Clone,
    Default,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Encode,
    EncodeAsType,
    Hash,
    TypeInfo,
    AsRef,
    Deref,
    Into,
)]
// `forward` makes the `AsRef`/`Deref` impls delegate to the inner `Cow<str>`,
// so the wrapper can be used as a `&str` directly.
#[as_ref(forward)]
#[deref(forward)]
pub struct LimitedStr<'a, const N: usize = 1024>(Cow<'a, str>);

/// Returns the largest byte index that is `<= pos` (after clamping `pos` to
/// `s.len()`) and lies on a UTF-8 character boundary of `s`.
///
/// Slicing `s` up to the returned index therefore never splits a multi-byte
/// character. Index `0` is always a boundary, so the walk below terminates.
fn nearest_char_boundary(s: &str, pos: usize) -> usize {
    let mut idx = pos.min(s.len());
    while !s.is_char_boundary(idx) {
        idx = idx.saturating_sub(1);
    }
    idx
}

impl<'a, const N: usize> LimitedStr<'a, N> {
    /// Maximum length of the string.
    pub const MAX_LEN: usize = N;

    /// Constructs a limited string from a limited
    /// vector of bytes of the same size.
    pub fn from_utf8(vec: LimitedVec<u8, N>) -> Result<Self, FromUtf8Error> {
        String::from_utf8(vec.into_vec()).map(Cow::Owned).map(Self)
    }

    /// Constructs a limited string from a string.
    ///
    /// Checks the size of the string.
    pub fn try_new<S: Into<Cow<'a, str>>>(s: S) -> Result<Self, LimitedStrError> {
        let s = s.into();

        if s.len() > Self::MAX_LEN {
            Err(LimitedStrError)
        } else {
            Ok(Self(s))
        }
    }

    /// Constructs a limited string from a string
    /// truncating it if it's too long.
    pub fn truncated<S: Into<Cow<'a, str>>>(s: S) -> Self {
        let s = s.into();
        let truncation_pos = nearest_char_boundary(&s, Self::MAX_LEN);

        match s {
            Cow::Borrowed(s) => Self(s[..truncation_pos].into()),
            Cow::Owned(mut s) => {
                s.truncate(truncation_pos);
                Self(s.into())
            }
        }
    }

    /// Constructs a limited string from a static
    /// string literal small enough to fit the limit.
    ///
    /// Should be used only with static string literals.
    /// In that case it can check the string length
    /// in compile time.
    ///
    /// # Panics
    ///
    /// Can panic in runtime if the passed string is
    /// not a static string literal and is too long.
    #[track_caller]
    pub const fn from_small_str(s: &'static str) -> Self {
        if s.len() > Self::MAX_LEN {
            panic!("{}", LimitedStrError::MESSAGE)
        }

        Self(Cow::Borrowed(s))
    }

    /// Return string slice.
    pub fn as_str(&self) -> &str {
        self.as_ref()
    }

    /// Return inner value.
    pub fn into_inner(self) -> Cow<'a, str> {
        self.0
    }
}

impl<'a, const N: usize> TryFrom<&'a str> for LimitedStr<'a, N> {
    type Error = LimitedStrError;

    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
        Self::try_new(value)
    }
}

impl<'a, const N: usize> TryFrom<String> for LimitedStr<'a, N> {
    type Error = LimitedStrError;

    fn try_from(value: String) -> Result<Self, Self::Error> {
        Self::try_new(value)
    }
}

impl<'a, const N: usize> Decode for LimitedStr<'a, N> {
    fn decode<I: parity_scale_codec::Input>(
        input: &mut I,
    ) -> Result<Self, parity_scale_codec::Error> {
        LimitedVec::decode(input)
            .and_then(|vec| Self::from_utf8(vec).map_err(|_| "Invalid UTF-8 sequence".into()))
    }
}

/// `scale_decode` visitor producing a [`LimitedStr`].
///
/// Accepts either a plain string, or a composite/tuple with exactly one field
/// that itself decodes as a limited string (e.g. a newtype wrapper).
impl<'a, Resolver, const N: usize> Visitor for LimitedVisitor<LimitedStr<'a, N>, Resolver>
where
    Resolver: TypeResolver,
{
    type Value<'scale, 'resolver> = LimitedStr<'a, N>;
    type Error = scale_decode::Error;
    type TypeResolver = Resolver;

    /// Decodes a string, failing with `WrongLength` if its byte length
    /// exceeds `N`.
    fn visit_str<'scale, 'resolver>(
        self,
        value: &mut Str<'scale>,
        type_id: TypeIdFor<Self>,
    ) -> Result<Self::Value<'scale, 'resolver>, Self::Error> {
        // Reject oversized input before delegating to the `String` visitor.
        if value.len() > N {
            return Err(scale_decode::Error::new(ErrorKind::WrongLength {
                actual_len: value.len(),
                expected_len: N,
            }));
        }

        String::into_visitor::<Resolver>()
            .visit_str(value, type_id)
            .map(Cow::Owned)
            .map(LimitedStr)
    }

    /// Unwraps a composite with exactly one field and decodes that field.
    fn visit_composite<'scale, 'resolver>(
        self,
        value: &mut Composite<'scale, 'resolver, Resolver>,
        _type_id: TypeIdFor<Self>,
    ) -> Result<Self::Value<'scale, 'resolver>, Self::Error> {
        if value.remaining() != 1 {
            return self.visit_unexpected(Unexpected::Composite);
        }

        value
            .decode_item(self)
            .expect("composite has exactly one remaining field (checked above)")
    }

    /// Unwraps a tuple with exactly one element and decodes that element.
    fn visit_tuple<'scale, 'resolver>(
        self,
        value: &mut Tuple<'scale, 'resolver, Resolver>,
        _type_id: TypeIdFor<Self>,
    ) -> Result<Self::Value<'scale, 'resolver>, Self::Error> {
        if value.remaining() != 1 {
            return self.visit_unexpected(Unexpected::Tuple);
        }

        value
            .decode_item(self)
            .expect("tuple has exactly one remaining element (checked above)")
    }
}

/// The error returned when a string is longer than the byte limit of a
/// [`LimitedStr`], e.g. by [`LimitedStr::try_new`] or the `TryFrom<&str>` /
/// `TryFrom<String>` conversions.
///
/// Its `Display` output is [`LimitedStrError::MESSAGE`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Display)]
#[display("{}", Self::MESSAGE)]
pub struct LimitedStrError;

impl LimitedStrError {
    /// Static error message.
    pub const MESSAGE: &str = "string length limit is exceeded";

    /// Converts the error into a static error message.
    pub const fn as_str(&self) -> &'static str {
        Self::MESSAGE
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use rand::{Rng, distributions::Standard};

    /// Asserts that cutting `string` at the nearest char boundary not after
    /// `max_bytes` yields exactly `expectation`.
    fn assert_result(string: &'static str, max_bytes: usize, expectation: &'static str) {
        let string = &string[..nearest_char_boundary(string, max_bytes)];
        assert_eq!(string, expectation);
    }

    /// Slices `initial_string` at `nearest_char_boundary(.., max_bytes)` for
    /// every `max_bytes` in `0..=upper_boundary`. Slicing a `str` at an index
    /// that is not a char boundary panics, so this fails on any bad index.
    fn check_panicking(initial_string: &'static str, upper_boundary: usize) {
        let initial_size = initial_string.len();

        for max_bytes in 0..=upper_boundary {
            let string = &initial_string[..nearest_char_boundary(initial_string, max_bytes)];

            // Extra check just for confidence: a limit at or beyond the
            // string's length must keep the whole string.
            if max_bytes >= initial_size {
                assert_eq!(string, initial_string);
            }
        }
    }

    #[test]
    fn truncate_test() {
        // ASCII-only string: every char takes exactly one byte.
        let utf_8 = "hello";
        // Length in bytes.
        assert_eq!(utf_8.len(), 5);
        // Length in chars.
        assert_eq!(utf_8.chars().count(), 5);

        // Check that slicing at `nearest_char_boundary` never panics
        // for `max_bytes` in `0..=len * 2`.
        check_panicking(utf_8, utf_8.len().saturating_mul(2));

        // Asserting results.
        assert_result(utf_8, 0, "");
        assert_result(utf_8, 1, "h");
        assert_result(utf_8, 2, "he");
        assert_result(utf_8, 3, "hel");
        assert_result(utf_8, 4, "hell");
        assert_result(utf_8, 5, "hello");
        assert_result(utf_8, 6, "hello");

        // CJK string: every char takes three bytes in UTF-8.
        let cjk = "你好吗";
        // Length in bytes.
        assert_eq!(cjk.len(), 9);
        // Length in chars.
        assert_eq!(cjk.chars().count(), 3);
        // Byte length of each char.
        assert!(cjk.chars().all(|c| c.len_utf8() == 3));

        // Check that slicing at `nearest_char_boundary` never panics
        // for `max_bytes` in `0..=len * 2`.
        check_panicking(cjk, cjk.len().saturating_mul(2));

        // Asserting results.
        assert_result(cjk, 0, "");
        assert_result(cjk, 1, "");
        assert_result(cjk, 2, "");
        assert_result(cjk, 3, "你");
        assert_result(cjk, 4, "你");
        assert_result(cjk, 5, "你");
        assert_result(cjk, 6, "你好");
        assert_result(cjk, 7, "你好");
        assert_result(cjk, 8, "你好");
        assert_result(cjk, 9, "你好吗");
        assert_result(cjk, 10, "你好吗");

        // Mixed one-byte and three-byte chars:
        // a chaotic interleaving of "hello" and "你好吗".
        let mix = "你he好l吗lo";
        // Length in bytes.
        assert_eq!(mix.len(), utf_8.len() + cjk.len());
        assert_eq!(mix.len(), 14);
        // Length in chars.
        assert_eq!(
            mix.chars().count(),
            utf_8.chars().count() + cjk.chars().count()
        );
        assert_eq!(mix.chars().count(), 8);

        // Check that slicing at `nearest_char_boundary` never panics
        // for `max_bytes` in `0..=len * 2`.
        check_panicking(mix, mix.len().saturating_mul(2));

        // Asserting results.
        assert_result(mix, 0, "");
        assert_result(mix, 1, "");
        assert_result(mix, 2, "");
        assert_result(mix, 3, "你");
        assert_result(mix, 4, "你h");
        assert_result(mix, 5, "你he");
        assert_result(mix, 6, "你he");
        assert_result(mix, 7, "你he");
        assert_result(mix, 8, "你he好");
        assert_result(mix, 9, "你he好l");
        assert_result(mix, 10, "你he好l");
        assert_result(mix, 11, "你he好l");
        assert_result(mix, 12, "你he好l吗");
        assert_result(mix, 13, "你he好l吗l");
        assert_result(mix, 14, "你he好l吗lo");
        assert_result(mix, 15, "你he好l吗lo");

        // `LimitedStr::truncated` must agree for both owned and borrowed input.
        assert_eq!(LimitedStr::<1>::truncated(String::from(mix)).as_str(), "");
        assert_eq!(LimitedStr::<5>::truncated(mix).as_str(), "你he");
        assert_eq!(
            LimitedStr::<9>::truncated(String::from(mix)).as_str(),
            "你he好l"
        );
        assert_eq!(LimitedStr::<13>::truncated(mix).as_str(), "你he好l吗l");
    }

    #[test]
    fn truncate_test_fuzz() {
        for _ in 0..50 {
            let mut thread_rng = rand::thread_rng();

            let rand_len = thread_rng.gen_range(0..=100_000);
            let max_bytes = thread_rng.gen_range(0..=rand_len);
            let mut string = thread_rng
                .sample_iter::<char, _>(Standard)
                .take(rand_len)
                .collect::<String>();
            string.truncate(nearest_char_boundary(&string, max_bytes));

            // Report lengths rather than the (up to 100k-char) string itself.
            assert!(
                string.len() <= max_bytes,
                "truncated length {} exceeds limit {}",
                string.len(),
                max_bytes
            );
        }
    }

    #[test]
    fn try_new_respects_byte_limit() {
        let exact = LimitedStr::<5>::try_new("hello").expect("fits exactly");
        assert_eq!(exact.as_str(), "hello");
        assert_eq!(LimitedStr::<4>::try_new("hello"), Err(LimitedStrError));

        // The limit counts bytes, not chars: three 3-byte chars need 9 bytes.
        assert_eq!(
            LimitedStr::<8>::try_new(String::from("你好吗")),
            Err(LimitedStrError)
        );
        assert!(LimitedStr::<9>::try_new(String::from("你好吗")).is_ok());
    }

    #[test]
    fn from_small_str_accepts_fitting_literal() {
        const GREETING: LimitedStr<'static, 5> = LimitedStr::<5>::from_small_str("hello");
        assert_eq!(GREETING.as_str(), "hello");
    }

    #[test]
    #[should_panic(expected = "string length limit is exceeded")]
    fn from_small_str_panics_when_too_long() {
        let _ = LimitedStr::<4>::from_small_str("hello");
    }

    #[test]
    fn test_decode() {
        // Limited string is encoded just like a normal string
        let normal_str = "amogus attacks";
        let encoded_str = normal_str.encode();
        let limited_str = LimitedStr::<20>::decode(&mut &encoded_str[..]).unwrap();

        assert_eq!(normal_str, limited_str.as_str());
    }

    #[test]
    fn test_too_large_decode_fails() {
        let bad_str = "amogus attacks again, but this time it's much harder to defeat him";
        let encoded_str = bad_str.encode();

        LimitedStr::<20>::decode(&mut &encoded_str[..]).expect_err("The string must be too large");
    }
}