use std::{
ffi::{CStr, CString, NulError},
fmt::{self, Display, Formatter},
path::Path,
};
use serde::{Deserialize, Serialize};
/// Error produced when a raw value does not correspond to any variant of a target enum.
///
/// `name` labels the target type in the error message; `value` is the rejected raw value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EnumConversionError<T> {
    /// Label for the enum the conversion targeted (used only in messages).
    pub name: &'static str,
    /// The raw value that `TryFrom` rejected.
    pub value: T,
}

impl<T: Display> Display for EnumConversionError<T> {
    /// Formats as `unknown <name> value <value>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self { name, value } = self;
        write!(f, "unknown {name} value {value}")
    }
}

impl<T> std::error::Error for EnumConversionError<T> where T: Display + fmt::Debug {}

/// Converts `value` into the enum `E` via `TryFrom`, labelling failures with `name`.
///
/// # Errors
///
/// Returns [`EnumConversionError`] carrying `name` and the original `value` when
/// `E::try_from(value)` fails; the underlying `TryFrom` error is discarded.
pub fn try_enum_from_raw<E, T>(name: &'static str, value: T) -> Result<E, EnumConversionError<T>>
where
    E: TryFrom<T>,
    T: Copy,
{
    match E::try_from(value) {
        Ok(converted) => Ok(converted),
        Err(_) => Err(EnumConversionError { name, value }),
    }
}
/// Implements `From` in both directions between two enum-like types `$from` and `$into`.
///
/// `impl_enum_conversion!(From, Into)` is shorthand for `impl_enum_conversion!(u32, From, Into)`.
/// With an explicit integer type `$ty`:
///
/// * `From<$from> for $into` casts the value to `$ty` with `as`, then converts it with
///   `TryFrom<$ty> for $into`.
/// * `From<$into> for $from` converts to `$ty` via `Into<$ty>`, then transmutes that integer into
///   `$from`. `transmute` rejects (at compile time) a `$from` whose size differs from `$ty`.
///
/// # Panics
///
/// `From<$from> for $into` panics when `TryFrom` rejects the raw value.
///
/// # Soundness
///
/// The `$into -> $from` direction is only sound if every integer produced by `Into<$ty>` is a
/// valid bit pattern for `$from` (for a fieldless `#[repr($ty)]` enum: a declared discriminant).
/// Transmuting any other value into an enum is undefined behavior, and the macro cannot check
/// this for you.
#[macro_export]
macro_rules! impl_enum_conversion {
    ($ty:ty, $from:ty, $into:ty $(,)?) => {
        const _: () = {
            impl ::core::convert::From<$from> for $into {
                fn from(value: $from) -> Self {
                    let raw = value as $ty;
                    // Fully qualified so the expansion does not depend on the caller's
                    // prelude/edition having `TryFrom` in scope.
                    match <$into as ::core::convert::TryFrom<$ty>>::try_from(raw) {
                        ::core::result::Result::Ok(converted) => converted,
                        ::core::result::Result::Err(_) => ::core::panic!(
                            "cannot convert raw value {} into {}",
                            raw,
                            ::core::stringify!($into),
                        ),
                    }
                }
            }
            impl ::core::convert::From<$into> for $from {
                fn from(value: $into) -> Self {
                    let raw = <$into as ::core::convert::Into<$ty>>::into(value);
                    // SAFETY: relies on the invocation contract documented above — every value
                    // `Into<$ty>` yields must be a valid `$from` bit pattern. Size equality is
                    // enforced by `transmute` itself.
                    unsafe { ::core::mem::transmute::<$ty, $from>(raw) }
                }
            }
        };
    };
    ($from:ty, $into:ty $(,)?) => {
        // `$crate::` so this resolves even when the caller invoked the macro by path without
        // importing it.
        $crate::impl_enum_conversion!(u32, $from, $into);
    };
}
/// Asserts that floating-point values agree within an absolute tolerance.
///
/// Forms:
///
/// * `assert_close!(actual, expected)` — element-wise sequence comparison with a default
///   tolerance of `1.0e-5`.
/// * `assert_close!(actual, expected, tolerance)` — element-wise sequence comparison. Both
///   operands must provide `len()` and `iter()` (arrays, slices, `Vec`s) over items castable to
///   `f64` with `as`; their lengths must match.
/// * `assert_close!(actual, expected, tolerance, name)` — compares two *scalars*; `name` appears
///   in the failure message.
///
/// Values that compare exactly equal always pass, so identical infinities are accepted even
/// though `inf - inf` is NaN. A NaN on either side never passes.
#[macro_export]
macro_rules! assert_close {
    ($actual:expr, $expected:expr $(,)?) => {
        $crate::assert_close!($actual, $expected, 1.0e-5);
    };
    ($actual:expr, $expected:expr, $tolerance:expr $(,)?) => {{
        let actual = $actual;
        let expected = $expected;
        let tolerance = $tolerance as f64;
        assert_eq!(
            actual.len(),
            expected.len(),
            "assert_close length mismatch: actual len {}, expected len {}",
            actual.len(),
            expected.len(),
        );
        for (index, (actual, expected)) in actual.iter().zip(expected.iter()).enumerate() {
            let actual = *actual as f64;
            let expected = *expected as f64;
            let difference = (actual - expected).abs();
            // Exact equality first: matching infinities have a NaN difference.
            assert!(
                actual == expected || difference <= tolerance,
                "assert_close failed at index {index}: actual {actual}, expected {expected}, difference {difference}, tolerance {tolerance}",
            );
        }
    }};
    ($actual:expr, $expected:expr, $tolerance:expr, $name:expr $(,)?) => {{
        let actual = $actual as f64;
        let expected = $expected as f64;
        let tolerance = $tolerance as f64;
        let difference = (actual - expected).abs();
        // Exact equality first: matching infinities have a NaN difference.
        assert!(
            actual == expected || difference <= tolerance,
            "assert_close failed for {}: actual {actual}, expected {expected}, difference {difference}, tolerance {tolerance}",
            $name,
        );
    }};
}
/// Converts a filesystem path into a NUL-terminated C string.
///
/// On Unix the path's raw bytes are used unchanged, so non-UTF-8 paths reach C code intact.
/// On other platforms the path is converted lossily through UTF-8.
///
/// # Errors
///
/// Returns [`NulError`] if the path contains an interior NUL byte.
pub fn path_to_cstring(path: &Path) -> Result<CString, NulError> {
    #[cfg(unix)]
    let bytes = {
        use std::os::unix::ffi::OsStrExt;
        // A lossy conversion here would replace invalid bytes with U+FFFD and hand the C
        // library a different (usually nonexistent) path.
        path.as_os_str().as_bytes().to_vec()
    };
    #[cfg(not(unix))]
    let bytes = path.as_os_str().to_string_lossy().into_owned().into_bytes();
    CString::new(bytes)
}
/// Converts a fixed-size C `char` buffer into an owned `String`.
///
/// Reads up to the first NUL. If the buffer contains no NUL, the whole buffer is used rather
/// than reading past its end. Invalid UTF-8 is replaced with U+FFFD.
pub fn string_from_c_chars(buffer: &[i8]) -> String {
    // Bound the terminator search to the slice: `CStr::from_ptr` would scan past the end of an
    // unterminated buffer (undefined behavior).
    let len = buffer.iter().position(|&c| c == 0).unwrap_or(buffer.len());
    // `i8 as u8` is a bit-for-bit reinterpretation of each C char.
    let bytes: Vec<u8> = buffer[..len].iter().map(|&c| c as u8).collect();
    String::from_utf8_lossy(&bytes).into_owned()
}
/// Copies a NUL-terminated C string into an owned `String`.
///
/// A null pointer yields an empty string. Invalid UTF-8 is replaced with U+FFFD.
///
/// # Safety
///
/// If `ptr` is non-null, it must satisfy the requirements of [`CStr::from_ptr`]:
///
/// * it points to a NUL-terminated string;
/// * the string is valid for reads up to and including that terminator;
/// * the string is not mutated for the duration of this call.
pub unsafe fn string_from_c_ptr(ptr: *const i8) -> String {
    if ptr.is_null() {
        return String::new();
    }
    // SAFETY: non-null was checked above; the caller upholds the `CStr::from_ptr` contract.
    // `cast` adapts `*const i8` to `*const c_char`, which is `*const u8` on some targets
    // (e.g. aarch64 Linux), where passing `ptr` directly would not compile.
    unsafe { CStr::from_ptr(ptr.cast()).to_string_lossy().into_owned() }
}
/// Copies `value` into a fixed-size C `char` buffer as a NUL-terminated string.
///
/// At most `N - 1` bytes are copied. Truncation happens on a UTF-8 character boundary, so a
/// multi-byte character is never split. Every byte after the copied text is set to `0`, so the
/// result is always terminated and never carries stale bytes from earlier contents. A zero-length
/// buffer is left untouched. An interior NUL in `value` makes C readers stop early.
pub fn copy_string_to_c_chars<const N: usize>(buffer: &mut [i8; N], value: &str) {
    // An empty buffer cannot even hold the terminator.
    if N == 0 {
        return;
    }
    let mut len = value.len().min(N - 1);
    // Back up to a char boundary; index 0 is always one, so this terminates.
    while !value.is_char_boundary(len) {
        len -= 1;
    }
    let (text, rest) = buffer.split_at_mut(len);
    for (slot, &byte) in text.iter_mut().zip(value.as_bytes()) {
        *slot = byte as i8;
    }
    // `rest` is non-empty because `len <= N - 1`; its first byte is the terminator.
    rest.fill(0);
}
/// Version number held as the raw `u64` it was constructed from.
///
/// Equality, ordering and hashing are derived from `raw`.
#[derive(Debug, Clone, Copy, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[derive(Serialize, Deserialize)]
pub struct LibraryVersion {
    /// The raw version integer.
    pub raw: u64,
}
impl LibraryVersion {
    /// Wraps a raw version integer.
    pub const fn from_raw(raw: u64) -> Self {
        LibraryVersion { raw }
    }
    /// Returns the raw version integer.
    pub const fn raw(self) -> u64 {
        let Self { raw } = self;
        raw
    }
}
impl From<u64> for LibraryVersion {
    fn from(raw: u64) -> Self {
        Self::from_raw(raw)
    }
}
impl From<u32> for LibraryVersion {
    fn from(raw: u32) -> Self {
        Self::from_raw(u64::from(raw))
    }
}
impl From<usize> for LibraryVersion {
    fn from(raw: usize) -> Self {
        // `usize` is at most 64 bits on current Rust targets, so this cast is lossless.
        Self::from_raw(raw as u64)
    }
}
/// Negative inputs are not valid versions and clamp to `0`.
///
/// A plain `as u64` cast would sign-extend them to values near `u64::MAX`. Such a value would
/// compare above every real version and could wrongly pass `version >= N` feature checks.
/// Clamping keeps the conversion monotonic.
impl From<i32> for LibraryVersion {
    fn from(raw: i32) -> Self {
        Self::from_raw(u64::try_from(raw).unwrap_or(0))
    }
}
impl PartialEq<u64> for LibraryVersion {
    /// Compares the raw value against `other`.
    fn eq(&self, other: &u64) -> bool {
        *other == self.raw
    }
}
impl PartialOrd<u64> for LibraryVersion {
    /// Orders the raw value against `other` (always `Some`).
    fn partial_cmp(&self, other: &u64) -> Option<std::cmp::Ordering> {
        Some(self.raw.cmp(other))
    }
}
impl PartialEq<u32> for LibraryVersion {
    /// Widens `other` to `u64` and compares.
    fn eq(&self, other: &u32) -> bool {
        <Self as PartialEq<u64>>::eq(self, &u64::from(*other))
    }
}
impl PartialOrd<u32> for LibraryVersion {
    /// Widens `other` to `u64` and orders.
    fn partial_cmp(&self, other: &u32) -> Option<std::cmp::Ordering> {
        <Self as PartialOrd<u64>>::partial_cmp(self, &u64::from(*other))
    }
}
impl PartialEq<i32> for LibraryVersion {
    /// A negative `other` never equals a version.
    fn eq(&self, other: &i32) -> bool {
        match u64::try_from(*other) {
            Ok(widened) => self.raw == widened,
            Err(_) => false,
        }
    }
}
impl PartialOrd<i32> for LibraryVersion {
    /// Every version orders above a negative `other`.
    fn partial_cmp(&self, other: &i32) -> Option<std::cmp::Ordering> {
        match u64::try_from(*other) {
            Ok(widened) => Some(self.raw.cmp(&widened)),
            Err(_) => Some(std::cmp::Ordering::Greater),
        }
    }
}
impl Display for LibraryVersion {
    /// Displays the raw integer, honouring any width/fill flags on the formatter.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let raw: &u64 = &self.raw;
        <u64 as Display>::fmt(raw, f)
    }
}
/// CUDA runtime version number held as the raw `u64` it was constructed from.
///
/// Equality, ordering and hashing are derived from `raw`.
#[derive(Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct CudaRuntimeVersion {
    /// The raw version integer.
    pub raw: u64,
}
impl CudaRuntimeVersion {
    /// Wraps a raw version integer.
    pub const fn from_raw(raw: u64) -> Self {
        CudaRuntimeVersion { raw }
    }
    /// Returns the raw version integer.
    pub const fn raw(self) -> u64 {
        let CudaRuntimeVersion { raw } = self;
        raw
    }
}
impl From<u64> for CudaRuntimeVersion {
    fn from(raw: u64) -> Self {
        Self::from_raw(raw)
    }
}
impl From<u32> for CudaRuntimeVersion {
    fn from(raw: u32) -> Self {
        Self::from_raw(u64::from(raw))
    }
}
impl From<usize> for CudaRuntimeVersion {
    fn from(raw: usize) -> Self {
        // `usize` is at most 64 bits on current Rust targets, so this cast is lossless.
        Self::from_raw(raw as u64)
    }
}
/// Negative inputs are not valid versions and clamp to `0`.
///
/// A plain `as u64` cast would sign-extend them to values near `u64::MAX`. Such a value would
/// compare above every real version and could wrongly pass `version >= N` feature checks.
/// Clamping keeps the conversion monotonic.
impl From<i32> for CudaRuntimeVersion {
    fn from(raw: i32) -> Self {
        Self::from_raw(u64::try_from(raw).unwrap_or(0))
    }
}
impl PartialEq<u64> for CudaRuntimeVersion {
    /// Compares the raw value against `other`.
    fn eq(&self, other: &u64) -> bool {
        let raw = self.raw;
        raw == *other
    }
}
impl PartialOrd<u64> for CudaRuntimeVersion {
    /// Orders the raw value against `other` (always `Some`).
    fn partial_cmp(&self, other: &u64) -> Option<std::cmp::Ordering> {
        let ordering = self.raw.cmp(other);
        Some(ordering)
    }
}
impl PartialEq<u32> for CudaRuntimeVersion {
    /// Widens `other` to `u64` and compares.
    fn eq(&self, other: &u32) -> bool {
        let widened = u64::from(*other);
        widened == self.raw
    }
}
impl PartialOrd<u32> for CudaRuntimeVersion {
    /// Widens `other` to `u64` and orders.
    fn partial_cmp(&self, other: &u32) -> Option<std::cmp::Ordering> {
        let widened = u64::from(*other);
        Some(self.raw.cmp(&widened))
    }
}
impl PartialEq<i32> for CudaRuntimeVersion {
    /// A negative `other` never equals a version.
    fn eq(&self, other: &i32) -> bool {
        matches!(u64::try_from(*other), Ok(widened) if widened == self.raw)
    }
}
impl PartialOrd<i32> for CudaRuntimeVersion {
    /// Every version orders above a negative `other`.
    fn partial_cmp(&self, other: &i32) -> Option<std::cmp::Ordering> {
        let ordering = u64::try_from(*other)
            .map_or(std::cmp::Ordering::Greater, |widened| self.raw.cmp(&widened));
        Some(ordering)
    }
}
impl Display for CudaRuntimeVersion {
    /// Displays the raw integer, honouring any width/fill flags on the formatter.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        <u64 as Display>::fmt(&self.raw, f)
    }
}