1use std::ffi::CStr;
2use std::io::{IoSlice, IoSliceMut};
3use std::iter::FusedIterator;
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
5use std::ops::ControlFlow;
6use std::os::fd::{BorrowedFd, OwnedFd};
7
8use bytemuck::checked::try_cast_slice;
9use bytemuck::{CheckedBitPattern, Pod, Zeroable, bytes_of, bytes_of_mut};
10use rustix::io::{Errno, ReadWriteFlags, preadv2, pwritev2};
11use rustix::net::{
12 AddressFamily, SocketAddrUnix, SocketFlags, SocketType, bind, connect, socket_with,
13};
14
15pub(crate) fn open_socket() -> Result<OwnedFd, SocketError> {
16 let addr = SocketAddrUnix::new(PATH_NSCDSOCKET).map_err(|_| SocketError::Addr)?;
17 let sock = socket_with(
18 AddressFamily::UNIX,
19 SocketType::STREAM,
20 SocketFlags::CLOEXEC | SocketFlags::NONBLOCK,
21 None,
22 )
23 .map_err(SocketError::Open)?;
24 bind(&sock, &SocketAddrUnix::new_unnamed()).map_err(SocketError::Bind)?;
25 connect(&sock, &addr).map_err(SocketError::Connect)?;
26 Ok(sock)
27}
28
29#[derive(Debug, Clone, Copy, thiserror::Error, displaydoc::Display)]
30pub enum SocketError {
31 Open(#[source] Errno),
33 Addr,
35 Bind(#[source] Errno),
37 Connect(#[source] Errno),
39}
40
41pub(crate) fn write_request(
42 sock: BorrowedFd<'_>,
43 io: &mut IoState,
44 host: &[u8],
45) -> Result<ControlFlow<()>, RequestError> {
46 let req = RequestHeader {
47 version: NSCD_VERSION,
48 r#type: GETAI,
49 key_len: (host.len() + 1).try_into().unwrap(),
50 };
51
52 let mut slices = [
53 IoSlice::new(bytes_of(&req)),
54 IoSlice::new(host),
55 IoSlice::new(&[0]),
56 ];
57 write_all(sock, io, &mut slices).map_err(RequestError::Write)
58}
59
60#[derive(Debug, Clone, Copy, thiserror::Error, displaydoc::Display)]
61pub enum RequestError {
62 Write(#[source] WriteError),
64}
65
66pub(crate) fn read_header(
67 sock: BorrowedFd<'_>,
68 io: &mut IoState,
69 resp: &mut AiResponseHeader,
70) -> Result<ControlFlow<IsEmpty, ()>, HeaderError> {
71 let mut slices = [IoSliceMut::new(bytes_of_mut(resp))];
72 if read_all(sock, io, &mut slices)?.is_continue() {
73 return Ok(ControlFlow::Continue(()));
74 }
75
76 if resp.version != NSCD_VERSION {
77 return Err(HeaderError::Version(resp.version));
78 } else if resp.found != 1 || resp.error != 0 || resp.naddrs == 0 || resp.addrslen == 0 {
79 return Ok(ControlFlow::Break(IsEmpty::Empty));
80 } else if resp.naddrs < 0 || resp.addrslen < 0 || resp.canonlen < 0 || resp.canonlen > 254 {
81 return Err(HeaderError::Data);
82 }
83
84 let Some(data_len) = Some(resp.naddrs)
85 .and_then(|l| l.checked_add(resp.addrslen))
86 .and_then(|l| l.checked_add(resp.canonlen))
87 else {
88 return Err(HeaderError::TooBig);
89 };
90 let data_len = data_len as u32 as usize;
91 if data_len > MAX_DATA_LEN {
92 return Err(HeaderError::TooBig);
93 }
94
95 Ok(ControlFlow::Break(IsEmpty::HasData(data_len)))
96}
97
98#[derive(Debug, Clone, Copy, thiserror::Error, displaydoc::Display)]
99pub enum HeaderError {
100 Read(#[from] ReadError),
102 Version(i32),
104 Data,
106 TooBig,
108}
109
/// Outcome of a completely read and validated response header (see `read_header`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum IsEmpty {
    /// The daemon reported no usable result (`found != 1`, `error != 0`, or no addresses).
    Empty,
    /// A payload of this many bytes (`naddrs + addrslen + canonlen`) follows the header.
    HasData(usize),
}
115
116pub(crate) fn read_data(
117 sock: BorrowedFd<'_>,
118 io: &mut IoState,
119 buf: &mut [u8],
120) -> Result<ControlFlow<()>, DataError> {
121 let mut slices = [IoSliceMut::new(buf)];
122 Ok(read_all(sock, io, &mut slices)?)
123}
124
125pub(crate) fn interpret_data<'a>(
126 resp: &AiResponseHeader,
127 buf: &'a [u8],
128) -> Result<IpAddrIterator<'a>, DataError> {
129 let slice = buf;
131 let (slice, canon) = if resp.canonlen != 0 {
132 match slice
133 .len()
134 .checked_sub(resp.canonlen.try_into().unwrap_or(usize::MAX))
135 .and_then(|at| slice.split_at_checked(at))
136 .and_then(|(slice, canon)| Some((slice, CStr::from_bytes_with_nul(canon).ok()?)))
137 {
138 Some((slice, canon)) => (slice, Some(canon)),
139 None => return Err(DataError::Canon),
140 }
141 } else {
142 (slice, None)
143 };
144
145 let Some((slice, families)) = slice
147 .len()
148 .checked_sub(resp.naddrs.try_into().unwrap_or(usize::MAX))
149 .and_then(|at| slice.split_at_checked(at))
150 .and_then(|(slice, families)| Some((slice, try_cast_slice(families).ok()?)))
151 else {
152 return Err(DataError::Family);
153 };
154
155 let expected_len: usize = families
157 .iter()
158 .map(|&family| match family {
159 Family::V4 => size_of::<Ipv4Addr>(),
160 Family::V6 => size_of::<Ipv6Addr>(),
161 })
162 .sum();
163 if expected_len != slice.len() {
164 return Err(DataError::DataLength {
165 actual: slice.len(),
166 expected: expected_len,
167 });
168 }
169
170 Ok(IpAddrIterator {
171 canon,
172 families,
173 slice,
174 })
175}
176
177#[derive(Debug, Clone, Copy, thiserror::Error, displaydoc::Display)]
178pub enum DataError {
179 Read(#[from] ReadError),
181 Canon,
183 Family,
185 DataLength { actual: usize, expected: usize },
187}
188
/// Iterator over the addresses in an nscd `GETAI` answer.
///
/// Built by `interpret_data`; the `Default` value yields nothing and has no canonical name.
#[derive(Debug, Default, Clone, Copy)]
pub struct IpAddrIterator<'a> {
    /// Remaining family entries; one address is yielded per entry.
    families: &'a [Family],
    /// Remaining raw address bytes in big-endian order, consumed 4 or 16 bytes at a time.
    slice: &'a [u8],
    /// Canonical host name, if the daemon sent one.
    canon: Option<&'a CStr>,
}
196
197impl Iterator for IpAddrIterator<'_> {
198 type Item = IpAddr;
199
200 fn next(&mut self) -> Option<Self::Item> {
201 let [family, families @ ..] = self.families else {
202 return None;
203 };
204 self.families = families;
205
206 match family {
207 Family::V4 => {
208 let bits;
209 (bits, self.slice) = self.slice.split_at_checked(size_of::<Ipv4Addr>())?;
210 let bits = u32::from_be_bytes(bits.try_into().unwrap());
211 Some(IpAddr::V4(Ipv4Addr::from_bits(bits)))
212 }
213 Family::V6 => {
214 let bits;
215 (bits, self.slice) = self.slice.split_at_checked(size_of::<Ipv6Addr>())?;
216 let bits = u128::from_be_bytes(bits.try_into().unwrap());
217 Some(IpAddr::V6(Ipv6Addr::from_bits(bits)))
218 }
219 }
220 }
221
222 #[inline]
223 fn size_hint(&self) -> (usize, Option<usize>) {
224 (0, Some(self.len()))
225 }
226}
227
// `families` only ever shrinks, so once it is empty `next` keeps returning `None`.
// NOTE(review): this relies on `slice` holding enough bytes for every remaining family entry,
// which `interpret_data` checks before constructing the iterator.
impl FusedIterator for IpAddrIterator<'_> {}

impl ExactSizeIterator for IpAddrIterator<'_> {
    /// One address is yielded per remaining family entry.
    #[inline]
    fn len(&self) -> usize {
        self.families.len()
    }
}
236
impl<'a> IpAddrIterator<'a> {
    /// Returns the canonical host name sent by nscd, or `None` if the answer had none.
    ///
    /// The result borrows the response buffer (lifetime `'a`), not the iterator, so it stays
    /// usable after the iterator is advanced or dropped.
    #[inline]
    pub fn canon(&self) -> Option<&'a CStr> {
        self.canon
    }
}
244
/// One-byte address family entry from the family table of an nscd `GETAI` answer.
///
/// NOTE(review): the discriminants presumably match Linux `AF_INET` (2) and `AF_INET6` (10);
/// confirm if other platforms are ever targeted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, CheckedBitPattern)]
#[repr(u8)]
// Values are only produced by reinterpreting response bytes (`try_cast_slice`), never by name,
// so rustc would presumably report the variants as never constructed.
#[allow(dead_code)] pub(crate) enum Family {
    /// IPv4; the address takes 4 bytes of address data.
    V4 = 2,
    /// IPv6; the address takes 16 bytes of address data.
    V6 = 10,
}
252
253fn write_all(
254 sock: BorrowedFd<'_>,
255 state: &mut IoState,
256 mut slices: &mut [IoSlice<'_>],
257) -> Result<ControlFlow<()>, WriteError> {
258 if state.pos > 0 {
259 IoSlice::advance_slices(&mut slices, state.pos);
260 }
261
262 match pwritev2(sock, slices, u64::MAX, ReadWriteFlags::NOWAIT) {
263 Ok(n) if n > 0 => {
264 IoSlice::advance_slices(&mut slices, n);
265 if slices.is_empty() {
266 Ok(ControlFlow::Break(()))
267 } else {
268 state.pos += n;
269 state.had_zero = false;
270 state.had_intr = false;
271 Ok(ControlFlow::Continue(()))
272 }
273 }
274 Ok(_) if !state.had_zero => {
275 state.had_zero = true;
276 Ok(ControlFlow::Continue(()))
277 }
278 Ok(_) => Err(WriteError(None)),
279 Err(Errno::INTR) if !state.had_intr => {
280 state.had_intr = true;
281 Ok(ControlFlow::Continue(()))
282 }
283 Err(err) => Err(WriteError(Some(err))),
284 }
285}
286
287#[derive(Debug, Clone, Copy, thiserror::Error, displaydoc::Display)]
288#[error("Could not write data to socket")]
289pub struct WriteError(#[source] pub Option<Errno>);
290
291fn read_all(
292 sock: BorrowedFd<'_>,
293 state: &mut IoState,
294 mut slices: &mut [IoSliceMut<'_>],
295) -> Result<ControlFlow<()>, ReadError> {
296 if state.pos > 0 {
297 IoSliceMut::advance_slices(&mut slices, state.pos);
298 }
299
300 match preadv2(sock, slices, u64::MAX, ReadWriteFlags::NOWAIT) {
301 Ok(n) if n > 0 => {
302 IoSliceMut::advance_slices(&mut slices, n);
303 if slices.is_empty() {
304 Ok(ControlFlow::Break(()))
305 } else {
306 state.pos += n;
307 state.had_zero = false;
308 state.had_intr = false;
309 Ok(ControlFlow::Continue(()))
310 }
311 }
312 Ok(_) if !state.had_zero => {
313 state.had_zero = true;
314 Ok(ControlFlow::Continue(()))
315 }
316 Ok(_) => Err(ReadError(None)),
317 Err(Errno::INTR) if !state.had_intr => {
318 state.had_intr = true;
319 Ok(ControlFlow::Continue(()))
320 }
321 Err(err) => Err(ReadError(Some(err))),
322 }
323}
324
325#[derive(Debug, Clone, Copy, thiserror::Error, displaydoc::Display)]
326pub struct ReadError(#[source] pub Option<Errno>);
328
/// Tells "the non-blocking socket is not ready yet" (`EWOULDBLOCK`) apart from real failures.
pub(crate) trait IsWouldblock {
    /// Returns `true` if this error is an `EWOULDBLOCK` condition.
    fn is_wouldblock(&self) -> bool;
}

/// A reference is would-block exactly when its referent is.
impl<T: IsWouldblock> IsWouldblock for &T {
    #[inline]
    fn is_wouldblock(&self) -> bool {
        (**self).is_wouldblock()
    }
}

/// `Some(err)` defers to `err`; `None` never counts as would-block.
impl<T: IsWouldblock> IsWouldblock for Option<T> {
    #[inline]
    fn is_wouldblock(&self) -> bool {
        self.as_ref().is_some_and(T::is_wouldblock)
    }
}
349
350impl IsWouldblock for RequestError {
351 #[inline]
352 fn is_wouldblock(&self) -> bool {
353 match self {
354 RequestError::Write(err) => err.is_wouldblock(),
355 }
356 }
357}
358
359impl IsWouldblock for HeaderError {
360 #[inline]
361 fn is_wouldblock(&self) -> bool {
362 match self {
363 HeaderError::Read(err) => err.is_wouldblock(),
364 HeaderError::Version(_) | HeaderError::Data | HeaderError::TooBig => false,
365 }
366 }
367}
368
369impl IsWouldblock for DataError {
370 #[inline]
371 fn is_wouldblock(&self) -> bool {
372 match self {
373 DataError::Read(err) => err.is_wouldblock(),
374 DataError::Canon | DataError::Family | DataError::DataLength { .. } => false,
375 }
376 }
377}
378
379impl IsWouldblock for WriteError {
380 #[inline]
381 fn is_wouldblock(&self) -> bool {
382 self.0.is_wouldblock()
383 }
384}
385
386impl IsWouldblock for ReadError {
387 #[inline]
388 fn is_wouldblock(&self) -> bool {
389 self.0.is_wouldblock()
390 }
391}
392
393impl IsWouldblock for Errno {
394 #[inline]
395 fn is_wouldblock(&self) -> bool {
396 *self == Errno::WOULDBLOCK
397 }
398}
399
/// Resumable progress of one `read_all`/`write_all` transfer across non-blocking calls.
///
/// NOTE(review): a transfer presumably starts from `IoState::default()`; confirm call sites.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct IoState {
    /// Bytes already transferred by earlier calls of this transfer.
    pos: usize,
    /// A zero-length transfer happened since the last progress; another one is an error.
    had_zero: bool,
    /// An `EINTR` happened since the last progress; another one is an error.
    had_intr: bool,
}
406
/// Fixed-size request header sent to nscd ahead of the key, as raw native-endian bytes.
///
/// `repr(C)` field order is the wire layout; do not reorder the fields.
#[derive(Debug, Clone, Copy, Pod, Zeroable)]
#[repr(C)]
struct RequestHeader {
    /// Protocol version; `write_request` sends `NSCD_VERSION`.
    version: i32,
    /// Request type; `write_request` sends `GETAI`.
    r#type: i32,
    /// Length of the key that follows, including its terminating NUL byte.
    key_len: i32,
}
414
/// Response header nscd sends for a `GETAI` request, read as raw bytes from the socket.
///
/// `repr(C)` field order is the wire layout; do not reorder the fields.
#[derive(Debug, Clone, Copy, Pod, Zeroable, Default)]
#[repr(C)]
pub(crate) struct AiResponseHeader {
    /// Protocol version; must equal `NSCD_VERSION`.
    version: i32,
    /// `1` when a result follows; any other value is treated as "no data".
    found: i32,
    /// Number of addresses, i.e. one-byte entries in the family table.
    naddrs: NscdSsize,
    /// Total length of the address bytes.
    addrslen: NscdSsize,
    /// Length of the canonical name including its NUL byte, or `0` if there is none.
    canonlen: NscdSsize,
    /// Error indicator; any non-zero value is treated as "no data".
    error: i32,
}
425
/// Signed integer type of the length fields in nscd responses.
type NscdSsize = i32;

/// Path of nscd's Unix domain socket.
const PATH_NSCDSOCKET: &CStr = c"/var/run/nscd/socket";
/// nscd request type for address lookups (`GETAI`; value presumably mirrors glibc's
/// nscd request-type enumeration; confirm).
const GETAI: i32 = 14;
/// nscd protocol version that is sent in requests and required in responses.
const NSCD_VERSION: i32 = 2;

/// Largest response payload (`naddrs + addrslen + canonlen`) that `read_header` accepts.
const MAX_DATA_LEN: usize = 8192;