raw_socket/
control.rs

1// Copyright (C) 2020 - Will Glozer. All rights reserved.
2
3use std::fmt;
4use std::iter;
5use std::mem::{size_of, zeroed};
6use std::net::Ipv6Addr;
7use std::ptr;
8use std::slice;
9use crate::ffi::*;
10
11#[derive(Debug)]
12pub enum CMsg<'a> {
13    Ipv6HopLimit(c_int),
14    Ipv6PathMtu(c_int),
15    Ipv6PktInfo(Ipv6PktInfo),
16    Raw(Raw<'a>),
17}
18
19pub struct Ipv6PktInfo(in6_pktinfo);
20
/// A control message without a dedicated [`CMsg`] variant.
#[derive(Debug)]
pub struct Raw<'a> {
    /// Value stored in / read from the header's `cmsg_level` field.
    pub level: c_int,
    /// Value stored in / read from the header's `cmsg_type` field.
    pub kind: c_int,
    /// The message payload bytes.
    pub data: &'a [u8],
}
27
/// Errors reported while encoding control messages.
#[derive(Copy, Clone, Debug)]
pub enum Error {
    /// The supplied buffer cannot hold a (further) control message header.
    BufferSize,
}
32
33impl<'a> CMsg<'a> {
34    pub fn encode<'b>(buf: &'b mut [u8], msgs: &[CMsg]) -> Result<&'b [u8], Error> {
35        let mut n = 0;
36
37        unsafe {
38            let mut root   = message_header(buf);
39            let mut header = first_header(&mut root)?;
40
41            for msg in msgs {
42                let len = msg.size() as _;
43
44                (*header).cmsg_len   = CMSG_LEN(len) as _;
45                (*header).cmsg_level = msg.level();
46                (*header).cmsg_type  = msg.kind();
47
48                msg.write(CMSG_DATA(header));
49
50                n += CMSG_SPACE(len) as usize;
51
52                header = next_header(&root, header)?;
53            }
54        }
55
56        Ok(&buf[..n])
57    }
58
59    pub fn decode<'b>(buf: &'b [u8]) -> impl Iterator<Item = CMsg<'b>> {
60        unsafe {
61            let mut root = message_header(buf);
62            let mut next = first_header(&mut root);
63
64            iter::from_fn(move || {
65                let header = next.ok()?;
66
67                let len   = (*header).cmsg_len;
68                let level = (*header).cmsg_level;
69                let kind  = (*header).cmsg_type;
70
71                let ptr = CMSG_DATA(header);
72                let len = len as usize;
73
74                next = next_header(&root, header);
75
76                Self::read(level, kind, ptr, len)
77            })
78        }
79    }
80
81    fn level(&self) -> c_int {
82        match self {
83            Self::Ipv6HopLimit(..) => IPPROTO_IPV6,
84            Self::Ipv6PathMtu(..)  => IPPROTO_IPV6,
85            Self::Ipv6PktInfo(..)  => IPPROTO_IPV6,
86            Self::Raw(raw)         => raw.level,
87        }
88    }
89
90    fn kind(&self) -> c_int {
91        match self {
92            Self::Ipv6HopLimit(..) => IPV6_HOPLIMIT,
93            Self::Ipv6PathMtu(..)  => IPV6_PATHMTU,
94            Self::Ipv6PktInfo(..)  => IPV6_PKTINFO,
95            Self::Raw(raw)         => raw.kind,
96        }
97    }
98
99    fn size(&self) -> usize {
100        match self {
101            Self::Ipv6HopLimit(..) => size_of::<c_int>(),
102            Self::Ipv6PathMtu(..)  => size_of::<c_int>(),
103            Self::Ipv6PktInfo(..)  => size_of::<in6_pktinfo>(),
104            Self::Raw(raw)         => raw.data.len(),
105        }
106    }
107
108    unsafe fn read<'b>(level: c_int, kind: c_int, ptr: *const u8, len: usize) -> Option<CMsg<'b>> {
109        const INVALID: c_int = 0;
110
111        Some(match (level, kind) {
112            (IPPROTO_IPV6, IPV6_HOPLIMIT) => CMsg::Ipv6HopLimit(read(ptr)),
113            (IPPROTO_IPV6, IPV6_PATHMTU ) => CMsg::Ipv6PathMtu(read(ptr)),
114            (IPPROTO_IPV6, IPV6_PKTINFO ) => Ipv6PktInfo(read(ptr)).into(),
115            (INVALID     , INVALID      ) => return None,
116            (_           , _            ) => Raw::read(level, kind, ptr, len).into(),
117        })
118    }
119
120    unsafe fn write(&self, ptr: *mut u8) {
121        match self {
122            Self::Ipv6HopLimit(limit) => write(ptr, limit.to_le()),
123            Self::Ipv6PathMtu(mtu)    => write(ptr, mtu),
124            Self::Ipv6PktInfo(info)   => write(ptr, info.0),
125            Self::Raw(raw)            => raw.write(ptr),
126        }
127    }
128}
129
130unsafe fn message_header(buf: &[u8]) -> msghdr {
131    let mut msg: msghdr = zeroed();
132    msg.msg_control     = buf.as_ptr() as *mut _;
133    msg.msg_controllen  = buf.len()    as      _;
134    msg
135}
136
137unsafe fn first_header(msg: &mut msghdr) -> Result<*mut cmsghdr, Error> {
138    match CMSG_FIRSTHDR(msg) {
139        ptr if ptr.is_null() => Err(Error::BufferSize),
140        ptr                  => Ok(ptr),
141    }
142}
143
144unsafe fn next_header(msg: &msghdr, cmsg: *const cmsghdr) -> Result<*mut cmsghdr, Error> {
145    match CMSG_NXTHDR(msg, cmsg) {
146        ptr if ptr.is_null() => Err(Error::BufferSize),
147        ptr                  => Ok(ptr),
148    }
149}
150
/// Reads a `T` from `src` without requiring `src` to be aligned for `T`.
///
/// # Safety
///
/// `src` must be valid for reads of `size_of::<T>()` bytes that form a valid `T`.
unsafe fn read<T>(src: *const u8) -> T {
    let typed = src.cast::<T>();
    ptr::read_unaligned(typed)
}
154
/// Stores `src` at `dst` without requiring `dst` to be aligned for `T`.
///
/// # Safety
///
/// `dst` must be valid for writes of `size_of::<T>()` bytes.
unsafe fn write<T>(dst: *mut u8, src: T) {
    let typed = dst.cast::<T>();
    ptr::write_unaligned(typed, src);
}
158
159impl Ipv6PktInfo {
160    pub fn addr(&self) -> Ipv6Addr {
161        Ipv6Addr::from(self.0.ipi6_addr.s6_addr)
162    }
163
164    pub fn ifindex(&self) -> u32 {
165        self.0.ipi6_ifindex as u32
166    }
167}
168
169impl<'a> Raw<'a> {
170    pub const fn from(level: c_int, kind: c_int, data: &'a [u8]) -> Self {
171        Self { level, kind, data }
172    }
173
174    unsafe fn read(level: c_int, kind: c_int, ptr: *const u8, len: usize) -> Self {
175        let len  = len - size_of::<cmsghdr>();
176        let data = slice::from_raw_parts(ptr, len);
177        Self { level, kind, data }
178    }
179
180    unsafe fn write(&self, ptr: *mut u8) {
181        let src = self.data.as_ptr();
182        let len = self.data.len();
183        ptr::copy_nonoverlapping(src, ptr, len);
184    }
185}
186
187impl<'a> From<Ipv6PktInfo> for CMsg<'a> {
188    fn from(info: Ipv6PktInfo) -> Self {
189        Self::Ipv6PktInfo(info)
190    }
191}
192
193impl<'a> From<Raw<'a>> for CMsg<'a> {
194    fn from(raw: Raw<'a>) -> Self {
195        Self::Raw(raw)
196    }
197}
198
199impl fmt::Debug for Ipv6PktInfo {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        let addr  = self.addr();
202        let ifidx = self.ifindex();
203        write!(f, "{{ addr: {}, ifindex: {} }}", addr, ifidx)
204    }
205}
206
207impl std::error::Error for Error {}
208
209impl fmt::Display for Error {
210    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211        write!(f, "{:?}", self)
212    }
213}