mwtitle/
ip.rs

1/*
2Copyright (C) 2021 Erutuon
3
4This program is free software: you can redistribute it and/or modify
5it under the terms of the GNU General Public License as published by
6the Free Software Foundation, either version 3 of the License, or
7(at your option) any later version.
8
9This program is distributed in the hope that it will be useful,
10but WITHOUT ANY WARRANTY; without even the implied warranty of
11MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12GNU General Public License for more details.
13
14You should have received a copy of the GNU General Public License
15along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17use crate::ipv6::{parse_ipv6_rev, Segment};
18use std::io::Cursor;
19
/// Returns `true` if `ip` is a dotted-quad IPv4 address.
///
/// The check is generous about leading zeros (`"000.001.00.255"` is accepted),
/// but otherwise strict: `ip` must consist of exactly four `.`-separated
/// numbers, each made of one to three ASCII digits with a value that fits in
/// a `u8`. A range or subpage suffix (`/32`) is not accepted; callers are
/// expected to strip it first.
fn is_ipv4(ip: &str) -> bool {
    let mut octets = ip.split('.');
    let valid_octets = octets
        .by_ref()
        .take(4)
        .filter(|octet| {
            // One to three characters...
            (1..=3).contains(&octet.len())
                // ...all of them digits (`u8::from_str` alone would also accept
                // a leading `+`)...
                && octet.bytes().all(|b| b.is_ascii_digit())
                // ...with a value no greater than `u8::MAX`.
                && octet.parse::<u8>().is_ok()
        })
        .count();
    // Must contain exactly 4 valid numbers and nothing after them.
    // The remaining segments are checked on the unfiltered iterator so that
    // a malformed fifth segment cannot be silently skipped.
    valid_octets == 4 && octets.next().is_none()
}
36
#[test]
fn is_ipv4_recognizes_ipv4_addresses() {
    // Leading zeros in each number are tolerated.
    let zero_padded = "000.001.00.255";
    assert!(is_ipv4(zero_padded));
}
41
#[test]
fn is_ipv4_rejects_invalid_ipv4_addresses() {
    // Baseline: the well-formed address that the cases below are derived from
    // is accepted, so each rejection is caused by the mutation alone.
    assert!(is_ipv4("000.001.00.255"));
    // a number is greater than u8::MAX
    assert!(!is_ipv4("000.001.00.256"));
    // range not allowed, even after an otherwise valid address
    assert!(!is_ipv4("000.001.00.255/32"));
    assert!(!is_ipv4("000.001.00.256/32"));
    // too many digits in number
    assert!(!is_ipv4("0000.1.1.1"));
    // empty number or empty input
    assert!(!is_ipv4("1..1.1"));
    assert!(!is_ipv4(""));
    // too few numbers
    assert!(!is_ipv4("1.1.1"));
    // too many numbers
    assert!(!is_ipv4("1.1.1.1.1"));
    // extraneous characters in an address with exactly four numbers
    assert!(!is_ipv4("1.1._.1"));
    assert!(!is_ipv4("_.1.1.1"));
    assert!(!is_ipv4("1_.1.1.1"));
    assert!(!is_ipv4("_1.1.1.1"));
    assert!(!is_ipv4("1.1.1.1_"));
    // extraneous characters combined with too many numbers
    assert!(!is_ipv4("_.1.1.1.1"));
    assert!(!is_ipv4("1_.1.1.1.1"));
    assert!(!is_ipv4("_1.1.1.1.1"));
    assert!(!is_ipv4("1.1.1.1.1_"));
}
62
/// Runs `writer` against a cursor over `buf` and returns the written prefix
/// of `buf` as a string slice.
///
/// # Errors
///
/// Returns any error produced by `writer` (for example when `buf` is too
/// small), and an [`std::io::ErrorKind::InvalidData`] error if the cursor ends
/// up past the end of `buf` or if the written bytes are not valid UTF-8.
fn write_to_buf(
    buf: &mut [u8],
    mut writer: impl FnMut(&mut Cursor<&mut [u8]>) -> std::io::Result<()>,
) -> std::io::Result<&str> {
    use std::io::{Error, ErrorKind};

    let position = {
        let mut cursor = Cursor::new(&mut *buf);
        writer(&mut cursor)?;
        cursor.position()
    };
    // Writing never moves the cursor past the end of the slice, but the writer
    // is free to reposition it; reject that instead of panicking on the slice
    // below. Comparing as `u64` also makes the later cast lossless.
    if position > buf.len() as u64 {
        return Err(Error::new(
            ErrorKind::InvalidData,
            "cursor position is past the end of the buffer",
        ));
    }
    let end = position as usize;
    // `end` marks the end of the bytes that were just written to `buf`;
    // they are validated rather than assumed to be UTF-8.
    std::str::from_utf8(&buf[..end]).map_err(|e| Error::new(ErrorKind::InvalidData, e))
}
76
77/// Convert IP addresses to a consistent verbose form.
78/// Strip leading 0 from segments of IPv4 and IPv6 addresses.
79/// Replace :: with zeroes and replace lowercase a-f with uppercase A-F in IPv6 addresses.
80pub(crate) fn sanitize_ip(input: &mut String) {
81    // SAFETY: `str::split` yields one `Some(_)`, even when it is given an empty string.
82    let ip = input.split('/').next().unwrap();
83    if is_ipv4(ip) {
84        let mut zeros_to_remove = [None, None, None, None];
85        let mut iter_zeros_to_remove = zeros_to_remove.iter_mut();
86        // Iterate over positions of zeros at beginning of input or before '.'.
87        // Iterate in reverse order because zeros_to_remove must be applied
88        // from the end of the string to the beginning
89        // for the ranges in zeros_to_remove to remain valid.
90        for pos in ip
91            .rmatch_indices('0')
92            .map(|(pos, _)| pos)
93            .filter(|&pos| pos == 0 || input.as_bytes()[pos - 1] == b'.')
94        {
95            let zero_count = input.as_bytes()[pos..]
96                .iter()
97                .position(|b| *b != b'0')
98                .unwrap_or(ip.len() - pos);
99            let zeros_to_remove =
100                    // If sequence of zeros is at the end of input or before '.', keep one zero.
101                    if pos + zero_count == ip.len() || input.as_bytes()[pos..][zero_count] == b'.' {
102                        zero_count - 1
103                    } else {
104                        zero_count
105                    };
106            if zeros_to_remove > 0 {
107                // This unwrap won't panic because is_ipv4 ensures
108                // that input contains exactly 4 numbers separated by '.'
109                // plus an optional range, which doesn't begin with a zero,
110                // so there are at most 4 sequences of zeros that are at the beginning of input
111                // or preceded by '.' that could need to be trimmed.
112                *iter_zeros_to_remove.next().unwrap() =
113                    Some(pos..pos + zeros_to_remove);
114            }
115        }
116        // Flattening will ensure every zero to remove is visited, because they are inserted sequentially.
117        for modification in zeros_to_remove.into_iter().flatten() {
118            input.replace_range(modification, "");
119        }
120    } else if let Ok(reverse_parsed_ipv6) = parse_ipv6_rev(ip).map_err(|_| ()) {
121        let ip_len = ip.len(); // to satisfy borrow checker
122
123        // Normalization done in reverse order, so that indices are not invalidated.
124        // parse_ipv6_rev guarantees that indices are valid.
125        use std::io::Write as _;
126        for segment in &reverse_parsed_ipv6 {
127            match segment {
128                // Convert a-f to uppercase and remove extra leading zeros.
129                Segment::Num(range) => {
130                    let num = &input[range.clone()];
131                    if (num.starts_with('0') && num.len() > 1)
132                        || num.bytes().any(|b| b.is_ascii_lowercase())
133                    {
134                        // Reserve enough space for 4 hex digits, enough for any u16.
135                        let mut buf = [0u8; 4];
136                        let hex = write_to_buf(&mut buf, |cursor| {
137                            write!(
138                                cursor,
139                                "{:X}",
140                                // SAFETY: `parse_ipv6_rev` checks
141                                // that `Segment::Num` contains up to 4 hex digits.
142                                u16::from_str_radix(num, 16).unwrap()
143                            )
144                        })
145                        // SAFETY: The only error that `std::io::Write` for `Cursor<&mut u8>`
146                        // emits is "no more space" (ErrorKind::WriteZero);
147                        // `buf` has enough space to write any `u16`.
148                        .unwrap();
149                        input.replace_range(range.clone(), hex);
150                    }
151                }
152                // Normalize :: to a sequence of zeros separated by :.
153                Segment::Colons(range) => {
154                    if range.len() == 2 {
155                        let number_count = reverse_parsed_ipv6
156                            .iter()
157                            .filter(|segment| {
158                                !matches!(segment, Segment::Colons(_))
159                            })
160                            .count();
161                        // SAFETY: This can't underflow
162                        // because `parse_ipv6_rev` checks that its return value has 8 or fewer elements.
163                        let missing_zero_count = 8 - number_count;
164                        // Reserve enough space for the maximum number of zeros (8) and colons (7).
165                        let mut buf = [0u8; 15];
166                        let zeros = write_to_buf(&mut buf, |cursor| {
167                            if range.start != 0 {
168                                cursor.write_all(b":")?;
169                            }
170                            for i in 0..missing_zero_count {
171                                cursor.write_all(if i == 0 {
172                                    b"0"
173                                } else {
174                                    b":0"
175                                })?;
176                            }
177                            if range.end != ip_len {
178                                cursor.write_all(b":")?;
179                            }
180                            Ok(())
181                        })
182                        // SAFETY: `write_all` will always return `Ok(_)`
183                        // because `buf` has enough space for the maximum number of characters.
184                        .unwrap();
185                        input.replace_range(range.clone(), zeros);
186                    }
187                }
188            }
189        }
190    }
191}
192
/// Test helper: sanitizes the first element of every pair and asserts that the
/// result equals the second one. On failure the message shows what
/// `parse_ipv6_rev` returns for the raw input.
#[cfg(test)]
fn test_sanitize_ip<const N: usize>(tests: [(&str, &str); N]) {
    for &(raw, wanted) in &tests {
        let mut sanitized = String::from(raw);
        sanitize_ip(&mut sanitized);
        assert_eq!(sanitized, wanted, "{:?}", parse_ipv6_rev(raw));
    }
}
201
#[test]
fn sanitize_ip_replaces_double_colons_with_zeros() {
    test_sanitize_ip([
        // `::` at the start
        ("::1", "0:0:0:0:0:0:0:1"),
        // nothing to expand
        ("0:0:0:0:0:0:0:1", "0:0:0:0:0:0:0:1"),
        // `::` on its own
        ("::", "0:0:0:0:0:0:0:0"),
        // `::` at the end
        ("0:0:0:1::", "0:0:0:1:0:0:0:0"),
        // `::` in the middle
        ("1::1", "1:0:0:0:0:0:0:1"),
    ]);
}
214
#[test]
fn sanitize_ip_uppercases() {
    // Each pair is (input, expected result of sanitizing it).
    let cases = [
        ("cebc:2004:f::", "CEBC:2004:F:0:0:0:0:0"),
        ("3f:535::e:fbb", "3F:535:0:0:0:0:E:FBB"),
        ("::1/24", "0:0:0:0:0:0:0:1/24"),
    ];
    test_sanitize_ip(cases);
}
223
#[test]
fn sanitize_ip_recognizes_subpages_of_ipv6_address() {
    // Each pair is (input, expected result of sanitizing it).
    let cases = [
        ("1::1/IP_subpage", "1:0:0:0:0:0:0:1/IP_subpage"),
        ("1::1_/not_IP_subpage", "1::1_/not_IP_subpage"),
        ("1::g/not_IP_subpage", "1::g/not_IP_subpage"),
        // Despite its look, this isn't an IP range.
        ("1::1/24", "1:0:0:0:0:0:0:1/24"),
    ];
    test_sanitize_ip(cases);
}
233}