use std::fmt::Display;
use std::str::FromStr;
use borsh::{BorshDeserialize, BorshSerialize};
use thiserror::Error;
use crate::schema::{IndexLinking, Item, Link, Primitive, Schema, UniversalWallet};
/// Validation errors produced when building or extending a [`SizedSafeString`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum SchemaStringError {
/// The string exceeded its maximum length. `length` is the offending length in bytes
/// (for `SizedSafeString::try_push`, the length the string would have reached after the
/// push) and `max` is the type's `MAX_LEN`.
#[error("String was too long: {length}, maximum: {max}")]
StringTooLong { length: usize, max: usize },
/// The string contained `character`, which is not a printable ASCII character
/// (see `SizedSafeString::is_valid_char`).
#[error("String contained invalid character: {character}. Only printable ASCII characters are allowed.")]
InvalidCharacter { character: char },
}
/// A `String` newtype holding at most `MAX_LEN` bytes, all of them printable ASCII
/// characters (see [`SizedSafeString::is_valid_char`]).
///
/// The inner field is private, so values come from the validating conversions in this
/// module (`TryFrom<String>`, `TryFrom<&str>`, `FromStr`, `try_push`, and the hand-written
/// `BorshDeserialize` impl) or from `Default`/`new`, which produce the empty string. Since
/// only ASCII is accepted, the byte length equals the character count.
///
/// `BorshSerialize` is derived and encodes the inner `String`. With the `serde` feature
/// enabled, serde converts through `String` (`try_from`/`into`), so deserialization runs
/// the `TryFrom<String>` checks.
#[derive(Default, Hash, Clone, PartialEq, Eq, PartialOrd, Ord, BorshSerialize)]
#[cfg_attr(
feature = "serde",
derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema)
)]
#[cfg_attr(feature = "serde", serde(try_from = "String", into = "String"))]
pub struct SizedSafeString<const MAX_LEN: usize>(String);
/// Maximum length, in bytes, of a [`SafeString`].
pub const DEFAULT_MAX_STRING_LENGTH: usize = 128;
/// A [`SizedSafeString`] capped at [`DEFAULT_MAX_STRING_LENGTH`] (128) bytes.
pub type SafeString = SizedSafeString<DEFAULT_MAX_STRING_LENGTH>;
impl<const MAX_LEN: usize> SizedSafeString<MAX_LEN> {
    /// Borrows the validated contents as a `&str`.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Returns `MAX_LEN`, the largest number of bytes this type may hold.
    pub const fn max_len(&self) -> usize {
        MAX_LEN
    }

    /// Creates an empty string, which satisfies every length limit.
    pub const fn new() -> Self {
        Self(String::new())
    }

    /// Returns the current length in bytes.
    pub fn len(&self) -> usize {
        self.as_str().len()
    }

    /// Returns `true` when the string holds no characters.
    pub fn is_empty(&self) -> bool {
        self.as_str().is_empty()
    }

    /// Appends `c` if there is room and it is an allowed character.
    ///
    /// # Errors
    ///
    /// * [`SchemaStringError::StringTooLong`] when the string already holds `MAX_LEN`
    ///   bytes; `length` is the length the string would have reached after the push.
    ///   This check runs first, so a full string reports it even for an invalid `c`.
    /// * [`SchemaStringError::InvalidCharacter`] when `c` fails [`Self::is_valid_char`].
    ///
    /// The string is left unchanged on error.
    pub fn try_push(&mut self, c: char) -> Result<(), SchemaStringError> {
        let current = self.len();
        if current >= MAX_LEN {
            Err(SchemaStringError::StringTooLong {
                length: current + 1,
                max: MAX_LEN,
            })
        } else if Self::is_valid_char(c) {
            self.0.push(c);
            Ok(())
        } else {
            Err(SchemaStringError::InvalidCharacter { character: c })
        }
    }

    /// Returns `true` for printable ASCII, i.e. `' '` (0x20) through `'~'` (0x7E):
    /// exactly the ASCII characters that are not control characters.
    pub const fn is_valid_char(c: char) -> bool {
        matches!(c, ' '..='~')
    }
}
impl<const MAX_LEN: usize> BorshDeserialize for SizedSafeString<MAX_LEN> {
    /// Reads a Borsh-encoded string (a `u32` byte-length prefix followed by that many
    /// bytes) and enforces the same invariants as `TryFrom<String>`.
    ///
    /// The prefix is checked against `MAX_LEN` before any payload is read or allocated, so
    /// an oversized prefix cannot trigger a large allocation.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`std::io::ErrorKind::InvalidData`] in these cases:
    /// * The declared length exceeds `MAX_LEN`, or the payload ends early
    ///   ("Unexpected length of input").
    /// * The payload is not valid UTF-8 ("Invalid UTF-8").
    /// * The payload contains a character that is not printable ASCII ("Invalid character").
    ///
    /// Other I/O errors raised while reading the payload are propagated unchanged.
    fn deserialize_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
        let len = u32::deserialize_reader(reader)? as usize;
        if len > MAX_LEN {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Unexpected length of input",
            ));
        }
        // Read the payload with one `read_exact` instead of one `u8` per call. On an
        // unbuffered reader, the per-byte approach issues a separate `read` for every byte.
        // `len` is already bounded by `MAX_LEN`, so allocating it up front is safe.
        let mut bytes = vec![0u8; len];
        reader.read_exact(&mut bytes).map_err(|err| {
            // Report truncated input the way borsh's primitive deserializers do
            // (InvalidData, "Unexpected length of input") rather than as a bare
            // UnexpectedEof.
            if err.kind() == std::io::ErrorKind::UnexpectedEof {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "Unexpected length of input",
                )
            } else {
                err
            }
        })?;
        let string = String::from_utf8(bytes)
            .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid UTF-8"))?;
        // UTF-8 is validated first, so malformed byte sequences report "Invalid UTF-8"
        // rather than "Invalid character".
        if !string.chars().all(Self::is_valid_char) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Invalid character",
            ));
        }
        Ok(Self(string))
    }
}
impl<const MAX_LEN: usize> TryFrom<String> for SizedSafeString<MAX_LEN> {
    type Error = SchemaStringError;

    /// Validates `value` and wraps it without copying.
    ///
    /// # Errors
    ///
    /// The length check runs first:
    /// * [`SchemaStringError::StringTooLong`] when `value` has more than `MAX_LEN` bytes.
    /// * Otherwise, [`SchemaStringError::InvalidCharacter`] carrying the first character
    ///   that is not printable ASCII.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        let length = value.len();
        if length > MAX_LEN {
            return Err(SchemaStringError::StringTooLong {
                length,
                max: MAX_LEN,
            });
        }
        let first_invalid = value.chars().find(|&ch| !Self::is_valid_char(ch));
        match first_invalid {
            Some(character) => Err(SchemaStringError::InvalidCharacter { character }),
            None => Ok(Self(value)),
        }
    }
}
impl<const MAX_LEN: usize> FromStr for SizedSafeString<MAX_LEN> {
    type Err = SchemaStringError;

    /// Parses `s` with exactly the validation of `TryFrom<&str>`, which enables
    /// `"...".parse::<SizedSafeString<N>>()`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::try_from(s)
    }
}
impl<const MAX_LEN: usize> UniversalWallet for SizedSafeString<MAX_LEN> {
    /// Describes the value as a plain string primitive. The `MAX_LEN` limit is not part
    /// of the scaffolded item.
    fn scaffold() -> Item<IndexLinking> {
        Item::Atom(Primitive::String)
    }

    /// Returns no child links; the schema argument is not consulted.
    fn get_child_links(_schema: &mut Schema) -> Vec<Link> {
        vec![]
    }
}
impl<'a, const MAX_LEN: usize> TryFrom<&'a str> for SizedSafeString<MAX_LEN> {
type Error = SchemaStringError;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
value.to_string().try_into()
}
}
impl<const MAX_LEN: usize> From<SizedSafeString<MAX_LEN>> for String {
fn from(value: SizedSafeString<MAX_LEN>) -> Self {
value.0
}
}
impl<const MAX_LEN: usize> AsRef<[u8]> for SizedSafeString<MAX_LEN> {
    /// Exposes the underlying bytes (all printable ASCII).
    fn as_ref(&self) -> &[u8] {
        self.0.as_bytes()
    }
}
impl<const MAX_LEN: usize> AsRef<str> for SizedSafeString<MAX_LEN> {
    /// Borrows the contents as a string slice.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
impl<const MAX_LEN: usize> std::fmt::Debug for SizedSafeString<MAX_LEN> {
    /// Writes the contents verbatim: no surrounding quotes and no escaping, unlike
    /// `String`'s `Debug`. Width/fill flags from the caller's format spec are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.0.as_str())
    }
}
impl<const MAX_LEN: usize> Display for SizedSafeString<MAX_LEN> {
    /// Writes the contents verbatim. Width/fill flags from the caller's format spec are
    /// not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
#[cfg(test)]
mod tests {
    use super::{SafeString, SchemaStringError, SizedSafeString};

    #[test]
    fn test_sizedsafestring_maxlen() {
        let string_good: String = ['a'; 31].iter().collect();
        let string_bad: String = ['a'; 32].iter().collect();
        let conversion_good = <SizedSafeString<31>>::try_from(string_good);
        assert!(conversion_good.is_ok());
        let conversion_bad = <SizedSafeString<31>>::try_from(string_bad);
        assert_eq!(
            conversion_bad,
            Err(SchemaStringError::StringTooLong {
                length: 32,
                max: 31
            })
        );
    }

    #[test]
    fn test_safestring_default_len() {
        let string_good: String = ['a'; 128].iter().collect();
        let string_bad: String = ['a'; 129].iter().collect();
        let conversion_good = SafeString::try_from(string_good);
        assert!(conversion_good.is_ok());
        let conversion_bad = SafeString::try_from(string_bad);
        assert_eq!(
            conversion_bad,
            Err(SchemaStringError::StringTooLong {
                length: 129,
                max: 128
            })
        );
    }

    #[test]
    fn test_safestring_rejects_nonascii() {
        let string = "hello •";
        let conversion = SafeString::try_from(string);
        assert_eq!(
            conversion,
            Err(SchemaStringError::InvalidCharacter { character: '•' })
        );
    }

    #[test]
    fn test_safestring_rejects_control_chars() {
        let string = "hello \n world";
        let conversion = SafeString::try_from(string);
        assert_eq!(
            conversion,
            Err(SchemaStringError::InvalidCharacter { character: '\n' })
        );
    }

    #[test]
    fn test_safestring_try_push() {
        let mut s = SizedSafeString::<2>::new();
        assert_eq!(
            s.try_push('\n'),
            Err(SchemaStringError::InvalidCharacter { character: '\n' })
        );
        assert_eq!(s.try_push('a'), Ok(()));
        assert_eq!(s.try_push('b'), Ok(()));
        // A full string reports the length it would have reached after the push.
        assert_eq!(
            s.try_push('c'),
            Err(SchemaStringError::StringTooLong { length: 3, max: 2 })
        );
        assert_eq!(s.as_str(), "ab");
    }

    // serde support is behind the optional `serde` feature. Without it, `SafeString` does
    // not implement `Deserialize`, so the JSON tests are compiled only when it is enabled.
    #[cfg(feature = "serde")]
    #[test]
    fn json_deserializing_safestring_accepts_valid() {
        let de: SafeString = serde_json::from_str("\"Good string\"").unwrap();
        let expected: SafeString = "Good string".try_into().unwrap();
        assert_eq!(de, expected);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn json_deserializing_safestring_rejects_invalid() {
        let de: Result<SafeString, _> = serde_json::from_str("\"Bad•string\"");
        assert!(de.is_err());
        assert_eq!(
            de.unwrap_err().to_string(),
            "String contained invalid character: •. Only printable ASCII characters are allowed."
        );
    }

    #[test]
    fn test_safe_string_borsh_invalid_char() {
        use borsh::{to_vec, BorshDeserialize};
        let input = String::from_utf8(vec![b'\n'; 1]).unwrap();
        assert_eq!(None, SafeString::try_from(input.clone()).ok());
        let encoded = to_vec(&input).unwrap();
        let output = SafeString::try_from_slice(&encoded);
        assert!(output.is_err());
    }

    #[test]
    fn test_safe_string_borsh_too_long() {
        use borsh::{to_vec, BorshDeserialize};
        let large_input = String::from_utf8(vec![b'a'; 300]).unwrap();
        assert_eq!(None, SafeString::try_from(large_input.clone()).ok());
        let encoded = to_vec(&large_input).unwrap();
        let output = SafeString::try_from_slice(&encoded);
        assert!(output.is_err());
    }

    #[test]
    fn test_safe_string_borsh_roundtrip() {
        use borsh::{to_vec, BorshDeserialize};
        let input = SafeString::try_from("hello world").unwrap();
        let encoded = to_vec(&input).unwrap();
        // Same wire format as a plain borsh `String`.
        assert_eq!(encoded, to_vec(&String::from("hello world")).unwrap());
        assert_eq!(SafeString::try_from_slice(&encoded).unwrap(), input);
    }

    #[test]
    fn test_safe_string_borsh_truncated_input() {
        use borsh::{to_vec, BorshDeserialize};
        let mut encoded = to_vec(&String::from("hello")).unwrap();
        // Keep the length prefix (5) but drop the last two payload bytes.
        encoded.truncate(encoded.len() - 2);
        let err = SafeString::try_from_slice(&encoded).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_safe_string_serde_invalid_char() {
        let de: Result<SafeString, _> = serde_json::from_str("\"\\n\"");
        assert!(de.is_err());
        assert_eq!(
            de.unwrap_err().to_string(),
            "String contained invalid character: \n. Only printable ASCII characters are allowed."
        );
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_safe_string_serde_too_long() {
        let large_input = String::from_utf8(vec![b'a'; 300]).unwrap();
        let de: Result<SafeString, _> = serde_json::from_str(&format!("\"{large_input}\""));
        assert!(de.is_err());
        assert_eq!(
            de.unwrap_err().to_string(),
            "String was too long: 300, maximum: 128"
        );
    }
}