use crate::error_utilities::last_error;
use rdma_sys::__be64;
use rdma_sys::ibv_device;
use rdma_sys::{ibv_free_device_list, ibv_get_device_list};
use rdma_sys::{ibv_get_device_guid, ibv_get_device_name};
use std::ffi::CStr;
use std::io;
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::os::raw::c_int;
use std::ptr::NonNull;
use std::{fmt, mem, slice};
use numeric_cast::NumericCast;
use scopeguard::guard_on_unwind;
/// Owned array of RDMA devices obtained from `ibv_get_device_list`.
///
/// Dereferences to `[Device]`; the array is handed back to `ibv_free_device_list` on drop.
pub struct DeviceList {
    /// Start of the `ibv_device*` array, viewed as `Device`s.
    arr: NonNull<Device>,
    /// Device count reported through `ibv_get_device_list`'s out-parameter.
    len: usize,
}

// SAFETY: the list is never mutated after construction and is only freed in `Drop`.
// NOTE(review): soundness also relies on libibverbs allowing the array to be read and freed
// from any thread — confirm against the libibverbs documentation.
unsafe impl Send for DeviceList {}
unsafe impl Sync for DeviceList {}
/// One RDMA device entry inside a [`DeviceList`].
///
/// `#[repr(transparent)]` over a non-null `ibv_device` pointer, so the raw array returned by
/// `ibv_get_device_list` can be viewed directly as `[Device]`. Intentionally neither `Copy` nor
/// `Clone`: a `Device` is only reachable by reference through its list, which keeps it from
/// outliving that list.
#[allow(missing_copy_implementations)]
#[repr(transparent)]
pub struct Device(NonNull<ibv_device>);

// SAFETY: the methods on `Device` only query the underlying device.
// NOTE(review): assumes those libibverbs queries are callable from any thread — confirm.
unsafe impl Send for Device {}
unsafe impl Sync for Device {}
/// A 64-bit RDMA device GUID, stored exactly as `ibv_get_device_guid` returns it
/// (a `__be64`, i.e. big-endian / network byte order in memory).
///
/// `Hash` is derived alongside `Eq` (and therefore consistent with it), so GUIDs can be used
/// as keys in hash maps and sets.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Guid(__be64);
impl DeviceList {
    /// Returns the list as the raw `ibv_device**` pointer expected by libibverbs.
    fn ffi_ptr(&self) -> *mut *mut ibv_device {
        self.arr.as_ptr().cast()
    }

    /// Queries libibverbs for the RDMA devices currently available.
    ///
    /// # Errors
    /// Returns `last_error()` when `ibv_get_device_list` returns a null list.
    ///
    /// # Panics
    /// Panics if the reported device count cannot be converted to `usize` (via
    /// `numeric_cast`), or if the pointer array would be too large to view as a slice.
    /// The list is freed before the panic propagates.
    #[inline]
    pub fn available() -> io::Result<Self> {
        let mut num_devices: c_int = 0;
        // SAFETY: `num_devices` is a valid, writable `c_int` for the duration of the call.
        let raw_list = unsafe { ibv_get_device_list(&mut num_devices) };

        // A null list means failure; fetch the error before anything else can overwrite it.
        let arr: NonNull<Device> = match NonNull::new(raw_list.cast()) {
            Some(non_null) => non_null,
            None => return Err(last_error()),
        };

        // Return the array to libibverbs if any of the checks below panic.
        let _free_on_panic = guard_on_unwind((), |()| {
            // SAFETY: `arr` came from `ibv_get_device_list` and has not been freed yet.
            unsafe { ibv_free_device_list(arr.as_ptr().cast()) }
        });

        let len: usize = num_devices.numeric_cast();

        // `slice::from_raw_parts` needs the byte size to stay within `isize::MAX`. That can only
        // be exceeded when `c_int` is at least as wide as `usize`.
        if mem::size_of::<c_int>() >= mem::size_of::<usize>() {
            let byte_len = len.saturating_mul(mem::size_of::<*mut ibv_device>());
            assert!(byte_len < usize::MAX.wrapping_div(2));
        }

        Ok(Self { arr, len })
    }

    /// Views the devices as a slice borrowed from this list.
    #[inline]
    #[must_use]
    pub fn as_slice(&self) -> &[Device] {
        let (first, count) = (self.arr.as_ptr(), self.len);
        // SAFETY: `first` points to `count` non-null `ibv_device*` entries owned by `self`.
        // `Device` is `#[repr(transparent)]` over `NonNull<ibv_device>`, and the size bound was
        // checked in `available`.
        unsafe { slice::from_raw_parts(first, count) }
    }
}
impl Drop for DeviceList {
    /// Hands the device array back to libibverbs.
    #[inline]
    fn drop(&mut self) {
        let raw_list = self.ffi_ptr();
        // SAFETY: `raw_list` was obtained from `ibv_get_device_list` in `available` and is
        // released only here, exactly once.
        unsafe { ibv_free_device_list(raw_list) }
    }
}
impl Deref for DeviceList {
type Target = [Device];
#[inline]
fn deref(&self) -> &Self::Target {
self.as_slice()
}
}
impl fmt::Debug for DeviceList {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<[Device] as fmt::Debug>::fmt(self, f)
}
}
impl Device {
    /// Returns the raw `ibv_device` pointer for FFI calls elsewhere in the crate.
    pub(crate) fn ffi_ptr(&self) -> *mut ibv_device {
        self.0.as_ptr()
    }

    /// Returns the device name reported by `ibv_get_device_name`, as a C string.
    #[inline]
    #[must_use]
    pub fn c_name(&self) -> &CStr {
        let device = self.ffi_ptr();
        // SAFETY: `device` is a live entry of the owning device list.
        // NOTE(review): relies on `ibv_get_device_name` returning a non-null, NUL-terminated
        // string that stays valid while the device is borrowed — confirm against libibverbs.
        unsafe { CStr::from_ptr(ibv_get_device_name(device)) }
    }

    /// Returns the device name as a UTF-8 string slice.
    ///
    /// # Panics
    /// Panics if the name reported by libibverbs is not valid UTF-8.
    #[inline]
    #[must_use]
    pub fn name(&self) -> &str {
        let raw_name = self.c_name();
        #[allow(clippy::expect_used)]
        raw_name.to_str().expect("non-utf8 device name")
    }

    /// Returns the device GUID reported by `ibv_get_device_guid`.
    #[inline]
    #[must_use]
    pub fn guid(&self) -> Guid {
        // SAFETY: `self` wraps a valid device pointer taken from a live device list.
        let raw_guid = unsafe { ibv_get_device_guid(self.ffi_ptr()) };
        Guid(raw_guid)
    }
}
impl fmt::Debug for Device {
    /// Formats the device as `Device { name: "...", guid: Guid(...) }`.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Use a lossy conversion rather than `Device::name`, which panics on non-UTF-8 names.
        // A `Debug` impl must never panic. For valid UTF-8 names the output is unchanged,
        // because `Cow<str>` formats exactly like `str`.
        let name = self.c_name().to_string_lossy();
        let guid = self.guid();
        f.debug_struct("Device")
            .field("name", &name)
            .field("guid", &guid)
            .finish()
    }
}
impl Guid {
#[inline]
#[must_use]
pub fn from_bytes(bytes: [u8; 8]) -> Self {
Self(u64::from_ne_bytes(bytes))
}
#[inline]
#[must_use]
pub fn as_bytes(&self) -> &[u8; 8] {
unsafe { &*<*const _>::cast(self) }
}
}
impl fmt::Debug for Guid {
    /// Formats as `Guid(<16 lowercase hex digits>)`.
    ///
    /// Goes through `write!`, so the hex part uses default formatting options rather than the
    /// caller's width/fill.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Guid({self:x})")
    }
}
/// Hex-encodes the eight bytes of `guid` (in memory order) into a 16-byte local buffer and
/// passes the resulting `&str` to `f`, returning whatever `f` returns.
///
/// `case` selects lowercase or uppercase hex digits.
fn guid_to_hex<R>(guid: Guid, case: hex_simd::AsciiCase, f: impl FnOnce(&str) -> R) -> R {
    let src: &[u8; 8] = guid.as_bytes();
    // 8 input bytes encode to exactly 16 ASCII hex digits.
    let mut buf: MaybeUninit<[u8; 16]> = MaybeUninit::uninit();
    let ans = {
        // The `cast()` target is inferred from `OutBuf::from_uninit_mut`'s parameter.
        // NOTE(review): presumably `MaybeUninit<u8>`, which is what makes viewing the
        // uninitialized buffer as a slice sound — confirm against the hex_simd version in use.
        let bytes = unsafe { slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), 16) };
        let dst = hex_simd::OutBuf::from_uninit_mut(bytes);
        let result = hex_simd::encode_as_str(src, dst, case);
        // The destination is exactly large enough for 8 source bytes.
        // NOTE(review): assumes `encode_as_str` only fails on an undersized output buffer.
        unsafe { result.unwrap_unchecked() }
    };
    f(ans)
}
impl fmt::LowerHex for Guid {
    /// Writes the GUID as 16 lowercase hex digits, honouring the formatter's padding options.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let guid = *self;
        // `Formatter::pad` is exactly what `<str as Display>::fmt` does.
        guid_to_hex(guid, hex_simd::AsciiCase::Lower, |digits| f.pad(digits))
    }
}
impl fmt::UpperHex for Guid {
    /// Writes the GUID as 16 uppercase hex digits, honouring the formatter's padding options.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let guid = *self;
        // `Formatter::pad` is exactly what `<str as Display>::fmt` does.
        guid_to_hex(guid, hex_simd::AsciiCase::Upper, |digits| f.pad(digits))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    use const_str::hex_bytes as hex;

    /// `Debug`, `LowerHex` and `UpperHex` all render the GUID bytes in memory order.
    #[test]
    fn guid_fmt() {
        const GUID_HEX: &str = "26418cfffe021df9";
        let guid = Guid::from_bytes(hex!(GUID_HEX));

        assert_eq!(format!("{guid:?}"), format!("Guid({GUID_HEX})"));
        assert_eq!(format!("{guid:x}"), GUID_HEX);
        assert_eq!(format!("{guid:X}"), GUID_HEX.to_ascii_uppercase());
    }

    /// Compile-time check that the public handle types may cross thread boundaries.
    #[test]
    fn marker() {
        fn assert_send_sync<T: Send + Sync>() {}

        assert_send_sync::<Device>();
        assert_send_sync::<DeviceList>();
        assert_send_sync::<Guid>();
    }
}