1use crate::{error::parser::StrParseError, primitives::LOG_TARGET, Error};
24
25use bytes::{BufMut, BytesMut};
26use nom::{bytes::complete::take, number::complete::be_u8, Err, IResult};
27
28use alloc::{borrow::ToOwned, string::String, sync::Arc, vec::Vec};
29use core::{
30 fmt,
31 hash::{Hash, Hasher},
32 ops,
33 str::FromStr,
34};
35
/// A string that is either a `'static` borrow or a reference-counted heap
/// allocation.
///
/// Cloning never copies the text: a `Static` clone copies the reference and an
/// `Allocated` clone clones the `Arc` handle.
#[derive(Clone, Debug)]
pub enum Str {
    /// Text with `'static` lifetime (e.g. a string literal); no allocation.
    Static(&'static str),
    /// Heap-allocated text shared through an `Arc`.
    Allocated(Arc<str>),
}
42
43impl From<&'static str> for Str {
44 fn from(protocol: &'static str) -> Self {
45 Str::Static(protocol)
46 }
47}
48
49impl fmt::Display for Str {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 match self {
52 Self::Static(protocol) => protocol.fmt(f),
53 Self::Allocated(protocol) => protocol.fmt(f),
54 }
55 }
56}
57
58impl From<String> for Str {
59 fn from(protocol: String) -> Self {
60 Str::Allocated(Arc::from(protocol))
61 }
62}
63
64impl From<Arc<str>> for Str {
65 fn from(protocol: Arc<str>) -> Self {
66 Self::Allocated(protocol)
67 }
68}
69
70impl TryFrom<&[u8]> for Str {
71 type Error = StrParseError;
72
73 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
74 let string = core::str::from_utf8(value).map_err(StrParseError::Utf8)?.to_owned();
75
76 Ok(Self::from(string))
77 }
78}
79
80impl FromStr for Str {
81 type Err = Error;
82
83 fn from_str(s: &str) -> Result<Self, Self::Err> {
84 if s.len() > 255 {
85 tracing::warn!(
86 target: LOG_TARGET,
87 len = ?s.len(),
88 "string is too large",
89 );
90 return Err(Error::InvalidData);
91 }
92
93 Ok(Str::from(s.to_owned()))
94 }
95}
96
97impl ops::Deref for Str {
98 type Target = str;
99
100 fn deref(&self) -> &Self::Target {
101 match self {
102 Self::Static(protocol) => protocol,
103 Self::Allocated(protocol) => protocol,
104 }
105 }
106}
107
108impl Hash for Str {
109 fn hash<H: Hasher>(&self, state: &mut H) {
110 (self as &str).hash(state)
111 }
112}
113
114impl PartialEq for Str {
115 fn eq(&self, other: &Self) -> bool {
116 (self as &str) == (other as &str)
117 }
118}
119
120impl Eq for Str {}
121
122impl Str {
123 pub fn serialize(&self) -> Vec<u8> {
125 let mut out = BytesMut::with_capacity(self.len() + 1);
126
127 debug_assert!(self.len() <= u8::MAX as usize);
128 out.put_u8(self.len() as u8);
129 out.put_slice(self.as_bytes());
130
131 out.freeze().to_vec()
132 }
133
134 pub fn parse_frame(input: &[u8]) -> IResult<&[u8], Self, StrParseError> {
136 let (rest, size) = be_u8(input)?;
137 let (rest, string) = take(size)(rest)?;
138 let string = Str::try_from(string).map_err(Err::Error)?;
139
140 Ok((rest, string))
141 }
142
143 pub fn parse(bytes: impl AsRef<[u8]>) -> Result<Str, StrParseError> {
145 Ok(Self::parse_frame(bytes.as_ref())?.1)
146 }
147
148 pub fn serialized_len(&self) -> usize {
150 self.len() + 1
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a frame by hand: one length byte, the UTF-8 bytes of `text`, then
    /// `trailing` verbatim. The helper does not use `Str::serialize`, so the parser
    /// is checked against the wire format itself.
    fn frame(text: &str, trailing: &[u8]) -> Vec<u8> {
        let len = u8::try_from(text.len()).expect("test frame text must fit in one length byte");
        let mut out = Vec::with_capacity(1 + text.len() + trailing.len());
        out.push(len);
        out.extend_from_slice(text.as_bytes());
        out.extend_from_slice(trailing);
        out
    }

    #[test]
    fn empty_string() {
        assert_eq!(
            Str::parse(Vec::new()).unwrap_err(),
            StrParseError::InvalidBitstream
        );
    }

    #[test]
    fn valid_string() {
        assert_eq!(
            Str::parse(frame("hello, world!", &[])),
            Ok(Str::from("hello, world!"))
        );
    }

    #[test]
    fn valid_string_with_extra_bytes() {
        assert_eq!(
            Str::parse(frame("hello, world!", &[1, 2, 3, 4])),
            Ok(Str::from("hello, world!"))
        );
    }

    #[test]
    fn extra_bytes_returned() {
        let bytes = frame("hello, world!", &[1, 2, 3, 4]);

        let (rest, string) = Str::parse_frame(&bytes).unwrap();

        assert_eq!(string, Str::from("hello, world!"));
        assert_eq!(rest, [1, 2, 3, 4]);
    }

    #[test]
    fn serialize_works() {
        let bytes = Str::from("hello, world!").serialize();

        assert_eq!(Str::parse(bytes), Ok(Str::from("hello, world!")));
    }

    #[test]
    fn serialize_matches_wire_format() {
        assert_eq!(
            Str::from("hello, world!").serialize(),
            frame("hello, world!", &[])
        );
        assert_eq!(Str::from("abc").serialized_len(), 4);
    }

    #[test]
    fn empty_str_round_trips() {
        let empty = Str::from("");

        assert_eq!(empty.serialize(), vec![0u8]);
        assert_eq!(empty.serialized_len(), 1);
        assert_eq!(Str::parse([0u8]), Ok(Str::from("")));
    }

    #[test]
    fn truncated_frame_is_rejected() {
        // The prefix announces 5 bytes but only 2 follow.
        assert!(Str::parse([5u8, b'a', b'b']).is_err());
    }

    #[test]
    fn max_length_string_round_trips() {
        let long = "a".repeat(255);
        assert!(Str::from_str(&long).is_ok());

        let bytes = Str::from(long.clone()).serialize();
        assert_eq!(bytes.len(), 256);
        assert_eq!(bytes[0], 255);
        assert_eq!(Str::parse(bytes), Ok(Str::from(long)));
    }

    #[test]
    fn contains_substring() {
        let string = Str::parse(frame("hello, world!", &[])).unwrap();

        assert!(string.contains("world"));
    }

    #[test]
    fn doesnt_contain_substring() {
        let string = Str::parse(frame("hello, world!", &[])).unwrap();

        assert!(!string.contains("goodbye"));
    }

    #[test]
    fn try_parse_non_utf8() {
        let string = vec![
            230, 214, 155, 197, 98, 170, 161, 183, 41, 58, 103, 216, 196, 180, 218, 194, 93, 131,
            248, 109, 234, 196, 246, 15, 126, 91, 198, 187, 11, 54, 197, 115, 230, 214, 155, 197,
            98, 170, 161, 183, 41, 58, 103, 216, 196, 180, 218, 194, 93, 131, 248, 109, 234, 196,
            246, 15, 126, 91, 198, 187, 11, 54, 197, 115, 230, 214, 155, 197, 98, 170, 161, 183,
            41, 58, 103, 216, 196, 180, 218, 194, 93, 131, 248, 109, 234, 196, 246, 15, 126, 91,
            198, 187, 11, 54, 197, 115, 230, 214, 155, 197, 98, 170, 161, 183, 41, 58, 103, 216,
            196, 180, 218, 194, 93, 131, 248, 109, 234, 196, 246, 15, 126, 91, 198, 187, 11, 54,
            197, 115, 230, 214, 155, 197, 98, 170, 161, 183, 41, 58, 103, 216, 196, 180, 218, 194,
            93, 131, 248, 109, 234, 196, 246, 15, 126, 91, 198, 187, 11, 54, 197, 115, 230, 214,
            155, 197, 98, 170, 161, 183, 41, 58, 103, 216, 196, 180, 218, 194, 93, 131, 248, 109,
            234, 196, 246, 15, 126, 91, 198, 187, 11, 54, 197, 115, 230, 214, 155, 197, 98, 170,
            161, 183, 41, 58, 103, 216, 196, 180, 218, 194, 93, 131, 248, 109, 234, 196, 246, 15,
            126, 91, 198, 187, 11, 54, 197, 115, 230, 64, 231, 155, 2, 143, 122, 48, 137, 247, 79,
            229, 220, 40, 212, 53, 67, 193, 196, 204, 21, 45, 109, 227, 237, 29, 17, 31, 189, 17,
            189, 195, 40, 5, 0, 4, 0, 7, 0, 0, 102, 216, 119, 64, 2, 88, 0, 0, 0, 0, 2, 0, 4, 0,
            32, 103, 57, 105, 36, 53, 6, 188, 207, 237, 100, 79, 208, 65, 73, 180, 118, 143, 162,
            202, 8, 103, 162, 220, 12, 95, 156, 67, 68, 62, 83, 112, 109, 0, 0, 1, 0, 119, 187, 61,
            243, 159, 159, 198, 178, 65, 81, 148, 19, 78, 105, 92, 175, 190, 170, 136, 62, 19, 45,
            23, 246, 228, 210, 215, 161, 129, 149, 160, 57, 137, 141, 144, 141, 163, 247, 34, 120,
        ];
        let tmp: &[u8] = string.as_ref();

        assert!(Str::try_from(tmp).is_err());
    }

    #[test]
    fn try_parse_too_long_str() {
        assert!(Str::from_str(&"a".repeat(256)).is_err());
    }
}