#![cfg_attr(not(std), no_std)]
// Entry point of the C shim this crate links against (its source is not part
// of this file).
//
// From the way the `extern "C"` trampolines in this file use it, `_smalloca`
// is expected to obtain `size` bytes of scratch memory and call
// `wrapper(space, dto)`, with `space` pointing at that memory and `dto` passed
// through unchanged.
// NOTE(review): presumably the memory comes from the C stack (alloca/VLA) and
// is released when `_smalloca` returns — confirm against the C source,
// including the alignment it guarantees for `space`.
extern "C" {
/// Requests `size` bytes of scratch memory and invokes `wrapper(space, dto)`.
fn _smalloca(size: usize, dto: *mut u8, wrapper: extern "C" fn(*mut u8, *mut u8));
}
/// State shared between the public entry points and the `extern "C"`
/// trampolines; it travels through `_smalloca` as an opaque `*mut u8`.
///
/// No trait bounds are placed on the definition itself: the functions that
/// build and consume it already carry the `FnOnce(&mut [T]) -> R` bound.
#[repr(C)]
struct SmallocaDto<T, F, R> {
    /// Number of bytes requested from `_smalloca` (`size_of::<T>() * len`).
    size: usize,
    /// Number of `T` elements the scratch buffer is viewed as.
    len: usize,
    /// Filled with the callback's result once it has run.
    ret: Option<R>,
    /// Value used to initialize the slots; `None` when slots are defaulted.
    prototype: Option<T>,
    /// The user callback; taken exactly once by the trampoline.
    function: Option<F>,
}
/// Initializes every slot of `smalloca` with a copy of `prototype`.
///
/// Slots are filled front to back: all but the last receive
/// `prototype.clone()` and the last receives `prototype` itself, so exactly
/// `len - 1` clones are made. For an empty slice nothing is written and
/// `prototype` is dropped.
///
/// The previous contents of the slots are neither read nor dropped, because
/// the trampolines call this on freshly obtained, uninitialized scratch
/// memory. Calling it on live values that own resources leaks them.
fn clone_init<T>(smalloca: &mut [T], prototype: T)
where
    T: Clone,
{
    if let Some((last, rest)) = smalloca.split_last_mut() {
        for slot in rest {
            // SAFETY: `slot` comes from a `&mut [T]`, so it is aligned and
            // writable. `ptr::write` neither reads nor drops the old contents,
            // which may be uninitialized memory.
            unsafe { core::ptr::write(slot, prototype.clone()) };
        }
        // SAFETY: same reasoning as above; the prototype fills the final slot.
        unsafe { core::ptr::write(last, prototype) };
    }
}
/// Fills every slot of `smalloca` with `T::default()`, front to back.
///
/// The previous contents of the slots are neither read nor dropped, because
/// the trampolines call this on freshly obtained, uninitialized scratch
/// memory. Calling it on live values that own resources leaks them.
fn default_init<T>(smalloca: &mut [T])
where
    T: Default,
{
    for slot in smalloca.iter_mut() {
        // SAFETY: `slot` comes from a `&mut [T]`, so it is aligned and
        // writable. `ptr::write` neither reads nor drops the old contents,
        // which may be uninitialized memory.
        unsafe { core::ptr::write(slot, T::default()) };
    }
}
/// Drops every element of `smalloca` in place, front to back, then
/// overwrites the slice's bytes with zeros.
///
/// The caller must guarantee every element is initialized. After this call
/// the slots hold no live values and must not be read as `T` again; the
/// caller is expected to release the memory.
///
/// Only raw bytes are written: no zeroed `T` value is ever created, because
/// `mem::zeroed::<T>()` is invalid (and panics at runtime) for types such as
/// references or `NonZero*` integers.
fn zero_out<T>(smalloca: &mut [T]) {
    let len = smalloca.len();
    let slots: *mut [T] = smalloca;
    // SAFETY: `slots` comes from a `&mut [T]` of initialized elements, so each
    // element may be dropped exactly once here. The following `write_bytes`
    // covers exactly the `len * size_of::<T>()` bytes of that same slice.
    unsafe {
        core::ptr::drop_in_place(slots);
        core::ptr::write_bytes(slots as *mut T, 0, len);
    }
}
/// `extern "C"` trampoline handed to `_smalloca` by `smalloca`.
///
/// `space` points at the scratch buffer and `dto` at the caller's
/// `SmallocaDto<T, F, R>`. The buffer is viewed as `dto.len` elements of `T`.
/// Each element is initialized from the prototype, and the slice is passed to
/// the user callback. The elements are then dropped and zeroed before control
/// returns to C. The callback's result is stored in `dto.ret`.
///
/// NOTE(review): if the callback panics, the panic unwinds out of this
/// `extern "C"` function. Since Rust 1.81 that aborts the process instead of
/// propagating. Propagating panics would need `extern "C-unwind"` here and in
/// the `_smalloca` declaration, plus unwind support in the C shim — confirm
/// the intended behavior.
extern "C" fn wrapper_clone<T, F, R>(space: *mut u8, dto: *mut u8)
where
    T: Clone,
    F: FnOnce(&mut [T]) -> R,
{
    // SAFETY: `dto` is the `&mut SmallocaDto<T, F, R>` that `smalloca` cast
    // to `*mut u8` and passed to `_smalloca`, which hands it back unchanged.
    let dto = unsafe { &mut *(dto as *mut SmallocaDto<T, F, R>) };

    // A zero-byte request (empty slice or zero-sized `T`) gives no guarantee
    // that `space` is non-null or aligned. `from_raw_parts_mut` requires both
    // even when no bytes are accessed, so use a dangling pointer instead.
    let base: *mut T = if core::mem::size_of::<T>() == 0 || dto.len == 0 {
        core::ptr::NonNull::<T>::dangling().as_ptr()
    } else {
        space as *mut T
    };
    debug_assert_eq!(
        base as usize % core::mem::align_of::<T>(),
        0,
        "smalloca: scratch buffer is not aligned for T"
    );

    // SAFETY: `base` addresses `dto.len` slots of `T` that stay valid for the
    // duration of this call. The slots are uninitialized until `clone_init`
    // writes them.
    let smalloca = unsafe { core::slice::from_raw_parts_mut(base, dto.len) };
    let prototype = dto
        .prototype
        .take()
        .expect("smalloca always provides a prototype");
    clone_init(smalloca, prototype);
    let function = dto
        .function
        .take()
        .expect("the callback is consumed exactly once");
    dto.ret = Some(function(&mut *smalloca));
    zero_out(smalloca);
}
/// `extern "C"` trampoline handed to `_smalloca` by `smalloca_default`.
///
/// `space` points at the scratch buffer and `dto` at the caller's
/// `SmallocaDto<T, F, R>`. The buffer is viewed as `dto.len` elements of `T`.
/// Each element is set to `T::default()`, and the slice is passed to the user
/// callback. The elements are then dropped and zeroed before control returns
/// to C. The callback's result is stored in `dto.ret`.
///
/// NOTE(review): if the callback panics, the panic unwinds out of this
/// `extern "C"` function. Since Rust 1.81 that aborts the process instead of
/// propagating. Propagating panics would need `extern "C-unwind"` here and in
/// the `_smalloca` declaration, plus unwind support in the C shim — confirm
/// the intended behavior.
extern "C" fn wrapper_default<T, F, R>(space: *mut u8, dto: *mut u8)
where
    T: Default,
    F: FnOnce(&mut [T]) -> R,
{
    // SAFETY: `dto` is the `&mut SmallocaDto<T, F, R>` that `smalloca_default`
    // cast to `*mut u8` and passed to `_smalloca`, which hands it back unchanged.
    let dto = unsafe { &mut *(dto as *mut SmallocaDto<T, F, R>) };

    // A zero-byte request (empty slice or zero-sized `T`) gives no guarantee
    // that `space` is non-null or aligned. `from_raw_parts_mut` requires both
    // even when no bytes are accessed, so use a dangling pointer instead.
    let base: *mut T = if core::mem::size_of::<T>() == 0 || dto.len == 0 {
        core::ptr::NonNull::<T>::dangling().as_ptr()
    } else {
        space as *mut T
    };
    debug_assert_eq!(
        base as usize % core::mem::align_of::<T>(),
        0,
        "smalloca: scratch buffer is not aligned for T"
    );

    // SAFETY: `base` addresses `dto.len` slots of `T` that stay valid for the
    // duration of this call. The slots are uninitialized until `default_init`
    // writes them.
    let smalloca = unsafe { core::slice::from_raw_parts_mut(base, dto.len) };
    default_init(smalloca);
    let function = dto
        .function
        .take()
        .expect("the callback is consumed exactly once");
    dto.ret = Some(function(&mut *smalloca));
    zero_out(smalloca);
}
/// Runs `function` on a scratch slice of `len` elements, each initialized as
/// a copy of `prototype`, and returns the callback's result.
///
/// The backing memory is obtained through the external `_smalloca` routine
/// (presumably stack memory provided by the C shim — confirm there). `len - 1`
/// clones are made; the last element is `prototype` itself. Every element is
/// dropped before the memory is handed back.
///
/// # Panics
///
/// Panics if `size_of::<T>() * len` overflows `usize`.
///
/// If `function` panics, the panic has to unwind through an `extern "C"`
/// trampoline. On Rust 1.81 and later this aborts the process rather than
/// propagating.
pub fn smalloca<T, F, R>(prototype: T, len: usize, function: F) -> R
where
    T: Clone,
    F: FnOnce(&mut [T]) -> R,
{
    // Checked: a wrapped product would request a buffer far smaller than the
    // `len`-element slice later built over it (release builds do not trap).
    let size = core::mem::size_of::<T>()
        .checked_mul(len)
        .expect("smalloca: requested size overflows usize");
    let mut dto = SmallocaDto {
        size,
        len,
        ret: None,
        prototype: Some(prototype),
        function: Some(function),
    };
    // SAFETY: `dto` outlives the call, and `wrapper_clone::<T, F, R>` casts the
    // pointer back to exactly this `SmallocaDto<T, F, R>` type.
    unsafe {
        _smalloca(
            size,
            &mut dto as *mut SmallocaDto<T, F, R> as *mut u8,
            wrapper_clone::<T, F, R>,
        );
    }
    dto.ret
        .expect("smalloca: _smalloca returned without running the callback")
}
/// Runs `function` on a scratch slice of `len` elements, each initialized to
/// `T::default()`, and returns the callback's result.
///
/// The backing memory is obtained through the external `_smalloca` routine
/// (presumably stack memory provided by the C shim — confirm there). Every
/// element is dropped before the memory is handed back.
///
/// # Panics
///
/// Panics if `size_of::<T>() * len` overflows `usize`.
///
/// If `function` panics, the panic has to unwind through an `extern "C"`
/// trampoline. On Rust 1.81 and later this aborts the process rather than
/// propagating.
pub fn smalloca_default<T, F, R>(len: usize, function: F) -> R
where
    T: Default,
    F: FnOnce(&mut [T]) -> R,
{
    // Checked: a wrapped product would request a buffer far smaller than the
    // `len`-element slice later built over it (release builds do not trap).
    let size = core::mem::size_of::<T>()
        .checked_mul(len)
        .expect("smalloca: requested size overflows usize");
    let mut dto = SmallocaDto {
        size,
        len,
        ret: None,
        prototype: None,
        function: Some(function),
    };
    // SAFETY: `dto` outlives the call, and `wrapper_default::<T, F, R>` casts
    // the pointer back to exactly this `SmallocaDto<T, F, R>` type.
    unsafe {
        _smalloca(
            size,
            &mut dto as *mut SmallocaDto<T, F, R> as *mut u8,
            wrapper_default::<T, F, R>,
        );
    }
    dto.ret
        .expect("smalloca: _smalloca returned without running the callback")
}
#[cfg(test)]
mod tests {
    use crate::{smalloca, smalloca_default};
    use core::sync::atomic::{AtomicIsize, Ordering};

    // Each memory-safety test owns its own live-instance counter. The test
    // harness runs tests on parallel threads, so a single shared counter let
    // one test observe the other's instances and fail spuriously.
    static CLONE_LIVE: AtomicIsize = AtomicIsize::new(0);
    static DEFAULT_LIVE: AtomicIsize = AtomicIsize::new(0);

    /// Test value that tracks how many instances are currently alive.
    ///
    /// Construction (`new`, `clone`, `default`) increments the counter the
    /// value is bound to and `drop` decrements it. A counter that returns to
    /// its starting value therefore shows that every instance was dropped
    /// exactly once.
    #[derive(Debug)]
    struct MemSafetyTester {
        live: &'static AtomicIsize,
    }

    impl MemSafetyTester {
        fn new(live: &'static AtomicIsize) -> Self {
            live.fetch_add(1, Ordering::SeqCst);
            MemSafetyTester { live }
        }
    }

    impl Clone for MemSafetyTester {
        /// Clones are counted on the same counter as the original.
        fn clone(&self) -> Self {
            MemSafetyTester::new(self.live)
        }
    }

    impl Default for MemSafetyTester {
        /// Default instances are always counted on `DEFAULT_LIVE`.
        fn default() -> Self {
            MemSafetyTester::new(&DEFAULT_LIVE)
        }
    }

    impl Drop for MemSafetyTester {
        fn drop(&mut self) {
            self.live.fetch_sub(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn memsafety_test_clone() {
        assert_eq!(CLONE_LIVE.load(Ordering::SeqCst), 0);
        smalloca(MemSafetyTester::new(&CLONE_LIVE), 8, |_a| {
            assert_eq!(CLONE_LIVE.load(Ordering::SeqCst), 8);
        });
        assert_eq!(CLONE_LIVE.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn memsafety_test_default() {
        assert_eq!(DEFAULT_LIVE.load(Ordering::SeqCst), 0);
        smalloca_default(8, |_a: &mut [MemSafetyTester]| {
            assert_eq!(DEFAULT_LIVE.load(Ordering::SeqCst), 8);
        });
        assert_eq!(DEFAULT_LIVE.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn return_test() {
        let x = smalloca(8, 8, |_a| 5);
        assert_eq!(x, 5);
    }

    #[test]
    fn dyn_smalloca_test() {
        for i in 0..32 {
            let x = smalloca(1, i, |a| a.iter().sum::<usize>());
            assert_eq!(x, i);
        }
    }

    fn inner_function(values: &mut [u32]) -> u32 {
        values.iter().sum()
    }

    #[test]
    fn inner_function_call_test() {
        let result = smalloca(1, 8, |a| inner_function(a));
        assert_eq!(result, 8);
    }

    // NOTE(review): the panic unwinds out of an `extern "C"` trampoline. On
    // Rust 1.81+ that aborts the whole test binary instead of being caught by
    // `should_panic`; this only passes on older toolchains or with
    // `extern "C-unwind"` trampolines.
    #[test]
    #[should_panic]
    fn panic_test() {
        smalloca(1, 8, |_a| panic!("This is a test"));
    }
}