use crate::inet::{
ip,
ipv6::{IpV6Address, SocketAddressV6},
unspecified::Unspecified,
ExplicitCongestionNotification,
};
use core::{fmt, mem::size_of, net};
use s2n_codec::zerocopy::U16;
/// Number of bytes in an IPv4 address (32 bits / 8).
const IPV4_LEN: usize = 32 / 8;
// An IPv4 address held as four raw octets. `octets[0]` is the leftmost component
// when printed in dotted-decimal form (see the `Display` impl).
define_inet_type!(
pub struct IpV4Address {
// The address bytes, in the same order as returned by `Ipv4Addr::octets`.
octets: [u8; IPV4_LEN],
}
);
impl IpV4Address {
    /// The unspecified address, `0.0.0.0`.
    pub const UNSPECIFIED: Self = Self {
        octets: [0; IPV4_LEN],
    };

    /// Returns the unicast scope of the address.
    ///
    /// Returns `None` when the address is not usable as a unicast peer. That covers
    /// "this network" (`0.0.0.0/8`), IETF protocol assignments (`192.0.0.0/24`),
    /// benchmarking (`198.18.0.0/15`), documentation ranges (`192.0.2.0/24`,
    /// `198.51.100.0/24`, `203.0.113.0/24`), multicast (`224.0.0.0/4`), and the
    /// reserved/broadcast block (`240.0.0.0/4`).
    ///
    /// Otherwise the result is one of:
    /// * `Loopback` for `127.0.0.0/8`.
    /// * `Private` for the RFC 1918 blocks and the shared address space (`100.64.0.0/10`).
    /// * `LinkLocal` for `169.254.0.0/16`.
    /// * `Global` for everything else, including the globally reachable anycast
    ///   addresses `192.0.0.9` and `192.0.0.10`.
    #[inline]
    pub const fn unicast_scope(self) -> Option<ip::UnicastScope> {
        use ip::UnicastScope::*;

        // Ranges follow the IANA IPv4 Special-Purpose Address Registry. Arm order
        // matters: specific addresses must precede the blocks that contain them.
        match self.octets {
            // RFC 1122 §3.2.1.3 (a)/(b): "this host" / hosts on "this network"
            [0, 0, 0, 0] => None,
            [0, _, _, _] => None,
            // RFC 1122 §3.2.1.3 (g): loopback
            [127, _, _, _] => Some(Loopback),
            // RFC 1918 §3: private internets
            [10, _, _, _] => Some(Private),
            [172, 16..=31, _, _] => Some(Private),
            [192, 168, _, _] => Some(Private),
            // RFC 6598: shared address space; treated as private even though it is
            // shared between networks
            [100, 64..=127, _, _] => Some(Private),
            // RFC 3927: IPv4 link-local
            [169, 254, _, _] => Some(LinkLocal),
            // RFC 7723 (192.0.0.9) and RFC 8155 (192.0.0.10) anycast addresses are
            // globally reachable despite being inside 192.0.0.0/24
            [192, 0, 0, 9] => Some(Global),
            [192, 0, 0, 10] => Some(Global),
            // RFC 6890: IETF protocol assignments
            [192, 0, 0, _] => None,
            // RFC 2544 (errata 423): benchmarking, 198.18.0.0/15
            [198, 18..=19, _, _] => None,
            // RFC 5737: TEST-NET-1/2/3 documentation ranges
            [192, 0, 2, _] => None,
            [198, 51, 100, _] => None,
            [203, 0, 113, _] => None,
            // RFC 6676: MCAST-TEST-NET (also covered by the multicast arm below)
            [233, 252, 0, _] => None,
            // RFC 1112 §4: multicast host groups, 224.0.0.0 - 239.255.255.255.
            // These are never unicast destinations.
            [224..=239, _, _, _] => None,
            // RFC 1112 §4 class E (reserved), which also contains the limited
            // broadcast address 255.255.255.255 (RFC 919)
            [240..=255, _, _, _] => None,
            // everything else is globally routable unicast
            _ => Some(Global),
        }
    }

    /// Converts the address into an IPv4-mapped IPv6 address (`::ffff:a.b.c.d`).
    #[inline]
    pub const fn to_ipv6_mapped(self) -> IpV6Address {
        let mut addr = [0; size_of::<IpV6Address>()];
        let [a, b, c, d] = self.octets;
        // `::ffff:0:0/96` prefix: ten zero bytes followed by two 0xFF bytes
        addr[10] = 0xFF;
        addr[11] = 0xFF;
        addr[12] = a;
        addr[13] = b;
        addr[14] = c;
        addr[15] = d;
        IpV6Address { octets: addr }
    }

    /// Pairs this address with `port` to form a [`SocketAddressV4`].
    #[inline]
    pub fn with_port(self, port: u16) -> SocketAddressV4 {
        SocketAddressV4 {
            ip: self,
            port: port.into(),
        }
    }
}
impl fmt::Debug for IpV4Address {
    /// Wraps the dotted-decimal form: `IPv4Address(a.b.c.d)`.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(fmt, "IPv4Address({self})")
    }
}

impl fmt::Display for IpV4Address {
    /// Renders the address in dotted-decimal notation, leftmost octet first.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let [a, b, c, d] = self.octets;
        write!(fmt, "{}.{}.{}.{}", a, b, c, d)
    }
}

impl Unspecified for IpV4Address {
    /// Returns `true` only for `0.0.0.0`.
    #[inline]
    fn is_unspecified(&self) -> bool {
        PartialEq::eq(&Self::UNSPECIFIED, self)
    }
}

test_inet_snapshot!(ipv4, ipv4_snapshot_test, IpV4Address);
// An IPv4 address together with a port number.
define_inet_type!(
    pub struct SocketAddressV4 {
        ip: IpV4Address,
        port: U16,
    }
);

impl SocketAddressV4 {
    /// The unspecified socket address, `0.0.0.0:0`.
    pub const UNSPECIFIED: Self = Self {
        ip: IpV4Address::UNSPECIFIED,
        port: U16::ZERO,
    };

    /// Borrows the IP address component.
    #[inline]
    pub const fn ip(&self) -> &IpV4Address {
        &self.ip
    }

    /// Returns the port number as a plain `u16`.
    #[inline]
    pub fn port(self) -> u16 {
        let port: u16 = self.port.into();
        port
    }

    /// Replaces the port number.
    #[inline]
    pub fn set_port(&mut self, port: u16) {
        self.port.set(port)
    }

    /// Scope of the IP component; the port does not affect the result.
    #[inline]
    pub const fn unicast_scope(&self) -> Option<ip::UnicastScope> {
        self.ip.unicast_scope()
    }

    /// Maps the IP component to `::ffff:a.b.c.d` and keeps the port unchanged.
    #[inline]
    pub const fn to_ipv6_mapped(self) -> SocketAddressV6 {
        SocketAddressV6 {
            ip: self.ip.to_ipv6_mapped(),
            port: self.port,
        }
    }
}
impl fmt::Debug for SocketAddressV4 {
    /// Wraps the display form: `SocketAddressV4(a.b.c.d:port)`.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(fmt, "SocketAddressV4({self})")
    }
}

impl fmt::Display for SocketAddressV4 {
    /// Renders as `a.b.c.d:port`.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let Self { ip, port } = self;
        write!(fmt, "{}:{}", ip, port)
    }
}

impl Unspecified for SocketAddressV4 {
    /// Returns `true` only for `0.0.0.0:0`.
    #[inline]
    fn is_unspecified(&self) -> bool {
        PartialEq::eq(&Self::UNSPECIFIED, self)
    }
}
impl From<net::Ipv4Addr> for IpV4Address {
fn from(address: net::Ipv4Addr) -> Self {
(&address).into()
}
}
impl From<&net::Ipv4Addr> for IpV4Address {
fn from(address: &net::Ipv4Addr) -> Self {
address.octets().into()
}
}
impl From<IpV4Address> for net::Ipv4Addr {
fn from(address: IpV4Address) -> Self {
address.octets.into()
}
}
impl From<net::SocketAddrV4> for SocketAddressV4 {
fn from(address: net::SocketAddrV4) -> Self {
let ip = address.ip().into();
let port = address.port().into();
Self { ip, port }
}
}
impl From<(net::Ipv4Addr, u16)> for SocketAddressV4 {
fn from((ip, port): (net::Ipv4Addr, u16)) -> Self {
Self::new(ip, port)
}
}
impl From<SocketAddressV4> for net::SocketAddrV4 {
    /// Converts into the standard-library IPv4 socket address.
    fn from(address: SocketAddressV4) -> Self {
        Self::from(&address)
    }
}

impl From<&SocketAddressV4> for net::SocketAddrV4 {
    /// Converts a borrowed address into the standard-library IPv4 socket address.
    fn from(address: &SocketAddressV4) -> Self {
        Self::new(address.ip.into(), address.port.into())
    }
}

impl From<SocketAddressV4> for net::SocketAddr {
    /// Converts into the `V4` variant of the standard-library socket address.
    fn from(address: SocketAddressV4) -> Self {
        Self::V4(address.into())
    }
}

impl From<&SocketAddressV4> for net::SocketAddr {
    /// Converts a borrowed address into the `V4` variant of the standard-library
    /// socket address.
    fn from(address: &SocketAddressV4) -> Self {
        Self::V4(address.into())
    }
}

test_inet_snapshot!(socket_v4, socket_v4_snapshot_test, SocketAddressV4);
impl From<[u8; IPV4_LEN]> for IpV4Address {
#[inline]
fn from(octets: [u8; IPV4_LEN]) -> Self {
Self { octets }
}
}
impl From<IpV4Address> for [u8; IPV4_LEN] {
#[inline]
fn from(address: IpV4Address) -> Self {
address.octets
}
}
// IPv4 header. Fields are declared in the same order as the header on the wire;
// there is no options field.
define_inet_type!(
    pub struct Header {
        vihl: Vihl,
        tos: Tos,
        total_len: U16,
        id: U16,
        flag_fragment: FlagFragment,
        ttl: u8,
        protocol: ip::Protocol,
        checksum: U16,
        source: IpV4Address,
        destination: IpV4Address,
    }
);

impl fmt::Debug for Header {
    /// Shows each header field individually, with the packed bitfields expanded
    /// and `id` / `checksum` rendered as zero-padded hex.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let vihl = &self.vihl;
        let tos = &self.tos;
        let flags = &self.flag_fragment;

        let mut out = f.debug_struct("ipv4::Header");
        out.field("version", &vihl.version());
        out.field("header_len", &vihl.header_len());
        out.field("dscp", &tos.dscp());
        out.field("ecn", &tos.ecn());
        out.field("total_len", &self.total_len);
        out.field("id", &format_args!("0x{:04x}", self.id.get()));
        out.field("flags (reserved)", &flags.reserved());
        out.field("flags (don't fragment)", &flags.dont_fragment());
        out.field("flags (more fragments)", &flags.more_fragments());
        out.field("fragment_offset", &flags.fragment_offset());
        out.field("ttl", &self.ttl);
        out.field("protocol", &self.protocol);
        out.field("checksum", &format_args!("0x{:04x}", self.checksum.get()));
        out.field("source", &self.source);
        out.field("destination", &self.destination);
        out.finish()
    }
}
impl Header {
    /// Exchanges the source and destination addresses in place.
    #[inline]
    pub fn swap(&mut self) {
        let Self {
            source,
            destination,
            ..
        } = self;
        core::mem::swap(source, destination)
    }

    /// Borrows the version / header-length byte.
    #[inline]
    pub const fn vihl(&self) -> &Vihl {
        &self.vihl
    }

    /// Mutably borrows the version / header-length byte.
    #[inline]
    pub fn vihl_mut(&mut self) -> &mut Vihl {
        &mut self.vihl
    }

    /// Borrows the type-of-service byte (DSCP + ECN).
    #[inline]
    pub const fn tos(&self) -> &Tos {
        &self.tos
    }

    /// Mutably borrows the type-of-service byte (DSCP + ECN).
    #[inline]
    pub fn tos_mut(&mut self) -> &mut Tos {
        &mut self.tos
    }

    /// Borrows the total-length field.
    #[inline]
    pub const fn total_len(&self) -> &U16 {
        &self.total_len
    }

    /// Mutably borrows the total-length field.
    #[inline]
    pub fn total_len_mut(&mut self) -> &mut U16 {
        &mut self.total_len
    }

    /// Borrows the identification field.
    #[inline]
    pub const fn id(&self) -> &U16 {
        &self.id
    }

    /// Mutably borrows the identification field.
    #[inline]
    pub fn id_mut(&mut self) -> &mut U16 {
        &mut self.id
    }

    /// Borrows the flags / fragment-offset word.
    #[inline]
    pub const fn flag_fragment(&self) -> &FlagFragment {
        &self.flag_fragment
    }

    /// Mutably borrows the flags / fragment-offset word.
    #[inline]
    pub fn flag_fragment_mut(&mut self) -> &mut FlagFragment {
        &mut self.flag_fragment
    }

    /// Borrows the time-to-live byte.
    #[inline]
    pub const fn ttl(&self) -> &u8 {
        &self.ttl
    }

    /// Mutably borrows the time-to-live byte.
    #[inline]
    pub fn ttl_mut(&mut self) -> &mut u8 {
        &mut self.ttl
    }

    /// Borrows the protocol field.
    #[inline]
    pub const fn protocol(&self) -> &ip::Protocol {
        &self.protocol
    }

    /// Mutably borrows the protocol field.
    #[inline]
    pub fn protocol_mut(&mut self) -> &mut ip::Protocol {
        &mut self.protocol
    }

    /// Borrows the header checksum field.
    #[inline]
    pub const fn checksum(&self) -> &U16 {
        &self.checksum
    }

    /// Mutably borrows the header checksum field.
    #[inline]
    pub fn checksum_mut(&mut self) -> &mut U16 {
        &mut self.checksum
    }

    /// Borrows the source address.
    #[inline]
    pub const fn source(&self) -> &IpV4Address {
        &self.source
    }

    /// Mutably borrows the source address.
    #[inline]
    pub fn source_mut(&mut self) -> &mut IpV4Address {
        &mut self.source
    }

    /// Borrows the destination address.
    #[inline]
    pub const fn destination(&self) -> &IpV4Address {
        &self.destination
    }

    /// Mutably borrows the destination address.
    #[inline]
    pub fn destination_mut(&mut self) -> &mut IpV4Address {
        &mut self.destination
    }

    /// Recomputes the header checksum over this header's bytes and stores it.
    ///
    /// The checksum field is cleared first, so its previous value does not feed
    /// into the new result.
    #[inline]
    pub fn update_checksum(&mut self) {
        use core::hash::Hasher;

        self.checksum.set(0);
        let mut hasher = crate::inet::checksum::Checksum::generic();
        hasher.write(self.as_bytes());
        self.checksum.set_be(hasher.finish_be());
    }
}
// Packed byte: version in the high nibble, header length in the low nibble.
define_inet_type!(
    pub struct Vihl {
        value: u8,
    }
);

impl fmt::Debug for Vihl {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (version, header_len) = (self.version(), self.header_len());
        f.debug_struct("Vihl")
            .field("version", &version)
            .field("header_len", &header_len)
            .finish()
    }
}

impl Vihl {
    /// Returns the version (upper four bits).
    #[inline]
    pub fn version(&self) -> u8 {
        self.value >> 4
    }

    /// Stores `value` in the upper four bits, keeping the header length.
    ///
    /// Bits of `value` above the low nibble are shifted out and discarded.
    #[inline]
    pub fn set_version(&mut self, value: u8) -> &mut Self {
        let header_len = self.value & 0x0F;
        self.value = header_len | (value << 4);
        self
    }

    /// Returns the header length (lower four bits).
    ///
    /// This is the raw IHL field; RFC 791 counts it in 32-bit words.
    #[inline]
    pub fn header_len(&self) -> u8 {
        self.value & 0x0F
    }

    /// Stores the low nibble of `value` as the header length, keeping the version.
    #[inline]
    pub fn set_header_len(&mut self, value: u8) -> &mut Self {
        let version = self.value & 0xF0;
        self.value = version | (value & 0x0F);
        self
    }
}
// Type-of-service byte: DSCP in the upper six bits, ECN in the lower two.
define_inet_type!(
    pub struct Tos {
        value: u8,
    }
);

impl Tos {
    /// Returns the DSCP (upper six bits).
    #[inline]
    pub fn dscp(&self) -> u8 {
        self.value >> 2
    }

    /// Stores `value` in the upper six bits, keeping the ECN bits.
    ///
    /// Bits of `value` above the low six are shifted out and discarded.
    #[inline]
    pub fn set_dscp(&mut self, value: u8) -> &mut Self {
        let ecn_bits = self.value & 0b11;
        self.value = ecn_bits | (value << 2);
        self
    }

    /// Decodes the ECN codepoint from the lower two bits.
    #[inline]
    pub fn ecn(&self) -> ExplicitCongestionNotification {
        let ecn_bits = self.value & 0b11;
        ExplicitCongestionNotification::new(ecn_bits)
    }

    /// Stores `ecn` in the lower two bits, keeping the DSCP.
    #[inline]
    pub fn set_ecn(&mut self, ecn: ExplicitCongestionNotification) -> &mut Self {
        let dscp_bits = self.value & !0b11;
        self.value = dscp_bits | ecn as u8;
        self
    }
}

impl fmt::Debug for Tos {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (dscp, ecn) = (self.dscp(), self.ecn());
        f.debug_struct("ipv4::Tos")
            .field("dscp", &dscp)
            .field("ecn", &ecn)
            .finish()
    }
}
// 16-bit word: reserved flag (bit 15), don't-fragment (bit 14),
// more-fragments (bit 13), then a 13-bit fragment offset.
define_inet_type!(
    pub struct FlagFragment {
        value: U16,
    }
);

impl FlagFragment {
    // The low 13 bits carry the fragment offset.
    const FRAGMENT_MASK: u16 = 0b0001_1111_1111_1111;

    /// Returns the reserved flag (most significant bit).
    #[inline]
    pub fn reserved(&self) -> bool {
        self.get(1 << 15)
    }

    /// Sets or clears the reserved flag.
    #[inline]
    pub fn set_reserved(&mut self, enabled: bool) -> &mut Self {
        self.set(1 << 15, enabled)
    }

    /// Returns the don't-fragment flag (bit 14).
    #[inline]
    pub fn dont_fragment(&self) -> bool {
        self.get(1 << 14)
    }

    /// Sets or clears the don't-fragment flag.
    #[inline]
    pub fn set_dont_fragment(&mut self, enabled: bool) -> &mut Self {
        self.set(1 << 14, enabled)
    }

    /// Returns the more-fragments flag (bit 13).
    #[inline]
    pub fn more_fragments(&self) -> bool {
        self.get(1 << 13)
    }

    /// Sets or clears the more-fragments flag.
    #[inline]
    pub fn set_more_fragments(&mut self, enabled: bool) -> &mut Self {
        self.set(1 << 13, enabled)
    }

    /// Returns the 13-bit fragment offset.
    #[inline]
    pub fn fragment_offset(&self) -> u16 {
        let word = self.value.get();
        word & Self::FRAGMENT_MASK
    }

    /// Stores the low 13 bits of `offset`, leaving the three flag bits untouched.
    #[inline]
    pub fn set_fragment_offset(&mut self, offset: u16) -> &mut Self {
        let flag_bits = self.value.get() & !Self::FRAGMENT_MASK;
        let offset_bits = offset & Self::FRAGMENT_MASK;
        self.value.set(flag_bits | offset_bits);
        self
    }

    /// Returns `true` when every bit in `mask` is set.
    #[inline]
    fn get(&self, mask: u16) -> bool {
        (self.value.get() & mask) == mask
    }

    /// Sets (`enabled`) or clears every bit in `mask`.
    #[inline]
    fn set(&mut self, mask: u16, enabled: bool) -> &mut Self {
        let current = self.value.get();
        let updated = match enabled {
            true => current | mask,
            false => current & !mask,
        };
        self.value.set(updated);
        self
    }
}

impl fmt::Debug for FlagFragment {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("ipv4::FlagFragment");
        out.field("reserved", &self.reserved());
        out.field("dont_fragment", &self.dont_fragment());
        out.field("more_fragments", &self.more_fragments());
        out.field("fragment_offset", &self.fragment_offset());
        out.finish()
    }
}
#[cfg(any(test, feature = "std"))]
mod std_conversion {
    use super::*;
    use std::net;

    impl net::ToSocketAddrs for SocketAddressV4 {
        type Iter = std::iter::Once<net::SocketAddr>;

        /// Yields exactly one address; this conversion cannot fail.
        fn to_socket_addrs(&self) -> std::io::Result<Self::Iter> {
            let addr = net::SocketAddr::from(self);
            Ok(std::iter::once(addr))
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use bolero::{check, generator::*};
use s2n_codec::{DecoderBuffer, DecoderBufferMut};
// Property test: for arbitrary addresses, the scope reported by `unicast_scope`
// must agree with the classification from `std::net::Ipv4Addr` and the
// `ip_network` crate.
#[test]
#[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
fn scope_test() {
let g = produce::<[u8; 4]>().map_gen(IpV4Address::from);
check!().with_generator(g).cloned().for_each(|subject| {
use ip::UnicastScope::*;
let expected = std::net::Ipv4Addr::from(subject);
let network = ip_network::Ipv4Network::from(expected);
match subject.unicast_scope() {
Some(Global) => {
// `unicast_scope` deliberately reports these two anycast addresses as
// global even though they sit inside 192.0.0.0/24.
// NOTE(review): the early return presumably exists because `ip_network`
// does not classify them as global — confirm.
if subject.octets == [192, 0, 0, 9] || subject.octets == [192, 0, 0, 10] {
return;
}
assert!(network.is_global());
}
Some(Private) => {
// `unicast_scope` also maps the shared address space (100.64.0.0/10)
// to `Private`, so accept either classification.
assert!(expected.is_private() || network.is_shared_address_space());
}
Some(Loopback) => {
assert!(expected.is_loopback());
}
Some(LinkLocal) => {
assert!(expected.is_link_local());
}
None => {
// Every address rejected as non-unicast must fall into at least one
// special-purpose range.
assert!(
expected.is_broadcast()
|| expected.is_multicast()
|| expected.is_documentation()
|| network.is_benchmarking()
|| network.is_ietf_protocol_assignments()
|| network.is_reserved()
|| network.is_local_identification()
|| network.is_unspecified()
);
}
}
})
}
// Decodes a header from a buffer of incrementing bytes, then from an all-0xFF
// buffer, and compares each `Debug` rendering against a stored insta snapshot.
#[test]
#[cfg_attr(miri, ignore)]
fn snapshot_test() {
let mut buffer = vec![0u8; core::mem::size_of::<Header>()];
for (idx, byte) in buffer.iter_mut().enumerate() {
*byte = idx as u8;
}
let decoder = DecoderBuffer::new(&buffer);
let (header, _) = decoder.decode::<&Header>().unwrap();
insta::assert_debug_snapshot!("snapshot_test", header);
buffer.fill(255);
let decoder = DecoderBuffer::new(&buffer);
let (header, _) = decoder.decode::<&Header>().unwrap();
insta::assert_debug_snapshot!("snapshot_filled_test", header);
}
// For an arbitrary header, copies every field through the public setters into
// a buffer that starts as all 0xFF bytes. It then decodes that buffer and checks
// each getter matches. Starting from 0xFF means a setter that fails to overwrite
// one of its bits is caught.
#[test]
#[cfg_attr(kani, kani::proof, kani::unwind(5), kani::solver(kissat))]
fn header_getter_setter_test() {
check!().with_type::<Header>().for_each(|expected| {
let mut buffer = [255u8; core::mem::size_of::<Header>()];
let decoder = DecoderBufferMut::new(&mut buffer);
let (header, _) = decoder.decode::<&mut Header>().unwrap();
{
header
.vihl_mut()
.set_version(expected.vihl().version())
.set_header_len(expected.vihl().header_len());
header
.tos_mut()
.set_dscp(expected.tos().dscp())
.set_ecn(expected.tos().ecn());
header.id_mut().set(expected.id().get());
header.total_len_mut().set(expected.total_len().get());
header
.flag_fragment_mut()
.set_reserved(expected.flag_fragment().reserved())
.set_dont_fragment(expected.flag_fragment().dont_fragment())
.set_more_fragments(expected.flag_fragment().more_fragments())
.set_fragment_offset(expected.flag_fragment().fragment_offset());
*header.ttl_mut() = *expected.ttl();
*header.protocol_mut() = *expected.protocol();
header.checksum_mut().set(expected.checksum().get());
*header.source_mut() = *expected.source();
*header.destination_mut() = *expected.destination();
}
let decoder = DecoderBuffer::new(&buffer);
let (actual, _) = decoder.decode::<&Header>().unwrap();
{
assert_eq!(expected, actual);
assert_eq!(expected.vihl(), actual.vihl());
assert_eq!(expected.vihl().version(), actual.vihl().version());
assert_eq!(expected.vihl().header_len(), actual.vihl().header_len());
assert_eq!(expected.tos(), actual.tos());
assert_eq!(expected.tos().dscp(), actual.tos().dscp());
assert_eq!(expected.tos().ecn(), actual.tos().ecn());
assert_eq!(expected.id(), actual.id());
assert_eq!(expected.total_len(), actual.total_len());
// The whole-word comparison also covers `fragment_offset`, which has no
// individual assertion below.
assert_eq!(expected.flag_fragment(), actual.flag_fragment());
assert_eq!(
expected.flag_fragment().reserved(),
actual.flag_fragment().reserved()
);
assert_eq!(
expected.flag_fragment().dont_fragment(),
actual.flag_fragment().dont_fragment()
);
assert_eq!(
expected.flag_fragment().more_fragments(),
actual.flag_fragment().more_fragments()
);
assert_eq!(expected.ttl(), actual.ttl());
assert_eq!(expected.protocol(), actual.protocol());
assert_eq!(expected.checksum(), actual.checksum());
assert_eq!(expected.source(), actual.source());
assert_eq!(expected.destination(), actual.destination());
}
})
}
// Feeds arbitrary byte buffers to s2n_codec's round-trip assertion for `Header`.
// NOTE(review): presumably this checks that decoding then re-encoding reproduces
// the input bytes — confirm against `assert_codec_round_trip_bytes!`.
#[test]
fn header_round_trip_test() {
check!().for_each(|buffer| {
s2n_codec::assert_codec_round_trip_bytes!(Header, buffer);
});
}
}