use std::{
ffi::{CStr, CString, NulError},
fmt::{self, Display, Formatter},
path::Path,
};
#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
use serde::{Deserialize, Serialize};
/// Error describing a raw value that could not be turned into an enum.
///
/// Its `Display` output has the form `unknown <name> value <value>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EnumConversionError<T> {
    /// Label inserted into the error message right after `unknown`.
    pub name: &'static str,
    /// The raw value that was rejected.
    pub value: T,
}

impl<T> Display for EnumConversionError<T>
where
    T: Display,
{
    /// Writes `unknown <name> value <value>` to the formatter.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self { name, value } = self;
        f.write_fmt(format_args!("unknown {} value {}", name, value))
    }
}

impl<T> std::error::Error for EnumConversionError<T> where T: Display + fmt::Debug {}
/// Converts `value` into `E` through `TryFrom`, reporting a failure as an
/// [`EnumConversionError`].
///
/// The error returned by `E::try_from` is discarded; the resulting error
/// carries only `name` and the original `value`.
///
/// # Errors
///
/// Returns `EnumConversionError { name, value }` when `E::try_from(value)` fails.
pub fn try_enum_from_raw<E, T>(name: &'static str, value: T) -> Result<E, EnumConversionError<T>>
where
    E: TryFrom<T>,
    T: Copy,
{
    match E::try_from(value) {
        Ok(converted) => Ok(converted),
        // `value` is `Copy`, so it is still available after the failed attempt.
        Err(_) => Err(EnumConversionError { name, value }),
    }
}
/// Converts `value` into `U` through `TryInto`, building the error lazily.
///
/// `error` is called only when the conversion fails. It receives `name` as an
/// owned `String`; the conversion's own error is discarded.
///
/// # Errors
///
/// Returns `error(name)` when `value` cannot be represented as `U`.
pub fn checked_int<T, U, E>(value: T, name: &str, error: impl FnOnce(String) -> E) -> Result<U, E>
where
    T: TryInto<U>,
{
    match value.try_into() {
        Ok(converted) => Ok(converted),
        Err(_) => Err(error(String::from(name))),
    }
}
/// Generates `From` conversions in both directions between a raw,
/// integer-backed type (`$from`) and a wrapper type (`$into`).
///
/// Forms:
/// * `impl_enum_conversion!(Repr, Raw, Wrapper)` — `Repr` is the integer type
///   used as the intermediate representation.
/// * `impl_enum_conversion!(Raw, Wrapper)` — shorthand for `Repr = u32`.
///
/// Requirements imposed by the generated code:
/// * `Raw` must be castable to `Repr` with `as` and must have the same size as
///   `Repr` (checked at compile time by `transmute`).
/// * `Wrapper` must provide `try_from(Repr)` (normally `TryFrom<Repr>`) and
///   `Into<Repr>`.
///
/// # Panics
///
/// `Wrapper::from(raw)` panics when `Wrapper::try_from` rejects the raw value.
///
/// # Safety contract
///
/// `Raw::from(wrapper)` reinterprets the integer returned by
/// `Into::<Repr>::into(wrapper)` as a `Raw`. The caller of this macro must make
/// sure that every integer `Wrapper` can produce is a valid value of `Raw`.
#[macro_export]
macro_rules! impl_enum_conversion {
    ($ty:ty, $from:ty, $into:ty $(,)?) => {
        const _: () = {
            impl From<$from> for $into {
                fn from(value: $from) -> Self {
                    let value = value as $ty;
                    <$into>::try_from(value).expect(concat!(
                        "raw ",
                        stringify!($from),
                        " value has no matching ",
                        stringify!($into),
                    ))
                }
            }

            impl From<$into> for $from {
                fn from(value: $into) -> Self {
                    let value: $ty = value.into();
                    // SAFETY: sound only if every `$ty` produced by
                    // `$into: Into<$ty>` is a valid `$from`; this is the
                    // macro caller's obligation (see the macro docs).
                    unsafe { core::mem::transmute::<$ty, $from>(value) }
                }
            }
        };
    };
    ($from:ty, $into:ty $(,)?) => {
        // `$crate::` keeps the recursive call resolvable when the macro is
        // invoked by path from another crate.
        $crate::impl_enum_conversion!(u32, $from, $into);
    };
}
/// Implements `Display` for `$enum` by mapping each listed variant path to a
/// string expression.
///
/// Usage: `impl_enum_display!(Type, { Type::A => "a", Type::B => "b" });`
///
/// The generated `match` has exactly the listed arms, so every variant of the
/// type has to be listed for the expansion to compile. The text is written
/// with `Formatter::write_str`, so width and padding flags are not applied.
#[macro_export]
macro_rules! impl_enum_display {
    ($enum:ty, { $($variant:path => $name:expr),+ $(,)? } $(,)?) => {
        impl std::fmt::Display for $enum {
            fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                match *self {
                    $(
                        $variant => formatter.write_str($name),
                    )+
                }
            }
        }
    };
}
/// Asserts that numeric values agree within an absolute tolerance.
///
/// Forms:
/// * `assert_close!(actual, expected)` — element-wise comparison of two
///   sequences with a tolerance of `1.0e-5`.
/// * `assert_close!(actual, expected, tolerance)` — element-wise comparison
///   with an explicit tolerance.
/// * `assert_close!(actual, expected, tolerance, name)` — comparison of two
///   scalars; `name` is included in the failure message.
///
/// The sequence forms call `len()` and `iter()` on both operands and cast each
/// element to `f64` with `as`; the scalar form casts the operands directly.
///
/// Exactly equal values always pass, so matching infinities are accepted even
/// though their difference is NaN. A NaN on either side never passes.
///
/// # Panics
///
/// Panics when the sequence lengths differ or when any pair of values differs
/// by more than `tolerance`.
#[macro_export]
macro_rules! assert_close {
    ($actual:expr, $expected:expr $(,)?) => {
        $crate::assert_close!($actual, $expected, 1.0e-5);
    };
    ($actual:expr, $expected:expr, $tolerance:expr $(,)?) => {{
        let actual = $actual;
        let expected = $expected;
        let tolerance = $tolerance as f64;
        assert_eq!(
            actual.len(),
            expected.len(),
            "assert_close length mismatch: actual len {}, expected len {}",
            actual.len(),
            expected.len(),
        );
        for (index, (actual, expected)) in actual.iter().zip(expected.iter()).enumerate() {
            let actual = *actual as f64;
            let expected = *expected as f64;
            let difference = (actual - expected).abs();
            assert!(
                // The equality test covers equal infinities, whose difference is NaN.
                actual == expected || difference <= tolerance,
                "assert_close failed at index {index}: actual {actual}, expected {expected}, difference {difference}, tolerance {tolerance}",
            );
        }
    }};
    ($actual:expr, $expected:expr, $tolerance:expr, $name:expr $(,)?) => {{
        let actual = $actual as f64;
        let expected = $expected as f64;
        let tolerance = $tolerance as f64;
        let difference = (actual - expected).abs();
        assert!(
            // The equality test covers equal infinities, whose difference is NaN.
            actual == expected || difference <= tolerance,
            "assert_close failed for {}: actual {actual}, expected {expected}, difference {difference}, tolerance {tolerance}",
            $name,
        );
    }};
}
/// Converts a filesystem path into a NUL-terminated C string.
///
/// On Unix the path's raw bytes are used unchanged. On other platforms the
/// path goes through `to_string_lossy` first, so anything that is not valid
/// Unicode is replaced before conversion.
///
/// # Errors
///
/// Returns `NulError` when the path contains an interior NUL byte.
pub fn path_to_cstring(path: &Path) -> Result<CString, NulError> {
    let os_str = path.as_os_str();

    #[cfg(unix)]
    let encoded = os_str.as_bytes();

    #[cfg(not(unix))]
    let lossy = os_str.to_string_lossy();
    #[cfg(not(unix))]
    let encoded = lossy.as_bytes();

    CString::new(encoded)
}
/// Decodes a fixed-size C character buffer into an owned `String`.
///
/// Reading stops at the first zero element, or at the end of the buffer when
/// there is none. Bytes that are not valid UTF-8 are replaced with U+FFFD.
pub fn string_from_c_chars(buffer: &[i8]) -> String {
    // Reinterpret each signed element as its unsigned byte, up to the terminator.
    let text: Vec<u8> = buffer
        .iter()
        .take_while(|&&unit| unit != 0)
        .map(|&unit| unit as u8)
        .collect();

    match String::from_utf8(text) {
        Ok(valid) => valid,
        Err(invalid) => String::from_utf8_lossy(invalid.as_bytes()).into_owned(),
    }
}
/// Copies a NUL-terminated C string into an owned `String`.
///
/// A null pointer yields an empty string. Bytes that are not valid UTF-8 are
/// replaced with U+FFFD.
///
/// # Safety
///
/// `ptr` must either be null or point to a NUL-terminated string that stays
/// valid and unmodified for the duration of the call.
pub unsafe fn string_from_c_ptr(ptr: *const i8) -> String {
    if ptr.is_null() {
        return String::new();
    }
    // `CStr::from_ptr` takes `*const c_char`, and `c_char` is `u8` on some
    // targets; `cast()` keeps this compiling regardless of its signedness.
    // SAFETY: `ptr` is non-null here, and the caller guarantees it points to a
    // valid NUL-terminated string.
    let c_str = unsafe { CStr::from_ptr(ptr.cast()) };
    c_str.to_string_lossy().into_owned()
}
/// Copies `value` into a fixed-size C character buffer as a NUL-terminated string.
///
/// At most `N - 1` bytes of `value` are copied, so there is always room for a
/// terminator. Every element after the copied bytes is set to zero, which
/// terminates the string and clears anything left from a previous, longer
/// value. Truncation works on bytes and may cut a multi-byte UTF-8 sequence.
/// A zero-length buffer is left untouched.
pub fn copy_string_to_c_chars<const N: usize>(buffer: &mut [i8; N], value: &str) {
    let bytes = value.as_bytes();
    let len = bytes.len().min(N.saturating_sub(1));
    let (head, tail) = buffer.split_at_mut(len);
    for (slot, &byte) in head.iter_mut().zip(bytes) {
        *slot = byte as i8;
    }
    // Without this the buffer would keep stale bytes and might lack a terminator.
    tail.fill(0);
}
/// A version number stored as separate major, minor and patch components.
///
/// The derived ordering compares `major` first, then `minor`, then `patch`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct LibraryVersion {
/// Major component.
major: u64,
/// Minor component.
minor: u64,
/// Patch component.
patch: u64,
}
impl LibraryVersion {
pub const fn new(major: u64, minor: u64, patch: u64) -> Self {
Self {
major,
minor,
patch,
}
}
pub const fn major(self) -> u64 {
self.major
}
pub const fn minor(self) -> u64 {
self.minor
}
pub const fn patch(self) -> u64 {
self.patch
}
pub const fn raw(self) -> u64 {
self.major * 10000 + self.minor * 100 + self.patch
}
}
/// Equality against a `u64` compares the packed `raw()` value.
impl PartialEq<u64> for LibraryVersion {
    fn eq(&self, other: &u64) -> bool {
        let encoded: u64 = self.raw();
        encoded == *other
    }
}

/// Ordering against a `u64` compares the packed `raw()` value.
impl PartialOrd<u64> for LibraryVersion {
    fn partial_cmp(&self, other: &u64) -> Option<std::cmp::Ordering> {
        // Integer comparison is total, so the result is always `Some`.
        Some(self.raw().cmp(other))
    }
}
/// Equality against a `u32` widens the operand to `u64` and compares it with `raw()`.
impl PartialEq<u32> for LibraryVersion {
    fn eq(&self, other: &u32) -> bool {
        let widened = u64::from(*other);
        self.raw() == widened
    }
}

/// Ordering against a `u32` widens the operand to `u64` and compares it with `raw()`.
impl PartialOrd<u32> for LibraryVersion {
    fn partial_cmp(&self, other: &u32) -> Option<std::cmp::Ordering> {
        let widened = u64::from(*other);
        Some(self.raw().cmp(&widened))
    }
}
impl PartialEq<i32> for LibraryVersion {
fn eq(&self, other: &i32) -> bool {
*other >= 0 && self.raw() == *other as u64
}
}
impl PartialOrd<i32> for LibraryVersion {
fn partial_cmp(&self, other: &i32) -> Option<std::cmp::Ordering> {
if *other < 0 {
Some(std::cmp::Ordering::Greater)
} else {
self.raw().partial_cmp(&(*other as u64))
}
}
}
/// Displays the version as its packed `raw()` integer.
impl Display for LibraryVersion {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let encoded = self.raw();
        // Delegate to the integer's own `Display` so width and fill flags still apply.
        <u64 as Display>::fmt(&encoded, f)
    }
}
/// A version held as a single raw integer.
///
/// Equality and ordering are derived from `raw`.
// NOTE(review): the numeric encoding of `raw` is not visible in this file — confirm against the producer of the value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct CudaRuntimeVersion {
/// The raw version number exactly as supplied to the constructor.
pub raw: u64,
}
impl CudaRuntimeVersion {
pub const fn from_raw(raw: u64) -> Self {
Self { raw }
}
pub const fn raw(self) -> u64 {
self.raw
}
}
impl From<u64> for CudaRuntimeVersion {
fn from(raw: u64) -> Self {
Self::from_raw(raw)
}
}
impl From<u32> for CudaRuntimeVersion {
fn from(raw: u32) -> Self {
Self::from_raw(raw as u64)
}
}
impl From<usize> for CudaRuntimeVersion {
fn from(raw: usize) -> Self {
Self::from_raw(raw as u64)
}
}
impl From<i32> for CudaRuntimeVersion {
fn from(raw: i32) -> Self {
Self::from_raw(raw as u64)
}
}
/// Equality against a `u64` compares the wrapped raw value.
impl PartialEq<u64> for CudaRuntimeVersion {
    fn eq(&self, other: &u64) -> bool {
        let Self { raw } = *self;
        raw == *other
    }
}

/// Ordering against a `u64` compares the wrapped raw value.
impl PartialOrd<u64> for CudaRuntimeVersion {
    fn partial_cmp(&self, other: &u64) -> Option<std::cmp::Ordering> {
        // Integer comparison is total, so the result is always `Some`.
        Some(self.raw.cmp(other))
    }
}
/// Equality against a `u32` widens the operand to `u64` first.
impl PartialEq<u32> for CudaRuntimeVersion {
    fn eq(&self, other: &u32) -> bool {
        let widened = u64::from(*other);
        self.raw == widened
    }
}

/// Ordering against a `u32` widens the operand to `u64` first.
impl PartialOrd<u32> for CudaRuntimeVersion {
    fn partial_cmp(&self, other: &u32) -> Option<std::cmp::Ordering> {
        let widened = u64::from(*other);
        Some(self.raw.cmp(&widened))
    }
}
impl PartialEq<i32> for CudaRuntimeVersion {
fn eq(&self, other: &i32) -> bool {
*other >= 0 && self.raw == *other as u64
}
}
impl PartialOrd<i32> for CudaRuntimeVersion {
fn partial_cmp(&self, other: &i32) -> Option<std::cmp::Ordering> {
if *other < 0 {
Some(std::cmp::Ordering::Greater)
} else {
self.raw.partial_cmp(&(*other as u64))
}
}
}
/// Displays the wrapped raw integer.
impl Display for CudaRuntimeVersion {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Delegate to the integer's own `Display` so width and fill flags still apply.
        <u64 as Display>::fmt(&self.raw, f)
    }
}