use std::{
any::TypeId,
collections::HashMap,
os::raw::c_uchar,
sync::{Arc, Mutex},
};
use crate::{cimpl::cimpl_error::CimplError, error::Error, maybe_send_sync::MaybeSend};
/// Type-erased closure that releases one tracked allocation.
///
/// It is `Send` so that the registry can live in a `static`: a `Mutex<T>` is only `Sync`
/// when `T: Send`.
type CleanupFn = Box<dyn FnMut() + Send>;
/// Table of raw pointers handed across the C boundary.
///
/// Each entry maps a pointer's address to the `TypeId` it was registered with (checked by
/// `validate` / `untrack`) and to the closure that frees it (run by `free`).
pub struct PointerRegistry {
    // address -> (registered concrete type, cleanup closure)
    tracked: Mutex<HashMap<usize, (TypeId, CleanupFn)>>,
}
impl PointerRegistry {
fn new() -> Self {
Self {
tracked: Mutex::new(HashMap::new()),
}
}
fn track(&self, ptr: usize, type_id: TypeId, cleanup: CleanupFn) {
if ptr != 0 {
if let Ok(mut tracked) = self.tracked.lock() {
tracked.insert(ptr, (type_id, cleanup));
}
}
}
pub fn validate(&self, ptr: usize, expected_type: TypeId) -> Result<(), Error> {
if ptr == 0 {
return Err(Error::from(CimplError::null_parameter("pointer")));
}
let tracked = self
.tracked
.lock()
.map_err(|_| Error::from(CimplError::mutex_poisoned()))?;
match tracked.get(&ptr) {
Some((actual_type, _)) if *actual_type == expected_type => Ok(()),
Some(_) => Err(Error::from(CimplError::wrong_pointer_type(ptr as u64))),
None => Err(Error::from(CimplError::untracked_pointer(ptr as u64))),
}
}
pub fn untrack(&self, ptr: usize, expected_type: TypeId) -> Result<(), Error> {
if ptr == 0 {
return Err(Error::from(CimplError::null_parameter("pointer")));
}
let mut tracked = self
.tracked
.lock()
.map_err(|_| Error::from(CimplError::mutex_poisoned()))?;
match tracked.get(&ptr) {
Some((actual_type, _)) if *actual_type == expected_type => {
tracked.remove(&ptr);
Ok(())
}
Some(_) => Err(Error::from(CimplError::wrong_pointer_type(ptr as u64))),
None => Err(Error::from(CimplError::untracked_pointer(ptr as u64))),
}
}
pub fn free(&self, ptr: usize) -> Result<(), Error> {
if ptr == 0 {
return Ok(()); }
let mut cleanup = {
let mut tracked = self
.tracked
.lock()
.map_err(|_| Error::from(CimplError::mutex_poisoned()))?;
match tracked.remove(&ptr) {
Some((_, cleanup)) => cleanup,
None => return Err(Error::from(CimplError::untracked_pointer(ptr as u64))),
}
};
cleanup(); Ok(())
}
}
/// Prints a warning to stderr listing how many pointers were still registered.
///
/// NOTE: the registry returned by `get_registry` is stored in a `static`, and Rust never
/// runs destructors for statics. This warning therefore does not fire for that instance
/// at process exit.
impl Drop for PointerRegistry {
    fn drop(&mut self) {
        // `&mut self` gives exclusive access, so no locking is needed. A poisoned table is
        // still inspected because the count is purely diagnostic.
        let outstanding = self
            .tracked
            .get_mut()
            .map_or_else(|poisoned| poisoned.into_inner().len(), |table| table.len());
        if outstanding == 0 {
            return;
        }
        eprintln!(
            "\n⚠️ WARNING: {} pointer(s) were not freed at shutdown!",
            outstanding
        );
        eprintln!("This indicates C code did not properly free all allocated pointers.");
        eprintln!("Each pointer should be freed exactly once with cimpl_free().\n");
    }
}
/// Returns the process-wide pointer registry, creating it on first use.
///
/// The registry lives in a `static`. Rust never runs destructors for statics, so its
/// `Drop` leak report is not printed for this instance at process exit.
pub(crate) fn get_registry() -> &'static PointerRegistry {
    static REGISTRY: std::sync::OnceLock<PointerRegistry> = std::sync::OnceLock::new();
    REGISTRY.get_or_init(PointerRegistry::new)
}
/// Registers a pointer obtained from `Box::into_raw` so that `cimpl_free` can release it.
///
/// The entry is recorded under `TypeId::of::<T>()`. A null `ptr` is not registered.
/// Returns `ptr` unchanged so the call can wrap the allocation inline.
pub fn track_box<T: 'static + MaybeSend>(ptr: *mut T) -> *mut T {
    // Capture the address as a plain integer: raw pointers are not `Send`, but the
    // cleanup closure must be.
    let addr = ptr as usize;
    get_registry().track(
        addr,
        TypeId::of::<T>(),
        Box::new(move || {
            // SAFETY: caller contract — `ptr` came from `Box::<T>::into_raw`. The registry
            // removes the entry before invoking this closure, so it runs at most once.
            let owned = unsafe { Box::from_raw(addr as *mut T) };
            drop(owned);
        }),
    );
    ptr
}
/// Registers a pointer obtained from `Arc::into_raw` so that `cimpl_free` can release it.
///
/// Freeing drops one strong reference via `Arc::from_raw`. Entries are keyed by address,
/// so registering the same `Arc` pointer twice replaces the first entry; that entry's
/// cleanup is dropped unrun, which leaks one strong reference. A null `ptr` is not
/// registered. Returns `ptr` unchanged.
pub fn track_arc<T: 'static + MaybeSend>(ptr: *mut T) -> *mut T {
    // Integer address, because raw pointers are not `Send`.
    let addr = ptr as usize;
    get_registry().track(
        addr,
        TypeId::of::<T>(),
        Box::new(move || {
            // SAFETY: caller contract — `ptr` came from `Arc::<T>::into_raw`. The closure
            // runs at most once because `free` removes the entry first.
            let shared = unsafe { Arc::from_raw(addr as *const T) };
            drop(shared);
        }),
    );
    ptr
}
/// Registers a pointer obtained from `Arc::into_raw` on an `Arc<Mutex<T>>`.
///
/// The entry is recorded under `TypeId::of::<Mutex<T>>()`. Freeing it through
/// `cimpl_free` drops one strong reference via `Arc::from_raw`. A null `ptr` is not
/// registered. Returns `ptr` unchanged.
pub fn track_arc_mutex<T: 'static + MaybeSend>(ptr: *mut Mutex<T>) -> *mut Mutex<T> {
    let addr = ptr as usize;
    let registered_as = TypeId::of::<Mutex<T>>();
    get_registry().track(
        addr,
        registered_as,
        Box::new(move || {
            // SAFETY: caller contract — `ptr` came from `Arc::<Mutex<T>>::into_raw`. The
            // closure runs at most once because `free` removes the entry first.
            let shared = unsafe { Arc::from_raw(addr as *const Mutex<T>) };
            drop(shared);
        }),
    );
    ptr
}
/// Checks that `ptr` is tracked by the global registry under the concrete type `T`.
///
/// # Errors
///
/// Propagates the registry's error for a null, poisoned, wrongly-typed, or untracked
/// pointer.
pub fn validate_pointer<T: 'static>(ptr: *mut T) -> Result<(), Error> {
    let expected = TypeId::of::<T>();
    let registry = get_registry();
    registry.validate(ptr as usize, expected)
}
/// Removes `ptr` (registered as type `T`) from the global registry without freeing it.
///
/// After success the caller owns the allocation again and must release it itself.
///
/// # Errors
///
/// Propagates the registry's error for a null, poisoned, wrongly-typed, or untracked
/// pointer.
pub fn untrack_pointer<T: 'static>(ptr: *mut T) -> Result<(), Error> {
    let expected = TypeId::of::<T>();
    let registry = get_registry();
    registry.untrack(ptr as usize, expected)
}
#[no_mangle]
pub extern "C" fn cimpl_free(ptr: *mut std::ffi::c_void) -> i32 {
match get_registry().free(ptr as usize) {
Ok(()) => 0,
Err(e) => {
let error = CimplError::from(e);
#[cfg(test)]
{
if ptr as usize != 0 {
eprintln!(
"\n⚠️ ERROR: cimpl_free failed for pointer 0x{:x}: {}\n\
This usually means:\n\
1. The pointer was not allocated with box_tracked!/track_box\n\
2. The pointer was already freed (double-free)\n\
3. The pointer is invalid/corrupted\n",
ptr as usize, error
);
}
}
error.set_last();
-1
}
}
}
/// Checks that a `(ptr, size)` pair describes a plausible, non-empty byte range.
///
/// Returns `false` when any of these holds:
/// - `size` is zero;
/// - `size` exceeds `isize::MAX`, the largest size Rust allows for a single slice;
/// - `ptr` is non-null and `ptr + size` would wrap past the end of the address space.
///
/// A null `ptr` is only checked for size; null handling is left to the caller. This is a
/// sanity check only: it cannot prove that `ptr` really points at `size` readable bytes.
///
/// # Safety
///
/// The check performs only integer arithmetic on the address and has no requirements of
/// its own. It stays `unsafe` to keep the public signature unchanged.
pub unsafe fn is_safe_buffer_size(size: usize, ptr: *const c_uchar) -> bool {
    if size == 0 || size > isize::MAX as usize {
        return false;
    }
    // Compute the end address on integers. `ptr.add(size)` is undefined behaviour when
    // the result leaves the allocation or wraps. The optimizer could therefore assume an
    // `end < ptr` comparison is never true and delete the overflow check.
    ptr.is_null() || (ptr as usize).checked_add(size).is_some()
}
/// Builds a byte slice from a raw C `(ptr, len)` pair after basic sanity checks.
///
/// # Errors
///
/// - `null_parameter(param_name)` if `ptr` is null;
/// - `invalid_buffer_size(len, param_name)` if `is_safe_buffer_size` rejects the pair,
///   which includes `len == 0`.
///
/// # Safety
///
/// The checks cannot prove that `ptr` is valid. The caller must guarantee that `ptr`
/// points to `len` initialized bytes that stay alive and unmodified while the returned
/// slice is in use.
///
/// Note that by lifetime elision the slice's lifetime is tied to `param_name`, the only
/// reference parameter, not to the underlying buffer.
pub unsafe fn safe_slice_from_raw_parts(
    ptr: *const c_uchar,
    len: usize,
    param_name: &str,
) -> Result<&[u8], Error> {
    if ptr.is_null() {
        return Err(CimplError::null_parameter(param_name).into());
    }
    if is_safe_buffer_size(len, ptr) {
        Ok(std::slice::from_raw_parts(ptr, len))
    } else {
        Err(CimplError::invalid_buffer_size(len, param_name).into())
    }
}
/// Converts `s` into a heap-allocated, NUL-terminated C string registered for `cimpl_free`.
///
/// Returns a null pointer, without recording an error, if `s` contains an interior NUL
/// byte. Otherwise the pointer is tracked under `TypeId::of::<CString>()` and must be
/// released with `cimpl_free`.
pub fn to_c_string(s: String) -> *mut std::os::raw::c_char {
    let Ok(owned) = std::ffi::CString::new(s) else {
        return std::ptr::null_mut();
    };
    let raw = owned.into_raw();
    // Integer address, because raw pointers are not `Send`.
    let addr = raw as usize;
    get_registry().track(
        addr,
        TypeId::of::<std::ffi::CString>(),
        Box::new(move || {
            // SAFETY: `addr` came from `CString::into_raw` above and the closure runs at
            // most once, since `free` removes the entry before calling it.
            let reclaimed = unsafe { std::ffi::CString::from_raw(addr as *mut std::os::raw::c_char) };
            drop(reclaimed);
        }),
    );
    raw
}
/// Moves `bytes` onto the heap and returns a pointer to the first byte, registered for
/// `cimpl_free`.
///
/// An empty vector yields a null pointer and nothing is registered. The length is not
/// returned, so the caller must know it to read the buffer. The cleanup closure captures
/// the length itself, so freeing needs only the pointer.
pub fn to_c_bytes(bytes: Vec<u8>) -> *const c_uchar {
    if bytes.is_empty() {
        return std::ptr::null();
    }
    let byte_count = bytes.len();
    let boxed: *mut [u8] = Box::into_raw(bytes.into_boxed_slice());
    let first_byte = boxed as *const c_uchar;
    // Integer address, because raw pointers are not `Send`.
    let addr = first_byte as usize;
    get_registry().track(
        addr,
        TypeId::of::<Box<[u8]>>(),
        Box::new(move || {
            let whole = std::ptr::slice_from_raw_parts_mut(addr as *mut u8, byte_count);
            // SAFETY: `whole` rebuilds the exact `Box<[u8]>` allocation (same address and
            // length) leaked above; the closure runs at most once.
            drop(unsafe { Box::from_raw(whole) });
        }),
    );
    first_byte
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_allocation_tracking_double_free_string() {
        use std::ffi::CString;
        let test_string = CString::new("test allocation tracking").unwrap();
        let c_string = to_c_string(test_string.to_str().unwrap().to_string());
        assert!(!c_string.is_null());
        let result1 = cimpl_free(c_string as *mut std::ffi::c_void);
        assert_eq!(result1, 0);
        // NOTE: relies on no concurrently running test re-tracking the same freed address
        // before this second call.
        let result2 = cimpl_free(c_string as *mut std::ffi::c_void);
        assert_eq!(result2, -1);
    }

    #[test]
    fn test_to_c_string_basic() {
        let rust_string = "Hello, C!".to_string();
        let c_string = to_c_string(rust_string);
        assert!(!c_string.is_null());
        // SAFETY: `c_string` is a live, NUL-terminated string produced by `to_c_string`.
        let text = unsafe { std::ffi::CStr::from_ptr(c_string) };
        assert_eq!(text.to_str().unwrap(), "Hello, C!");
        assert_eq!(cimpl_free(c_string as *mut std::ffi::c_void), 0);
    }

    #[test]
    fn test_to_c_bytes_basic() {
        let bytes = vec![1, 2, 3, 4, 5];
        let ptr = to_c_bytes(bytes);
        assert!(!ptr.is_null());
        // SAFETY: `to_c_bytes` returned a live allocation of exactly 5 bytes.
        let view = unsafe { std::slice::from_raw_parts(ptr, 5) };
        assert_eq!(view, &[1, 2, 3, 4, 5]);
        assert_eq!(cimpl_free(ptr as *mut std::ffi::c_void), 0);
    }

    #[test]
    fn test_to_c_bytes_empty_returns_null() {
        assert!(to_c_bytes(Vec::new()).is_null());
    }

    #[test]
    fn test_to_c_string_with_null_byte() {
        let bad_string = "Hello\0World".to_string();
        let c_string = to_c_string(bad_string);
        assert!(c_string.is_null());
    }

    #[test]
    fn test_cimpl_free_null_is_noop() {
        assert_eq!(cimpl_free(std::ptr::null_mut()), 0);
    }

    #[test]
    fn test_validate_pointer_wrong_type_fails() {
        let ptr = track_box(Box::into_raw(Box::new(7i32)));
        assert!(validate_pointer::<u64>(ptr as *mut u64).is_err());
        assert!(validate_pointer::<i32>(ptr).is_ok());
        assert_eq!(cimpl_free(ptr as *mut std::ffi::c_void), 0);
    }

    #[test]
    fn test_untrack_pointer_removes_from_registry() {
        let ptr = track_box(Box::into_raw(Box::new(42i32)));
        assert!(validate_pointer::<i32>(ptr).is_ok());
        assert!(untrack_pointer::<i32>(ptr).is_ok());
        let result = cimpl_free(ptr as *mut std::ffi::c_void);
        assert_eq!(result, -1);
        // Untracking handed ownership back, so the test must free the box itself.
        unsafe { drop(Box::from_raw(ptr)) };
    }

    #[test]
    fn test_untrack_wrong_type_fails() {
        let ptr = track_box(Box::into_raw(Box::new(42i32)));
        let result = untrack_pointer::<u64>(ptr as *mut u64);
        assert!(result.is_err());
        // A failed untrack must leave the original registration intact.
        assert!(validate_pointer::<i32>(ptr).is_ok());
        assert_eq!(cimpl_free(ptr as *mut std::ffi::c_void), 0);
    }

    #[test]
    fn test_untrack_null_pointer_fails() {
        let result = untrack_pointer::<i32>(std::ptr::null_mut());
        assert!(result.is_err());
    }

    #[test]
    fn test_untrack_unregistered_pointer_fails() {
        let ptr = Box::into_raw(Box::new(42i32));
        let result = untrack_pointer::<i32>(ptr);
        assert!(result.is_err());
        unsafe { drop(Box::from_raw(ptr)) };
    }

    #[test]
    fn test_untrack_already_untracked_fails() {
        let ptr = track_box(Box::into_raw(Box::new(42i32)));
        assert!(untrack_pointer::<i32>(ptr).is_ok());
        let result = untrack_pointer::<i32>(ptr);
        assert!(result.is_err());
        unsafe { drop(Box::from_raw(ptr)) };
    }
}