mwtitle 0.2.7

MediaWiki title validation and formatting
Documentation
/*
Copyright (C) 2021 Erutuon

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
use crate::ipv6::{parse_ipv6_rev, Segment};
use std::io::Cursor;

/// Returns `true` if `ip` is a dotted-quad IPv4 address such as `"192.168.0.1"`.
///
/// Validation is generous in one respect: a number may carry leading zeros
/// (`"000.001.00.255"` is accepted). Otherwise there must be exactly four
/// numbers separated by `.`, each made of one to three ASCII digits with a
/// value of at most 255.
///
/// A range suffix (`/NN`) is *not* accepted; callers must strip it first.
fn is_ipv4(ip: &str) -> bool {
    /// One component of the address: 1–3 ASCII digits whose value fits in a `u8`.
    fn is_octet(num: &str) -> bool {
        (1..=3).contains(&num.len())
            // `u8::from_str` also tolerates a leading `+`, so check digits explicitly.
            && num.bytes().all(|b| b.is_ascii_digit())
            && num.parse::<u8>().is_ok()
    }

    let mut nums = ip.split('.');
    // Exactly four valid numbers, and nothing at all after the fourth one.
    (0..4).all(|_| matches!(nums.next(), Some(num) if is_octet(num))) && nums.next().is_none()
}

#[test]
fn is_ipv4_recognizes_ipv4_addresses() {
    assert!(is_ipv4("000.001.00.255"));
    assert!(is_ipv4("0.0.0.0"));
    assert!(is_ipv4("255.255.255.255"));
    assert!(is_ipv4("192.168.1.10"));
}

#[test]
fn is_ipv4_rejects_invalid_ipv4_addresses() {
    // a number is greater than u8::MAX
    assert!(!is_ipv4("000.001.00.256"));
    // range not allowed
    assert!(!is_ipv4("000.001.00.255/32"));
    assert!(!is_ipv4("000.001.00.256/32"));
    // too many digits in number
    assert!(!is_ipv4("0000.1.1.1"));
    // too few numbers
    assert!(!is_ipv4("1.1.1"));
    // too many numbers
    assert!(!is_ipv4("1.1.1.1.1"));
    // too many numbers, the extra one being longer than three digits
    assert!(!is_ipv4("1.1.1.1.1234"));
    assert!(!is_ipv4("00.00.00.00.0000"));
    // empty numbers
    assert!(!is_ipv4(""));
    assert!(!is_ipv4("1..1.1"));
    assert!(!is_ipv4("1.1.1.1."));
    // signs are not digits
    assert!(!is_ipv4("+1.1.1.1"));
    // extraneous characters
    assert!(!is_ipv4("1.1._.1"));
    assert!(!is_ipv4("_.1.1.1.1"));
    assert!(!is_ipv4("1_.1.1.1.1"));
    assert!(!is_ipv4("_1.1.1.1.1"));
    assert!(!is_ipv4("1.1.1.1.1_"));
}

/// Lets `writer` fill `buf` through a [`Cursor`] and returns the written prefix as `&str`.
///
/// Only the bytes actually written (up to the cursor's final position) are returned;
/// the remainder of `buf` is ignored.
///
/// # Errors
/// Returns any error produced by `writer` (for instance `ErrorKind::WriteZero` when
/// `buf` runs out of room), or `ErrorKind::InvalidData` if the written bytes are not
/// valid UTF-8.
fn write_to_buf(
    buf: &mut [u8],
    mut writer: impl FnMut(&mut Cursor<&mut [u8]>) -> std::io::Result<()>,
) -> std::io::Result<&str> {
    let mut cursor = Cursor::new(&mut *buf);
    writer(&mut cursor)?;
    // The cursor only advances over bytes that were written, so this is the prefix length.
    let written_len = cursor.position() as usize;
    let written = &buf[..written_len];
    // Checked conversion: `writer` is arbitrary, so UTF-8 validity is not assumed.
    std::str::from_utf8(written)
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))
}

/// Convert IP addresses to a consistent verbose form, in place.
///
/// Only the part of `input` before the first `/` is treated as the address;
/// anything after it (such as a range or a subpage) is left untouched.
///
/// - IPv4 (as accepted by `is_ipv4`): leading zeros are stripped from each
///   number, keeping a single `0` for a number that is all zeros.
/// - IPv6 (as accepted by `parse_ipv6_rev`): leading zeros are stripped from
///   each group, lowercase `a`-`f` are replaced with uppercase `A`-`F`, and
///   `::` is replaced with the missing `0` groups separated by `:`.
///
/// Any other input is left unchanged.
pub(crate) fn sanitize_ip(input: &mut String) {
    // Can't panic: `str::split` yields one `Some(_)`, even when it is given an empty string.
    let ip = input.split('/').next().unwrap();
    if is_ipv4(ip) {
        // Rebuild the address from its trimmed numbers rather than deleting
        // zero runs in place. This needs no fixed-size bookkeeping of how many
        // runs there are, so it cannot panic even if `is_ipv4` were to accept
        // more than four numbers.
        let normalized = ip
            .split('.')
            .map(|num| match num.trim_start_matches('0') {
                // The number consisted only of zeros: keep exactly one.
                "" => "0",
                trimmed => trimmed,
            })
            .collect::<Vec<_>>()
            .join(".");
        let ip_len = ip.len(); // to satisfy borrow checker
        input.replace_range(..ip_len, &normalized);
    } else if let Ok(reverse_parsed_ipv6) = parse_ipv6_rev(ip).map_err(|_| ()) {
        // NOTE(review): `map_err(|_| ())` presumably discards an error type that
        // borrows `ip`, which would otherwise keep `input` borrowed for this whole
        // block — confirm before simplifying it away.
        let ip_len = ip.len(); // to satisfy borrow checker

        // Normalization done in reverse order, so that indices are not invalidated.
        // parse_ipv6_rev guarantees that indices are valid.
        use std::io::Write as _;
        for segment in &reverse_parsed_ipv6 {
            match segment {
                // Convert a-f to uppercase and remove extra leading zeros.
                Segment::Num(range) => {
                    let num = &input[range.clone()];
                    if (num.starts_with('0') && num.len() > 1)
                        || num.bytes().any(|b| b.is_ascii_lowercase())
                    {
                        // Reserve enough space for 4 hex digits, enough for any u16.
                        let mut buf = [0u8; 4];
                        let hex = write_to_buf(&mut buf, |cursor| {
                            write!(
                                cursor,
                                "{:X}",
                                // Can't panic: `parse_ipv6_rev` checks
                                // that `Segment::Num` contains up to 4 hex digits.
                                u16::from_str_radix(num, 16).unwrap()
                            )
                        })
                        // Can't panic: the only error that `std::io::Write` for
                        // `Cursor<&mut [u8]>` emits is "no more space" (ErrorKind::WriteZero);
                        // `buf` has enough space to write any `u16`.
                        .unwrap();
                        input.replace_range(range.clone(), hex);
                    }
                }
                // Normalize :: to a sequence of zeros separated by :.
                Segment::Colons(range) => {
                    if range.len() == 2 {
                        let number_count = reverse_parsed_ipv6
                            .iter()
                            .filter(|segment| {
                                !matches!(segment, Segment::Colons(_))
                            })
                            .count();
                        // Can't underflow: `parse_ipv6_rev` checks that its
                        // return value has 8 or fewer elements.
                        let missing_zero_count = 8 - number_count;
                        // Reserve enough space for the maximum number of zeros (8) and colons (7).
                        let mut buf = [0u8; 15];
                        let zeros = write_to_buf(&mut buf, |cursor| {
                            // Re-create the ':' separating the zeros from a preceding group.
                            if range.start != 0 {
                                cursor.write_all(b":")?;
                            }
                            for i in 0..missing_zero_count {
                                cursor.write_all(if i == 0 {
                                    b"0"
                                } else {
                                    b":0"
                                })?;
                            }
                            // Re-create the ':' separating the zeros from a following group.
                            if range.end != ip_len {
                                cursor.write_all(b":")?;
                            }
                            Ok(())
                        })
                        // Can't panic: `write_all` will always return `Ok(_)`
                        // because `buf` has enough space for the maximum number of characters.
                        .unwrap();
                        input.replace_range(range.clone(), zeros);
                    }
                }
            }
        }
    }
}

/// Asserts that `sanitize_ip` rewrites each input into its expected form.
///
/// Each pair is `(input, expected)`. On a mismatch, the result of
/// `parse_ipv6_rev(input)` is printed to help diagnose the failure.
#[cfg(test)]
fn test_sanitize_ip<const N: usize>(tests: [(&str, &str); N]) {
    for (raw, want) in tests {
        let mut got = String::from(raw);
        sanitize_ip(&mut got);
        assert_eq!(got, want, "{:?}", parse_ipv6_rev(raw));
    }
}

/// `::` is expanded to the missing zero groups wherever it appears:
/// at the start, at the end, in the middle, or as the whole address.
#[test]
fn sanitize_ip_replaces_double_colons_with_zeros() {
    test_sanitize_ip([
        ("::1", "0:0:0:0:0:0:0:1"),
        ("0:0:0:0:0:0:0:1", "0:0:0:0:0:0:0:1"),
        ("::", "0:0:0:0:0:0:0:0"),
        ("0:0:0:1::", "0:0:0:1:0:0:0:0"),
        ("1::", "1:0:0:0:0:0:0:0"),
        ("1::1", "1:0:0:0:0:0:0:1"),
        ("1:2::7:8", "1:2:0:0:0:0:7:8"),
    ]);
}

/// Lowercase hex digits are uppercased while `::` is expanded;
/// a `/…` suffix after the address is kept as-is.
#[test]
fn sanitize_ip_uppercases() {
    let cases = [
        ("cebc:2004:f::", "CEBC:2004:F:0:0:0:0:0"),
        ("3f:535::e:fbb", "3F:535:0:0:0:0:E:FBB"),
        ("::1/24", "0:0:0:0:0:0:0:1/24"),
    ];
    test_sanitize_ip(cases);
}

/// Only the text before the first `/` is checked as an address: a valid IPv6
/// address gets normalized, while an invalid one leaves the input unchanged.
#[test]
fn sanitize_ip_recognizes_subpages_of_ipv6_address() {
    let cases = [
        ("1::1/IP_subpage", "1:0:0:0:0:0:0:1/IP_subpage"),
        ("1::1_/not_IP_subpage", "1::1_/not_IP_subpage"),
        ("1::g/not_IP_subpage", "1::g/not_IP_subpage"),
        // Despite its look, this isn't an IP range.
        ("1::1/24", "1:0:0:0:0:0:0:1/24"),
    ];
    test_sanitize_ip(cases);
}
}