1use std::{borrow::Cow, str::FromStr};
2
3use super::{Ipv4, Ipv6, StakeCredential};
4use cml_core::DeserializeError;
5use cml_crypto::RawBytesEncoding;
6
7impl StakeCredential {
8 pub fn to_raw_bytes(&self) -> &[u8] {
10 match self {
11 Self::PubKey { hash, .. } => hash.to_raw_bytes(),
12 Self::Script { hash, .. } => hash.to_raw_bytes(),
13 }
14 }
15}
16
17#[derive(Debug, thiserror::Error)]
18pub enum IPStringParsingError {
19 #[error("Invalid IPv4 Address String, expected period-separated bytes e.g. 0.0.0.0")]
20 IPv4StringFormat,
21 #[error("Invalid IPv6 Address String, expected colon-separated hextets e.g. 2001:0db8:0000:0000:0000:8a2e:0370:7334")]
22 IPv6StringFormat,
23 #[error("Deserializing from bytes: {0:?}")]
24 DeserializeError(DeserializeError),
25}
26
27impl std::fmt::Display for Ipv4 {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 write!(
30 f,
31 "{}",
32 self.inner
33 .iter()
34 .map(ToString::to_string)
35 .collect::<Vec<String>>()
36 .join(".")
37 )
38 }
39}
40
41impl FromStr for Ipv4 {
42 type Err = IPStringParsingError;
43
44 fn from_str(s: &str) -> Result<Self, Self::Err> {
45 s.split('.')
46 .map(FromStr::from_str)
47 .collect::<Result<Vec<u8>, _>>()
48 .map_err(|_e| IPStringParsingError::IPv4StringFormat)
49 .and_then(|bytes| Self::new(bytes).map_err(IPStringParsingError::DeserializeError))
50 }
51}
52
53impl serde::Serialize for Ipv4 {
54 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
55 where
56 S: serde::Serializer,
57 {
58 serializer.serialize_str(&self.to_string())
59 }
60}
61
62impl<'de> serde::de::Deserialize<'de> for Ipv4 {
63 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
64 where
65 D: serde::de::Deserializer<'de>,
66 {
67 let s = <String as serde::de::Deserialize>::deserialize(deserializer)?;
68 Self::from_str(&s).map_err(|_e| {
69 serde::de::Error::invalid_value(serde::de::Unexpected::Str(&s), &"invalid ipv4 address")
70 })
71 }
72}
73
74impl schemars::JsonSchema for Ipv4 {
75 fn schema_name() -> String {
76 String::from("Ipv4")
77 }
78
79 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
80 String::json_schema(gen)
81 }
82
83 fn is_referenceable() -> bool {
84 String::is_referenceable()
85 }
86}
87
88impl Ipv6 {
89 const LEN: usize = 16;
90
91 pub fn hextets(&self) -> Vec<u16> {
92 let mut ret = Vec::with_capacity(Self::LEN / 2);
93 for i in (0..self.inner.len()).step_by(2) {
94 ret.push(((self.inner[i + 1] as u16) << 8) | (self.inner[i] as u16));
95 }
96 ret
97 }
98}
99
100impl std::fmt::Display for Ipv6 {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 let mut best_gap_len = 0;
112 let mut best_gap_start = 0;
113 const UNDEF: usize = usize::MAX;
115 let mut cur_gap_start = UNDEF;
116 let hextets = self.hextets();
117 for (i, hextet) in hextets.iter().enumerate() {
118 if *hextet == 0 {
119 if cur_gap_start == UNDEF {
120 cur_gap_start = i;
121 }
122 } else {
123 if cur_gap_start != UNDEF && (i - cur_gap_start) > best_gap_len {
124 best_gap_len = i - cur_gap_start;
125 best_gap_start = cur_gap_start;
126 }
127 cur_gap_start = UNDEF;
128 }
129 }
130 if cur_gap_start != UNDEF && (hextets.len() - cur_gap_start) > best_gap_len {
131 best_gap_len = hextets.len() - cur_gap_start;
132 best_gap_start = cur_gap_start;
133 }
134 fn ipv6_substr(hextet_substr: &[u16]) -> String {
135 hextet_substr
136 .iter()
137 .map(|hextet| {
138 let trimmed = hex::encode(hextet.to_le_bytes())
139 .trim_start_matches('0')
140 .to_owned();
141 if trimmed.is_empty() {
142 "0".to_owned()
143 } else {
144 trimmed
145 }
146 })
147 .collect::<Vec<String>>()
148 .join(":")
149 }
150 let canonical_str_rep = if best_gap_len > 1 {
151 format!(
152 "{}::{}",
153 ipv6_substr(&hextets[..best_gap_start]),
154 ipv6_substr(&hextets[(best_gap_start + best_gap_len)..])
155 )
156 } else {
157 ipv6_substr(&hextets)
158 };
159 write!(f, "{}", canonical_str_rep)
160 }
161}
162
163impl FromStr for Ipv6 {
164 type Err = IPStringParsingError;
165
166 fn from_str(s: &str) -> Result<Self, Self::Err> {
167 fn ipv6_subbytes(substr: &str) -> Result<Vec<u8>, IPStringParsingError> {
168 let mut bytes = Vec::new();
169 for hextet_str in substr.split(':') {
170 let padded_str = if hextet_str.len() % 2 == 0 {
172 Cow::Borrowed(hextet_str)
173 } else {
174 Cow::Owned(format!("0{hextet_str}"))
175 };
176 let hextet_bytes = hex::decode(padded_str.as_bytes())
177 .map_err(|_e| IPStringParsingError::IPv6StringFormat)?;
178 match hextet_bytes.len() {
179 0 => {
180 bytes.extend(&[0, 0]);
181 }
182 1 => {
183 bytes.push(0);
184 bytes.push(hextet_bytes[0]);
185 }
186 2 => {
187 bytes.extend(&hextet_bytes);
188 }
189 _ => return Err(IPStringParsingError::IPv6StringFormat),
190 }
191 }
192 Ok(bytes)
193 }
194 let bytes = if let Some((left_str, right_str)) = s.split_once("::") {
195 let mut bytes = ipv6_subbytes(left_str)?;
196 let right_bytes = ipv6_subbytes(right_str)?;
197 bytes.resize(Self::LEN - right_bytes.len(), 0);
199 bytes.extend(&right_bytes);
200 bytes
201 } else {
202 ipv6_subbytes(s)?
203 };
204 Self::new(bytes).map_err(IPStringParsingError::DeserializeError)
205 }
206}
207
208impl serde::Serialize for Ipv6 {
209 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
210 where
211 S: serde::Serializer,
212 {
213 serializer.serialize_str(&self.to_string())
214 }
215}
216
217impl<'de> serde::de::Deserialize<'de> for Ipv6 {
218 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
219 where
220 D: serde::de::Deserializer<'de>,
221 {
222 let s = <String as serde::de::Deserialize>::deserialize(deserializer)?;
223 Self::from_str(&s).map_err(|_e| {
224 serde::de::Error::invalid_value(serde::de::Unexpected::Str(&s), &"invalid ipv6 address")
225 })
226 }
227}
228
229impl schemars::JsonSchema for Ipv6 {
230 fn schema_name() -> String {
231 String::from("Ipv6")
232 }
233
234 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
235 String::json_schema(gen)
236 }
237
238 fn is_referenceable() -> bool {
239 String::is_referenceable()
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `json` deserializes to an `Ipv4` and serializes back unchanged.
    fn ipv4_json_roundtrip(json: &str) {
        let parsed: Ipv4 = serde_json::from_str(json).unwrap();
        let reserialized = serde_json::to_string_pretty(&parsed).unwrap();
        assert_eq!(json, reserialized);
    }

    #[test]
    fn ipv4_json() {
        ipv4_json_roundtrip("\"0.0.0.0\"");
        ipv4_json_roundtrip("\"255.255.255.255\"");
        ipv4_json_roundtrip("\"192.168.1.10\"");
    }

    #[test]
    fn ipv4_json_rejects_malformed() {
        for bad in ["\"256.0.0.1\"", "\"a.b.c.d\"", "\"1..2.3\""] {
            assert!(serde_json::from_str::<Ipv4>(bad).is_err(), "accepted {bad}");
        }
    }

    /// Checks that the long form and the canonical form parse to the same bytes, and that
    /// both serialize to the canonical form.
    fn ipv6_json_testcase(long_form_json: &str, canonical_form_json: &str) {
        let from_long: Ipv6 = serde_json::from_str(long_form_json).unwrap();
        let to_json_1 = serde_json::to_string_pretty(&from_long).unwrap();
        assert_eq!(canonical_form_json, to_json_1);
        let from_canonical: Ipv6 = serde_json::from_str(canonical_form_json).unwrap();
        let to_json_2 = serde_json::to_string_pretty(&from_canonical).unwrap();
        assert_eq!(canonical_form_json, to_json_2);
        assert_eq!(from_long.inner, from_canonical.inner);
    }

    #[test]
    fn ipv6_json() {
        ipv6_json_testcase(
            "\"2001:0db8:0000:0000:0000:ff00:0042:8329\"",
            "\"2001:db8::ff00:42:8329\"",
        );
        ipv6_json_testcase(
            "\"2001:0db8:0000:0000:1111:0000:0000:8329\"",
            "\"2001:db8::1111:0:0:8329\"",
        );
        ipv6_json_testcase(
            "\"0001:0000:0002:0000:0000:0000:0003:0000\"",
            "\"1:0:2::3:0\"",
        );
        ipv6_json_testcase("\"000a:000b:0000:0000:0000:0000:0000:0000\"", "\"a:b::\"");
        ipv6_json_testcase(
            "\"0000:0000:0000:0000:0000:0000:abcd:0000\"",
            "\"::abcd:0\"",
        );
        ipv6_json_testcase(
            "\"0000:000a:0000:000b:0000:000c:0000:000d\"",
            "\"0:a:0:b:0:c:0:d\"",
        );
    }

    #[test]
    fn ipv6_json_edge_cases() {
        // The all-zero address collapses entirely.
        ipv6_json_testcase("\"0000:0000:0000:0000:0000:0000:0000:0000\"", "\"::\"");
        // Zero runs touching either end.
        ipv6_json_testcase("\"0000:0000:0000:0000:0000:0000:0000:0001\"", "\"::1\"");
        ipv6_json_testcase("\"0001:0000:0000:0000:0000:0000:0000:0000\"", "\"1::\"");
        // Two equally long zero runs: only the first is compressed.
        ipv6_json_testcase(
            "\"0001:0000:0000:0002:0000:0000:0003:0004\"",
            "\"1::2:0:0:3:4\"",
        );
    }

    #[test]
    fn ipv6_from_str_rejects_malformed() {
        for bad in ["12345::", "g::", "::g"] {
            assert!(bad.parse::<Ipv6>().is_err(), "accepted {bad}");
        }
    }
}