use core::ffi::CStr;
use core::mem::MaybeUninit;
use crate::Addr;
use crate::GaihAddrTuple;
use crate::err::NssErr;
use crate::err::NssRes;
/// Writer that lays out a hostname followed by a linked list of
/// [`GaihAddrTuple`] nodes inside a caller-provided byte buffer.
///
/// Built by `try_new`; nodes are appended one at a time with `push`.
pub(crate) struct Gaih4Buf<'a> {
    /// Start of the buffer, where `try_new` copied the NUL-terminated hostname.
    /// Passed to every node created by `push`.
    hostname: *const libc::c_char,
    /// Node slots carved out of the buffer after the hostname, starting at the
    /// first address aligned for `GaihAddrTuple`. Empty when no whole node fits.
    addrs: &'a mut [MaybeUninit<GaihAddrTuple>],
    /// Number of leading slots of `addrs` that `push` has already initialized.
    addrs_len: usize,
    /// Caller-owned head pointer. May be null on entry (then `push` points it at
    /// the first slot of `addrs`) or non-null (then `push` writes the first
    /// node through it).
    maybe_head: &'a mut *mut GaihAddrTuple,
    /// `true` once the head node has been written or assigned by `push`.
    set_head: bool,
}
impl<'a> Gaih4Buf<'a> {
    /// Copies `hostname` to the front of `buffer` and reserves the remainder of
    /// the buffer as storage for address nodes.
    ///
    /// The NUL-terminated hostname is written at `buffer`. The node array starts
    /// at the first address after the hostname that is aligned for
    /// [`GaihAddrTuple`] and holds as many whole nodes as fit in `buf_len`
    /// (possibly zero).
    ///
    /// # Errors
    ///
    /// * `NssErr::INVALID_INPUT` if `buffer` is null, or if the aligned start of
    ///   the node array wraps around the address space.
    /// * `NssErr::BUF_TOO_SMALL` if `buf_len` cannot hold the hostname and its
    ///   terminating NUL.
    ///
    /// # Safety
    ///
    /// * `buffer` must be valid for writes of `buf_len` bytes for all of `'a`,
    ///   and must not overlap the memory backing `hostname`.
    /// * If `*maybe_head` is non-null it must point to memory that is valid for
    ///   writing a `GaihAddrTuple` for all of `'a`; `push` writes through it.
    pub(crate) unsafe fn try_new(
        hostname: &CStr,
        maybe_head: &'a mut *mut GaihAddrTuple,
        buffer: *mut libc::c_char,
        buf_len: libc::size_t,
    ) -> NssRes<Self> {
        if buffer.is_null() {
            return Err(NssErr::INVALID_INPUT);
        }

        // Place the hostname (including its NUL) at the very start of the buffer.
        let name_bytes = hostname.to_bytes_with_nul();
        let name_len = name_bytes.len();
        if name_len > buf_len {
            return Err(NssErr::BUF_TOO_SMALL);
        }
        // SAFETY: `buffer` is non-null and, per this function's contract, valid
        // for `buf_len >= name_len` bytes and disjoint from `hostname`.
        unsafe {
            core::ptr::copy_nonoverlapping(name_bytes.as_ptr(), buffer.cast::<u8>(), name_len);
        }

        // Work out where the first properly aligned node can live and how many
        // whole nodes fit in what is left.
        let node_align = core::mem::align_of::<GaihAddrTuple>();
        let node_size = core::mem::size_of::<GaihAddrTuple>();
        let base = buffer as usize;
        let offset_bytes = (base + name_len).next_multiple_of(node_align) - base;
        let arr_len = buf_len.saturating_sub(offset_bytes) / node_size;

        let addrs: &'a mut [MaybeUninit<GaihAddrTuple>] = if arr_len > 0 {
            let arr_start = buffer.wrapping_add(offset_bytes);
            // A start address below the buffer means the pointer arithmetic wrapped.
            if (arr_start as usize) < base {
                return Err(NssErr::INVALID_INPUT);
            }
            let arr = arr_start.cast::<MaybeUninit<GaihAddrTuple>>();
            debug_assert_eq!(arr as usize % node_align, 0, "arr_start is aligned");
            debug_assert!(
                offset_bytes + arr_len * node_size <= buf_len,
                "name and array fit in the buffer allocation"
            );
            // SAFETY: `arr` is aligned for `GaihAddrTuple`, and `arr_len` whole
            // nodes starting at `offset_bytes` lie inside the `buf_len` bytes the
            // caller guarantees are writable for `'a`. `MaybeUninit` places no
            // requirement on the current contents.
            unsafe { core::slice::from_raw_parts_mut(arr, arr_len) }
        } else {
            &mut []
        };

        Ok(Self {
            hostname: buffer as *const libc::c_char,
            addrs,
            addrs_len: 0,
            maybe_head,
            set_head: false,
        })
    }

    /// Appends `addr` to the linked list, returning `false` when no slot is left.
    ///
    /// The first address is written through the caller's head pointer when that
    /// pointer is non-null. Every other address takes the next free slot of the
    /// buffer and is linked after the previously written node. When the head
    /// pointer was null, it is pointed at the first slot used.
    pub(crate) fn push(&mut self, addr: Addr) -> bool {
        // A caller-supplied head node absorbs the first address without
        // consuming a buffer slot.
        let head_is_seeded = !(*self.maybe_head).is_null();
        if head_is_seeded && !self.set_head {
            let node = GaihAddrTuple::new_addr(self.hostname, addr);
            // SAFETY: the head pointer is non-null, and `try_new`'s contract
            // requires a non-null head to be writable.
            unsafe {
                **self.maybe_head = node;
            }
            self.set_head = true;
            return true;
        }

        let index = self.addrs_len;
        let child = match self.addrs.get_mut(index) {
            Some(slot) => {
                core::ptr::from_mut(slot.write(GaihAddrTuple::new_addr(self.hostname, addr)))
            }
            None => return false,
        };

        if index > 0 {
            // Link from the slot filled by the previous push.
            let parent = &mut self.addrs[index - 1];
            // SAFETY: every slot below `addrs_len` was initialized by an
            // earlier call to `push`.
            unsafe {
                parent.assume_init_mut().next = child;
            }
        } else if self.set_head {
            // First buffer slot, following a head node written earlier.
            // SAFETY: `set_head` with no slots used means the head was written
            // through the caller's non-null pointer, so it is valid to update.
            unsafe {
                (**self.maybe_head).next = child;
            }
        } else {
            // First node overall: the first buffer slot becomes the head.
            debug_assert!(
                (*self.maybe_head).is_null(),
                "if pat were non null, we would have written to it and returned early"
            );
            *self.maybe_head = child;
            self.set_head = true;
        }

        self.addrs_len += 1;
        true
    }
}
#[cfg(test)]
mod buf_iter {
    use core::ffi::CStr;
    use core::marker::PhantomData;
    use core::net::Ipv4Addr;
    use core::net::Ipv6Addr;
    use crate::Addr;
    use crate::GaihAddrTuple;
    use crate::buf::Gaih4Buf;

    impl<'a> Gaih4Buf<'a> {
        /// Test-only walk over the nodes written so far, starting at the head.
        ///
        /// Yields nothing when no head has been set yet.
        pub fn iter(&self) -> Gaih4BufIter<'_> {
            let next = if self.set_head {
                *self.maybe_head
            } else {
                // Without a head, no buffer slot can have been used either.
                assert_eq!(self.addrs_len, 0);
                core::ptr::null_mut()
            };
            Gaih4BufIter {
                next,
                _t: PhantomData,
            }
        }
    }

    /// Iterator following the `next` pointers of a [`Gaih4Buf`]'s node list.
    pub struct Gaih4BufIter<'a> {
        /// Node to yield next; null once the list is exhausted.
        next: *mut GaihAddrTuple,
        /// Ties the iterator to the borrow of the buffer it walks.
        _t: PhantomData<&'a Gaih4Buf<'a>>,
    }

    impl<'a> Iterator for Gaih4BufIter<'a> {
        type Item = (&'a CStr, Addr);

        /// Decodes the current node into its hostname and address, then
        /// advances to the node's `next` pointer.
        ///
        /// Panics if the node's family is neither `AF_INET` nor `AF_INET6`.
        fn next(&mut self) -> Option<Self::Item> {
            let node = self.next;
            if node.is_null() {
                return None;
            }

            // SAFETY: a non-null `node` is either the head or a `next` link
            // written by `Gaih4Buf::push`, so it points at an initialized node
            // whose `name` is the NUL-terminated hostname in the buffer.
            let (name, family, raw, scope_id) = unsafe {
                let fields = (
                    CStr::from_ptr((*node).name),
                    (*node).family,
                    (*node).addr,
                    (*node).scope_id,
                );
                self.next = (*node).next;
                fields
            };

            let ip = match family {
                // IPv4 occupies only the first word of the address array.
                libc::AF_INET => Ipv4Addr::from(raw[0].to_ne_bytes()).into(),
                libc::AF_INET6 => {
                    let mut bytes = raw.iter().flat_map(|word| word.to_ne_bytes());
                    let octets = core::array::from_fn(|_| {
                        bytes.next().expect("there should be exactly 4 * 4 bytes")
                    });
                    assert_eq!(bytes.next(), None);
                    Ipv6Addr::from(octets).into()
                }
                other => panic!("valid nodes are only ever IPv4 or IPv6. Found libc::AF_{other}"),
            };

            Some((name, Addr { ip, scope_id }))
        }
    }
}
#[cfg(test)]
mod buf_tests {
    use crate::Addr;
    use crate::GaihAddrTuple;
    use crate::buf::Gaih4Buf;
    use crate::err::NssErr;
    use crate::err::NssRes;
    use core::ffi::CStr;
    use core::net::Ipv4Addr;
    use core::net::Ipv6Addr;

    // Scratch buffers below are typed as `libc::c_char` rather than `i8` so the
    // tests also compile on targets where `c_char` is unsigned.

    /// Plenty of space and a caller-provided head node: every push succeeds.
    #[test]
    fn large_buf_seed_pat() {
        const ADDRS4: &[u32] = &[111, 222, 333];
        const ADDRS6: &[u128] = &[777, 888, 999];
        const HOSTNAME: &CStr = c"AMBIGUOUS_NEIGHBOR";
        let mut pat = core::pin::pin!(GaihAddrTuple {
            next: core::ptr::null_mut(),
            name: core::ptr::null(),
            family: libc::AF_UNSPEC,
            addr: [0; 4],
            scope_id: 0,
        });
        let mut pat_ptr = &raw mut *pat;
        let mut bytes = core::pin::pin!([0 as libc::c_char; 512]);
        let mut buf =
            unsafe { Gaih4Buf::try_new(HOSTNAME, &mut pat_ptr, bytes.as_mut_ptr(), bytes.len()) }
                .expect("well formed inputs should be successful");
        self::push_and_check(HOSTNAME, &mut buf, true, ADDRS4, ADDRS6)
            .expect("should pass with large buf and seeded PAT");
    }

    /// Plenty of space and a null head pointer: the head lands in the buffer.
    #[test]
    fn large_buf_null_pat() {
        const ADDRS4: &[u32] = &[!111, !222];
        const ADDRS6: &[u128] = &[!777, !888, !999, !1010];
        const HOSTNAME: &CStr = c"another_host";
        let mut pat = core::ptr::null_mut();
        let mut bytes = core::pin::pin!([0 as libc::c_char; 512]);
        let mut buf =
            unsafe { Gaih4Buf::try_new(HOSTNAME, &mut pat, bytes.as_mut_ptr(), bytes.len()) }
                .expect("well formed inputs should be successful");
        self::push_and_check(HOSTNAME, &mut buf, true, ADDRS4, ADDRS6)
            .expect("should pass with large buf and null PAT");
    }

    /// The buffer only fits the hostname; the single address goes into the
    /// caller-provided head node.
    #[test]
    fn tiny_buf_seed_pat() {
        const HOSTNAME: &CStr = c"RunningOutOfIdeas";
        const ADDRS4: &[u32] = &[2130706433];
        const ADDRS6: &[u128] = &[];
        let mut pat = core::pin::pin!(GaihAddrTuple {
            next: core::ptr::null_mut(),
            name: core::ptr::null(),
            family: libc::AF_UNSPEC,
            addr: [0; 4],
            scope_id: 0,
        });
        let mut pat_ptr = &raw mut *pat;
        let mut bytes = core::pin::pin!([0 as libc::c_char; 19]);
        let mut buf =
            unsafe { Gaih4Buf::try_new(HOSTNAME, &mut pat_ptr, bytes.as_mut_ptr(), bytes.len()) }
                .expect("well formed inputs should be successful");
        self::push_and_check(HOSTNAME, &mut buf, true, ADDRS4, ADDRS6)
            .expect("should pass with tiny buf and seeded PAT");
    }

    /// The buffer is one byte short of holding the hostname and its NUL.
    #[test]
    fn fail_tinier_buf_seed_pat() {
        const HOSTNAME: &CStr = c"RunningOutOfIdeas2";
        let mut pat = core::pin::pin!(GaihAddrTuple {
            next: core::ptr::null_mut(),
            name: core::ptr::null(),
            family: libc::AF_UNSPEC,
            addr: [0; 4],
            scope_id: 0,
        });
        let mut pat_ptr = &raw mut *pat;
        let mut bytes = core::pin::pin!([0 as libc::c_char; 18]);
        let buf =
            unsafe { Gaih4Buf::try_new(HOSTNAME, &mut pat_ptr, bytes.as_mut_ptr(), bytes.len()) };
        let Err(err) = buf else {
            panic!("buf should be too small for the hostname");
        };
        assert_eq!(err, NssErr::BUF_TOO_SMALL);
    }

    /// The buffer holds the hostname but not all five nodes, so a push must
    /// eventually report failure.
    #[test]
    fn fail_small_buf_null_pat() {
        const ADDRS4: &[u32] = &[12345, 6789];
        const ADDRS6: &[u128] = &[10111213, 1416171828, 9018937654];
        const HOSTNAME: &CStr = c"should-fail-no-space";
        let mut pat = core::ptr::null_mut();
        let mut bytes = core::pin::pin!([0 as libc::c_char; 97]);
        let mut buf =
            unsafe { Gaih4Buf::try_new(HOSTNAME, &mut pat, bytes.as_mut_ptr(), bytes.len()) }
                .expect("well formed inputs should be successful");
        let err = self::push_and_check(HOSTNAME, &mut buf, false, ADDRS4, ADDRS6)
            .expect_err("buf is not large enough for all results");
        assert_eq!(err, NssErr::BUF_TOO_SMALL);
    }

    /// Pushes every address in `v4` then `v6` (IPv6 scope ids are the element
    /// indices) and verifies the buffered list reads back in the same order.
    ///
    /// With `expect_success` set, any failed push panics. Otherwise the first
    /// push that actually reports failure yields `Err(NssErr::BUF_TOO_SMALL)`;
    /// if every push unexpectedly succeeds the list is still verified and
    /// `Ok(())` is returned, so the calling test can notice.
    fn push_and_check(
        hostname: &CStr,
        buf: &mut Gaih4Buf,
        expect_success: bool,
        v4: &[u32],
        v6: &[u128],
    ) -> NssRes<()> {
        for ip in v4.iter().copied().map(Ipv4Addr::from_bits) {
            let success = buf.push(Addr {
                ip: ip.into(),
                scope_id: 0,
            });
            if expect_success {
                assert!(success, "v4 push should succeed");
            } else if !success {
                // Only a push that really failed counts as "out of space".
                return Err(NssErr::BUF_TOO_SMALL);
            }
        }
        for (scope_id, ip) in v6.iter().copied().map(Ipv6Addr::from_bits).enumerate() {
            let success = buf.push(Addr {
                ip: ip.into(),
                scope_id: scope_id as u32,
            });
            if expect_success {
                assert!(success, "v6 push should succeed");
            } else if !success {
                return Err(NssErr::BUF_TOO_SMALL);
            }
        }

        let mut buffered = buf.iter();
        let mut count = 0;

        // The expected addresses drive each zip: `zip` polls its left side
        // first, so the buffer iterator is never advanced past the last
        // expected element and surplus nodes stay observable below.
        let expected_v4 = v4.iter().copied().map(Ipv4Addr::from_bits).map(|ip| Addr {
            ip: ip.into(),
            scope_id: 0,
        });
        for (expected, (host, addr)) in expected_v4.zip(&mut buffered) {
            assert_eq!(host, hostname);
            assert_eq!(addr, expected);
            count += 1;
        }

        let expected_v6 = v6.iter().copied().enumerate().map(|(scope_id, bits)| Addr {
            ip: Ipv6Addr::from_bits(bits).into(),
            scope_id: scope_id as u32,
        });
        for (expected, (host, addr)) in expected_v6.zip(&mut buffered) {
            assert_eq!(host, hostname);
            assert_eq!(addr, expected);
            count += 1;
        }

        assert_eq!(
            count,
            v4.len() + v6.len(),
            "should have buffered all addresses"
        );
        assert!(
            buffered.next().is_none(),
            "buffer should not hold more addresses than were pushed"
        );
        Ok(())
    }
}