use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::sync::{Arc, Mutex, RwLock};
/// Namespace for memory-safety documentation and lightweight heuristic checks.
///
/// All functionality is exposed through associated functions; the type carries no state.
pub struct MemorySafety;

impl MemorySafety {
    /// Describes the safety guarantees, `unsafe` usage and mitigation strategies of an operation.
    ///
    /// Recognised names are `"array_indexing"`, `"parallel_processing"` and `"gpu_operations"`.
    /// Any other name yields a single generic guarantee with no unsafe blocks and no mitigation
    /// strategies. The returned `operation` field always echoes `operation`.
    pub fn document_safety(operation: &str) -> MemorySafetyGuarantee {
        // Converts a list of string literals into owned strings.
        fn owned(items: &[&str]) -> Vec<String> {
            items.iter().map(|s| s.to_string()).collect()
        }

        let (guarantees, unsafe_blocks, mitigation_strategies) = match operation {
            "array_indexing" => (
                owned(&[
                    "Bounds checking prevents buffer overflows",
                    // `[]` indexing is bounds-checked in every build profile; there is no
                    // release-mode opt-out, so the guarantee must not be described as debug-only.
                    "Panic on out-of-bounds indexing in both debug and release builds",
                    "Unchecked indexing requires an explicit unsafe call such as get_unchecked",
                ]),
                Vec::new(),
                owned(&[
                    "Use checked indexing methods when bounds are uncertain",
                    "Validate input dimensions before processing",
                ]),
            ),
            "parallel_processing" => (
                owned(&[
                    "Send and Sync traits prevent data races",
                    "Rayon provides work-stealing without data races",
                    "Immutable borrows allow safe parallel reading",
                ]),
                Vec::new(),
                owned(&[
                    "Use Arc<T> for shared ownership across threads",
                    "Use Mutex<T> or RwLock<T> for shared mutable access",
                ]),
            ),
            "gpu_operations" => (
                owned(&[
                    "CUDA memory is managed through RAII wrappers",
                    "GPU pointers are opaque and cannot be dereferenced on CPU",
                    "Automatic cleanup of GPU resources on drop",
                ]),
                owned(&[
                    "CUDA FFI calls require unsafe blocks",
                    "Memory transfers between CPU and GPU use unsafe operations",
                ]),
                owned(&[
                    "Wrap all CUDA operations in safe abstractions",
                    "Validate GPU memory allocation success",
                    "Use typed GPU pointers to prevent type confusion",
                ]),
            ),
            _ => (
                owned(&["General Rust memory safety guarantees apply"]),
                Vec::new(),
                Vec::new(),
            ),
        };

        MemorySafetyGuarantee {
            operation: operation.to_string(),
            guarantees,
            unsafe_blocks,
            mitigation_strategies,
        }
    }

    /// Scans `code_block` for risky patterns using plain substring matching.
    ///
    /// Each matched pattern adds one issue and one recommendation, in table order, and costs
    /// 20 points from a starting score of 100. The score never drops below 0. Results
    /// scoring below 80 (two or more findings) are flagged with `requires_review`.
    pub fn validate_unsafe_usage(code_block: &str) -> UnsafeValidationResult {
        // (substring to look for, issue text, recommendation text)
        const CHECKS: [(&str, &str, &str); 3] = [
            (
                "transmute",
                "transmute operations can break type safety",
                "Consider using safe casting alternatives",
            ),
            (
                "from_raw_parts",
                "Raw pointer operations require careful validation",
                "Ensure pointer validity and proper alignment",
            ),
            (
                "assume_init",
                "Uninitialized memory access detected",
                // `assume_init` is itself a `MaybeUninit` method, so "use MaybeUninit" is not
                // actionable; the real obligation is full initialization beforehand.
                "Ensure the value is fully initialized before calling assume_init",
            ),
        ];
        const PENALTY_PER_ISSUE: usize = 20;

        let mut issues = Vec::new();
        let mut recommendations = Vec::new();
        for &(pattern, issue, recommendation) in CHECKS.iter() {
            if code_block.contains(pattern) {
                issues.push(issue.to_string());
                recommendations.push(recommendation.to_string());
            }
        }

        // Saturating arithmetic cannot underflow however many checks are added later.
        let score = 100usize.saturating_sub(issues.len().saturating_mul(PENALTY_PER_ISSUE));
        // `score` is at most 100, so this narrowing cast is lossless.
        let safety_score = score as u8;

        UnsafeValidationResult {
            safety_score,
            issues,
            recommendations,
            requires_review: safety_score < 80,
        }
    }
}

/// Documentation record produced by [`MemorySafety::document_safety`].
#[derive(Debug, Clone)]
pub struct MemorySafetyGuarantee {
    /// Name of the operation this record describes (echoes the query string).
    pub operation: String,
    /// Safety properties claimed for the operation.
    pub guarantees: Vec<String>,
    /// Places where the operation relies on `unsafe` code.
    pub unsafe_blocks: Vec<String>,
    /// Recommended practices for using the operation safely.
    pub mitigation_strategies: Vec<String>,
}

/// Result of the heuristic scan performed by [`MemorySafety::validate_unsafe_usage`].
#[derive(Debug, Clone)]
pub struct UnsafeValidationResult {
    /// Score from 0 to 100; 100 means no risky pattern was found.
    pub safety_score: u8,
    /// One description per risky pattern found.
    pub issues: Vec<String>,
    /// One recommendation per entry in `issues`, in the same order.
    pub recommendations: Vec<String>,
    /// `true` when `safety_score < 80`.
    pub requires_review: bool,
}
/// Bounds-checked element access for arrays addressed by a slice of per-axis indices.
///
/// Intended contract: an index with the wrong number of components, or with any component out
/// of range, is reported as `None` / `false` rather than causing a panic.
pub trait SafeArrayOps<T> {
    /// Returns `true` when `index` names an existing element.
    fn is_valid_index(&self, index: &[usize]) -> bool;

    /// Borrows the element at `index`, or returns `None` when the index is invalid.
    fn safe_get(&self, index: &[usize]) -> Option<&T>;

    /// Mutably borrows the element at `index`, or returns `None` when the index is invalid.
    fn safe_get_mut(&mut self, index: &[usize]) -> Option<&mut T>;

    /// Checks that the array's shape is usable, describing the problem in `Err` otherwise.
    fn validate_dimensions(&self) -> Result<(), String>;
}
/// Two-dimensional access: an index must be exactly `[row, col]`.
impl<T> SafeArrayOps<T> for Array2<T> {
    fn safe_get(&self, index: &[usize]) -> Option<&T> {
        match *index {
            [row, col] => self.get((row, col)),
            _ => None,
        }
    }

    fn safe_get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
        match *index {
            [row, col] => self.get_mut((row, col)),
            _ => None,
        }
    }

    /// Rejects arrays with an empty axis first, then axes longer than `isize::MAX`.
    fn validate_dimensions(&self) -> Result<(), String> {
        let (rows, cols) = (self.nrows(), self.ncols());
        let limit = isize::MAX as usize;

        if rows == 0 || cols == 0 {
            return Err("Array has zero-sized dimension".to_string());
        }
        if rows > limit || cols > limit {
            return Err("Array dimension exceeds maximum safe size".to_string());
        }
        Ok(())
    }

    fn is_valid_index(&self, index: &[usize]) -> bool {
        match *index {
            [row, col] => row < self.nrows() && col < self.ncols(),
            _ => false,
        }
    }
}
/// One-dimensional access: an index must be exactly `[i]`.
impl<T> SafeArrayOps<T> for Array1<T> {
    fn safe_get(&self, index: &[usize]) -> Option<&T> {
        if let [i] = *index {
            self.get(i)
        } else {
            None
        }
    }

    fn safe_get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
        if let [i] = *index {
            self.get_mut(i)
        } else {
            None
        }
    }

    /// Rejects an empty array first, then one longer than `isize::MAX`.
    fn validate_dimensions(&self) -> Result<(), String> {
        match self.len() {
            0 => Err("Array is empty".to_string()),
            n if n > isize::MAX as usize => {
                Err("Array length exceeds maximum safe size".to_string())
            }
            _ => Ok(()),
        }
    }

    fn is_valid_index(&self, index: &[usize]) -> bool {
        matches!(*index, [i] if i < self.len())
    }
}
/// Thread-safe pool of reusable `Vec<T>` buffers, bucketed by requested capacity.
///
/// Buffers handed out by [`SafeMemoryPool::allocate`] go back to the bucket for their requested
/// capacity when the [`SafePooledBuffer`] is dropped. If that bucket already holds
/// `max_pool_size` idle buffers, the returning buffer is freed instead. Poisoned locks are
/// recovered rather than propagated.
pub struct SafeMemoryPool<T> {
    /// Idle buffers keyed by the capacity they were requested with.
    pools: Arc<Mutex<HashMap<usize, Vec<Vec<T>>>>>,
    /// Number of buffers currently checked out.
    allocated_count: Arc<Mutex<usize>>,
    /// Maximum number of idle buffers retained per capacity bucket.
    max_pool_size: usize,
}

impl<T> SafeMemoryPool<T> {
    /// Per-bucket retention limit used by [`SafeMemoryPool::new`].
    const DEFAULT_MAX_POOL_SIZE: usize = 1000;

    /// Creates an empty pool that keeps at most 1000 idle buffers per capacity bucket.
    pub fn new() -> Self {
        Self::with_limits(Self::DEFAULT_MAX_POOL_SIZE)
    }

    /// Creates an empty pool that keeps at most `max_pool_size` idle buffers per bucket.
    pub fn with_limits(max_pool_size: usize) -> Self {
        Self {
            pools: Arc::new(Mutex::new(HashMap::new())),
            allocated_count: Arc::new(Mutex::new(0)),
            max_pool_size,
        }
    }

    /// Hands out an empty buffer whose capacity is at least `capacity`.
    ///
    /// An idle buffer from the `capacity` bucket is reused when one is available; otherwise a
    /// fresh `Vec::with_capacity(capacity)` is allocated.
    pub fn allocate(&self, capacity: usize) -> SafePooledBuffer<T> {
        let recycled = {
            let mut pools = self.pools.lock().unwrap_or_else(|e| e.into_inner());
            pools.get_mut(&capacity).and_then(Vec::pop)
        };

        let buffer = match recycled {
            Some(mut buffer) => {
                buffer.clear();
                // A previous holder may have shrunk the buffer through `DerefMut`. With
                // `len() == 0`, `reserve(capacity)` restores `capacity() >= capacity`, the same
                // guarantee a fresh allocation gives; it is a no-op when capacity suffices.
                buffer.reserve(capacity);
                buffer
            }
            None => Vec::with_capacity(capacity),
        };

        *self
            .allocated_count
            .lock()
            .unwrap_or_else(|e| e.into_inner()) += 1;

        SafePooledBuffer {
            buffer: Some(buffer),
            capacity,
            pool: Arc::clone(&self.pools),
            allocated_count: Arc::clone(&self.allocated_count),
            max_pool_size: self.max_pool_size,
        }
    }

    /// Returns a snapshot of checked-out and idle buffer counts.
    ///
    /// `pool_sizes` is sorted by capacity so the report is stable across calls.
    pub fn stats(&self) -> MemoryPoolStats {
        let allocated_count = *self
            .allocated_count
            .lock()
            .unwrap_or_else(|e| e.into_inner());

        let mut pool_sizes: Vec<(usize, usize)> = {
            let pools = self.pools.lock().unwrap_or_else(|e| e.into_inner());
            pools.iter().map(|(&cap, idle)| (cap, idle.len())).collect()
        };
        // `HashMap` iteration order is unspecified; sort for a deterministic report.
        pool_sizes.sort_unstable();
        let pooled_count = pool_sizes.iter().map(|&(_, idle)| idle).sum();

        MemoryPoolStats {
            allocated_count,
            pooled_count,
            pool_sizes,
        }
    }
}

impl<T> Default for SafeMemoryPool<T> {
    /// Equivalent to [`SafeMemoryPool::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Snapshot returned by [`SafeMemoryPool::stats`].
#[derive(Debug, Clone)]
pub struct MemoryPoolStats {
    /// Buffers currently checked out of the pool.
    pub allocated_count: usize,
    /// Idle buffers held across all buckets.
    pub pooled_count: usize,
    /// `(capacity, idle buffer count)` per bucket, sorted by capacity.
    pub pool_sizes: Vec<(usize, usize)>,
}

/// A buffer checked out of a [`SafeMemoryPool`]; returned to the pool when dropped.
pub struct SafePooledBuffer<T> {
    /// `None` only after [`SafePooledBuffer::into_inner`] has taken the vector.
    buffer: Option<Vec<T>>,
    /// Capacity bucket this buffer was requested from (and is returned to).
    capacity: usize,
    pool: Arc<Mutex<HashMap<usize, Vec<Vec<T>>>>>,
    allocated_count: Arc<Mutex<usize>>,
    max_pool_size: usize,
}

impl<T> SafePooledBuffer<T> {
    /// Mutable access to the underlying vector.
    pub fn as_mut_vec(&mut self) -> &mut Vec<T> {
        self.buffer.as_mut().expect("Buffer has been consumed")
    }

    /// Shared access to the underlying vector.
    pub fn as_ref_vec(&self) -> &Vec<T> {
        self.buffer.as_ref().expect("Buffer has been consumed")
    }

    /// Takes ownership of the vector; it is not returned to the pool.
    ///
    /// The checkout is still released, because `Drop` runs on `self` when this returns.
    pub fn into_inner(mut self) -> Vec<T> {
        self.buffer.take().expect("Buffer has been consumed")
    }
}

impl<T> Drop for SafePooledBuffer<T> {
    fn drop(&mut self) {
        if let Some(mut buffer) = self.buffer.take() {
            // Release the contents now instead of keeping them alive while the buffer idles.
            buffer.clear();
            let mut pools = self.pool.lock().unwrap_or_else(|e| e.into_inner());
            let bucket = pools.entry(self.capacity).or_default();
            if bucket.len() < self.max_pool_size {
                bucket.push(buffer);
            }
        }
        // Decrement unconditionally. A buffer consumed by `into_inner` leaves `None` behind but
        // was still counted by `allocate`; skipping it would leak the checkout count.
        let mut count = self
            .allocated_count
            .lock()
            .unwrap_or_else(|e| e.into_inner());
        *count = count.saturating_sub(1);
    }
}

impl<T> std::ops::Deref for SafePooledBuffer<T> {
    type Target = Vec<T>;
    fn deref(&self) -> &Self::Target {
        self.as_ref_vec()
    }
}

impl<T> std::ops::DerefMut for SafePooledBuffer<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_mut_vec()
    }
}
/// Non-null raw pointer wrapper whose constructor and accessors are all `unsafe`.
///
/// The wrapper only stores the pointer: it never dereferences or frees the pointee.
/// NOTE(review): the `PhantomData<T>` marker gives `T` ownership-like variance/drop-check
/// treatment, yet there is no `Drop` impl — confirm that is intended.
#[derive(Debug)]
pub struct SafePtr<T> {
    ptr: NonNull<T>,
    _marker: PhantomData<T>,
}

impl<T> SafePtr<T> {
    /// Wraps `ptr`.
    ///
    /// # Safety
    ///
    /// This constructor does not dereference `ptr`. NOTE(review): no contract was stated
    /// originally; presumably callers must keep `ptr` valid and aligned for every later
    /// dereference of the pointers returned by `as_ptr` / `as_mut_ptr` — confirm.
    pub unsafe fn new(ptr: NonNull<T>) -> Self {
        let _marker = PhantomData;
        SafePtr { ptr, _marker }
    }

    /// Returns the wrapped pointer as `*const T`.
    ///
    /// # Safety
    ///
    /// Reading the pointer value is harmless; dereferencing it requires the caller to uphold
    /// Rust's validity and aliasing rules.
    pub unsafe fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr() as *const T
    }

    /// Returns the wrapped pointer as `*mut T` (note: from `&self`).
    ///
    /// # Safety
    ///
    /// As for [`SafePtr::as_ptr`]; any write through the result must not alias live references.
    pub unsafe fn as_mut_ptr(&self) -> *mut T {
        let raw: *mut T = self.ptr.as_ptr();
        raw
    }
}

// SAFETY: `SafePtr` never dereferences the pointer itself; it only hands out raw pointers, so
// every access to the pointee happens in a separate `unsafe` block whose author is responsible
// for synchronization. Thread-safety is propagated from `T` in the same way as for `Box<T>`.
unsafe impl<T: Send> Send for SafePtr<T> {}
// SAFETY: see the `Send` impl above — sharing `&SafePtr<T>` only shares raw pointer values.
unsafe impl<T: Sync> Sync for SafePtr<T> {}
/// A named model behind `Arc<RwLock<T>>`, allowing shared reads and exclusive writes.
pub struct SafeSharedModel<T> {
    inner: Arc<RwLock<T>>,
    /// Label included in the panic message when the lock is found poisoned.
    id: String,
}

impl<T> SafeSharedModel<T> {
    /// Wraps `model` in a fresh lock labelled `id`.
    pub fn new(model: T, id: String) -> Self {
        let inner = Arc::new(RwLock::new(model));
        SafeSharedModel { inner, id }
    }

    /// Blocks until shared read access is granted.
    ///
    /// # Panics
    ///
    /// Panics with a message naming this model's `id` if the lock is poisoned.
    pub fn read(&self) -> std::sync::RwLockReadGuard<'_, T> {
        match self.inner.read() {
            Ok(guard) => guard,
            Err(e) => panic!("RwLock poisoned for model {}: {}", self.id, e),
        }
    }

    /// Blocks until exclusive write access is granted.
    ///
    /// # Panics
    ///
    /// Panics with a message naming this model's `id` if the lock is poisoned.
    pub fn write(&self) -> std::sync::RwLockWriteGuard<'_, T> {
        match self.inner.write() {
            Ok(guard) => guard,
            Err(e) => panic!("RwLock poisoned for model {}: {}", self.id, e),
        }
    }

    /// Non-blocking read; `None` if the lock is currently unavailable *or* poisoned.
    pub fn try_read(&self) -> Option<std::sync::RwLockReadGuard<'_, T>> {
        match self.inner.try_read() {
            Ok(guard) => Some(guard),
            Err(_) => None,
        }
    }

    /// Non-blocking write; `None` if the lock is currently unavailable *or* poisoned.
    pub fn try_write(&self) -> Option<std::sync::RwLockWriteGuard<'_, T>> {
        match self.inner.try_write() {
            Ok(guard) => Some(guard),
            Err(_) => None,
        }
    }

    /// Returns another handle to the same model; the lock is shared, `T` is not copied.
    pub fn clone_ref(&self) -> Self {
        let inner = self.inner.clone();
        let id = self.id.clone();
        SafeSharedModel { inner, id }
    }
}

impl<T: Clone> SafeSharedModel<T> {
    /// Returns a copy of the current model value, taken under a read lock.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`SafeSharedModel::read`].
    pub fn clone_model(&self) -> T {
        let guard = self.read();
        (*guard).clone()
    }
}
/// Unit tests for the memory-safety helpers, the buffer pool, and the shared-model wrapper.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};
    use std::ptr::NonNull;

    #[test]
    fn test_memory_safety_documentation() {
        let guarantee = MemorySafety::document_safety("array_indexing");
        assert_eq!(guarantee.operation, "array_indexing");
        assert!(!guarantee.guarantees.is_empty());
    }

    #[test]
    fn test_document_safety_fallback_and_gpu() {
        let guarantee = MemorySafety::document_safety("something_else");
        assert_eq!(guarantee.operation, "something_else");
        assert_eq!(guarantee.guarantees.len(), 1);
        assert!(guarantee.unsafe_blocks.is_empty());
        assert!(guarantee.mitigation_strategies.is_empty());
        assert!(!MemorySafety::document_safety("gpu_operations")
            .unsafe_blocks
            .is_empty());
    }

    #[test]
    fn test_unsafe_validation() {
        let safe_code = "let x = vec![1, 2, 3]; let y = &x[0];";
        let result = MemorySafety::validate_unsafe_usage(safe_code);
        assert_eq!(result.safety_score, 100);
        assert!(result.issues.is_empty());
        let unsafe_code = "let x = transmute::<i32, f32>(42);";
        let result = MemorySafety::validate_unsafe_usage(unsafe_code);
        assert!(result.safety_score < 100);
        assert!(!result.issues.is_empty());
    }

    #[test]
    fn test_unsafe_validation_accumulates_issues() {
        let code = "transmute(x); slice::from_raw_parts(p, n);";
        let result = MemorySafety::validate_unsafe_usage(code);
        assert_eq!(result.issues.len(), 2);
        assert_eq!(result.recommendations.len(), 2);
        assert_eq!(result.safety_score, 60);
        assert!(result.requires_review);
    }

    #[test]
    fn test_safe_array_operations() {
        let array = Array2::<f64>::zeros((10, 10));
        assert!(array.safe_get(&[0, 0]).is_some());
        assert!(array.safe_get(&[10, 10]).is_none());
        assert!(array.safe_get(&[5]).is_none());
        assert!(array.validate_dimensions().is_ok());
        assert!(array.is_valid_index(&[5, 5]));
        assert!(!array.is_valid_index(&[10, 5]));
        assert!(Array2::<f64>::zeros((0, 3)).validate_dimensions().is_err());
    }

    #[test]
    fn test_safe_array1_operations() {
        let mut array = Array1::from(vec![1.0_f64, 2.0, 3.0]);
        assert_eq!(array.safe_get(&[1]), Some(&2.0));
        assert!(array.safe_get(&[3]).is_none());
        assert!(array.safe_get(&[0, 0]).is_none());
        if let Some(v) = array.safe_get_mut(&[0]) {
            *v = 10.0;
        }
        assert_eq!(array.safe_get(&[0]), Some(&10.0));
        assert!(array.is_valid_index(&[2]));
        assert!(!array.is_valid_index(&[3]));
        assert!(array.validate_dimensions().is_ok());
        assert!(Array1::<f64>::zeros(0).validate_dimensions().is_err());
    }

    #[test]
    fn test_memory_pool() {
        let pool = SafeMemoryPool::<i32>::new();
        let buffer = pool.allocate(100);
        assert_eq!(buffer.capacity(), 100);
        let stats = pool.stats();
        assert_eq!(stats.allocated_count, 1);
        drop(buffer);
        let stats = pool.stats();
        assert_eq!(stats.allocated_count, 0);
        assert_eq!(stats.pooled_count, 1);
    }

    #[test]
    fn test_memory_pool_reuse_and_limit() {
        let pool = SafeMemoryPool::<u8>::with_limits(1);
        let first = pool.allocate(8);
        let second = pool.allocate(8);
        assert_eq!(pool.stats().allocated_count, 2);
        drop(first);
        drop(second);
        let stats = pool.stats();
        assert_eq!(stats.allocated_count, 0);
        // With a limit of 1, only one idle buffer is retained in the bucket.
        assert_eq!(stats.pooled_count, 1);
        let reused = pool.allocate(8);
        assert!(reused.is_empty());
        let stats = pool.stats();
        assert_eq!(stats.allocated_count, 1);
        assert_eq!(stats.pooled_count, 0);
    }

    #[test]
    fn test_shared_model() {
        let model = vec![1, 2, 3, 4, 5];
        let shared = SafeSharedModel::new(model, "test_model".to_string());
        {
            let reader = shared.read();
            assert_eq!(reader.len(), 5);
        }
        {
            let mut writer = shared.write();
            writer.push(6);
            assert_eq!(writer.len(), 6);
        }
        let shared2 = shared.clone_ref();
        let reader = shared2.read();
        assert_eq!(reader.len(), 6);
    }

    #[test]
    fn test_shared_model_try_lock_and_clone() {
        let shared = SafeSharedModel::new(vec![1, 2], "try_model".to_string());
        {
            let _writer = shared.write();
            assert!(shared.try_read().is_none());
            assert!(shared.try_write().is_none());
        }
        assert!(shared.try_read().is_some());
        assert_eq!(shared.clone_model(), vec![1, 2]);
    }

    #[test]
    fn test_pooled_buffer_deref() {
        let pool = SafeMemoryPool::<i32>::new();
        let mut buffer = pool.allocate(10);
        buffer.push(42);
        assert_eq!(buffer.len(), 1);
        assert_eq!(buffer[0], 42);
        let inner = buffer.into_inner();
        assert_eq!(inner, vec![42]);
    }

    #[test]
    fn test_safe_ptr_exposes_wrapped_pointer() {
        let mut value = 7_i32;
        let raw = NonNull::from(&mut value);
        // SAFETY: `raw` points to `value`, which outlives `ptr` and is not otherwise accessed
        // while the pointer is in use.
        let ptr = unsafe { SafePtr::new(raw) };
        unsafe {
            assert_eq!(ptr.as_mut_ptr(), raw.as_ptr());
            *ptr.as_mut_ptr() += 1;
            assert_eq!(*ptr.as_ptr(), 8);
        }
    }
}