use alloc::{borrow::Cow, string::String};
use derive_more::{AsRef, Deref, Display, Into};
use parity_scale_codec::{Decode, Encode};
use scale_decode::DecodeAsType;
use scale_encode::EncodeAsType;
use scale_info::TypeInfo;
/// A string, either borrowed or owned (held in a [`Cow`]), whose length is
/// meant to stay within `N` bytes; `N` defaults to 1024.
///
/// Build it with one of three constructors:
/// - [`LimitedStr::try_new`] rejects over-long input.
/// - [`LimitedStr::truncated`] cuts input at a UTF-8 character boundary.
/// - [`LimitedStr::from_small_str`] checks a `'static` string and panics if it
///   is too long; it is usable in const contexts.
///
/// The limit is compared against `str::len`, i.e. it counts bytes, not characters.
///
/// `AsRef` and `Deref` are forwarded through the inner `Cow<str>`, so the value can
/// be used where a `&str` is expected. The derived `Into` converts into the inner `Cow`.
///
/// NOTE(review): the derived `Decode` / `DecodeAsType` impls decode the inner field
/// directly rather than going through `try_new`, so decoded data is presumably not
/// checked against `N` — confirm whether callers rely on (or must guard against) this.
#[derive(
Debug,
Display,
Clone,
Default,
PartialEq,
Eq,
PartialOrd,
Ord,
Decode,
DecodeAsType,
Encode,
EncodeAsType,
Hash,
TypeInfo,
AsRef,
Deref,
Into,
)]
// `forward`: implement `AsRef`/`Deref` for whatever the inner `Cow<str>` itself
// converts/dereferences to, rather than for the `Cow` type.
#[as_ref(forward)]
#[deref(forward)]
pub struct LimitedStr<'a, const N: usize = 1024>(Cow<'a, str>);
/// Returns the largest index that is `<= pos`, clamped to `s.len()`, and lies on a
/// UTF-8 character boundary of `s`.
///
/// Slicing `&s[..idx]` with the returned index never panics, and the slice is at
/// most `pos` bytes long.
fn nearest_char_boundary(s: &str, pos: usize) -> usize {
    let mut idx = pos.min(s.len());
    // Index 0 is always a char boundary, so this never underflows.
    while !s.is_char_boundary(idx) {
        idx -= 1;
    }
    idx
}
impl<'a, const N: usize> LimitedStr<'a, N> {
pub const MAX_LEN: usize = N;
pub fn try_new<S: Into<Cow<'a, str>>>(s: S) -> Result<Self, LimitedStrError> {
let s = s.into();
if s.len() > Self::MAX_LEN {
Err(LimitedStrError)
} else {
Ok(Self(s))
}
}
pub fn truncated<S: Into<Cow<'a, str>>>(s: S) -> Self {
let s = s.into();
let truncation_pos = nearest_char_boundary(&s, Self::MAX_LEN);
match s {
Cow::Borrowed(s) => Self(s[..truncation_pos].into()),
Cow::Owned(mut s) => {
s.truncate(truncation_pos);
Self(s.into())
}
}
}
#[track_caller]
pub const fn from_small_str(s: &'static str) -> Self {
if s.len() > Self::MAX_LEN {
panic!("{}", LimitedStrError::MESSAGE)
}
Self(Cow::Borrowed(s))
}
pub fn as_str(&self) -> &str {
self.as_ref()
}
pub fn into_inner(self) -> Cow<'a, str> {
self.0
}
}
impl<'a> TryFrom<&'a str> for LimitedStr<'a> {
type Error = LimitedStrError;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
Self::try_new(value)
}
}
impl<'a> TryFrom<String> for LimitedStr<'a> {
type Error = LimitedStrError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::try_new(value)
}
}
/// Error returned by [`LimitedStr::try_new`] (and the `TryFrom` conversions into
/// `LimitedStr`) when the input is longer than the length limit.
///
/// Its `Display` output is [`LimitedStrError::MESSAGE`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Display)]
#[display("{}", Self::MESSAGE)]
pub struct LimitedStrError;
impl LimitedStrError {
pub const MESSAGE: &str = "string length limit is exceeded";
pub const fn as_str(&self) -> &'static str {
Self::MESSAGE
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{Rng, distributions::Standard};

    /// Slices `string` at `nearest_char_boundary(string, max_bytes)` and checks the
    /// result equals `expectation`.
    fn assert_result(string: &'static str, max_bytes: usize, expectation: &'static str) {
        let string = &string[..nearest_char_boundary(string, max_bytes)];
        assert_eq!(string, expectation);
    }

    /// Slices `initial_string` at the boundary computed for every limit in
    /// `0..=upper_boundary`. The slice itself panics if the index is not a char
    /// boundary. Limits at or past the string length must keep the string whole.
    fn check_panicking(initial_string: &'static str, upper_boundary: usize) {
        let initial_size = initial_string.len();
        for max_bytes in 0..=upper_boundary {
            let string = &initial_string[..nearest_char_boundary(initial_string, max_bytes)];
            if max_bytes >= initial_size {
                assert_eq!(string, initial_string);
            }
        }
    }

    #[test]
    fn truncate_test() {
        let utf_8 = "hello";
        assert_eq!(utf_8.len(), 5);
        assert_eq!(utf_8.chars().count(), 5);
        check_panicking(utf_8, utf_8.len().saturating_mul(2));
        assert_result(utf_8, 0, "");
        assert_result(utf_8, 1, "h");
        assert_result(utf_8, 2, "he");
        assert_result(utf_8, 3, "hel");
        assert_result(utf_8, 4, "hell");
        assert_result(utf_8, 5, "hello");
        assert_result(utf_8, 6, "hello");

        let cjk = "你好吗";
        assert_eq!(cjk.len(), 9);
        assert_eq!(cjk.chars().count(), 3);
        assert!(cjk.chars().all(|c| c.len_utf8() == 3));
        check_panicking(cjk, cjk.len().saturating_mul(2));
        assert_result(cjk, 0, "");
        assert_result(cjk, 1, "");
        assert_result(cjk, 2, "");
        assert_result(cjk, 3, "你");
        assert_result(cjk, 4, "你");
        assert_result(cjk, 5, "你");
        assert_result(cjk, 6, "你好");
        assert_result(cjk, 7, "你好");
        assert_result(cjk, 8, "你好");
        assert_result(cjk, 9, "你好吗");
        assert_result(cjk, 10, "你好吗");

        let mix = "你he好l吗lo";
        assert_eq!(mix.len(), utf_8.len() + cjk.len());
        assert_eq!(mix.len(), 14);
        assert_eq!(
            mix.chars().count(),
            utf_8.chars().count() + cjk.chars().count()
        );
        assert_eq!(mix.chars().count(), 8);
        check_panicking(mix, mix.len().saturating_mul(2));
        assert_result(mix, 0, "");
        assert_result(mix, 1, "");
        assert_result(mix, 2, "");
        assert_result(mix, 3, "你");
        assert_result(mix, 4, "你h");
        assert_result(mix, 5, "你he");
        assert_result(mix, 6, "你he");
        assert_result(mix, 7, "你he");
        assert_result(mix, 8, "你he好");
        assert_result(mix, 9, "你he好l");
        assert_result(mix, 10, "你he好l");
        assert_result(mix, 11, "你he好l");
        assert_result(mix, 12, "你he好l吗");
        assert_result(mix, 13, "你he好l吗l");
        assert_result(mix, 14, "你he好l吗lo");
        assert_result(mix, 15, "你he好l吗lo");

        assert_eq!(LimitedStr::<1>::truncated(String::from(mix)).as_str(), "");
        assert_eq!(LimitedStr::<5>::truncated(mix).as_str(), "你he");
        assert_eq!(
            LimitedStr::<9>::truncated(String::from(mix)).as_str(),
            "你he好l"
        );
        assert_eq!(LimitedStr::<13>::truncated(mix).as_str(), "你he好l吗l");
    }

    #[test]
    fn truncated_keeps_cow_variant() {
        // Borrowed input is re-sliced; owned input is truncated in place.
        assert!(matches!(
            LimitedStr::<5>::truncated("abcdefg").into_inner(),
            Cow::Borrowed("abcde")
        ));
        let owned = LimitedStr::<3>::truncated(String::from("abcdef")).into_inner();
        assert!(matches!(owned, Cow::Owned(ref s) if s == "abc"));
        // A zero limit always yields the empty string.
        assert_eq!(LimitedStr::<0>::truncated("abc").as_str(), "");
    }

    #[test]
    fn try_new_enforces_byte_limit() {
        assert_eq!(LimitedStr::<7>::MAX_LEN, 7);
        assert_eq!(LimitedStr::<3>::try_new("abc").unwrap().as_str(), "abc");
        assert_eq!(LimitedStr::<2>::try_new("abc"), Err(LimitedStrError));
        // The limit counts bytes, not chars: "你好吗" is 3 chars but 9 bytes.
        assert_eq!(LimitedStr::<8>::try_new("你好吗"), Err(LimitedStrError));
        assert_eq!(
            LimitedStr::<9>::try_new(String::from("你好吗")).unwrap().as_str(),
            "你好吗"
        );
        // `TryFrom` impls use the default limit of 1024 bytes.
        assert_eq!(LimitedStr::<1024>::try_from("hello").unwrap().as_str(), "hello");
        assert_eq!(
            LimitedStr::<1024>::try_from("x".repeat(1025)),
            Err(LimitedStrError)
        );
        assert_eq!(LimitedStrError.as_str(), LimitedStrError::MESSAGE);
    }

    #[test]
    fn from_small_str_accepts_exact_limit() {
        assert_eq!(LimitedStr::<4>::from_small_str("abcd").as_str(), "abcd");
    }

    #[test]
    #[should_panic(expected = "string length limit is exceeded")]
    fn from_small_str_rejects_long_input() {
        let _ = LimitedStr::<2>::from_small_str("abc");
    }

    #[test]
    fn truncate_test_fuzz() {
        // One generator for the whole test. Passing `&mut rng` to `sample_iter`
        // keeps the generator from being consumed.
        let mut rng = rand::thread_rng();
        for _ in 0..50 {
            let rand_len = rng.gen_range(0..=100_000);
            let max_bytes = rng.gen_range(0..=rand_len);
            let string = (&mut rng)
                .sample_iter::<char, _>(Standard)
                .take(rand_len)
                .collect::<String>();

            let end = nearest_char_boundary(&string, max_bytes);
            // Slicing panics if `end` is not a char boundary.
            let truncated = &string[..end];
            assert!(
                truncated.len() <= max_bytes,
                "boundary {} exceeds the limit of {} bytes",
                end,
                max_bytes
            );
            // The boundary must be the nearest one: the following char must not fit.
            if let Some(next) = string[end..].chars().next() {
                assert!(
                    end + next.len_utf8() > max_bytes,
                    "boundary {} is not the nearest one below the limit of {} bytes",
                    end,
                    max_bytes
                );
            }
        }
    }
}