#![feature(iter_map_windows)]
#![feature(allocator_api)]
#![warn(missing_docs)]
use std::{
alloc::{Allocator, Global, Layout},
fmt::Debug,
marker::PhantomData,
mem::MaybeUninit,
ops::{Deref, DerefMut},
ptr::NonNull,
};
/// Owns one allocation that is shared by several [`Column`]s.
///
/// Built by [`Columned::new`] / [`Columned::new_in`], which split a single buffer into
/// consecutive columns. Dropping a `Columned` hands the buffer back to its allocator, so every
/// column pointing into it must stop being used (and, if its element type needs dropping, must
/// itself be dropped) before then.
pub struct Columned<A: Allocator = Global> {
    /// Allocator the buffer is returned to on drop.
    alloc: A,
    /// Start of the shared buffer; may be dangling when the buffer is zero-sized.
    ptr: NonNull<u8>,
    /// Layout passed to `deallocate` on drop.
    layout: Layout,
    /// Set on drop. Every column carved from this buffer holds a clone of this flag, which is
    /// how the `asserts` feature detects use of a column after its memory was freed.
    #[cfg(feature = "asserts")]
    deallocated: std::sync::Arc<std::sync::OnceLock<()>>,
}
impl Columned {
    /// Allocates one buffer from the global allocator and points every column into it.
    ///
    /// Shorthand for [`Columned::new_in`] with [`Global`].
    ///
    /// # Safety
    ///
    /// Same contract as [`Columned::new_in`].
    pub unsafe fn new<const N: usize>(columns: [ColumnAlloc<'_>; N]) -> Self {
        // SAFETY: the caller upholds `new_in`'s contract, which `new` forwards unchanged.
        unsafe { Self::new_in(columns, Global) }
    }

    /// Allocates one buffer from `alloc` and points every column into it.
    ///
    /// Each column receives `requested_len` elements. The columns are laid out back to back,
    /// in the order given. Every column's data pointer and length are overwritten; any
    /// elements it referenced before are *not* dropped. The new elements are left
    /// uninitialized.
    ///
    /// # Safety
    ///
    /// - `columns` must be ordered by non-increasing alignment. The buffer is aligned only
    ///   for the first column. Later columns stay aligned because every column's byte length
    ///   is a multiple of its own alignment. This is checked only with the `asserts` feature.
    /// - The returned [`Columned`] owns the memory. It must outlive every use of the columns,
    ///   and columns whose element type needs dropping must be dropped before it.
    /// - Elements must be initialized (e.g. via [`Column::maybe_uninit`]) before they are
    ///   read or dropped.
    ///
    /// # Panics
    ///
    /// Panics if the total byte size of all columns overflows `usize` or is too large for a
    /// [`Layout`]. If the allocator fails, [`std::alloc::handle_alloc_error`] is called.
    pub unsafe fn new_in<const N: usize, A: Allocator>(
        mut columns: [ColumnAlloc<'_>; N],
        alloc: A,
    ) -> Columned<A> {
        #[cfg(feature = "asserts")]
        for (i, cols) in columns.windows(2).enumerate() {
            assert!(
                cols[0].align >= cols[1].align,
                "columns should be ordered by alignment, but align(columns[{}]) < align(columns[{}])",
                i,
                i + 1
            )
        }
        // Columns are ordered by decreasing alignment, so the first one is the strictest.
        // With no columns at all this yields a zero-sized, byte-aligned layout. That layout
        // is still requested from the allocator, so `Drop` only ever deallocates a block
        // that `alloc` really handed out.
        let align = columns.first().map_or(1, |column| column.align);
        // Checked arithmetic: a wrapped size would allocate a buffer smaller than the
        // columns that are about to be pointed into it.
        let size = columns
            .iter()
            .try_fold(0usize, |total, column| {
                column.size.checked_mul(column.requested_len)?.checked_add(total)
            })
            .expect("total size of all columns overflows usize");
        let layout = Layout::from_size_align(size, align)
            .expect("total size of all columns exceeds isize::MAX");
        let ptr: NonNull<u8> = match alloc.allocate(layout) {
            Ok(block) => block.cast(),
            Err(_) => std::alloc::handle_alloc_error(layout),
        };
        #[cfg(feature = "asserts")]
        let deallocated = std::sync::Arc::new(std::sync::OnceLock::new());
        let mut cursor = ptr.as_ptr();
        for column in columns.iter_mut() {
            *column.ptr = cursor.cast::<()>();
            *column.len = column.requested_len;
            // Cannot overflow: the sum of all column byte lengths was checked above.
            cursor = cursor.wrapping_add(column.size * column.requested_len);
            #[cfg(feature = "asserts")]
            {
                *column.deallocated = std::sync::Arc::clone(&deallocated);
                // A column with no elements has nothing to initialize.
                *column.init = column.requested_len == 0;
            }
        }
        Columned {
            alloc,
            ptr,
            layout,
            #[cfg(feature = "asserts")]
            deallocated,
        }
    }
}
impl<A: Allocator> Drop for Columned<A> {
    /// Hands the shared buffer back to the allocator it came from.
    ///
    /// With the `asserts` feature, the deallocation is recorded first so that every column
    /// sharing this buffer can detect later accesses.
    fn drop(&mut self) {
        #[cfg(feature = "asserts")]
        let _ = self.deallocated.set(());
        // SAFETY: `Columned::new_in` is expected to have obtained `ptr` from `self.alloc` with
        // exactly `self.layout`; `drop` runs once, so the block is released exactly once.
        unsafe { self.alloc.deallocate(self.ptr, self.layout) };
    }
}
/// A pending request to place one [`Column`] inside a [`Columned`] buffer.
///
/// Obtained from [`Column::alloc`]. It mutably borrows the column and records the element
/// size, alignment and requested length. Pass it to [`Columned::new`] or
/// [`Columned::new_in`], which write the column's new pointer and length through it.
pub struct ColumnAlloc<'a> {
    /// `size_of` of the column's element type.
    size: usize,
    /// `align_of` of the column's element type.
    align: usize,
    /// The column's data pointer; overwritten by `Columned::new_in`.
    ptr: &'a mut *mut (),
    /// The column's length; overwritten with `requested_len`.
    len: &'a mut usize,
    /// Number of elements to reserve for the column.
    requested_len: usize,
    /// The column's deallocation flag; replaced with the owning `Columned`'s shared flag.
    #[cfg(feature = "asserts")]
    deallocated: &'a mut std::sync::Arc<std::sync::OnceLock<()>>,
    /// The column's "initialized" flag; reset by `Columned::new_in`.
    #[cfg(feature = "asserts")]
    init: &'a mut bool,
}
/// A fixed-length column of `E` whose storage lives inside a [`Columned`] buffer.
///
/// A default column is empty and owns no memory. [`Column::alloc`] together with
/// [`Columned::new`] points it into a shared buffer, and [`Column::maybe_uninit`] is used to
/// initialize the elements. After that, the column dereferences to `[E]`. Dropping the column
/// drops its elements (when `E` needs dropping) but never frees the buffer itself.
pub struct Column<E>
where
    E: Sized,
{
    /// Type-erased pointer to the first element (dangling for an empty, unallocated column).
    ptr: *mut (),
    /// Number of elements.
    len: usize,
    /// Marks the column as holding values of type `E`.
    pd: PhantomData<E>,
    /// Shared flag that the owning `Columned` sets when it frees the buffer.
    #[cfg(feature = "asserts")]
    deallocated: std::sync::Arc<std::sync::OnceLock<()>>,
    /// Whether the elements have been handed out for initialization via `maybe_uninit`.
    #[cfg(feature = "asserts")]
    init: bool,
}
impl<E: Debug> Debug for Column<E> {
    /// Formats the column exactly like the `[E]` slice it dereferences to.
    ///
    /// The conversion goes through `Deref`, so with the `asserts` feature this panics on a
    /// deallocated or uninitialized column.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let elements: &[E] = self;
        Debug::fmt(elements, f)
    }
}
impl<E> Drop for Column<E> {
    /// Drops the column's `len` elements in place.
    ///
    /// The backing memory belongs to the [`Columned`] that allocated it and is not freed
    /// here. Nothing happens when `E` needs no drop glue.
    ///
    /// # Panics
    ///
    /// With the `asserts` feature and an `E` that needs dropping, this panics if the backing
    /// memory has already been deallocated or was never marked initialized. If the thread is
    /// already unwinding, the elements are leaked instead, because panicking again inside
    /// `drop` would abort the process.
    fn drop(&mut self) {
        if !std::mem::needs_drop::<E>() {
            return;
        }
        #[cfg(feature = "asserts")]
        {
            let freed = self.deallocated.get().is_some();
            if (freed || !self.init) && std::thread::panicking() {
                // A second panic during unwinding aborts; leaking is the safe fallback.
                return;
            }
            assert!(
                !freed,
                "Underlying memory of Column has been deallocated. Therefore, cannot drop."
            );
            assert!(
                self.init,
                "Underlying memory not initialized. Therefore, cannot drop."
            );
        }
        // Dropping the elements as one slice (instead of one by one) keeps dropping the
        // remaining elements even if one element's destructor panics.
        let elements = std::ptr::slice_from_raw_parts_mut(self.ptr.cast::<E>(), self.len);
        // SAFETY: per `Columned::new`'s contract, `ptr` addresses `len` initialized elements
        // of `E` in a buffer that is still allocated (or is dangling with `len == 0`), and
        // nothing reads them after this drop.
        unsafe { std::ptr::drop_in_place(elements) };
    }
}
impl<E> Default for Column<E>
where
E: Sized,
{
fn default() -> Self {
Self {
ptr: std::ptr::dangling_mut::<E>() as *mut (),
len: 0,
pd: Default::default(),
#[cfg(feature = "asserts")]
deallocated: std::sync::Arc::new(std::sync::OnceLock::new()),
#[cfg(feature = "asserts")]
init: true,
}
}
}
impl<E> Column<E>
where
    E: Sized,
{
    /// Creates an empty column that owns no memory yet; same as [`Column::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Requests `len` elements for this column in a [`Columned::new`] /
    /// [`Columned::new_in`] call.
    ///
    /// The returned [`ColumnAlloc`] mutably borrows the column. The column itself is only
    /// changed once the request is passed to a `Columned` constructor, which then points it
    /// into the shared buffer.
    #[must_use = "a ColumnAlloc does nothing until passed to Columned::new or Columned::new_in"]
    pub fn alloc(&mut self, len: usize) -> ColumnAlloc<'_> {
        ColumnAlloc {
            size: core::mem::size_of::<E>(),
            align: core::mem::align_of::<E>(),
            ptr: &mut self.ptr,
            len: &mut self.len,
            requested_len: len,
            #[cfg(feature = "asserts")]
            deallocated: &mut self.deallocated,
            #[cfg(feature = "asserts")]
            init: &mut self.init,
        }
    }

    /// Returns the column's storage as possibly-uninitialized elements, for initializing it.
    ///
    /// With the `asserts` feature, calling this marks the column as initialized. The caller is
    /// expected to write every element before the column is read through `Deref` or dropped.
    ///
    /// # Panics
    ///
    /// With the `asserts` feature, panics if the backing memory has already been deallocated.
    pub fn maybe_uninit(&mut self) -> &mut [MaybeUninit<E>] {
        #[cfg(feature = "asserts")]
        {
            // Without this check, writes through the returned slice would hit freed memory
            // unnoticed, even though `Deref`/`DerefMut` already guard against it.
            assert!(
                self.deallocated.get().is_none(),
                "Underlying memory of Column has been deallocated"
            );
            self.init = true;
        }
        // SAFETY: `ptr` addresses `len` slots of `E` (or is dangling with `len == 0`);
        // `MaybeUninit<E>` has the same layout as `E` and accepts any contents; `&mut self`
        // guarantees exclusive access for the returned lifetime.
        unsafe { std::slice::from_raw_parts_mut(self.ptr.cast::<MaybeUninit<E>>(), self.len) }
    }
}
impl<E> Deref for Column<E> {
    type Target = [E];

    /// Views the column as a slice of its `len` elements.
    ///
    /// # Panics
    ///
    /// With the `asserts` feature, panics if the backing memory was deallocated or has not
    /// been marked initialized.
    fn deref(&self) -> &[E] {
        #[cfg(feature = "asserts")]
        {
            let freed = self.deallocated.get().is_some();
            assert!(!freed, "Underlying memory of Column has been deallocated");
            assert!(self.init, "Underlying memory not initialized");
        }
        let data = self.ptr.cast_const().cast::<E>();
        // SAFETY: per `Columned::new`'s contract, `data` addresses `len` initialized elements
        // in a live buffer (or is dangling with `len == 0`).
        unsafe { std::slice::from_raw_parts(data, self.len) }
    }
}
impl<E> DerefMut for Column<E> {
    /// Views the column as a mutable slice of its `len` elements.
    ///
    /// # Panics
    ///
    /// With the `asserts` feature, panics if the backing memory was deallocated or has not
    /// been marked initialized.
    fn deref_mut(&mut self) -> &mut [E] {
        #[cfg(feature = "asserts")]
        {
            let freed = self.deallocated.get().is_some();
            assert!(!freed, "Underlying memory of Column has been deallocated");
            assert!(self.init, "Underlying memory not initialized");
        }
        let data = self.ptr.cast::<E>();
        // SAFETY: per `Columned::new`'s contract, `data` addresses `len` initialized elements
        // in a live buffer (or is dangling with `len == 0`); `&mut self` makes access exclusive.
        unsafe { std::slice::from_raw_parts_mut(data, self.len) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Locals are dropped in reverse declaration order. Declaring `_columned` *before* the
    // columns therefore makes the columns drop first, while their memory is still allocated.
    // Tests that declare it *after* the columns deliberately break that rule.

    #[test]
    fn test_basic() {
        let _columned;
        let mut xs: Column<u64> = Default::default();
        let mut ys: Column<u64> = Default::default();
        let mut sums: Column<u64> = Default::default();
        // SAFETY: all columns hold `u64` (equal alignment), `_columned` outlives them, and each
        // column is initialized through `maybe_uninit` before it is read.
        _columned = unsafe { Columned::new([xs.alloc(10), ys.alloc(10), sums.alloc(10)]) };
        for (i, x) in xs.maybe_uninit().iter_mut().enumerate() {
            x.write(i as u64);
        }
        for (i, y) in ys.maybe_uninit().iter_mut().enumerate() {
            y.write(i as u64);
        }
        for sum in sums.maybe_uninit().iter_mut() {
            sum.write(0);
        }
        for ((sum, x), y) in sums.iter_mut().zip(xs.iter()).zip(ys.iter()) {
            *sum = x + y;
        }
        for (i, sum) in sums.iter().enumerate() {
            assert_eq!(*sum, 2 * i as u64);
        }
    }

    #[test]
    fn test_empty_columns() {
        // SAFETY: there are no columns, so there is nothing to order, initialize or outlive.
        let _columned = unsafe { Columned::new([]) };
    }

    #[cfg(feature = "asserts")]
    #[test]
    #[should_panic(expected = "has been deallocated")]
    fn test_use_after_free() {
        let mut xs: Column<u64> = Default::default();
        let mut ys: Column<u64> = Default::default();
        let mut sums: Column<u64> = Default::default();
        // SAFETY: deliberately misused below; the `asserts` feature must catch the access.
        let _columned = unsafe { Columned::new([xs.alloc(10), ys.alloc(10), sums.alloc(10)]) };
        for (i, x) in xs.maybe_uninit().iter_mut().enumerate() {
            x.write(i as u64);
        }
        for (i, y) in ys.maybe_uninit().iter_mut().enumerate() {
            y.write(i as u64);
        }
        for sum in sums.maybe_uninit().iter_mut() {
            sum.write(0);
        }
        drop(_columned);
        // Reading through the column after its buffer was freed must trip the assertion.
        let _value = xs[0];
    }

    #[test]
    fn test_no_drop_no_init() {
        let mut xs: Column<u64> = Default::default();
        let mut ys: Column<u64> = Default::default();
        let mut sums: Column<u64> = Default::default();
        // SAFETY: `u64` needs no drop, so the columns outliving `_columned` is harmless as long
        // as they are not read afterwards.
        let _columned = unsafe { Columned::new([xs.alloc(10), ys.alloc(10), sums.alloc(10)]) };
        for (i, x) in xs.maybe_uninit().iter_mut().enumerate() {
            x.write(i as u64);
        }
        for (i, y) in ys.maybe_uninit().iter_mut().enumerate() {
            y.write(i as u64);
        }
        for sum in sums.maybe_uninit().iter_mut() {
            sum.write(0);
        }
    }

    #[cfg(feature = "asserts")]
    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_drop_no_init() {
        struct WillDrop;
        impl Drop for WillDrop {
            fn drop(&mut self) {}
        }
        // Declared first so the column drops while the buffer is still alive: the panic must
        // come from the missing initialization, not from the deallocation check.
        let _columned;
        let mut xs: Column<WillDrop> = Default::default();
        // SAFETY: deliberately never initialized; the `asserts` feature must catch the drop.
        _columned = unsafe { Columned::new([xs.alloc(10)]) };
    }

    #[cfg(feature = "asserts")]
    #[test]
    #[should_panic(expected = "has been deallocated")]
    fn test_drop_with_init_but_wrong_order() {
        struct WillDrop;
        impl Drop for WillDrop {
            fn drop(&mut self) {}
        }
        let mut xs: Column<WillDrop> = Default::default();
        // SAFETY: deliberately dropped in the wrong order; the `asserts` feature must catch it.
        let _columned = unsafe { Columned::new([xs.alloc(10)]) };
        for x in xs.maybe_uninit() {
            x.write(WillDrop);
        }
    }

    #[test]
    fn test_drop_with_init_but_right_order() {
        struct WillDrop;
        impl Drop for WillDrop {
            fn drop(&mut self) {}
        }
        let _columned;
        let mut xs: Column<WillDrop> = Default::default();
        // SAFETY: `_columned` is declared first, so `xs` drops its initialized elements while
        // the buffer is still allocated.
        _columned = unsafe { Columned::new([xs.alloc(10)]) };
        for x in xs.maybe_uninit() {
            x.write(WillDrop);
        }
    }
}