1use std::ops::ControlFlow;
4use std::os::fd::{AsFd, OwnedFd};
5use std::time::{Duration, Instant};
6
7use rustix::event::{PollFd, PollFlags, Timespec, poll};
8use rustix::io::Errno;
9
10use crate::protocol::{
11 AiResponseHeader, DataError, HeaderError, IoState, IpAddrIterator, IsEmpty, IsWouldblock,
12 RequestError, SocketError, interpret_data, open_socket, read_data, read_header, write_request,
13};
14
15#[inline]
23pub fn lookup(
24 host: impl AsRef<[u8]>,
25 buf: &mut Vec<u8>,
26 timeout: Option<Duration>,
27) -> Result<Option<IpAddrIterator<'_>>, Error> {
28 do_lookup(host.as_ref(), buf, timeout.unwrap_or(DEFAULT_TIMEOUT))
29}
30
31fn do_lookup<'a>(
32 host: &[u8],
33 buf: &'a mut Vec<u8>,
34 timeout: Duration,
35) -> Result<Option<IpAddrIterator<'a>>, Error> {
36 let deadline = Instant::now() + timeout;
37
38 let sock = open_socket().map_err(Error::Socket)?;
39
40 let mut io = IoState::default();
41 loop {
42 match write_request(sock.as_fd(), &mut io, host) {
43 Ok(ControlFlow::Continue(())) => continue,
44 Ok(ControlFlow::Break(())) => break,
45 Err(err) if err.is_wouldblock() => await_writable(&sock, deadline)?,
46 Err(err) => return Err(Error::Request(err)),
47 }
48 }
49
50 io = IoState::default();
51 let mut resp = AiResponseHeader::default();
52 let data_len = loop {
53 await_readable(&sock, deadline)?;
54 match read_header(sock.as_fd(), &mut io, &mut resp) {
55 Ok(ControlFlow::Continue(())) => continue,
56 Ok(ControlFlow::Break(IsEmpty::Empty)) => return Ok(None),
57 Ok(ControlFlow::Break(IsEmpty::HasData(data_len))) => break data_len,
58 Err(err) if err.is_wouldblock() => await_readable(&sock, deadline)?,
59 Err(err) => return Err(Error::Header(err)),
60 }
61 };
62 buf.resize(data_len, 0);
63
64 io = IoState::default();
65 loop {
66 match read_data(sock.as_fd(), &mut io, buf) {
67 Ok(ControlFlow::Continue(())) => continue,
68 Ok(ControlFlow::Break(())) => break,
69 Err(err) if err.is_wouldblock() => await_readable(&sock, deadline)?,
70 Err(err) => return Err(Error::Data(err)),
71 }
72 }
73
74 Ok(Some(interpret_data(&resp, buf)?))
75}
76
77fn await_writable(sock: &OwnedFd, deadline: Instant) -> Result<(), Error> {
78 let events = PollFlags::IN
79 | PollFlags::OUT
80 | PollFlags::WRNORM
81 | PollFlags::WRBAND
82 | PollFlags::ERR
83 | PollFlags::HUP;
84 await_io(sock, deadline, events)
85}
86
87fn await_readable(sock: &OwnedFd, deadline: Instant) -> Result<(), Error> {
88 let events = PollFlags::IN
89 | PollFlags::PRI
90 | PollFlags::RDNORM
91 | PollFlags::RDBAND
92 | PollFlags::ERR
93 | PollFlags::HUP;
94 await_io(sock, deadline, events)
95}
96
97fn await_io(sock: &OwnedFd, deadline: Instant, events: PollFlags) -> Result<(), Error> {
98 let Some(remaining) = deadline.checked_duration_since(Instant::now()) else {
99 return Err(Error::Timeout(None));
100 };
101
102 let timeout = Timespec::try_from(remaining).map_err(|_| Error::Timeout(None))?;
103 let mut fds = [PollFd::new(sock, events)];
104 if poll(&mut fds, Some(&timeout)).map_err(|err| Error::Timeout(Some(err)))? == 0 {
105 return Err(Error::Timeout(None));
106 }
107
108 Ok(())
109}
110
// Error type returned by `lookup`, with one variant per failing stage of the exchange.
//
// NOTE(review): `displaydoc::Display` builds each variant's `Display` text from that
// variant's `///` doc comment, so only plain `//` comments are used here to leave the
// rendered messages untouched. No `///` lines are visible on the variants in this view.
// Confirm that every variant carries one, because the derive needs one per variant.
#[derive(Debug, Clone, Copy, thiserror::Error, displaydoc::Display)]
pub enum Error {
    // `open_socket` failed.
    Socket(#[from] SocketError),
    // Writing the request failed with something other than would-block.
    Request(#[from] RequestError),
    // Reading the response header failed with something other than would-block.
    Header(#[from] HeaderError),
    // Reading the response payload failed with something other than would-block.
    // NOTE(review): `interpret_data`'s error presumably also lands here through `?` and
    // `From`; confirm against its signature.
    Data(#[from] DataError),
    // Waiting for socket readiness failed: the deadline passed, the remaining time could
    // not be converted for `poll`, or `poll` itself returned an error. Only in that last
    // case is the `Errno` kept, and it is exposed as the error's source.
    Timeout(#[source] Option<Errno>),
}
125
/// Timeout used by [`lookup`] when the caller passes `None`.
///
/// Like any timeout given to `lookup`, it bounds the whole request/response exchange
/// rather than each individual read or write.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);