use std::{
any::TypeId,
collections::HashMap,
ffi::c_void,
os::raw::c_uchar,
sync::{Arc, Mutex},
};
use crate::error::Error;
type CleanupFn = Box<dyn FnMut() + Send>;
pub struct PointerRegistry {
tracked: Mutex<HashMap<usize, (TypeId, CleanupFn)>>,
}
impl PointerRegistry {
fn new() -> Self {
Self {
tracked: Mutex::new(HashMap::new()),
}
}
fn track(&self, ptr: usize, type_id: TypeId, cleanup: CleanupFn) {
if ptr != 0 {
if let Ok(mut tracked) = self.tracked.lock() {
tracked.insert(ptr, (type_id, cleanup));
}
}
}
pub fn validate(&self, ptr: usize, expected_type: TypeId) -> Result<(), Error> {
if ptr == 0 {
return Err(Error::null_parameter("pointer"));
}
let tracked = self.tracked.lock().map_err(|_| Error::mutex_poisoned())?;
match tracked.get(&ptr) {
Some((actual_type, _)) if *actual_type == expected_type => Ok(()),
Some(_) => Err(Error::wrong_pointer_type(ptr as u64)),
None => Err(Error::untracked_pointer(ptr as u64)),
}
}
pub fn free(&self, ptr: usize) -> Result<(), Error> {
if ptr == 0 {
return Ok(()); }
let mut cleanup = {
let mut tracked = self.tracked.lock().map_err(|_| Error::mutex_poisoned())?;
match tracked.remove(&ptr) {
Some((_, cleanup)) => cleanup,
None => return Err(Error::untracked_pointer(ptr as u64)),
}
};
cleanup(); Ok(())
}
}
/// Prints a stderr warning when a registry is dropped while pointers are still tracked.
///
/// The remaining cleanup closures are dropped without being run, so those allocations are
/// not released here.
///
/// NOTE(review): the registry returned by `get_registry` is stored in a `static`, and Rust
/// never runs destructors for statics. This report therefore does not fire for that global
/// instance at process exit, only for registries dropped normally.
impl Drop for PointerRegistry {
    fn drop(&mut self) {
        // Exclusive access means no locking is needed. Recover the map even if the mutex
        // was poisoned: this is diagnostics only.
        let tracked = self
            .tracked
            .get_mut()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let leaked = tracked.len();
        if leaked == 0 {
            return;
        }
        eprintln!(
            "\n⚠️ WARNING: {} pointer(s) were not freed at shutdown!",
            leaked
        );
        eprintln!("This indicates C code did not properly free all allocated pointers.");
        eprintln!("Each pointer should be freed exactly once with cimpl_free().\n");
    }
}
/// Returns the process-wide [`PointerRegistry`], creating it on first use.
///
/// The instance lives in a function-local `static`. Statics are never dropped, so the
/// registry (and any still-tracked entries) persists until the process exits.
pub(crate) fn get_registry() -> &'static PointerRegistry {
    static REGISTRY: std::sync::OnceLock<PointerRegistry> = std::sync::OnceLock::new();
    REGISTRY.get_or_init(PointerRegistry::new)
}
/// Registers a pointer produced by `Box::<T>::into_raw` so C can later release it with
/// [`cimpl_free`], and returns the same pointer unchanged.
///
/// A null `ptr` is returned as-is and not tracked.
///
/// Caller contract (not enforced, although this function is safe): a non-null `ptr` must
/// come from `Box::<T>::into_raw` and must not be released any other way. `cimpl_free`
/// rebuilds the `Box<T>` from it and drops it.
///
/// NOTE(review): `T` is not required to be `Send`, yet the cleanup only captures the
/// address, so the value is dropped on whichever thread calls `cimpl_free`. Also, boxes of
/// zero-sized types do not allocate and share a dangling address, so distinct ZST boxes
/// collide on the same registry key.
pub fn track_box<T: 'static>(ptr: *mut T) -> *mut T {
    let addr = ptr as usize;
    get_registry().track(
        addr,
        TypeId::of::<T>(),
        Box::new(move || {
            // SAFETY: per the caller contract, `addr` came from `Box::<T>::into_raw`, and
            // the registry runs this closure at most once.
            unsafe { drop(Box::from_raw(addr as *mut T)) }
        }),
    );
    ptr
}
/// Registers a pointer produced by `Arc::<T>::into_raw` so C can later release that strong
/// reference with [`cimpl_free`], and returns the same pointer unchanged.
///
/// A null `ptr` is returned as-is and not tracked.
///
/// Caller contract (not enforced, although this function is safe): a non-null `ptr` must
/// come from `Arc::<T>::into_raw`, and the strong count it represents must not be released
/// elsewhere. `cimpl_free` rebuilds the `Arc<T>` and drops that one reference.
pub fn track_arc<T: 'static>(ptr: *mut T) -> *mut T {
    let addr = ptr as usize;
    get_registry().track(
        addr,
        TypeId::of::<T>(),
        Box::new(move || {
            // SAFETY: per the caller contract, `addr` came from `Arc::<T>::into_raw`, and
            // the registry runs this closure at most once.
            unsafe { drop(Arc::from_raw(addr as *const T)) }
        }),
    );
    ptr
}
pub fn track_arc_mutex<T: 'static>(ptr: *mut Mutex<T>) -> *mut Mutex<T> {
let ptr_val = ptr as usize; let cleanup = move || unsafe {
drop(Arc::from_raw(ptr_val as *const Mutex<T>));
};
get_registry().track(ptr as usize, TypeId::of::<Mutex<T>>(), Box::new(cleanup));
ptr
}
/// Checks that `ptr` is currently tracked in the global registry as a `T`.
///
/// # Errors
/// Propagates `PointerRegistry::validate`: null pointer, untracked pointer, pointer
/// registered under a different type, or a poisoned registry lock.
pub fn validate_pointer<T: 'static>(ptr: *mut T) -> Result<(), Error> {
    let expected = TypeId::of::<T>();
    let registry = get_registry();
    registry.validate(ptr as usize, expected)
}
pub fn cimpl_free(ptr: *mut c_void) -> i32 {
match get_registry().free(ptr as usize) {
Ok(()) => 0,
Err(error) => {
#[cfg(test)]
{
if ptr as usize != 0 {
eprintln!(
"\n⚠️ ERROR: cimpl_free failed for pointer 0x{:x}: {}\n\
This usually means:\n\
1. The pointer was not allocated with box_tracked!/track_box\n\
2. The pointer was already freed (double-free)\n\
3. The pointer is invalid/corrupted\n",
ptr as usize, error
);
}
}
error.set_last();
-1
}
}
}
/// Checks that a C-supplied `(ptr, size)` pair describes a plausible byte buffer.
///
/// Returns `false` in three cases:
/// - `size` is zero;
/// - `size` exceeds `isize::MAX`, the largest size a Rust allocation or slice may have;
/// - `ptr + size` would wrap past the end of the address space.
///
/// A null `ptr` is not rejected here; callers check for null separately.
///
/// # Safety
///
/// This function never dereferences `ptr` and performs no unsafe operations. It stays
/// `unsafe` for API compatibility. A `true` result does NOT prove that `ptr` points to
/// `size` readable bytes; that remains the caller's obligation.
pub unsafe fn is_safe_buffer_size(size: usize, ptr: *const c_uchar) -> bool {
    if size == 0 || size > isize::MAX as usize {
        return false;
    }
    // Use checked integer arithmetic on the address instead of `ptr.add(size)`. `add`
    // requires the result to stay inside the same allocation, so an overflowing `add` is
    // undefined behaviour, and an `end < ptr` test after it may be optimized away.
    ptr.is_null() || (ptr as usize).checked_add(size).is_some()
}
pub unsafe fn safe_slice_from_raw_parts(
ptr: *const c_uchar,
len: usize,
param_name: &str,
) -> Result<&[u8], Error> {
if ptr.is_null() {
return Err(Error::null_parameter(param_name));
}
if !is_safe_buffer_size(len, ptr) {
return Err(Error::invalid_buffer_size(len, param_name));
}
Ok(std::slice::from_raw_parts(ptr, len))
}
/// Converts `s` into a heap-allocated, NUL-terminated C string tracked by the registry.
///
/// The returned pointer must be released exactly once with [`cimpl_free`]. Returns null
/// when `s` contains an interior NUL byte, since it cannot be represented as a C string.
pub fn to_c_string(s: String) -> *mut std::os::raw::c_char {
    use std::ffi::CString;

    let Ok(owned) = CString::new(s) else {
        return std::ptr::null_mut();
    };
    let raw = owned.into_raw();
    let addr = raw as usize;
    get_registry().track(
        addr,
        TypeId::of::<CString>(),
        Box::new(move || {
            // SAFETY: `addr` came from `CString::into_raw` above, and the registry runs
            // this closure at most once.
            unsafe { drop(CString::from_raw(addr as *mut std::os::raw::c_char)) }
        }),
    );
    raw
}
/// Hands an owned byte buffer to C as a tracked raw pointer.
///
/// Only the pointer is returned; the caller must pass `bytes.len()` to C separately. The
/// pointer must be released exactly once with [`cimpl_free`].
///
/// Empty input is backed by a real one-byte allocation (contents unspecified to C, which
/// should read zero bytes), so every call returns a distinct, independently freeable pointer.
pub fn to_c_bytes(bytes: Vec<u8>) -> *const c_uchar {
    // An empty `Box<[u8]>` does not allocate: its pointer is a dangling address shared by
    // every empty buffer. Two live empty buffers would then collide on the same registry
    // key. The second `track` would overwrite the first, and the second `cimpl_free` would
    // report the pointer as untracked.
    let boxed: Box<[u8]> = if bytes.is_empty() {
        Box::new([0u8])
    } else {
        bytes.into_boxed_slice()
    };
    // Capture the length of the *boxed* slice: the cleanup must rebuild the box with the
    // exact length it was allocated with.
    let len = boxed.len();
    let ptr = Box::into_raw(boxed) as *const c_uchar;
    let ptr_val = ptr as usize;
    get_registry().track(
        ptr_val,
        TypeId::of::<Box<[u8]>>(),
        Box::new(move || {
            // SAFETY: `ptr_val`/`len` describe the `Box<[u8]>` leaked above, and the registry
            // runs this closure at most once.
            unsafe {
                drop(Box::from_raw(std::ptr::slice_from_raw_parts_mut(
                    ptr_val as *mut u8,
                    len,
                )))
            }
        }),
    );
    ptr
}
#[cfg(test)]
mod tests {
    //! Unit tests for the pointer registry. Every `cimpl_free` result is asserted so that
    //! tracking regressions (e.g. a pointer silently not recorded) fail loudly.
    use super::*;

    #[test]
    fn test_allocation_tracking_double_free_string() {
        use std::ffi::CString;
        let test_string = CString::new("test allocation tracking").unwrap();
        let c_string = to_c_string(test_string.to_str().unwrap().to_string());
        assert!(!c_string.is_null());
        let result1 = cimpl_free(c_string as *mut c_void);
        assert_eq!(result1, 0);
        // The pointer was untracked by the first free, so the second must be rejected.
        let result2 = cimpl_free(c_string as *mut c_void);
        assert_eq!(result2, -1);
    }

    #[test]
    fn test_to_c_string_basic() {
        let rust_string = "Hello, C!".to_string();
        let c_string = to_c_string(rust_string);
        assert!(!c_string.is_null());
        let round_trip = unsafe { std::ffi::CStr::from_ptr(c_string) };
        assert_eq!(round_trip.to_str().unwrap(), "Hello, C!");
        assert_eq!(cimpl_free(c_string as *mut c_void), 0);
    }

    #[test]
    fn test_to_c_bytes_basic() {
        let bytes = vec![1, 2, 3, 4, 5];
        let ptr = to_c_bytes(bytes);
        assert!(!ptr.is_null());
        let view = unsafe { std::slice::from_raw_parts(ptr, 5) };
        assert_eq!(view, &[1, 2, 3, 4, 5]);
        assert_eq!(cimpl_free(ptr as *mut c_void), 0);
    }

    #[test]
    fn test_to_c_string_with_null_byte() {
        let bad_string = "Hello\0World".to_string();
        let c_string = to_c_string(bad_string);
        assert!(c_string.is_null());
    }

    #[test]
    fn test_to_c_string_empty() {
        let empty = "".to_string();
        let c_string = to_c_string(empty);
        assert!(!c_string.is_null());
        assert_eq!(cimpl_free(c_string as *mut c_void), 0);
    }

    #[test]
    fn test_to_c_bytes_empty() {
        let empty_bytes: Vec<u8> = vec![];
        let ptr = to_c_bytes(empty_bytes);
        assert!(!ptr.is_null());
        assert_eq!(cimpl_free(ptr as *mut c_void), 0);
    }

    #[test]
    fn test_box_tracked() {
        struct TestStruct {
            value: i32,
        }
        let test = TestStruct { value: 42 };
        let ptr = track_box(Box::into_raw(Box::new(test)));
        assert!(!ptr.is_null());
        unsafe {
            let test_ref = &*ptr;
            assert_eq!(test_ref.value, 42);
        }
        assert_eq!(cimpl_free(ptr as *mut c_void), 0);
    }

    #[test]
    fn test_track_box_returns_pointer() {
        let value = Box::new(123i32);
        let ptr = track_box(Box::into_raw(value));
        assert!(!ptr.is_null());
        unsafe {
            assert_eq!(*ptr, 123);
        }
        assert_eq!(cimpl_free(ptr as *mut c_void), 0);
    }

    #[test]
    fn test_validate_pointer_with_valid_pointer() {
        let value = Box::new(456i32);
        let ptr = track_box(Box::into_raw(value));
        let result = validate_pointer::<i32>(ptr);
        assert!(result.is_ok());
        assert_eq!(cimpl_free(ptr as *mut c_void), 0);
    }

    #[test]
    fn test_validate_pointer_with_wrong_type() {
        let ptr = track_box(Box::into_raw(Box::new(5i32)));
        // Tracked as `i32`, so asking for `u32` must fail.
        let result = validate_pointer::<u32>(ptr as *mut u32);
        assert!(result.is_err());
        assert_eq!(cimpl_free(ptr as *mut c_void), 0);
    }

    #[test]
    fn test_validate_pointer_with_null() {
        let result = validate_pointer::<i32>(std::ptr::null_mut());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.variant(), Some("NullParameter"));
    }

    #[test]
    fn test_validate_pointer_with_untracked() {
        let value = Box::new(789i32);
        let ptr = Box::into_raw(value);
        let result = validate_pointer::<i32>(ptr);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.variant(), Some("UntrackedPointer"));
        unsafe {
            let _ = Box::from_raw(ptr);
        }
    }

    #[test]
    fn test_cimpl_free_null_is_noop() {
        assert_eq!(cimpl_free(std::ptr::null_mut()), 0);
    }

    #[test]
    fn test_cimpl_free_untracked_pointer() {
        let raw = Box::into_raw(Box::new(7u64));
        assert_eq!(cimpl_free(raw as *mut c_void), -1);
        // The registry refused to free it, so this test still owns the allocation.
        unsafe {
            drop(Box::from_raw(raw));
        }
    }

    #[test]
    fn test_safe_slice_from_raw_parts_valid() {
        let data = vec![1u8, 2, 3, 4, 5];
        let ptr = data.as_ptr();
        let len = data.len();
        let result = unsafe { safe_slice_from_raw_parts(ptr, len, "test_data") };
        assert!(result.is_ok());
        let slice = result.unwrap();
        assert_eq!(slice, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_safe_slice_from_raw_parts_null() {
        let result = unsafe { safe_slice_from_raw_parts(std::ptr::null(), 5, "test_param") };
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.variant(), Some("NullParameter"));
    }

    #[test]
    fn test_safe_slice_from_raw_parts_invalid_size() {
        let data = vec![1u8, 2, 3];
        let ptr = data.as_ptr();
        let result = unsafe { safe_slice_from_raw_parts(ptr, usize::MAX, "overflow_test") };
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert_eq!(err.variant(), Some("InvalidBufferSize"));
    }

    #[test]
    fn test_safe_slice_from_raw_parts_zero_len() {
        let data = [1u8];
        let result = unsafe { safe_slice_from_raw_parts(data.as_ptr(), 0, "empty") };
        let err = result.unwrap_err();
        assert_eq!(err.variant(), Some("InvalidBufferSize"));
    }

    #[test]
    fn test_arc_tracked() {
        use std::sync::Arc;
        let value = Arc::new(100i32);
        let ptr = track_arc(Arc::into_raw(value) as *mut i32);
        assert!(!ptr.is_null());
        unsafe {
            assert_eq!(*ptr, 100);
        }
        assert_eq!(cimpl_free(ptr as *mut c_void), 0);
    }

    #[test]
    fn test_arc_mutex_tracked() {
        use std::sync::{Arc, Mutex};
        let value = Arc::new(Mutex::new(200i32));
        let ptr = track_arc_mutex(Arc::into_raw(value) as *mut Mutex<i32>);
        assert!(!ptr.is_null());
        unsafe {
            let guard = (*ptr).lock().unwrap();
            assert_eq!(*guard, 200);
        }
        assert_eq!(cimpl_free(ptr as *mut c_void), 0);
    }
}