use crate::error::{DaosError, Result};
use std::ffi::{CStr, CString};
use std::ptr::NonNull;
/// Null-checking constructors for [`NonNull`].
///
/// Both methods are associated functions (no `self` receiver), so they are
/// invoked on the type, e.g. `NonNull::<T>::check_null(ptr)`, rather than on
/// an existing `NonNull` value. Implementations are expected to return an
/// error instead of producing a `NonNull` from a null pointer.
pub trait NonNullExt<T> {
/// Wraps a `*const T` in a `NonNull<T>`, failing if `ptr` is null.
fn check_null(ptr: *const T) -> Result<NonNull<T>>;
/// Wraps a `*mut T` in a `NonNull<T>`, failing if `ptr` is null.
fn check_null_mut(ptr: *mut T) -> Result<NonNull<T>>;
}
impl<T> NonNullExt<T> for NonNull<T> {
    /// Converts a `*const T` into a `NonNull<T>`, rejecting null.
    ///
    /// The returned `NonNull` holds the same address as `ptr`.
    ///
    /// # Errors
    ///
    /// Returns [`DaosError::InvalidArg`] if `ptr` is null.
    #[inline]
    fn check_null(ptr: *const T) -> Result<NonNull<T>> {
        // `NonNull` always stores a `*mut T`. The cast only changes the pointer's
        // type; writing through the result is only valid if the pointee is
        // actually mutable.
        NonNull::new(ptr as *mut T).ok_or(DaosError::InvalidArg)
    }

    /// Converts a `*mut T` into a `NonNull<T>`, rejecting null.
    ///
    /// The returned `NonNull` holds the same address as `ptr`.
    ///
    /// # Errors
    ///
    /// Returns [`DaosError::InvalidArg`] if `ptr` is null.
    #[inline]
    fn check_null_mut(ptr: *mut T) -> Result<NonNull<T>> {
        // `NonNull::new` performs the null check itself, so no `unsafe` is needed.
        NonNull::new(ptr).ok_or(DaosError::InvalidArg)
    }
}
#[inline]
pub fn validate_non_null<T>(ptr: *const T) -> Result<NonNull<T>> {
NonNull::check_null(ptr)
}
#[inline]
pub fn validate_non_null_mut<T>(ptr: *mut T) -> Result<NonNull<T>> {
NonNull::check_null_mut(ptr)
}
/// Borrows a NUL-terminated C string as a UTF-8 `&str`, without copying.
///
/// The returned slice points directly into the caller's C buffer.
///
/// # Errors
///
/// * [`DaosError::InvalidArg`] if `ptr` is null.
/// * [`DaosError::Internal`] if the bytes before the terminating NUL are not
///   valid UTF-8.
///
/// NOTE(review): this is a *safe* function that dereferences a raw pointer and
/// returns a reference with a caller-chosen lifetime `'a`, so it is unsound if
/// misused. Callers must guarantee that a non-null `ptr` points to a readable,
/// NUL-terminated buffer that stays alive and unmodified for all of `'a`
/// (the requirements of `CStr::from_ptr`). Making this an `unsafe fn` would
/// encode that contract in the type system, at the cost of changing call sites.
pub fn validate_c_str<'a>(ptr: *const std::os::raw::c_char) -> Result<&'a str> {
    if !ptr.is_null() {
        // SAFETY: `ptr` is non-null; validity, NUL-termination and lifetime of
        // the buffer are the caller's responsibility (see NOTE above).
        let text: &'a CStr = unsafe { CStr::from_ptr(ptr) };
        return match text.to_str() {
            Ok(utf8) => Ok(utf8),
            Err(_) => Err(DaosError::Internal("Invalid UTF-8 in C string".to_string())),
        };
    }
    Err(DaosError::InvalidArg)
}
pub fn validate_c_str_mut(ptr: *mut std::os::raw::c_char) -> Result<()> {
if ptr.is_null() {
Err(DaosError::InvalidArg)
} else {
Ok(())
}
}
/// Converts a Rust string slice into an owned, NUL-terminated [`CString`].
///
/// Despite the name, this returns the owned `CString`, not a raw pointer.
/// Call `.as_ptr()` on the result to get a `*const c_char`. Keep the
/// `CString` alive for as long as that pointer is used: a pointer taken from
/// a temporary dangles as soon as the temporary is dropped.
///
/// # Errors
///
/// Returns [`DaosError::InvalidArg`] if `s` contains an interior NUL byte.
pub fn as_char_ptr(s: &str) -> Result<CString> {
    match CString::new(s) {
        Ok(c_string) => Ok(c_string),
        Err(_interior_nul) => Err(DaosError::InvalidArg),
    }
}
/// Alias of [`as_char_ptr`]: converts `s` into an owned [`CString`].
///
/// Obtain a `*const c_char` with `.as_ptr()` on the result, and keep the
/// `CString` alive for as long as that pointer is in use.
///
/// # Errors
///
/// Same as [`as_char_ptr`]: fails if `s` contains an interior NUL byte.
pub fn as_const_char_ptr(s: &str) -> Result<CString> {
    let c_string = as_char_ptr(s)?;
    Ok(c_string)
}
/// Extracts the raw `*mut T` held by a [`NonNull<T>`].
///
/// Equivalent to [`NonNull::as_ptr`]. The returned pointer is never null.
#[inline]
pub fn as_mut_ptr<T>(ptr: NonNull<T>) -> *mut T {
    NonNull::as_ptr(ptr)
}
/// Fallible conversion of `&self` into an implementor-chosen `Target` type.
///
/// NOTE(review): no implementations are visible in this module. The name
/// suggests `Target` is a representation meant to be passed across the FFI
/// boundary (e.g. an owned C string or a raw pointer). Confirm against the
/// implementors, and document who owns the result and how long it stays valid.
pub trait AsFFIPtr {
/// The type produced by [`AsFFIPtr::as_ffi_ptr`].
type Target;
/// Converts `self` into `Self::Target`, or returns an error if it cannot be represented.
fn as_ffi_ptr(&self) -> Result<Self::Target>;
}
#[cfg(test)]
mod tests {
    //! Unit tests for the pointer and C-string validation helpers.
    use super::*;

    #[test]
    fn test_check_null_rejects_null() {
        let null_ptr: *const i32 = std::ptr::null();
        let result = NonNull::<i32>::check_null(null_ptr);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), DaosError::InvalidArg));
    }

    #[test]
    fn test_check_null_accepts_valid() {
        let value = 42i32;
        let valid_ptr: *const i32 = &value;
        let result = NonNull::<i32>::check_null(valid_ptr);
        assert!(result.is_ok());
    }

    #[test]
    fn test_check_null_mut_rejects_null() {
        let result = NonNull::<i32>::check_null_mut(std::ptr::null_mut());
        assert!(matches!(result, Err(DaosError::InvalidArg)));
    }

    #[test]
    fn test_validate_non_null_rejects_null() {
        let result = validate_non_null::<u8>(std::ptr::null());
        assert!(matches!(result, Err(DaosError::InvalidArg)));
    }

    #[test]
    fn test_validate_non_null_mut() {
        let mut value = 42i32;
        let result = validate_non_null_mut(&mut value);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_non_null_mut_preserves_address() {
        let mut value = 7i32;
        let raw: *mut i32 = &mut value;
        let non_null = validate_non_null_mut(raw).unwrap();
        assert_eq!(non_null.as_ptr(), raw);
    }

    #[test]
    fn test_validate_c_str_rejects_null() {
        let result = validate_c_str(std::ptr::null());
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_c_str_valid() {
        let c_string = CString::new("hello").unwrap();
        let ptr = c_string.as_ptr();
        let result = validate_c_str(ptr);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "hello");
    }

    #[test]
    fn test_validate_c_str_rejects_invalid_utf8() {
        // 0xFF never appears in well-formed UTF-8, and the bytes contain no NUL.
        let c_string = CString::new(vec![0xffu8, 0xfe]).unwrap();
        let result = validate_c_str(c_string.as_ptr());
        assert!(matches!(result, Err(DaosError::Internal(_))));
    }

    #[test]
    fn test_validate_c_str_mut() {
        assert!(matches!(
            validate_c_str_mut(std::ptr::null_mut()),
            Err(DaosError::InvalidArg)
        ));
        let mut buffer: [std::os::raw::c_char; 4] = [0; 4];
        assert!(validate_c_str_mut(buffer.as_mut_ptr()).is_ok());
    }

    #[test]
    fn test_as_char_ptr() {
        let s = "test string";
        let result = as_char_ptr(s);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_bytes(), s.as_bytes());
    }

    #[test]
    fn test_as_char_ptr_rejects_nul() {
        let s = "line1\0line2";
        let result = as_char_ptr(s);
        assert!(result.is_err());
    }

    #[test]
    fn test_as_const_char_ptr_matches_as_char_ptr() {
        let s = "abc";
        assert_eq!(as_const_char_ptr(s).unwrap(), as_char_ptr(s).unwrap());
        assert!(as_const_char_ptr("a\0b").is_err());
    }
}