1use std::{fmt, marker::PhantomData, str::FromStr};
2
3#[derive(Clone, Copy, Debug, thiserror::Error)]
4pub enum StringError {
5 #[error("invalid length `{found_len}` max length: {max_len}")]
6 InvalidLength { max_len: usize, found_len: usize },
7
8 #[error("contained non ascii char")]
9 NonAsciiChar(char),
10
11 #[error("only printable(32-127) ascii characters allowed. Found {0:?}")]
12 InvalidChar(char),
13}
14
15trait StringVariant {
16 fn validate(s: &str) -> Result<(), StringError>;
17}
18
19#[derive(Clone, Debug, PartialEq, Eq)]
20pub struct SensitiveUtf8;
21
22impl StringVariant for SensitiveUtf8 {
23 fn validate(_: &str) -> Result<(), StringError> {
24 Ok(())
25 }
26}
27
28#[derive(Clone, Debug, PartialEq, Eq)]
29pub struct InsensitiveAscii;
30impl StringVariant for InsensitiveAscii {
31 fn validate(s: &str) -> Result<(), StringError> {
32 if let Some(c) = s.chars().find(|c| !c.is_ascii()) {
33 return Err(StringError::NonAsciiChar(c));
34 }
35
36 if let Some(b) = s.as_bytes().iter().find(|c| !(32u8..127).contains(c)) {
37 return Err(StringError::InvalidChar(*b as char));
38 }
39
40 Ok(())
41 }
42}
43
/// An owned string tagged with a validation variant `V` and a length bound `N`.
///
/// `N` is the maximum length in bytes enforced by parsing (`FromStr`) and by
/// `new`. `V` is a marker type tracked only at the type level.
#[derive(PartialEq, Eq, Clone)]
pub struct BaseString<V, const N: usize> {
    // The stored text.
    s: String,
    // Type-level tag selecting the validation rule; holds no data.
    _marker: PhantomData<V>,
}
49
50impl<V, const N: usize> fmt::Debug for BaseString<V, N> {
51 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
52 write!(f, "BaseString<{}>({:?})", N, self.s)
53 }
54}
55
56pub type CsString<const N: usize> = BaseString<SensitiveUtf8, N>;
64
65pub type CiString<const N: usize> = BaseString<InsensitiveAscii, N>;
72
73impl<V, const N: usize> FromStr for BaseString<V, N>
74where
75 V: StringVariant,
76{
77 type Err = StringError;
78
79 fn from_str(s: &str) -> Result<Self, Self::Err> {
80 if N < s.len() {
81 return Err(StringError::InvalidLength {
82 max_len: N,
83 found_len: s.len(),
84 });
85 }
86
87 V::validate(s)?;
88
89 Ok(BaseString {
90 s: s.into(),
91 _marker: PhantomData::default(),
92 })
93 }
94}
95
96impl<V, const N: usize> fmt::Display for BaseString<V, N> {
97 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
98 f.write_str(&self.s)
99 }
100}
101
102impl<V, const N: usize> BaseString<V, N> {
103 pub fn new(s: impl Into<String>) -> Self {
105 let s = s.into();
106
107 if s.len() <= N {
108 Self {
109 s,
110 _marker: PhantomData::default(),
111 }
112 } else {
113 panic!("String to long");
114 }
115 }
116}
117
118impl<V, const N: usize> BaseString<V, N> {
119 pub fn as_str(&self) -> &str {
120 self.s.as_str()
121 }
122}
123
124impl<V, const N: usize> TryFrom<String> for BaseString<V, N>
125where
126 V: StringVariant,
127{
128 type Error = StringError;
129
130 fn try_from(s: String) -> Result<Self, Self::Error> {
131 V::validate(&s)?;
132 Ok(BaseString {
133 s,
134 _marker: PhantomData::default(),
135 })
136 }
137}
138
139impl<V, const N: usize> From<BaseString<V, N>> for String {
140 fn from(s: BaseString<V, N>) -> String {
141 s.s
142 }
143}
144
145impl<V, const N: usize> AsRef<str> for BaseString<V, N> {
146 fn as_ref(&self) -> &str {
147 &self.s
148 }
149}
150
151impl<V, const N: usize> PartialEq<&str> for BaseString<V, N> {
152 fn eq(&self, other: &&str) -> bool {
153 self.s.as_str().eq(*other)
154 }
155}
156
157impl<V, const N: usize> serde::Serialize for BaseString<V, N> {
158 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
159 where
160 S: serde::Serializer,
161 {
162 serializer.serialize_str(self.as_str())
163 }
164}
165
166impl<'de, V, const N: usize> serde::de::Deserialize<'de> for BaseString<V, N>
167where
168 V: StringVariant,
169{
170 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
171 where
172 D: serde::de::Deserializer<'de>,
173 {
174 deserializer.deserialize_str(BaseStringVisitor(PhantomData::default()))
175 }
176}
177
/// Serde visitor that builds a `BaseString<V, N>` from a visited string.
struct BaseStringVisitor<V, const N: usize>(PhantomData<V>);
179
180impl<'de, V, const N: usize> serde::de::Visitor<'de> for BaseStringVisitor<V, N>
181where
182 V: StringVariant,
183{
184 type Value = BaseString<V, N>;
185
186 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
187 formatter.write_str("an ascii printable string (32-127)")
188 }
189
190 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
191 where
192 E: serde::de::Error,
193 {
194 value.parse::<BaseString<V, N>>().map_err(E::custom)
195 }
196}