use std::{fmt, marker::PhantomData, str::FromStr};
#[derive(Clone, Copy, Debug, thiserror::Error)]
pub enum StringError {
#[error("invalid length `{found_len}` max length: {max_len}")]
InvalidLength { max_len: usize, found_len: usize },
#[error("contained non ascii char")]
NonAsciiChar(char),
#[error("only printable(32-127) ascii characters allowed. Found {0:?}")]
InvalidChar(char),
}
/// Character-level validation policy for a `BaseString`.
///
/// Implementors in this file are zero-sized marker types. `validate` receives only the
/// text, not the length limit `N`, so implementations inspect characters only.
trait StringVariant {
    /// Returns `Ok(())` when every character of `input` is acceptable for this variant.
    fn validate(input: &str) -> Result<(), StringError>;
}
/// Marker variant that places no restriction on characters: any UTF-8 text is accepted.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SensitiveUtf8;

impl StringVariant for SensitiveUtf8 {
    /// Always succeeds; every valid `&str` satisfies this variant.
    fn validate(_input: &str) -> Result<(), StringError> {
        Ok(())
    }
}
/// Marker variant restricting content to printable ASCII (bytes 32..=126).
///
/// NOTE(review): the name suggests case-insensitive semantics, but this marker only
/// restricts characters; equality on `BaseString` is the derived byte-wise comparison.
/// Confirm this is intended.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct InsensitiveAscii;

impl StringVariant for InsensitiveAscii {
    /// Checks that `s` consists solely of printable ASCII.
    ///
    /// # Errors
    /// - `StringError::NonAsciiChar` with the first non-ASCII character, if any exists
    ///   anywhere in `s`.
    /// - Otherwise `StringError::InvalidChar` with the first byte outside 32..=126.
    fn validate(s: &str) -> Result<(), StringError> {
        // The whole string is scanned for non-ASCII first, so that error wins even
        // when an unprintable ASCII character appears earlier.
        if let Some(wide) = s.chars().find(|ch| !ch.is_ascii()) {
            return Err(StringError::NonAsciiChar(wide));
        }
        // All characters are now single bytes, so scanning bytes visits each char once.
        match s.bytes().find(|&byte| !(32..=126).contains(&byte)) {
            Some(byte) => Err(StringError::InvalidChar(char::from(byte))),
            None => Ok(()),
        }
    }
}
/// An owned string tagged at the type level with a byte-length limit `N` and a
/// character-validation marker `V`.
///
/// `V` is never stored; it only selects which `StringVariant::validate` rules the
/// checked conversions (`FromStr`, `TryFrom<String>`) apply.
///
/// NOTE(review): equality is the derived, byte-wise comparison of the contents for
/// every `V`. Confirm this is intended for the `InsensitiveAscii` variant.
#[derive(PartialEq, Eq, Clone)]
pub struct BaseString<V, const N: usize> {
    /// The string contents.
    s: String,
    /// Carries the marker type `V` without storing a value of it.
    _marker: PhantomData<V>,
}

impl<V, const N: usize> fmt::Debug for BaseString<V, N> {
    /// Renders as `BaseString<N>("contents")`. The marker `V` is left out, so no
    /// `V: Debug` bound is required.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "BaseString<{limit}>({text:?})", limit = N, text = self.s)
    }
}
/// `BaseString` using the `SensitiveUtf8` variant: any UTF-8 content, byte limit `N`.
/// (Presumably "Cs" = case-sensitive; that is a naming assumption, not something the code enforces.)
pub type CsString<const N: usize> = BaseString<SensitiveUtf8, N>;
/// `BaseString` using the `InsensitiveAscii` variant: printable ASCII (32..=126) only,
/// byte limit `N`. (Presumably "Ci" = case-insensitive; equality is still byte-wise, so confirm.)
pub type CiString<const N: usize> = BaseString<InsensitiveAscii, N>;
impl<V, const N: usize> FromStr for BaseString<V, N>
where
    V: StringVariant,
{
    type Err = StringError;

    /// Copies `s` into a new `BaseString`, checking the `N`-byte limit before `V`'s
    /// character rules.
    ///
    /// # Errors
    /// `StringError::InvalidLength` when `s.len()` (bytes) exceeds `N`. Otherwise,
    /// whatever `V::validate` reports.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let found_len = s.len();
        if found_len > N {
            return Err(StringError::InvalidLength { max_len: N, found_len });
        }
        V::validate(s).map(|()| BaseString {
            s: s.to_owned(),
            _marker: PhantomData,
        })
    }
}
impl<V, const N: usize> fmt::Display for BaseString<V, N> {
    /// Writes the contents as text, honouring width, fill, alignment and precision
    /// flags the same way `str`'s `Display` does (e.g. `format!("{:>8}", value)` pads).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // `write_str` would silently ignore the caller's formatting flags; `pad` applies them.
        f.pad(&self.s)
    }
}
impl<V, const N: usize> BaseString<V, N> {
    /// Wraps `s` after checking only the `N`-byte length limit.
    ///
    /// The variant's character rules (`V`) are NOT checked here. Use `str::parse` /
    /// `FromStr` when the input is not already known to satisfy them.
    ///
    /// # Panics
    /// Panics if `s` is longer than `N` bytes.
    #[track_caller]
    pub fn new(s: impl Into<String>) -> Self {
        let s = s.into();
        // Report both lengths so the failing input can be diagnosed from the panic alone.
        assert!(
            s.len() <= N,
            "String too long: {} bytes exceeds the maximum of {}",
            s.len(),
            N
        );
        Self {
            s,
            _marker: PhantomData,
        }
    }
}
impl<V, const N: usize> BaseString<V, N> {
    /// Borrows the contents as a string slice.
    pub fn as_str(&self) -> &str {
        &self.s
    }
}
impl<V, const N: usize> TryFrom<String> for BaseString<V, N>
where
V: StringVariant,
{
type Error = StringError;
fn try_from(s: String) -> Result<Self, Self::Error> {
V::validate(&s)?;
Ok(BaseString {
s,
_marker: PhantomData::default(),
})
}
}
impl<V, const N: usize> From<BaseString<V, N>> for String {
fn from(s: BaseString<V, N>) -> String {
s.s
}
}
impl<V, const N: usize> AsRef<str> for BaseString<V, N> {
    /// Borrows the contents as `&str`.
    fn as_ref(&self) -> &str {
        self.s.as_str()
    }
}
impl<V, const N: usize> PartialEq<&str> for BaseString<V, N> {
    /// Exact (byte-wise, case-sensitive) comparison of the contents with `other`.
    fn eq(&self, other: &&str) -> bool {
        self.s == *other
    }
}
impl<V, const N: usize> serde::Serialize for BaseString<V, N> {
    /// Serializes as a plain string; the type-level `V` and `N` are not emitted.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text: &str = &self.s;
        serializer.serialize_str(text)
    }
}
impl<'de, V, const N: usize> serde::de::Deserialize<'de> for BaseString<V, N>
where
V: StringVariant,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
deserializer.deserialize_str(BaseStringVisitor(PhantomData::default()))
}
}
struct BaseStringVisitor<V, const N: usize>(std::marker::PhantomData<V>);
impl<'de, V, const N: usize> serde::de::Visitor<'de> for BaseStringVisitor<V, N>
where
V: StringVariant,
{
type Value = BaseString<V, N>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("an ascii printable string (32-127)")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
value.parse::<BaseString<V, N>>().map_err(E::custom)
}
}