rhymuri/
validate_ipv4_address.rs

1#![warn(clippy::pedantic)]
2
3use super::{
4    character_classes::DIGIT,
5    context::Context,
6    error::Error,
7};
8
/// Bookkeeping that travels with the validator through every state.
struct Shared {
    /// Digits of the octet currently being read; emptied after each `.`.
    octet_buffer: String,

    /// How many octets have been closed off (by a `.` or by end of input).
    num_groups: usize,
}

/// Where the validator currently is within the dotted-decimal text.
enum State {
    /// At the very start, or just past a `.`: the next character has to
    /// begin a new octet.
    NotInOctet(Shared),

    /// Part-way through an octet: the next character may either extend it
    /// or close it with a `.`.
    ExpectDigitOrDot(Shared),
}
18
19impl State {
20    fn finalize(self) -> Result<(), Error> {
21        match self {
22            Self::NotInOctet(_) => Err(Error::TruncatedHost),
23            Self::ExpectDigitOrDot(state) => {
24                Self::finalize_expect_digit_or_dot(state)
25            },
26        }
27    }
28
29    fn finalize_expect_digit_or_dot(state: Shared) -> Result<(), Error> {
30        let mut state = state;
31        if !state.octet_buffer.is_empty() {
32            state.num_groups += 1;
33            if state.octet_buffer.parse::<u8>().is_err() {
34                return Err(Error::InvalidDecimalOctet);
35            }
36        }
37        match state.num_groups {
38            4 => Ok(()),
39            n if n < 4 => Err(Error::TooFewAddressParts),
40            _ => Err(Error::TooManyAddressParts),
41        }
42    }
43
44    fn new() -> Self {
45        Self::NotInOctet(Shared {
46            num_groups: 0,
47            octet_buffer: String::new(),
48        })
49    }
50
51    fn next(
52        self,
53        c: char,
54    ) -> Result<Self, Error> {
55        match self {
56            Self::NotInOctet(state) => Self::next_not_in_octet(state, c),
57            Self::ExpectDigitOrDot(state) => {
58                Self::next_expect_digit_or_dot(state, c)
59            },
60        }
61    }
62
63    fn next_not_in_octet(
64        state: Shared,
65        c: char,
66    ) -> Result<Self, Error> {
67        let mut state = state;
68        if DIGIT.contains(&c) {
69            state.octet_buffer.push(c);
70            Ok(Self::ExpectDigitOrDot(state))
71        } else {
72            Err(Error::IllegalCharacter(Context::Ipv4Address))
73        }
74    }
75
76    fn next_expect_digit_or_dot(
77        state: Shared,
78        c: char,
79    ) -> Result<Self, Error> {
80        let mut state = state;
81        if c == '.' {
82            state.num_groups += 1;
83            if state.num_groups > 4 {
84                return Err(Error::TooManyAddressParts);
85            }
86            if state.octet_buffer.parse::<u8>().is_err() {
87                return Err(Error::InvalidDecimalOctet);
88            }
89            state.octet_buffer.clear();
90            Ok(Self::NotInOctet(state))
91        } else if DIGIT.contains(&c) {
92            state.octet_buffer.push(c);
93            Ok(Self::ExpectDigitOrDot(state))
94        } else {
95            Err(Error::IllegalCharacter(Context::Ipv4Address))
96        }
97    }
98}
99
100pub fn validate_ipv4_address<T>(address: T) -> Result<(), Error>
101where
102    T: AsRef<str>,
103{
104    address.as_ref().chars().try_fold(State::new(), State::next)?.finalize()
105}
106
#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn good() {
        let test_vectors = [
            "0.0.0.0",
            "1.2.3.0",
            "1.2.3.4",
            "1.2.3.255",
            "1.2.255.4",
            "1.255.3.4",
            "255.2.3.4",
            "255.255.255.255",
            "10.0.0.1",
            "192.168.100.200",
            "9.99.199.249",
        ];
        for address in &test_vectors {
            // Comparing against `Ok(())` (rather than `is_ok()`) makes a
            // failure report the actual error next to the offending address.
            assert_eq!(validate_ipv4_address(*address), Ok(()), "{}", address);
        }
    }

    #[test]
    // NOTE: This lint is disabled because it's triggered inside the
    // `named_tuple!` macro expansion.
    #[allow(clippy::from_over_into)]
    fn bad() {
        named_tuple!(
            struct TestVector {
                address_string: &'static str,
                expected_error: Error,
            }
        );
        let test_vectors: &[TestVector] = &[
            ("1.2.x.4", Error::IllegalCharacter(Context::Ipv4Address)).into(),
            ("1.2.3.4.8", Error::TooManyAddressParts).into(),
            ("1.2.3", Error::TooFewAddressParts).into(),
            ("1.2.3.", Error::TruncatedHost).into(),
            ("1.2.3.256", Error::InvalidDecimalOctet).into(),
            ("1.2.3.-4", Error::IllegalCharacter(Context::Ipv4Address)).into(),
            ("1.2.3. 4", Error::IllegalCharacter(Context::Ipv4Address)).into(),
            ("1.2.3.4 ", Error::IllegalCharacter(Context::Ipv4Address)).into(),
            // Empty input and a trailing '.' both leave no final octet.
            ("", Error::TruncatedHost).into(),
            ("1.2.3.4.", Error::TruncatedHost).into(),
            // An octet may not start with '.', so leading or doubled dots
            // are illegal characters.
            (".1.2.3", Error::IllegalCharacter(Context::Ipv4Address)).into(),
            ("1..2.3", Error::IllegalCharacter(Context::Ipv4Address)).into(),
            // A fifth '.' is rejected as soon as it is seen.
            ("1.2.3.4.5.6", Error::TooManyAddressParts).into(),
            // Overflow in the first octet, not just the last one.
            ("256.1.1.1", Error::InvalidDecimalOctet).into(),
        ];
        for test_vector in test_vectors {
            let result = validate_ipv4_address(test_vector.address_string());
            // `err()` maps an unexpected `Ok` to `None`, so this single
            // assertion catches both "should have failed" and "failed with
            // the wrong error", and prints the actual outcome either way.
            assert_eq!(
                result.as_ref().err(),
                Some(test_vector.expected_error()),
                "{}",
                test_vector.address_string()
            );
        }
    }
}