cml_chain/certs/
utils.rs

1use std::{borrow::Cow, str::FromStr};
2
3use super::{Ipv4, Ipv6, StakeCredential};
4use cml_core::DeserializeError;
5use cml_crypto::RawBytesEncoding;
6
7impl StakeCredential {
8    // we don't implement RawBytesEncoding as from_raw_bytes() would be unable to distinguish
9    pub fn to_raw_bytes(&self) -> &[u8] {
10        match self {
11            Self::PubKey { hash, .. } => hash.to_raw_bytes(),
12            Self::Script { hash, .. } => hash.to_raw_bytes(),
13        }
14    }
15}
16
/// Error returned by the `FromStr` implementations of [`Ipv4`] and [`Ipv6`] (and
/// therefore by their string-based serde deserialization).
#[derive(Debug, thiserror::Error)]
pub enum IPStringParsingError {
    /// The text could not be split into valid IPv4 byte values.
    #[error("Invalid IPv4 Address String, expected period-separated bytes e.g. 0.0.0.0")]
    IPv4StringFormat,
    /// The text could not be decoded as colon-separated IPv6 hextets.
    #[error("Invalid IPv6 Address String, expected colon-separated hextets e.g. 2001:0db8:0000:0000:0000:8a2e:0370:7334")]
    IPv6StringFormat,
    /// The text was turned into bytes, but the address type's `new` constructor
    /// rejected them; wraps the constructor's error.
    #[error("Deserializing from bytes: {0:?}")]
    DeserializeError(DeserializeError),
}
26
27impl std::fmt::Display for Ipv4 {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        write!(
30            f,
31            "{}",
32            self.inner
33                .iter()
34                .map(ToString::to_string)
35                .collect::<Vec<String>>()
36                .join(".")
37        )
38    }
39}
40
41impl FromStr for Ipv4 {
42    type Err = IPStringParsingError;
43
44    fn from_str(s: &str) -> Result<Self, Self::Err> {
45        s.split('.')
46            .map(FromStr::from_str)
47            .collect::<Result<Vec<u8>, _>>()
48            .map_err(|_e| IPStringParsingError::IPv4StringFormat)
49            .and_then(|bytes| Self::new(bytes).map_err(IPStringParsingError::DeserializeError))
50    }
51}
52
53impl serde::Serialize for Ipv4 {
54    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
55    where
56        S: serde::Serializer,
57    {
58        serializer.serialize_str(&self.to_string())
59    }
60}
61
62impl<'de> serde::de::Deserialize<'de> for Ipv4 {
63    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
64    where
65        D: serde::de::Deserializer<'de>,
66    {
67        let s = <String as serde::de::Deserialize>::deserialize(deserializer)?;
68        Self::from_str(&s).map_err(|_e| {
69            serde::de::Error::invalid_value(serde::de::Unexpected::Str(&s), &"invalid ipv4 address")
70        })
71    }
72}
73
74impl schemars::JsonSchema for Ipv4 {
75    fn schema_name() -> String {
76        String::from("Ipv4")
77    }
78
79    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
80        String::json_schema(gen)
81    }
82
83    fn is_referenceable() -> bool {
84        String::is_referenceable()
85    }
86}
87
88impl Ipv6 {
89    const LEN: usize = 16;
90
91    pub fn hextets(&self) -> Vec<u16> {
92        let mut ret = Vec::with_capacity(Self::LEN / 2);
93        for i in (0..self.inner.len()).step_by(2) {
94            ret.push(((self.inner[i + 1] as u16) << 8) | (self.inner[i] as u16));
95        }
96        ret
97    }
98}
99
100impl std::fmt::Display for Ipv6 {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        // Using the canonical format for IPV6 in RFC5952
103        // 4.1) Leading zeros MUST be suppressed.
104        // 4.2.1) :: MUST shorten as much as possible
105        // 4.2.2) :: MUST NOT be used for a single 0 field
106        // 4.2.3) :: Ties are broken by choosing the location first in the string
107        // 4.3) Hex chars MUST be lowercase
108        // NOTE: we do NOT support IPv4-Mapped IPv6 special text representations (Section 5)
109        //       this is only RECOMMENDED, not required, and only when the format is known
110        //       e.g. specific prefixes are used
111        let mut best_gap_len = 0;
112        let mut best_gap_start = 0;
113        // usize::MAX is fine since we're max 16 here
114        const UNDEF: usize = usize::MAX;
115        let mut cur_gap_start = UNDEF;
116        let hextets = self.hextets();
117        for (i, hextet) in hextets.iter().enumerate() {
118            if *hextet == 0 {
119                if cur_gap_start == UNDEF {
120                    cur_gap_start = i;
121                }
122            } else {
123                if cur_gap_start != UNDEF && (i - cur_gap_start) > best_gap_len {
124                    best_gap_len = i - cur_gap_start;
125                    best_gap_start = cur_gap_start;
126                }
127                cur_gap_start = UNDEF;
128            }
129        }
130        if cur_gap_start != UNDEF && (hextets.len() - cur_gap_start) > best_gap_len {
131            best_gap_len = hextets.len() - cur_gap_start;
132            best_gap_start = cur_gap_start;
133        }
134        fn ipv6_substr(hextet_substr: &[u16]) -> String {
135            hextet_substr
136                .iter()
137                .map(|hextet| {
138                    let trimmed = hex::encode(hextet.to_le_bytes())
139                        .trim_start_matches('0')
140                        .to_owned();
141                    if trimmed.is_empty() {
142                        "0".to_owned()
143                    } else {
144                        trimmed
145                    }
146                })
147                .collect::<Vec<String>>()
148                .join(":")
149        }
150        let canonical_str_rep = if best_gap_len > 1 {
151            format!(
152                "{}::{}",
153                ipv6_substr(&hextets[..best_gap_start]),
154                ipv6_substr(&hextets[(best_gap_start + best_gap_len)..])
155            )
156        } else {
157            ipv6_substr(&hextets)
158        };
159        write!(f, "{}", canonical_str_rep)
160    }
161}
162
163impl FromStr for Ipv6 {
164    type Err = IPStringParsingError;
165
166    fn from_str(s: &str) -> Result<Self, Self::Err> {
167        fn ipv6_subbytes(substr: &str) -> Result<Vec<u8>, IPStringParsingError> {
168            let mut bytes = Vec::new();
169            for hextet_str in substr.split(':') {
170                // hex::decode does not allow odd-length strings so pad it
171                let padded_str = if hextet_str.len() % 2 == 0 {
172                    Cow::Borrowed(hextet_str)
173                } else {
174                    Cow::Owned(format!("0{hextet_str}"))
175                };
176                let hextet_bytes = hex::decode(padded_str.as_bytes())
177                    .map_err(|_e| IPStringParsingError::IPv6StringFormat)?;
178                match hextet_bytes.len() {
179                    0 => {
180                        bytes.extend(&[0, 0]);
181                    }
182                    1 => {
183                        bytes.push(0);
184                        bytes.push(hextet_bytes[0]);
185                    }
186                    2 => {
187                        bytes.extend(&hextet_bytes);
188                    }
189                    _ => return Err(IPStringParsingError::IPv6StringFormat),
190                }
191            }
192            Ok(bytes)
193        }
194        let bytes = if let Some((left_str, right_str)) = s.split_once("::") {
195            let mut bytes = ipv6_subbytes(left_str)?;
196            let right_bytes = ipv6_subbytes(right_str)?;
197            // pad middle with 0s
198            bytes.resize(Self::LEN - right_bytes.len(), 0);
199            bytes.extend(&right_bytes);
200            bytes
201        } else {
202            ipv6_subbytes(s)?
203        };
204        Self::new(bytes).map_err(IPStringParsingError::DeserializeError)
205    }
206}
207
208impl serde::Serialize for Ipv6 {
209    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
210    where
211        S: serde::Serializer,
212    {
213        serializer.serialize_str(&self.to_string())
214    }
215}
216
217impl<'de> serde::de::Deserialize<'de> for Ipv6 {
218    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
219    where
220        D: serde::de::Deserializer<'de>,
221    {
222        let s = <String as serde::de::Deserialize>::deserialize(deserializer)?;
223        Self::from_str(&s).map_err(|_e| {
224            serde::de::Error::invalid_value(serde::de::Unexpected::Str(&s), &"invalid ipv6 address")
225        })
226    }
227}
228
229impl schemars::JsonSchema for Ipv6 {
230    fn schema_name() -> String {
231        String::from("Ipv6")
232    }
233
234    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
235        String::json_schema(gen)
236    }
237
238    fn is_referenceable() -> bool {
239        String::is_referenceable()
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `json_str` deserializes to an `Ipv4` which serializes back to the
    /// identical JSON string.
    fn ipv4_json_roundtrip(json_str: &str) {
        let from_json: Ipv4 = serde_json::from_str(json_str).unwrap();
        let to_json = serde_json::to_string_pretty(&from_json).unwrap();
        assert_eq!(json_str, to_json);
    }

    #[test]
    fn ipv4_json() {
        ipv4_json_roundtrip("\"0.0.0.0\"");
        ipv4_json_roundtrip("\"255.255.255.255\"");
        ipv4_json_roundtrip("\"192.168.1.10\"");
    }

    /// Asserts that both `long_form_json` and `canonical_form_json` deserialize to the
    /// same bytes, and that both serialize to `canonical_form_json`.
    fn ipv6_json_testcase(long_form_json: &str, canonical_form_json: &str) {
        let from_long: Ipv6 = serde_json::from_str(long_form_json).unwrap();
        let to_json_1 = serde_json::to_string_pretty(&from_long).unwrap();
        assert_eq!(canonical_form_json, to_json_1);
        let from_canonical: Ipv6 = serde_json::from_str(canonical_form_json).unwrap();
        let to_json_2 = serde_json::to_string_pretty(&from_canonical).unwrap();
        assert_eq!(canonical_form_json, to_json_2);
        assert_eq!(from_long.inner, from_canonical.inner);
    }

    #[test]
    fn ipv6_json() {
        // This tests that we abide by RFC5952 for IPV6 Canonical text form
        // (e.g. that lowercase is used + leading 0s are omitted)
        ipv6_json_testcase(
            "\"2001:0db8:0000:0000:0000:ff00:0042:8329\"",
            "\"2001:db8::ff00:42:8329\"",
        );
        // ties broken by first one
        ipv6_json_testcase(
            "\"2001:0db8:0000:0000:1111:0000:0000:8329\"",
            "\"2001:db8::1111:0:0:8329\"",
        );
        // min run not first
        ipv6_json_testcase(
            "\"0001:0000:0002:0000:0000:0000:0003:0000\"",
            "\"1:0:2::3:0\"",
        );
        // ends in min run
        ipv6_json_testcase("\"000a:000b:0000:0000:0000:0000:0000:0000\"", "\"a:b::\"");
        // starts with min run
        ipv6_json_testcase(
            "\"0000:0000:0000:0000:0000:0000:abcd:0000\"",
            "\"::abcd:0\"",
        );
        // don't use runs for single 0 hextets
        ipv6_json_testcase(
            "\"0000:000a:0000:000b:0000:000c:0000:000d\"",
            "\"0:a:0:b:0:c:0:d\"",
        );
        // the all-zero address compresses to just `::`
        ipv6_json_testcase(
            "\"0000:0000:0000:0000:0000:0000:0000:0000\"",
            "\"::\"",
        );
        // loopback
        ipv6_json_testcase(
            "\"0000:0000:0000:0000:0000:0000:0000:0001\"",
            "\"::1\"",
        );
        // uppercase input is accepted and printed lowercase
        ipv6_json_testcase(
            "\"2001:0DB8:0000:0000:0000:0000:0000:0001\"",
            "\"2001:db8::1\"",
        );
    }
}