1use std::{
4 fmt::{Display, LowerHex, UpperHex, Write},
5 num::ParseIntError,
6 str::FromStr,
7};
8
9use serde::{Deserialize, Serialize};
10
11use super::serde::BytesVisitor;
12
/// Errors produced when turning text or decoded bytes into a [`Hex`] value.
#[derive(Debug, thiserror::Error)]
pub enum HexError {
    /// A chunk of the input could not be parsed as a hex byte. The wrapped
    /// [`ParseIntError`] is forwarded unchanged (`#[error(transparent)]`), and `#[from]`
    /// lets `?` convert it automatically.
    #[error(transparent)]
    ParseIntError(#[from] ParseIntError),

    /// The input length is not acceptable for the target type. The payload is the
    /// offending length: the byte length of the hex string after stripping `0x` when
    /// parsing text, or the length of the decoded buffer in the binary deserialize path.
    #[error("Invalid hex length: {0}")]
    InvalidHexLength(usize),
}
21
/// Newtype wrapper that gives a byte container hex formatting and parsing.
///
/// The formatting impls (`LowerHex`, `UpperHex`, `Display`) accept any `T: AsRef<[u8]>`.
/// Parsing, conversions and serde support are implemented for `Vec<u8>` and `[u8; N]`.
/// Note that the derived `Debug` prints the inner value as-is (e.g. `Hex([1, 2])`), not
/// as hex.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Hex<T>(pub T);
25
26impl<T> AsRef<[u8]> for Hex<T>
27where
28 T: AsRef<[u8]>,
29{
30 fn as_ref(&self) -> &[u8] {
31 self.0.as_ref()
32 }
33}
34
35impl<T> Hex<T>
36where
37 T: AsRef<[u8]>,
38{
39 pub fn len(&self) -> usize {
41 self.0.as_ref().len()
42 }
43}
44
45impl<T> LowerHex for Hex<T>
46where
47 T: AsRef<[u8]>,
48{
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 let buf = self.0.as_ref();
51
52 let mut prefix = String::new();
53 let mut leading_spaces = String::new();
54
55 if let Some(width) = f.width() {
56 let content_with = buf.len() * 2 + if f.alternate() { 2 } else { 0 };
57
58 if width > content_with {
59 if f.sign_aware_zero_pad() {
60 prefix = "0".repeat(width - content_with);
61 } else {
62 leading_spaces = " ".repeat(width - content_with)
63 }
64 }
65 }
66
67 let mut content = String::with_capacity(buf.len() * 2);
68 for &b in buf {
69 write!(&mut content, "{:02x}", b)?;
70 }
71
72 if f.alternate() {
73 write!(f, "{}0x{}{}", leading_spaces, prefix, content)
74 } else {
75 write!(f, "{}{}{}", leading_spaces, prefix, content)
76 }
77 }
78}
79
80impl<T> UpperHex for Hex<T>
81where
82 T: AsRef<[u8]>,
83{
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 let buf = self.0.as_ref();
86
87 let mut prefix = String::new();
88 let mut leading_spaces = String::new();
89
90 if let Some(width) = f.width() {
91 let content_with = buf.len() * 2 + if f.alternate() { 2 } else { 0 };
92
93 if width > content_with {
94 if f.sign_aware_zero_pad() {
95 prefix = "0".repeat(width - content_with);
96 } else {
97 leading_spaces = " ".repeat(width - content_with)
98 }
99 }
100 }
101
102 let mut content = String::with_capacity(buf.len() * 2);
103 for &b in buf {
104 write!(&mut content, "{:02X}", b)?;
105 }
106
107 if f.alternate() {
108 write!(f, "{}0x{}{}", leading_spaces, prefix, content)
109 } else {
110 write!(f, "{}{}{}", leading_spaces, prefix, content)
111 }
112 }
113}
114
115impl FromStr for Hex<Vec<u8>> {
116 type Err = HexError;
117
118 fn from_str(s: &str) -> Result<Self, Self::Err> {
119 let s = if s.starts_with("0x") { &s[2..] } else { s };
120
121 if s.len() < 2 {
122 if s == "0" {
123 return Ok(Hex(vec![0x0]));
124 }
125 }
126
127 let s = if s.len() % 2 != 0 {
128 "0".to_string() + s
129 } else {
130 s.to_owned()
131 };
132
133 let buf: Result<Vec<u8>, ParseIntError> = (0..s.len())
134 .step_by(2)
135 .map(|i| u8::from_str_radix(&s[i..i + 2], 16))
136 .collect();
137
138 Ok(Hex(buf?))
139 }
140}
141
142impl<const N: usize> FromStr for Hex<[u8; N]> {
143 type Err = HexError;
144
145 fn from_str(s: &str) -> Result<Self, Self::Err> {
146 let s = if s.starts_with("0x") { &s[2..] } else { s };
147
148 if s.len() > N * 2 {
149 return Err(HexError::InvalidHexLength(s.len()));
150 }
151
152 if s.len() % 2 != 0 {
153 return Err(HexError::InvalidHexLength(s.len()));
154 }
155
156 let offset = N - s.len() / 2;
157
158 let mut buf = [0u8; N];
159
160 for i in offset..N {
161 buf[i] = u8::from_str_radix(&s[(i - offset) * 2..(i - offset) * 2 + 2], 16)?;
162 }
163
164 Ok(Hex(buf))
165 }
166}
167
168impl TryFrom<&str> for Hex<Vec<u8>> {
169 type Error = HexError;
170
171 fn try_from(value: &str) -> Result<Self, Self::Error> {
172 value.parse()
173 }
174}
175
176impl TryFrom<String> for Hex<Vec<u8>> {
177 type Error = HexError;
178
179 fn try_from(value: String) -> Result<Self, Self::Error> {
180 value.parse()
181 }
182}
183
184impl<const N: usize> TryFrom<&str> for Hex<[u8; N]> {
185 type Error = HexError;
186
187 fn try_from(value: &str) -> Result<Self, Self::Error> {
188 value.parse()
189 }
190}
191
192impl<const N: usize> TryFrom<String> for Hex<[u8; N]> {
193 type Error = HexError;
194
195 fn try_from(value: String) -> Result<Self, Self::Error> {
196 value.parse()
197 }
198}
199
200impl From<Vec<u8>> for Hex<Vec<u8>> {
201 fn from(value: Vec<u8>) -> Self {
202 Self(value)
203 }
204}
205impl From<Hex<Vec<u8>>> for Vec<u8> {
206 fn from(value: Hex<Vec<u8>>) -> Self {
207 value.0
208 }
209}
210
211impl<const N: usize> From<[u8; N]> for Hex<[u8; N]> {
212 fn from(value: [u8; N]) -> Self {
213 Self(value)
214 }
215}
216impl<const N: usize> From<Hex<[u8; N]>> for [u8; N] {
217 fn from(value: Hex<[u8; N]>) -> Self {
218 value.0
219 }
220}
221
222impl Serialize for Hex<Vec<u8>> {
223 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
224 where
225 S: serde::Serializer,
226 {
227 if serializer.is_human_readable() {
228 format!("{:#x}", self).serialize(serializer)
229 } else {
230 serializer.serialize_newtype_struct("bytes", &self.0)
231 }
232 }
233}
234
235impl<'de> Deserialize<'de> for Hex<Vec<u8>> {
236 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
237 where
238 D: serde::Deserializer<'de>,
239 {
240 if deserializer.is_human_readable() {
241 let data = String::deserialize(deserializer)?;
242
243 Hex::<Vec<u8>>::from_str(&data).map_err(serde::de::Error::custom)
244 } else {
245 let hex = Vec::<u8>::deserialize(deserializer)?;
246
247 Ok(Hex(hex))
248 }
249 }
250}
251
252impl<const N: usize> Serialize for Hex<[u8; N]> {
253 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
254 where
255 S: serde::Serializer,
256 {
257 if serializer.is_human_readable() {
258 format!("{:#x}", self).serialize(serializer)
259 } else {
260 serializer.serialize_newtype_struct("bytesN", &self.0.to_vec())
261 }
262 }
263}
264
265impl<'de, const N: usize> Deserialize<'de> for Hex<[u8; N]> {
266 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
267 where
268 D: serde::Deserializer<'de>,
269 {
270 if deserializer.is_human_readable() {
271 let data = String::deserialize(deserializer)?;
272
273 Hex::<[u8; N]>::from_str(&data).map_err(serde::de::Error::custom)
274 } else {
275 let buf = deserializer.deserialize_newtype_struct("bytesN", BytesVisitor)?;
276
277 if buf.len() < N {
278 return Err(HexError::InvalidHexLength(buf.len()))
279 .map_err(serde::de::Error::custom);
280 }
281
282 let mut hex = [0u8; N];
283
284 hex.copy_from_slice(&buf[(buf.len() - N)..]);
285
286 Ok(Hex(hex))
287 }
288 }
289}
290
291impl<T: AsRef<[u8]>> Display for Hex<T> {
292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 write!(f, "{:#x}", self)
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared input for the padding tests: a leading zero byte plus a letter digit, so both
    /// zero-padding and letter case are visible in the output.
    const SAMPLE: [u8; 3] = [0x00, 0x01, 0x0a];

    #[test]
    fn test_lower_hex_padding() {
        let hex = Hex(&SAMPLE);

        // Space padding goes before the `0x` prefix; the width counts the prefix.
        assert_eq!(format!("{:#66x}", hex), format!("{:>66}", "0x00010a"));

        // Zero padding goes between the prefix and the digits.
        assert_eq!(
            format!("{:#066x}", hex),
            format!("0x{}00010a", "0".repeat(58))
        );
    }

    #[test]
    fn test_upper_hex_padding() {
        let hex = Hex(&SAMPLE);

        assert_eq!(format!("{:#66X}", hex), format!("{:>66}", "0x00010A"));

        assert_eq!(
            format!("{:#066X}", hex),
            format!("0x{}00010A", "0".repeat(58))
        );
    }
}