1use core::fmt;
29use core::str::FromStr;
30
31use super::IpAddr;
32
33#[cfg(feature = "serde")]
34use serde::{Deserialize, Serialize};
35
/// Error returned when constructing or parsing a `Cidr` fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[non_exhaustive]
pub enum CidrError {
    /// The input did not split into exactly one `<address>/<prefix>` pair.
    InvalidFormat,
    /// The prefix was not a valid `u8`, or it exceeded the maximum for the
    /// address family (32 for IPv4, 128 for IPv6).
    InvalidPrefixLength,
    /// The part before the `/` could not be parsed as an IP address.
    InvalidIpAddr,
}

impl fmt::Display for CidrError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the fixed message first, then emit it in a single write.
        let message = match self {
            Self::InvalidFormat => "invalid CIDR format",
            Self::InvalidPrefixLength => "invalid prefix length",
            Self::InvalidIpAddr => "invalid IP address",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for CidrError {}
70
/// An IP network in CIDR notation: an address plus a prefix length.
///
/// The address is stored exactly as given, so host bits past the prefix may
/// be set (e.g. `192.168.1.100/24` is a valid value). The derived equality and
/// hashing compare that stored address, not the masked network address.
// NOTE(review): with the `serde` feature enabled, the derived `Deserialize`
// does not go through `Cidr::new`, so a deserialized value can carry a prefix
// length above the family maximum. Consider `#[serde(try_from = ...)]` to
// validate on deserialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Cidr {
    /// The address as supplied at construction; may include host bits.
    address: IpAddr,
    /// Number of leading network bits. `Cidr::new` restricts this to 0..=32 for
    /// IPv4 and 0..=128 for IPv6.
    prefix_length: u8,
}
107
108impl Cidr {
109 pub const fn new(address: IpAddr, prefix_length: u8) -> Result<Self, CidrError> {
126 let max_prefix = if address.as_inner().is_ipv4() {
127 32
128 } else {
129 128
130 };
131
132 if prefix_length > max_prefix {
133 return Err(CidrError::InvalidPrefixLength);
134 }
135
136 Ok(Self {
137 address,
138 prefix_length,
139 })
140 }
141
142 #[must_use]
153 #[inline]
154 pub const fn address(&self) -> IpAddr {
155 self.address
156 }
157
158 #[must_use]
169 #[inline]
170 pub const fn prefix_length(&self) -> u8 {
171 self.prefix_length
172 }
173
174 #[must_use]
188 pub fn network_address(&self) -> IpAddr {
189 let inner = self.address.as_inner();
190
191 if let core::net::IpAddr::V4(ip) = inner {
192 let octets = ip.octets();
193 let ip = u32::from_be_bytes([octets[0], octets[1], octets[2], octets[3]]);
194 let mask = u32::MAX << (32 - u32::from(self.prefix_length));
195 let network = ip & mask;
196 let network_octets = network.to_be_bytes();
197 IpAddr::new(core::net::IpAddr::V4(core::net::Ipv4Addr::new(
198 network_octets[0],
199 network_octets[1],
200 network_octets[2],
201 network_octets[3],
202 )))
203 } else if let core::net::IpAddr::V6(ip) = inner {
204 let segments = ip.segments();
205 let mut network_segments = [0u16; 8];
206
207 let full_segments = (self.prefix_length / 16) as usize;
208 let partial_bits = self.prefix_length % 16;
209
210 network_segments[..full_segments].copy_from_slice(&segments[..full_segments]);
211
212 if partial_bits > 0 && full_segments < 8 {
213 let mask = u16::MAX << (16 - u32::from(partial_bits));
214 network_segments[full_segments] = segments[full_segments] & mask;
215 }
216
217 IpAddr::new(core::net::IpAddr::V6(core::net::Ipv6Addr::new(
218 network_segments[0],
219 network_segments[1],
220 network_segments[2],
221 network_segments[3],
222 network_segments[4],
223 network_segments[5],
224 network_segments[6],
225 network_segments[7],
226 )))
227 } else {
228 self.address
229 }
230 }
231
232 #[must_use]
247 #[allow(clippy::missing_const_for_fn)]
248 pub fn broadcast_address(&self) -> Option<IpAddr> {
249 let inner = self.address.as_inner();
250
251 if let core::net::IpAddr::V4(ip) = inner {
252 let octets = ip.octets();
253 let ip = u32::from_be_bytes([octets[0], octets[1], octets[2], octets[3]]);
254 let mask = u32::MAX << (32 - u32::from(self.prefix_length));
255 let broadcast = ip | !mask;
256 let broadcast_octets = broadcast.to_be_bytes();
257 Some(IpAddr::new(core::net::IpAddr::V4(
258 core::net::Ipv4Addr::new(
259 broadcast_octets[0],
260 broadcast_octets[1],
261 broadcast_octets[2],
262 broadcast_octets[3],
263 ),
264 )))
265 } else {
266 None }
268 }
269
270 #[must_use]
285 pub fn contains(&self, ip: &IpAddr) -> bool {
286 let network = self.network_address();
287 let network_inner = network.as_inner();
288 let ip_inner = ip.as_inner();
289
290 match (network_inner, ip_inner) {
291 (core::net::IpAddr::V4(network), core::net::IpAddr::V4(ip)) => {
292 let network_octets = network.octets();
293 let network = u32::from_be_bytes([
294 network_octets[0],
295 network_octets[1],
296 network_octets[2],
297 network_octets[3],
298 ]);
299 let ip_octets = ip.octets();
300 let ip =
301 u32::from_be_bytes([ip_octets[0], ip_octets[1], ip_octets[2], ip_octets[3]]);
302 let mask = u32::MAX << (32 - u32::from(self.prefix_length));
303 (network & mask) == (ip & mask)
304 }
305 (core::net::IpAddr::V6(network), core::net::IpAddr::V6(ip)) => {
306 let network_segments = network.segments();
307 let ip_segments = ip.segments();
308
309 let full_segments = (self.prefix_length / 16) as usize;
310 let partial_bits = self.prefix_length % 16;
311
312 for i in 0..full_segments {
313 if network_segments[i] != ip_segments[i] {
314 return false;
315 }
316 }
317
318 if partial_bits > 0 && full_segments < 8 {
319 let mask = u16::MAX << (16 - u32::from(partial_bits));
320 if (network_segments[full_segments] & mask)
321 != (ip_segments[full_segments] & mask)
322 {
323 return false;
324 }
325 }
326
327 true
328 }
329 _ => false,
330 }
331 }
332
333 #[must_use]
344 #[allow(clippy::missing_const_for_fn)]
345 pub fn size(&self) -> u128 {
346 if self.address.as_inner().is_ipv4() {
347 1u128 << (32 - u32::from(self.prefix_length))
348 } else {
349 let shift = 128 - u32::from(self.prefix_length);
350 if shift == 128 {
351 u128::MAX
352 } else {
353 1u128 << shift
354 }
355 }
356 }
357}
358
#[cfg(feature = "arbitrary")]
impl<'a> arbitrary::Arbitrary<'a> for Cidr {
    /// Draws an address of either family, then a prefix length folded into
    /// the range `Cidr::new` accepts for that family. Every generated value is
    /// therefore one `Cidr::new` would also accept.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let address = IpAddr::arbitrary(u)?;

        // Count of valid prefix lengths: 0..=32 for IPv4, 0..=128 for IPv6.
        let choices: u8 = if address.as_inner().is_ipv4() { 33 } else { 129 };
        let raw = u8::arbitrary(u)?;

        Ok(Self {
            address,
            prefix_length: raw % choices,
        })
    }
}
378
379impl FromStr for Cidr {
380 type Err = CidrError;
381
382 fn from_str(s: &str) -> Result<Self, Self::Err> {
383 let parts: Vec<&str> = s.split('/').collect();
384
385 if parts.len() != 2 {
386 return Err(CidrError::InvalidFormat);
387 }
388
389 let address: IpAddr = parts[0].parse().map_err(|_| CidrError::InvalidIpAddr)?;
390
391 let prefix_length: u8 = parts[1]
392 .parse()
393 .map_err(|_| CidrError::InvalidPrefixLength)?;
394
395 Self::new(address, prefix_length)
396 }
397}
398
399impl fmt::Display for Cidr {
400 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
401 write!(f, "{}/{}", self.address, self.prefix_length)
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a CIDR literal, panicking with the input and error on failure.
    fn cidr(s: &str) -> Cidr {
        s.parse()
            .unwrap_or_else(|e| panic!("failed to parse CIDR {s:?}: {e}"))
    }

    /// Parses an IP address literal, panicking with the input on failure.
    fn ip(s: &str) -> IpAddr {
        s.parse()
            .unwrap_or_else(|_| panic!("failed to parse IP address {s:?}"))
    }

    #[test]
    fn test_cidr_creation() {
        let address = ip("192.168.1.0");
        let block = Cidr::new(address, 24).unwrap();
        assert_eq!(block.address(), address);
        assert_eq!(block.prefix_length(), 24);
    }

    #[test]
    fn test_cidr_parsing() {
        let block = cidr("192.168.1.0/24");
        assert_eq!(block.address().to_string(), "192.168.1.0");
        assert_eq!(block.prefix_length(), 24);
    }

    #[test]
    fn test_invalid_prefix_length() {
        assert!(Cidr::new(ip("192.168.1.0"), 33).is_err());
        assert_eq!(
            Cidr::new(ip("2001:db8::"), 129),
            Err(CidrError::InvalidPrefixLength)
        );
    }

    #[test]
    fn test_parse_errors() {
        let cases = [
            ("192.168.1.0", CidrError::InvalidFormat),
            ("192.168.1.0/24/8", CidrError::InvalidFormat),
            ("not-an-ip/24", CidrError::InvalidIpAddr),
            ("192.168.1.0/", CidrError::InvalidPrefixLength),
            ("192.168.1.0/abc", CidrError::InvalidPrefixLength),
            ("192.168.1.0/256", CidrError::InvalidPrefixLength),
            ("192.168.1.0/33", CidrError::InvalidPrefixLength),
            ("2001:db8::/129", CidrError::InvalidPrefixLength),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<Cidr>(), Err(expected), "input: {input:?}");
        }
    }

    #[test]
    fn test_network_address() {
        let block = cidr("192.168.1.100/24");
        assert_eq!(block.network_address().to_string(), "192.168.1.0");
    }

    #[test]
    fn test_broadcast_address() {
        let block = cidr("192.168.1.0/24");
        let broadcast = block.broadcast_address().unwrap();
        assert_eq!(broadcast.to_string(), "192.168.1.255");
    }

    #[test]
    fn test_contains() {
        let block = cidr("192.168.1.0/24");
        assert!(block.contains(&ip("192.168.1.100")));
        assert!(!block.contains(&ip("192.168.2.100")));
    }

    #[test]
    fn test_contains_rejects_other_family() {
        assert!(!cidr("192.168.1.0/24").contains(&ip("2001:db8::1")));
        assert!(!cidr("2001:db8::/32").contains(&ip("192.168.1.1")));
    }

    #[test]
    fn test_ipv4_single_host() {
        let block = cidr("10.0.0.1/32");
        assert_eq!(block.network_address().to_string(), "10.0.0.1");
        assert_eq!(block.broadcast_address().unwrap().to_string(), "10.0.0.1");
        assert!(block.contains(&ip("10.0.0.1")));
        assert!(!block.contains(&ip("10.0.0.2")));
    }

    #[test]
    fn test_size() {
        assert_eq!(cidr("192.168.1.0/24").size(), 256);
    }

    #[test]
    fn test_ipv6_cidr() {
        let block = cidr("2001:db8::/32");
        assert_eq!(block.prefix_length(), 32);
        assert!(block.contains(&ip("2001:db8:85a3::8a2e:370:7334")));
    }

    #[test]
    fn test_display() {
        assert_eq!(cidr("192.168.1.0/24").to_string(), "192.168.1.0/24");
    }

    #[test]
    fn test_ipv6_network_address() {
        let block = cidr("2001:db8:85a3::8a2e:370:7334/64");
        assert_eq!(block.network_address().to_string(), "2001:db8:85a3::");
    }

    #[test]
    fn test_ipv6_partial_segment_prefix() {
        // /36 keeps two full 16-bit segments plus the top nibble of the third.
        let block = cidr("2001:db8:abcd:1234::1/36");
        assert_eq!(block.network_address().to_string(), "2001:db8:a000::");
        assert!(block.contains(&ip("2001:db8:afff::")));
        assert!(!block.contains(&ip("2001:db8:b000::")));
    }

    #[test]
    fn test_ipv6_contains() {
        let block = cidr("2001:db8::/32");
        assert!(block.contains(&ip("2001:db8:85a3::8a2e:370:7334")));
        assert!(!block.contains(&ip("2001:db9::1")));
    }

    #[test]
    fn test_ipv6_size() {
        assert_eq!(cidr("2001:db8::/32").size(), 1u128 << 96);
    }

    #[test]
    fn test_ipv6_broadcast_none() {
        assert!(cidr("2001:db8::/32").broadcast_address().is_none());
    }

    #[test]
    fn test_ipv6_max_prefix() {
        let block = cidr("2001:db8::/128");
        assert_eq!(block.prefix_length(), 128);
        assert_eq!(block.size(), 1);
    }

    #[test]
    fn test_ipv6_zero_prefix() {
        let block = cidr("2001:db8::/0");
        assert_eq!(block.prefix_length(), 0);
        assert_eq!(block.size(), u128::MAX);
    }

    #[test]
    fn test_ipv4_max_prefix() {
        let block = cidr("192.168.1.0/32");
        assert_eq!(block.prefix_length(), 32);
        assert_eq!(block.size(), 1);
    }

    #[test]
    fn test_ipv4_zero_prefix() {
        let block = cidr("192.168.1.0/0");
        assert_eq!(block.prefix_length(), 0);
        assert_eq!(block.size(), 1u128 << 32);
    }
}