1use std::fmt;
4use std::iter;
5use std::mem::{size_of, zeroed};
6use std::net::Ipv6Addr;
7use std::ptr;
8use std::slice;
9use crate::ffi::*;
10
/// A single control message that can be encoded into, or decoded from, a
/// control buffer.
///
/// The typed variants all use level `IPPROTO_IPV6`; anything else is carried
/// verbatim as [`CMsg::Raw`].
#[derive(Debug)]
pub enum CMsg<'a> {
    /// `IPPROTO_IPV6` / `IPV6_HOPLIMIT` message with a `c_int` payload.
    Ipv6HopLimit(c_int),
    /// `IPPROTO_IPV6` / `IPV6_PATHMTU` message with a `c_int` payload.
    Ipv6PathMtu(c_int),
    /// `IPPROTO_IPV6` / `IPV6_PKTINFO` message with an `in6_pktinfo` payload.
    Ipv6PktInfo(Ipv6PktInfo),
    /// Any other message, kept as its level, type and borrowed payload bytes.
    Raw(Raw<'a>),
}
18
/// Newtype wrapper around the FFI `in6_pktinfo` structure.
///
/// The inner value is private; use [`Ipv6PktInfo::addr`] and
/// [`Ipv6PktInfo::ifindex`] to inspect it.
pub struct Ipv6PktInfo(in6_pktinfo);
20
/// An untyped control message: header fields plus borrowed payload bytes.
#[derive(Debug)]
pub struct Raw<'a> {
    /// Value of the header's `cmsg_level` field.
    pub level: c_int,
    /// Value of the header's `cmsg_type` field.
    pub kind: c_int,
    /// Payload bytes that follow the header.
    pub data: &'a [u8],
}
27
/// Errors reported while walking or filling a control buffer.
#[derive(Copy, Clone, Debug)]
pub enum Error {
    /// The control buffer cannot provide a (further) control message header,
    /// e.g. the `CMSG_FIRSTHDR` / `CMSG_NXTHDR` lookup yielded a null pointer.
    BufferSize,
}
32
33impl<'a> CMsg<'a> {
34 pub fn encode<'b>(buf: &'b mut [u8], msgs: &[CMsg]) -> Result<&'b [u8], Error> {
35 let mut n = 0;
36
37 unsafe {
38 let mut root = message_header(buf);
39 let mut header = first_header(&mut root)?;
40
41 for msg in msgs {
42 let len = msg.size() as _;
43
44 (*header).cmsg_len = CMSG_LEN(len) as _;
45 (*header).cmsg_level = msg.level();
46 (*header).cmsg_type = msg.kind();
47
48 msg.write(CMSG_DATA(header));
49
50 n += CMSG_SPACE(len) as usize;
51
52 header = next_header(&root, header)?;
53 }
54 }
55
56 Ok(&buf[..n])
57 }
58
59 pub fn decode<'b>(buf: &'b [u8]) -> impl Iterator<Item = CMsg<'b>> {
60 unsafe {
61 let mut root = message_header(buf);
62 let mut next = first_header(&mut root);
63
64 iter::from_fn(move || {
65 let header = next.ok()?;
66
67 let len = (*header).cmsg_len;
68 let level = (*header).cmsg_level;
69 let kind = (*header).cmsg_type;
70
71 let ptr = CMSG_DATA(header);
72 let len = len as usize;
73
74 next = next_header(&root, header);
75
76 Self::read(level, kind, ptr, len)
77 })
78 }
79 }
80
81 fn level(&self) -> c_int {
82 match self {
83 Self::Ipv6HopLimit(..) => IPPROTO_IPV6,
84 Self::Ipv6PathMtu(..) => IPPROTO_IPV6,
85 Self::Ipv6PktInfo(..) => IPPROTO_IPV6,
86 Self::Raw(raw) => raw.level,
87 }
88 }
89
90 fn kind(&self) -> c_int {
91 match self {
92 Self::Ipv6HopLimit(..) => IPV6_HOPLIMIT,
93 Self::Ipv6PathMtu(..) => IPV6_PATHMTU,
94 Self::Ipv6PktInfo(..) => IPV6_PKTINFO,
95 Self::Raw(raw) => raw.kind,
96 }
97 }
98
99 fn size(&self) -> usize {
100 match self {
101 Self::Ipv6HopLimit(..) => size_of::<c_int>(),
102 Self::Ipv6PathMtu(..) => size_of::<c_int>(),
103 Self::Ipv6PktInfo(..) => size_of::<in6_pktinfo>(),
104 Self::Raw(raw) => raw.data.len(),
105 }
106 }
107
108 unsafe fn read<'b>(level: c_int, kind: c_int, ptr: *const u8, len: usize) -> Option<CMsg<'b>> {
109 const INVALID: c_int = 0;
110
111 Some(match (level, kind) {
112 (IPPROTO_IPV6, IPV6_HOPLIMIT) => CMsg::Ipv6HopLimit(read(ptr)),
113 (IPPROTO_IPV6, IPV6_PATHMTU ) => CMsg::Ipv6PathMtu(read(ptr)),
114 (IPPROTO_IPV6, IPV6_PKTINFO ) => Ipv6PktInfo(read(ptr)).into(),
115 (INVALID , INVALID ) => return None,
116 (_ , _ ) => Raw::read(level, kind, ptr, len).into(),
117 })
118 }
119
120 unsafe fn write(&self, ptr: *mut u8) {
121 match self {
122 Self::Ipv6HopLimit(limit) => write(ptr, limit.to_le()),
123 Self::Ipv6PathMtu(mtu) => write(ptr, mtu),
124 Self::Ipv6PktInfo(info) => write(ptr, info.0),
125 Self::Raw(raw) => raw.write(ptr),
126 }
127 }
128}
129
130unsafe fn message_header(buf: &[u8]) -> msghdr {
131 let mut msg: msghdr = zeroed();
132 msg.msg_control = buf.as_ptr() as *mut _;
133 msg.msg_controllen = buf.len() as _;
134 msg
135}
136
137unsafe fn first_header(msg: &mut msghdr) -> Result<*mut cmsghdr, Error> {
138 match CMSG_FIRSTHDR(msg) {
139 ptr if ptr.is_null() => Err(Error::BufferSize),
140 ptr => Ok(ptr),
141 }
142}
143
144unsafe fn next_header(msg: &msghdr, cmsg: *const cmsghdr) -> Result<*mut cmsghdr, Error> {
145 match CMSG_NXTHDR(msg, cmsg) {
146 ptr if ptr.is_null() => Err(Error::BufferSize),
147 ptr => Ok(ptr),
148 }
149}
150
/// Reads a `T` from `src` without requiring `src` to be aligned for `T`.
///
/// # Safety
///
/// `src` must be valid for reads of `size_of::<T>()` bytes, and those bytes
/// must form a valid `T`.
unsafe fn read<T>(src: *const u8) -> T {
    let typed: *const T = src.cast();
    ptr::read_unaligned(typed)
}
154
/// Stores `src` at `dst` without requiring `dst` to be aligned for `T`.
///
/// # Safety
///
/// `dst` must be valid for writes of `size_of::<T>()` bytes.
unsafe fn write<T>(dst: *mut u8, src: T) {
    let typed: *mut T = dst.cast();
    ptr::write_unaligned(typed, src);
}
158
159impl Ipv6PktInfo {
160 pub fn addr(&self) -> Ipv6Addr {
161 Ipv6Addr::from(self.0.ipi6_addr.s6_addr)
162 }
163
164 pub fn ifindex(&self) -> u32 {
165 self.0.ipi6_ifindex as u32
166 }
167}
168
169impl<'a> Raw<'a> {
170 pub const fn from(level: c_int, kind: c_int, data: &'a [u8]) -> Self {
171 Self { level, kind, data }
172 }
173
174 unsafe fn read(level: c_int, kind: c_int, ptr: *const u8, len: usize) -> Self {
175 let len = len - size_of::<cmsghdr>();
176 let data = slice::from_raw_parts(ptr, len);
177 Self { level, kind, data }
178 }
179
180 unsafe fn write(&self, ptr: *mut u8) {
181 let src = self.data.as_ptr();
182 let len = self.data.len();
183 ptr::copy_nonoverlapping(src, ptr, len);
184 }
185}
186
187impl<'a> From<Ipv6PktInfo> for CMsg<'a> {
188 fn from(info: Ipv6PktInfo) -> Self {
189 Self::Ipv6PktInfo(info)
190 }
191}
192
193impl<'a> From<Raw<'a>> for CMsg<'a> {
194 fn from(raw: Raw<'a>) -> Self {
195 Self::Raw(raw)
196 }
197}
198
199impl fmt::Debug for Ipv6PktInfo {
200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 let addr = self.addr();
202 let ifidx = self.ifindex();
203 write!(f, "{{ addr: {}, ifindex: {} }}", addr, ifidx)
204 }
205}
206
// Marks `Error` as a standard error type; the empty body relies on the
// trait's default methods.
impl std::error::Error for Error {}
208
209impl fmt::Display for Error {
210 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 write!(f, "{:?}", self)
212 }
213}