1use core::fmt;
29use core::str::FromStr;
30
31use super::IpAddr;
32
33#[cfg(feature = "serde")]
34use serde::{Deserialize, Serialize};
35
/// Errors produced while constructing or parsing a [`Cidr`].
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
///
/// `Hash` is derived alongside the comparison traits so that errors can be
/// used as set members or map keys, matching what [`Cidr`] itself offers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[non_exhaustive]
pub enum CidrError {
    /// The input was not exactly one address and one prefix length separated
    /// by a single `/`.
    InvalidFormat,
    /// The prefix length was not a valid number, or exceeded the maximum for
    /// the address family (32 for IPv4, 128 for IPv6).
    InvalidPrefixLength,
    /// The address part could not be parsed as an IP address.
    InvalidIpAddr,
}
57
58impl fmt::Display for CidrError {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 match self {
61 Self::InvalidFormat => write!(f, "invalid CIDR format"),
62 Self::InvalidPrefixLength => write!(f, "invalid prefix length"),
63 Self::InvalidIpAddr => write!(f, "invalid IP address"),
64 }
65 }
66}
67
// Only available with the `std` feature, because this impl names
// `std::error::Error`. No methods are overridden: the variants carry no
// underlying cause, so the trait's default behavior is used as-is.
#[cfg(feature = "std")]
impl std::error::Error for CidrError {}
70
71#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
100#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
101pub struct Cidr {
102 address: IpAddr,
104 prefix_length: u8,
106}
107
108impl Cidr {
109 pub const fn new(address: IpAddr, prefix_length: u8) -> Result<Self, CidrError> {
126 let max_prefix = if address.as_inner().is_ipv4() {
127 32
128 } else {
129 128
130 };
131
132 if prefix_length > max_prefix {
133 return Err(CidrError::InvalidPrefixLength);
134 }
135
136 Ok(Self {
137 address,
138 prefix_length,
139 })
140 }
141
142 #[must_use]
153 #[inline]
154 pub const fn address(&self) -> IpAddr {
155 self.address
156 }
157
158 #[must_use]
169 #[inline]
170 pub const fn prefix_length(&self) -> u8 {
171 self.prefix_length
172 }
173
174 #[must_use]
188 pub fn network_address(&self) -> IpAddr {
189 let inner = self.address.as_inner();
190
191 if let core::net::IpAddr::V4(ip) = inner {
192 let octets = ip.octets();
193 let ip = u32::from_be_bytes([octets[0], octets[1], octets[2], octets[3]]);
194 let mask = u32::MAX << (32 - u32::from(self.prefix_length));
195 let network = ip & mask;
196 let network_octets = network.to_be_bytes();
197 IpAddr::new(core::net::IpAddr::V4(core::net::Ipv4Addr::new(
198 network_octets[0],
199 network_octets[1],
200 network_octets[2],
201 network_octets[3],
202 )))
203 } else if let core::net::IpAddr::V6(ip) = inner {
204 let segments = ip.segments();
205 let mut network_segments = [0u16; 8];
206
207 let full_segments = (self.prefix_length / 16) as usize;
208 let partial_bits = self.prefix_length % 16;
209
210 network_segments[..full_segments].copy_from_slice(&segments[..full_segments]);
211
212 if partial_bits > 0 && full_segments < 8 {
213 let mask = u16::MAX << (16 - u32::from(partial_bits));
214 network_segments[full_segments] = segments[full_segments] & mask;
215 }
216
217 IpAddr::new(core::net::IpAddr::V6(core::net::Ipv6Addr::new(
218 network_segments[0],
219 network_segments[1],
220 network_segments[2],
221 network_segments[3],
222 network_segments[4],
223 network_segments[5],
224 network_segments[6],
225 network_segments[7],
226 )))
227 } else {
228 self.address
229 }
230 }
231
232 #[must_use]
247 #[allow(clippy::missing_const_for_fn)]
248 pub fn broadcast_address(&self) -> Option<IpAddr> {
249 let inner = self.address.as_inner();
250
251 if let core::net::IpAddr::V4(ip) = inner {
252 let octets = ip.octets();
253 let ip = u32::from_be_bytes([octets[0], octets[1], octets[2], octets[3]]);
254 let mask = u32::MAX << (32 - u32::from(self.prefix_length));
255 let broadcast = ip | !mask;
256 let broadcast_octets = broadcast.to_be_bytes();
257 Some(IpAddr::new(core::net::IpAddr::V4(
258 core::net::Ipv4Addr::new(
259 broadcast_octets[0],
260 broadcast_octets[1],
261 broadcast_octets[2],
262 broadcast_octets[3],
263 ),
264 )))
265 } else {
266 None }
268 }
269
270 #[must_use]
285 pub fn contains(&self, ip: &IpAddr) -> bool {
286 let network = self.network_address();
287 let network_inner = network.as_inner();
288 let ip_inner = ip.as_inner();
289
290 match (network_inner, ip_inner) {
291 (core::net::IpAddr::V4(network), core::net::IpAddr::V4(ip)) => {
292 let network_octets = network.octets();
293 let network = u32::from_be_bytes([
294 network_octets[0],
295 network_octets[1],
296 network_octets[2],
297 network_octets[3],
298 ]);
299 let ip_octets = ip.octets();
300 let ip =
301 u32::from_be_bytes([ip_octets[0], ip_octets[1], ip_octets[2], ip_octets[3]]);
302 let mask = u32::MAX << (32 - u32::from(self.prefix_length));
303 (network & mask) == (ip & mask)
304 }
305 (core::net::IpAddr::V6(network), core::net::IpAddr::V6(ip)) => {
306 let network_segments = network.segments();
307 let ip_segments = ip.segments();
308
309 let full_segments = (self.prefix_length / 16) as usize;
310 let partial_bits = self.prefix_length % 16;
311
312 for i in 0..full_segments {
313 if network_segments[i] != ip_segments[i] {
314 return false;
315 }
316 }
317
318 if partial_bits > 0 && full_segments < 8 {
319 let mask = u16::MAX << (16 - u32::from(partial_bits));
320 if (network_segments[full_segments] & mask)
321 != (ip_segments[full_segments] & mask)
322 {
323 return false;
324 }
325 }
326
327 true
328 }
329 _ => false,
330 }
331 }
332
333 #[must_use]
344 #[allow(clippy::missing_const_for_fn)]
345 pub fn size(&self) -> u128 {
346 if self.address.as_inner().is_ipv4() {
347 1u128 << (32 - u32::from(self.prefix_length))
348 } else {
349 let shift = 128 - u32::from(self.prefix_length);
350 if shift == 128 {
351 u128::MAX
352 } else {
353 1u128 << shift
354 }
355 }
356 }
357}
358
359impl FromStr for Cidr {
360 type Err = CidrError;
361
362 fn from_str(s: &str) -> Result<Self, Self::Err> {
363 let parts: Vec<&str> = s.split('/').collect();
364
365 if parts.len() != 2 {
366 return Err(CidrError::InvalidFormat);
367 }
368
369 let address: IpAddr = parts[0].parse().map_err(|_| CidrError::InvalidIpAddr)?;
370
371 let prefix_length: u8 = parts[1]
372 .parse()
373 .map_err(|_| CidrError::InvalidPrefixLength)?;
374
375 Self::new(address, prefix_length)
376 }
377}
378
379impl fmt::Display for Cidr {
380 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381 write!(f, "{}/{}", self.address, self.prefix_length)
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a CIDR literal that the test expects to be valid.
    fn cidr(text: &str) -> Cidr {
        text.parse().expect("CIDR literal used in a test should parse")
    }

    /// Parses an address literal that the test expects to be valid.
    fn ip(text: &str) -> IpAddr {
        text.parse()
            .expect("address literal used in a test should parse")
    }

    #[test]
    fn test_cidr_creation() {
        let address = ip("192.168.1.0");
        let block = Cidr::new(address, 24).expect("/24 is a valid IPv4 prefix");
        assert_eq!(block.address(), address);
        assert_eq!(block.prefix_length(), 24);
    }

    #[test]
    fn test_cidr_parsing() {
        let block = cidr("192.168.1.0/24");
        assert_eq!(block.address().to_string(), "192.168.1.0");
        assert_eq!(block.prefix_length(), 24);
    }

    #[test]
    fn test_invalid_prefix_length() {
        let result = Cidr::new(ip("192.168.1.0"), 33);
        assert!(result.is_err());
    }

    #[test]
    fn test_network_address() {
        let block = cidr("192.168.1.100/24");
        assert_eq!(block.network_address().to_string(), "192.168.1.0");
    }

    #[test]
    fn test_broadcast_address() {
        let block = cidr("192.168.1.0/24");
        let broadcast = block
            .broadcast_address()
            .expect("IPv4 blocks have a broadcast address");
        assert_eq!(broadcast.to_string(), "192.168.1.255");
    }

    #[test]
    fn test_contains() {
        let block = cidr("192.168.1.0/24");
        let inside = ip("192.168.1.100");
        let outside = ip("192.168.2.100");

        assert!(block.contains(&inside));
        assert!(!block.contains(&outside));
    }

    #[test]
    fn test_size() {
        assert_eq!(cidr("192.168.1.0/24").size(), 256);
    }

    #[test]
    fn test_ipv6_cidr() {
        let block = cidr("2001:db8::/32");
        assert_eq!(block.prefix_length(), 32);
        assert!(block.contains(&ip("2001:db8:85a3::8a2e:370:7334")));
    }

    #[test]
    fn test_display() {
        let block = cidr("192.168.1.0/24");
        assert_eq!(block.to_string(), "192.168.1.0/24");
    }

    #[test]
    fn test_ipv6_network_address() {
        let block = cidr("2001:db8:85a3::8a2e:370:7334/64");
        assert_eq!(block.network_address().to_string(), "2001:db8:85a3::");
    }

    #[test]
    fn test_ipv6_contains() {
        let block = cidr("2001:db8::/32");
        let inside = ip("2001:db8:85a3::8a2e:370:7334");
        let outside = ip("2001:db9::1");

        assert!(block.contains(&inside));
        assert!(!block.contains(&outside));
    }

    #[test]
    fn test_ipv6_size() {
        assert_eq!(cidr("2001:db8::/32").size(), 1u128 << 96);
    }

    #[test]
    fn test_ipv6_broadcast_none() {
        assert!(cidr("2001:db8::/32").broadcast_address().is_none());
    }

    #[test]
    fn test_ipv6_max_prefix() {
        let block = cidr("2001:db8::/128");
        assert_eq!(block.prefix_length(), 128);
        assert_eq!(block.size(), 1);
    }

    #[test]
    fn test_ipv6_zero_prefix() {
        let block = cidr("2001:db8::/0");
        assert_eq!(block.prefix_length(), 0);
        // 2^128 does not fit in a u128, so the size saturates.
        assert_eq!(block.size(), u128::MAX);
    }

    #[test]
    fn test_ipv4_max_prefix() {
        let block = cidr("192.168.1.0/32");
        assert_eq!(block.prefix_length(), 32);
        assert_eq!(block.size(), 1);
    }

    #[test]
    fn test_ipv4_zero_prefix() {
        let block = cidr("192.168.1.0/0");
        assert_eq!(block.prefix_length(), 0);
        assert_eq!(block.size(), 1u128 << 32);
    }
}