nscd_lookup/
protocol.rs

1use std::ffi::CStr;
2use std::io::{IoSlice, IoSliceMut};
3use std::iter::FusedIterator;
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
5use std::ops::ControlFlow;
6use std::os::fd::{BorrowedFd, OwnedFd};
7
8use bytemuck::checked::try_cast_slice;
9use bytemuck::{CheckedBitPattern, Pod, Zeroable, bytes_of, bytes_of_mut};
10use rustix::io::{Errno, ReadWriteFlags, preadv2, pwritev2};
11use rustix::net::{
12    AddressFamily, SocketAddrUnix, SocketFlags, SocketType, bind, connect, socket_with,
13};
14
15pub(crate) fn open_socket() -> Result<OwnedFd, SocketError> {
16    let addr = SocketAddrUnix::new(PATH_NSCDSOCKET).map_err(|_| SocketError::Addr)?;
17    let sock = socket_with(
18        AddressFamily::UNIX,
19        SocketType::STREAM,
20        SocketFlags::CLOEXEC | SocketFlags::NONBLOCK,
21        None,
22    )
23    .map_err(SocketError::Open)?;
24    bind(&sock, &SocketAddrUnix::new_unnamed()).map_err(SocketError::Bind)?;
25    connect(&sock, &addr).map_err(SocketError::Connect)?;
26    Ok(sock)
27}
28
29#[derive(Debug, Clone, Copy, thiserror::Error, displaydoc::Display)]
30pub enum SocketError {
31    /// Could not open socket
32    Open(#[source] Errno),
33    /// Nscd socket address was invalid
34    Addr,
35    /// Could not bind anonymous socket
36    Bind(#[source] Errno),
37    /// Could not connect to nscd socket
38    Connect(#[source] Errno),
39}
40
41pub(crate) fn write_request(
42    sock: BorrowedFd<'_>,
43    io: &mut IoState,
44    host: &[u8],
45) -> Result<ControlFlow<()>, RequestError> {
46    let req = RequestHeader {
47        version: NSCD_VERSION,
48        r#type: GETAI,
49        key_len: (host.len() + 1).try_into().unwrap(),
50    };
51
52    let mut slices = [
53        IoSlice::new(bytes_of(&req)),
54        IoSlice::new(host),
55        IoSlice::new(&[0]),
56    ];
57    write_all(sock, io, &mut slices).map_err(RequestError::Write)
58}
59
60#[derive(Debug, Clone, Copy, thiserror::Error, displaydoc::Display)]
61pub enum RequestError {
62    /// Could not send request
63    Write(#[source] WriteError),
64}
65
66pub(crate) fn read_header(
67    sock: BorrowedFd<'_>,
68    io: &mut IoState,
69    resp: &mut AiResponseHeader,
70) -> Result<ControlFlow<IsEmpty, ()>, HeaderError> {
71    let mut slices = [IoSliceMut::new(bytes_of_mut(resp))];
72    if read_all(sock, io, &mut slices)?.is_continue() {
73        return Ok(ControlFlow::Continue(()));
74    }
75
76    if resp.version != NSCD_VERSION {
77        return Err(HeaderError::Version(resp.version));
78    } else if resp.found != 1 || resp.error != 0 || resp.naddrs == 0 || resp.addrslen == 0 {
79        return Ok(ControlFlow::Break(IsEmpty::Empty));
80    } else if resp.naddrs < 0 || resp.addrslen < 0 || resp.canonlen < 0 || resp.canonlen > 254 {
81        return Err(HeaderError::Data);
82    }
83
84    let Some(data_len) = Some(resp.naddrs)
85        .and_then(|l| l.checked_add(resp.addrslen))
86        .and_then(|l| l.checked_add(resp.canonlen))
87    else {
88        return Err(HeaderError::TooBig);
89    };
90    let data_len = data_len as u32 as usize;
91    if data_len > MAX_DATA_LEN {
92        return Err(HeaderError::TooBig);
93    }
94
95    Ok(ControlFlow::Break(IsEmpty::HasData(data_len)))
96}
97
98#[derive(Debug, Clone, Copy, thiserror::Error, displaydoc::Display)]
99pub enum HeaderError {
100    /// Could not read header
101    Read(#[from] ReadError),
102    /// Wrong version {0:x}, expected {NSCD_VERSION:x}
103    Version(i32),
104    /// nscd response not understood
105    Data,
106    /// nscd response unreasonable large
107    TooBig,
108}
109
110#[derive(Debug, Clone, Copy, PartialEq, Eq)]
111pub(crate) enum IsEmpty {
112    Empty,
113    HasData(usize),
114}
115
116pub(crate) fn read_data(
117    sock: BorrowedFd<'_>,
118    io: &mut IoState,
119    buf: &mut [u8],
120) -> Result<ControlFlow<()>, DataError> {
121    let mut slices = [IoSliceMut::new(buf)];
122    Ok(read_all(sock, io, &mut slices)?)
123}
124
125pub(crate) fn interpret_data<'a>(
126    resp: &AiResponseHeader,
127    buf: &'a [u8],
128) -> Result<IpAddrIterator<'a>, DataError> {
129    // read canonical name
130    let slice = buf;
131    let (slice, canon) = if resp.canonlen != 0 {
132        match slice
133            .len()
134            .checked_sub(resp.canonlen.try_into().unwrap_or(usize::MAX))
135            .and_then(|at| slice.split_at_checked(at))
136            .and_then(|(slice, canon)| Some((slice, CStr::from_bytes_with_nul(canon).ok()?)))
137        {
138            Some((slice, canon)) => (slice, Some(canon)),
139            None => return Err(DataError::Canon),
140        }
141    } else {
142        (slice, None)
143    };
144
145    // make sure that all address families are `AF_INET` or `AF_INET6`
146    let Some((slice, families)) = slice
147        .len()
148        .checked_sub(resp.naddrs.try_into().unwrap_or(usize::MAX))
149        .and_then(|at| slice.split_at_checked(at))
150        .and_then(|(slice, families)| Some((slice, try_cast_slice(families).ok()?)))
151    else {
152        return Err(DataError::Family);
153    };
154
155    // make sure that the lengths of all addresses combined equals the length of the buffer
156    let expected_len: usize = families
157        .iter()
158        .map(|&family| match family {
159            Family::V4 => size_of::<Ipv4Addr>(),
160            Family::V6 => size_of::<Ipv6Addr>(),
161        })
162        .sum();
163    if expected_len != slice.len() {
164        return Err(DataError::DataLength {
165            actual: slice.len(),
166            expected: expected_len,
167        });
168    }
169
170    Ok(IpAddrIterator {
171        canon,
172        families,
173        slice,
174    })
175}
176
177#[derive(Debug, Clone, Copy, thiserror::Error, displaydoc::Display)]
178pub enum DataError {
179    /// Could not read header
180    Read(#[from] ReadError),
181    /// Could not extract canonical name
182    Canon,
183    /// Response contained address families other than AF_INET / AF_INET6
184    Family,
185    /// Actual length of IP addresses {actual} != expected length {expected}
186    DataLength { actual: usize, expected: usize },
187}
188
189/// An iterator of [`IpAddr`]esses, returned by [`lookup()`][crate::sync::lookup].
190#[derive(Debug, Default, Clone, Copy)]
191pub struct IpAddrIterator<'a> {
192    families: &'a [Family],
193    slice: &'a [u8],
194    canon: Option<&'a CStr>,
195}
196
197impl Iterator for IpAddrIterator<'_> {
198    type Item = IpAddr;
199
200    fn next(&mut self) -> Option<Self::Item> {
201        let [family, families @ ..] = self.families else {
202            return None;
203        };
204        self.families = families;
205
206        match family {
207            Family::V4 => {
208                let bits;
209                (bits, self.slice) = self.slice.split_at_checked(size_of::<Ipv4Addr>())?;
210                let bits = u32::from_be_bytes(bits.try_into().unwrap());
211                Some(IpAddr::V4(Ipv4Addr::from_bits(bits)))
212            }
213            Family::V6 => {
214                let bits;
215                (bits, self.slice) = self.slice.split_at_checked(size_of::<Ipv6Addr>())?;
216                let bits = u128::from_be_bytes(bits.try_into().unwrap());
217                Some(IpAddr::V6(Ipv6Addr::from_bits(bits)))
218            }
219        }
220    }
221
222    #[inline]
223    fn size_hint(&self) -> (usize, Option<usize>) {
224        (0, Some(self.len()))
225    }
226}
227
228impl FusedIterator for IpAddrIterator<'_> {}
229
230impl ExactSizeIterator for IpAddrIterator<'_> {
231    #[inline]
232    fn len(&self) -> usize {
233        self.families.len()
234    }
235}
236
237impl<'a> IpAddrIterator<'a> {
238    /// The canonical name of the host.
239    #[inline]
240    pub fn canon(&self) -> Option<&'a CStr> {
241        self.canon
242    }
243}
244
245#[derive(Debug, Clone, Copy, PartialEq, Eq, CheckedBitPattern)]
246#[repr(u8)]
247#[allow(dead_code)] // constructed by [`try_cast_slice()`]
248pub(crate) enum Family {
249    V4 = 2,
250    V6 = 10,
251}
252
253fn write_all(
254    sock: BorrowedFd<'_>,
255    state: &mut IoState,
256    mut slices: &mut [IoSlice<'_>],
257) -> Result<ControlFlow<()>, WriteError> {
258    if state.pos > 0 {
259        IoSlice::advance_slices(&mut slices, state.pos);
260    }
261
262    match pwritev2(sock, slices, u64::MAX, ReadWriteFlags::NOWAIT) {
263        Ok(n) if n > 0 => {
264            IoSlice::advance_slices(&mut slices, n);
265            if slices.is_empty() {
266                Ok(ControlFlow::Break(()))
267            } else {
268                state.pos += n;
269                state.had_zero = false;
270                state.had_intr = false;
271                Ok(ControlFlow::Continue(()))
272            }
273        }
274        Ok(_) if !state.had_zero => {
275            state.had_zero = true;
276            Ok(ControlFlow::Continue(()))
277        }
278        Ok(_) => Err(WriteError(None)),
279        Err(Errno::INTR) if !state.had_intr => {
280            state.had_intr = true;
281            Ok(ControlFlow::Continue(()))
282        }
283        Err(err) => Err(WriteError(Some(err))),
284    }
285}
286
287#[derive(Debug, Clone, Copy, thiserror::Error, displaydoc::Display)]
288#[error("Could not write data to socket")]
289pub struct WriteError(#[source] pub Option<Errno>);
290
291fn read_all(
292    sock: BorrowedFd<'_>,
293    state: &mut IoState,
294    mut slices: &mut [IoSliceMut<'_>],
295) -> Result<ControlFlow<()>, ReadError> {
296    if state.pos > 0 {
297        IoSliceMut::advance_slices(&mut slices, state.pos);
298    }
299
300    match preadv2(sock, slices, u64::MAX, ReadWriteFlags::NOWAIT) {
301        Ok(n) if n > 0 => {
302            IoSliceMut::advance_slices(&mut slices, n);
303            if slices.is_empty() {
304                Ok(ControlFlow::Break(()))
305            } else {
306                state.pos += n;
307                state.had_zero = false;
308                state.had_intr = false;
309                Ok(ControlFlow::Continue(()))
310            }
311        }
312        Ok(_) if !state.had_zero => {
313            state.had_zero = true;
314            Ok(ControlFlow::Continue(()))
315        }
316        Ok(_) => Err(ReadError(None)),
317        Err(Errno::INTR) if !state.had_intr => {
318            state.had_intr = true;
319            Ok(ControlFlow::Continue(()))
320        }
321        Err(err) => Err(ReadError(Some(err))),
322    }
323}
324
325#[derive(Debug, Clone, Copy, thiserror::Error, displaydoc::Display)]
326/// Could not read data from socket
327pub struct ReadError(#[source] pub Option<Errno>);
328
329pub(crate) trait IsWouldblock {
330    fn is_wouldblock(&self) -> bool;
331}
332
333impl<T: IsWouldblock> IsWouldblock for &T {
334    #[inline]
335    fn is_wouldblock(&self) -> bool {
336        <T as IsWouldblock>::is_wouldblock(self)
337    }
338}
339
340impl<T: IsWouldblock> IsWouldblock for Option<T> {
341    #[inline]
342    fn is_wouldblock(&self) -> bool {
343        match self {
344            Some(err) => err.is_wouldblock(),
345            None => false,
346        }
347    }
348}
349
350impl IsWouldblock for RequestError {
351    #[inline]
352    fn is_wouldblock(&self) -> bool {
353        match self {
354            RequestError::Write(err) => err.is_wouldblock(),
355        }
356    }
357}
358
359impl IsWouldblock for HeaderError {
360    #[inline]
361    fn is_wouldblock(&self) -> bool {
362        match self {
363            HeaderError::Read(err) => err.is_wouldblock(),
364            HeaderError::Version(_) | HeaderError::Data | HeaderError::TooBig => false,
365        }
366    }
367}
368
369impl IsWouldblock for DataError {
370    #[inline]
371    fn is_wouldblock(&self) -> bool {
372        match self {
373            DataError::Read(err) => err.is_wouldblock(),
374            DataError::Canon | DataError::Family | DataError::DataLength { .. } => false,
375        }
376    }
377}
378
379impl IsWouldblock for WriteError {
380    #[inline]
381    fn is_wouldblock(&self) -> bool {
382        self.0.is_wouldblock()
383    }
384}
385
386impl IsWouldblock for ReadError {
387    #[inline]
388    fn is_wouldblock(&self) -> bool {
389        self.0.is_wouldblock()
390    }
391}
392
393impl IsWouldblock for Errno {
394    #[inline]
395    fn is_wouldblock(&self) -> bool {
396        *self == Errno::WOULDBLOCK
397    }
398}
399
400#[derive(Debug, Clone, Copy, Default)]
401pub(crate) struct IoState {
402    pos: usize,
403    had_zero: bool,
404    had_intr: bool,
405}
406
407#[derive(Debug, Clone, Copy, Pod, Zeroable)]
408#[repr(C)]
409struct RequestHeader {
410    version: i32,
411    r#type: i32,
412    key_len: i32,
413}
414
415#[derive(Debug, Clone, Copy, Pod, Zeroable, Default)]
416#[repr(C)]
417pub(crate) struct AiResponseHeader {
418    version: i32,
419    found: i32,
420    naddrs: NscdSsize,
421    addrslen: NscdSsize,
422    canonlen: NscdSsize,
423    error: i32,
424}
425
426// typedef in `nscd-types.h`
427type NscdSsize = i32;
428
429// constants in `nscd-client.h`
430const PATH_NSCDSOCKET: &CStr = c"/var/run/nscd/socket";
431const GETAI: i32 = 14;
432const NSCD_VERSION: i32 = 2;
433
434const MAX_DATA_LEN: usize = 8192; // This is not a constant in nscd, just a safety check.