use std::alloc::Layout;
use std::cell::Cell;
use std::marker::PhantomData;
use std::mem;
use std::mem::ManuallyDrop;
use std::ptr;
/// A heap-allocated two-variant enum whose first variant `V0` may be unsized
/// (e.g. a `dyn Trait` or a slice) and whose second variant `V1` is sized.
///
/// This struct is the `V0` view of the allocation; `UnsizedEnum_V1` is the
/// `V1` view of the same memory, reached by pointer casts. Bit 0 of
/// `header.disc` selects which view is live (0 = `V0`, 1 = `V1`).
#[repr(C)]
pub struct UnsizedEnum<B, V0: ?Sized, V1> {
    /// Shared prefix; first in both `repr(C)` views so it sits at the same
    /// offset whichever view is used.
    header: Header<B>,
    /// `V1` is otherwise unused in this view; `PhantomData` lets the struct
    /// carry it as a type parameter.
    phantomdata: PhantomData<V1>,
    /// Storage for the `V0` variant. `ManuallyDrop` because only the active
    /// variant may be dropped, and that is done by hand (see `drop_value`).
    val: ManuallyDrop<V0>,
}
/// `V1` view of an `UnsizedEnum` allocation: the same memory reinterpreted
/// (via pointer casts) when bit 0 of the discriminant is set.
///
/// Written directly by `new_v1`; everywhere else it is only reached by
/// casting a `*mut UnsizedEnum` to `*mut UnsizedEnum_V1`.
#[repr(C)]
struct UnsizedEnum_V1<B, V0: ?Sized, V1> {
    /// Must match `UnsizedEnum::header` in type and offset (first field of a
    /// `repr(C)` struct in both views).
    header: Header<B>,
    /// Storage for the `V1` variant.
    val: V1,
    /// `V0` is otherwise unused in this view; `PhantomData` keeps the generic
    /// parameters identical to `UnsizedEnum`'s so the casts line up.
    phantomdata: PhantomData<V0>,
}
/// Prefix shared by both variant views of an `UnsizedEnum`.
#[repr(C)]
struct Header<B> {
    /// Bit 0 selects the active variant (0 = `V0`, 1 = `V1`); the remaining
    /// bits are the user-controlled "spare" bits written by `set_spare`.
    disc: usize,
    /// Payload common to both variants, exposed through `base()`.
    base: B,
}
/// Returns the larger of `a` and `b` (either one when they are equal).
///
/// Evaluated in `const` context to size `UnsizedEnum::LAYOUT`. The choice is
/// made by indexing a two-element array with the comparison result instead of
/// using `if`, presumably so it also compiles on Rust versions before 1.46,
/// where `if` was not allowed inside a `const fn`.
#[allow(dead_code)]
const fn max(a: usize, b: usize) -> usize {
    let pair = [a, b];
    let second_is_larger = a < b;
    pair[second_is_larger as usize]
}
/// Returns the pointer metadata carried by `fat`, or null for a thin pointer.
///
/// The metadata is the vtable pointer for a trait object, or the element count
/// (reinterpreted as a pointer) for a slice or `str`. When `&T` is a plain thin
/// pointer, a null pointer is returned.
///
/// NOTE(review): this assumes a fat reference is laid out as
/// `(data, metadata)`. Current compilers use that layout, but the language
/// does not guarantee it.
#[inline(always)]
fn vtable_of<T: ?Sized>(fat: &T) -> *const () {
    #[repr(C)]
    struct RawParts {
        addr: *const (),
        metadata: *const (),
    }

    let ref_size = mem::size_of::<&T>();
    let fat_size = mem::size_of::<RawParts>();
    if ref_size >= fat_size {
        assert_eq!(ref_size, fat_size);
        // SAFETY: `&T` is exactly two pointers wide (checked above), so reading
        // it as `RawParts` reads only initialised bytes; `transmute_copy` does
        // an unaligned read, so alignment is not a concern.
        let parts: RawParts = unsafe { mem::transmute_copy(&fat) };
        parts.metadata
    } else {
        std::ptr::null()
    }
}
impl<B, V0, V1> UnsizedEnum<B, V0, V1> {
    /// Layout requested for every allocation: the larger size and the stricter
    /// alignment of the two variant views.
    ///
    /// `alloc` asserts that the `V1` view fits inside the `V0` view, so in
    /// practice this always equals the layout of `Self`. That is the layout
    /// `Box` uses when it frees the allocation.
    #[allow(dead_code)]
    const LAYOUT: Layout = unsafe {
        // SAFETY: both alignments come from real types, so their maximum is a
        // non-zero power of two; both sizes are sizes of real types.
        Layout::from_size_align_unchecked(
            max(
                mem::size_of::<UnsizedEnum<B, V0, V1>>(),
                mem::size_of::<UnsizedEnum_V1<B, V0, V1>>(),
            ),
            max(
                mem::align_of::<UnsizedEnum<B, V0, V1>>(),
                mem::align_of::<UnsizedEnum_V1<B, V0, V1>>(),
            ),
        )
    };

    /// Allocates uninitialised storage for one enum that can hold either variant.
    ///
    /// # Panics
    /// Panics if the `V1` view is larger or more strictly aligned than the `V0`
    /// view (see the comment in the body). Calls `handle_alloc_error` if the
    /// allocator fails.
    ///
    /// # Safety
    /// The returned memory is uninitialised. The caller must write a complete
    /// variant, header included, before reading it or wrapping it in a `Box`.
    unsafe fn alloc() -> *mut UnsizedEnum<B, V0, V1> {
        // The pointer is handed to `Box::from_raw`. `Box` later frees it with
        // `Layout::new::<Self>()`, or with the equivalent layout computed from
        // the vtable after unsizing to `dyn Trait`. Freeing with a layout
        // different from the allocating one is undefined behaviour. So reject
        // type combinations where the `V1` view needs a bigger or more-aligned
        // block; this also makes `LAYOUT == Layout::new::<Self>()`.
        assert!(
            mem::size_of::<UnsizedEnum_V1<B, V0, V1>>() <= mem::size_of::<Self>()
                && mem::align_of::<UnsizedEnum_V1<B, V0, V1>>() <= mem::align_of::<Self>(),
            "UnsizedEnum: the V1 variant must not be larger or more strictly aligned than the V0 variant"
        );
        let p = std::alloc::alloc(Self::LAYOUT) as *mut UnsizedEnum<B, V0, V1>;
        if p.is_null() {
            std::alloc::handle_alloc_error(Self::LAYOUT);
        }
        p
    }

    /// Boxes a new enum holding the `V0` variant `val` alongside the shared `base`.
    ///
    /// The result can be unsized-coerced, e.g. to
    /// `Box<UnsizedEnum<B, dyn Trait, V1>>`.
    ///
    /// # Panics
    /// Panics if the `V1` variant would not fit in the space of the `V0`
    /// variant (larger size or stricter alignment).
    pub fn new_v0(base: B, val: V0) -> Box<Self> {
        let inner = UnsizedEnum {
            header: Header { disc: 0, base },
            phantomdata: PhantomData,
            val: ManuallyDrop::new(val),
        };
        unsafe {
            // SAFETY: `alloc` returns a block with exactly the layout of `Self`.
            // Writing `inner` initialises all of it, so `Box` may own it and
            // later free it with that same layout.
            let p = Self::alloc();
            ptr::write(p, inner);
            Box::from_raw(p)
        }
    }

    /// Boxes a new enum holding the `V1` variant `val` alongside the shared `base`.
    ///
    /// The storage is sized for `V0`, so it can later be switched to `V0` with
    /// `set_v0`.
    ///
    /// # Panics
    /// Panics if the `V1` variant would not fit in the space of the `V0`
    /// variant (larger size or stricter alignment).
    pub fn new_v1(base: B, val: V1) -> Box<Self> {
        let inner = UnsizedEnum_V1 {
            header: Header { disc: 1, base },
            val,
            phantomdata: PhantomData,
        };
        unsafe {
            // SAFETY: `alloc` guarantees the `V1` view fits inside the block.
            // The header has bit 0 set, so every reader interprets the bytes
            // through the `V1` view; the `V0` slot is never read while
            // uninitialised.
            let p = Self::alloc();
            ptr::write(p as *mut UnsizedEnum_V1<B, V0, V1>, inner);
            Box::from_raw(p)
        }
    }
}
impl<B, V0: ?Sized, V1> UnsizedEnum<B, V0, V1> {
    /// Returns a mutable reference to the shared `B` payload stored in the header.
    pub fn base(&mut self) -> &mut B {
        &mut self.header.base
    }

    /// Stores `spare` in the spare bits of the header word.
    ///
    /// Bit 0 of the word holds the variant discriminant. Bit 0 of `spare` is
    /// therefore discarded and the discriminant is preserved.
    pub fn set_spare(&mut self, spare: usize) {
        self.header.disc = (self.header.disc & 1) | (spare & !1);
    }

    /// Returns the spare bits of the header word.
    ///
    /// Bit 0, the variant discriminant, is masked off. The result is what
    /// `set_spare` stored, whichever variant is currently active.
    pub fn get_spare(&self) -> usize {
        self.header.disc & !1
    }

    /// Replaces the stored value with the `V0` value held by `v0`, taking
    /// ownership of it. Afterwards `v0` no longer drops the value.
    ///
    /// When `V0` is a trait object or slice, the new value must have the same
    /// pointer metadata (vtable or length) as the value this allocation was
    /// created for, because the allocation size is fixed.
    ///
    /// # Panics
    /// Panics if `v0` was already moved out of, or if its metadata differs
    /// from `self`'s. In both cases nothing is consumed: the current value
    /// stays in place, and `v0` still owns (and will drop) its value.
    ///
    /// NOTE(review): the trait-object check compares vtable addresses. Rust
    /// does not guarantee those are unique per type (see the caveats on
    /// `std::ptr::eq`), so a value of the right type could be spuriously
    /// rejected.
    #[inline]
    pub fn set_v0(&mut self, v0: &Move<V0>) {
        // Inspect the value without consuming it, so a rejected value is still
        // dropped normally by its `Move`.
        if v0.moved.get() {
            panic!("Already moved value passed to set_v0");
        }
        let src: &V0 = &*v0.value;
        assert_eq!(
            vtable_of(src),
            vtable_of(self),
            "Passed a value to set_v0() from a different underlying type"
        );
        // All checks passed; from here on `self` owns the value.
        v0.moved.set(true);
        unsafe {
            // SAFETY: the old value is dropped first. Equal metadata means equal
            // size and alignment (a vtable records both; a length fixes a
            // slice's size), so the copy stays inside the `V0` slot. `src`
            // lives in `v0`, which cannot overlap `*self` because `self` is
            // borrowed uniquely. Clearing bit 0 marks `V0` as active.
            self.drop_value_or_abort();
            ptr::copy_nonoverlapping(
                src as *const V0 as *const u8,
                (&mut *self.val) as *mut V0 as *mut u8,
                mem::size_of_val(src),
            );
            self.header.disc &= !1;
        }
    }

    /// Replaces the stored value with the `V1` value `v1`, dropping the old one.
    #[inline]
    pub fn set_v1(&mut self, v1: V1) {
        unsafe {
            // SAFETY: the allocation is sized for both views, so writing
            // through the `V1` view stays in bounds. The old value is dropped
            // before it is overwritten, and setting bit 0 marks `V1` as active.
            self.drop_value_or_abort();
            let p = self as *mut UnsizedEnum<B, V0, V1>;
            let p = p as *mut UnsizedEnum_V1<B, V0, V1>;
            ptr::write(&mut (*p).val, v1);
            self.header.disc |= 1;
        }
    }

    /// Like `drop_value`, but aborts the process if the value's destructor panics.
    ///
    /// Unwinding out of `drop_value` would leave the discriminant pointing at
    /// an already-dropped value. `Drop for UnsizedEnum` would then drop it a
    /// second time.
    ///
    /// # Safety
    /// Same contract as `drop_value`.
    unsafe fn drop_value_or_abort(&mut self) {
        struct AbortOnUnwind;
        impl Drop for AbortOnUnwind {
            fn drop(&mut self) {
                std::process::abort();
            }
        }
        let bomb = AbortOnUnwind;
        self.drop_value();
        // Reached only if `drop_value` returned normally: defuse the bomb.
        mem::forget(bomb);
    }

    /// Drops whichever variant bit 0 of the discriminant marks as active.
    ///
    /// # Safety
    /// The active variant must be initialised. After this call it must not be
    /// read or dropped again until a new value has been written.
    unsafe fn drop_value(&mut self) {
        if 0 == (self.header.disc & 1) {
            ManuallyDrop::drop(&mut self.val);
        } else {
            let p = self as *mut UnsizedEnum<B, V0, V1>;
            let p = p as *mut UnsizedEnum_V1<B, V0, V1>;
            ptr::drop_in_place(&mut (*p).val);
        }
    }

    /// Mutably borrows the active variant.
    pub fn get_mut(&mut self) -> EnumRef<V0, V1> {
        unsafe {
            // SAFETY: bit 0 of the discriminant says which view holds an
            // initialised value, and only that view is borrowed.
            if 0 == (self.header.disc & 1) {
                EnumRef::V0(&mut *self.val)
            } else {
                let p = self as *mut UnsizedEnum<B, V0, V1>;
                let p = p as *mut UnsizedEnum_V1<B, V0, V1>;
                EnumRef::V1(&mut (*p).val)
            }
        }
    }
}
impl<B, V0: ?Sized, V1> Drop for UnsizedEnum<B, V0, V1> {
    /// Drops whichever variant is currently active.
    ///
    /// The header, including its `B` payload, is dropped afterwards by the
    /// compiler-generated drop glue.
    fn drop(&mut self) {
        // SAFETY: the active variant is initialised for as long as the enum is
        // alive, and after `drop` nothing reads it again.
        unsafe {
            Self::drop_value(self);
        }
    }
}
/// Mutable borrow of whichever variant an `UnsizedEnum` currently holds, as
/// returned by `UnsizedEnum::get_mut`.
pub enum EnumRef<'a, V0: ?Sized, V1> {
    /// The first (possibly unsized) variant is active.
    V0(&'a mut V0),
    /// The second (sized) variant is active.
    V1(&'a mut V1),
}
/// Owns a value that can be handed over through a shared reference exactly once.
///
/// Once `moved` is set (e.g. by `get_ref`), this wrapper no longer drops the
/// value; whoever set the flag has taken ownership. `UnsizedEnum::set_v0`
/// does this by bit-copying the value into its own storage. Because `Move<T>`
/// allows `T: ?Sized`, `&Move<Concrete>` can coerce to `&Move<dyn Trait>`.
pub struct Move<T: ?Sized> {
    /// Set once the value has been handed out; a `Cell` because the flag is
    /// flipped through `&self`.
    moved: Cell<bool>,
    /// The wrapped value. `ManuallyDrop` lets `Drop` skip it once moved.
    value: ManuallyDrop<T>,
}
impl<T> Move<T> {
#[inline]
pub fn new(val: T) -> Self {
Self {
moved: Cell::new(false),
value: ManuallyDrop::new(val),
}
}
}
impl<T: ?Sized> Move<T> {
    /// Hands out the wrapped value by reference, at most once.
    ///
    /// The first call marks the value as moved and returns `Some`. From then
    /// on this `Move` will not drop the value, so the caller is responsible for
    /// it (e.g. by bit-copying it elsewhere). Every later call returns `None`.
    #[inline]
    pub fn get_ref(&self) -> Option<&T> {
        let was_moved = self.moved.replace(true);
        if was_moved {
            return None;
        }
        Some(&*self.value)
    }
}
impl<T: ?Sized> Drop for Move<T> {
    /// Drops the wrapped value unless ownership was already handed out.
    fn drop(&mut self) {
        if self.moved.get() {
            // Someone else owns the value now; dropping it here would be a
            // double drop.
            return;
        }
        // SAFETY: the value was never handed out, and this is the only place
        // that drops it.
        unsafe { ManuallyDrop::drop(&mut self.value) }
    }
}
#[cfg(test)]
mod tests {
    use super::{EnumRef, Move, UnsizedEnum};

    /// Header payload shared by every test enum.
    struct Base(usize);

    /// The sized `V1` alternative.
    struct A(u16);

    /// Object-safe trait used as the unsized `V0` type.
    trait Sum {
        fn sum(&self) -> f64;
    }

    /// `Sum` implementor made of two `f64`s.
    struct B {
        a: f64,
        b: f64,
    }

    impl Sum for B {
        fn sum(&self) -> f64 {
            self.a + self.b
        }
    }

    /// `Sum` implementor of a different size than `B` (three `u32`s).
    struct C {
        a: u32,
        b: u32,
        c: u32,
    }

    impl Sum for C {
        fn sum(&self) -> f64 {
            let total = self.a + self.b + self.c;
            total as f64
        }
    }

    /// Reads a number back from whichever variant is active.
    fn calc_sum(r: &mut UnsizedEnum<Base, dyn Sum, A>) -> f64 {
        match r.get_mut() {
            EnumRef::V1(a) => f64::from(a.0),
            EnumRef::V0(v) => v.sum(),
        }
    }

    #[test]
    fn test() {
        let mut e: Box<UnsizedEnum<Base, dyn Sum, A>> =
            UnsizedEnum::new_v0(Base(654321), B { a: 1.0, b: 2.0 });
        assert_eq!(calc_sum(&mut e), 3.0);

        e.set_v1(A(54321));
        assert_eq!(calc_sum(&mut e), 54321.0);

        e.set_v0(&Move::new(B { a: 3.0, b: 4.0 }));
        assert_eq!(calc_sum(&mut e), 7.0);

        // Rebind to an allocation whose concrete `V0` type is `C`.
        e = UnsizedEnum::<Base, C, A>::new_v1(Base(654321), A(12345));
        assert_eq!(calc_sum(&mut e), 12345.0);

        e.set_v0(&Move::new(C { a: 3, b: 4, c: 5 }));
        assert_eq!(calc_sum(&mut e), 12.0);

        e.set_v1(A(13542));
        assert_eq!(calc_sum(&mut e), 13542.0);
    }

    #[test]
    #[should_panic]
    fn test_writing_wrong_type_1() {
        // Allocated for `C`, so a `B` must be rejected even though `V1` is active.
        let mut e: Box<UnsizedEnum<Base, dyn Sum, A>> =
            UnsizedEnum::<Base, C, A>::new_v1(Base(654321), A(12345));
        e.set_v0(&Move::new(B { a: 3.0, b: 4.0 }));
    }

    #[test]
    #[should_panic]
    fn test_writing_wrong_type_2() {
        // Allocated for `C` while `V0` is active; a `B` must be rejected.
        let mut e: Box<UnsizedEnum<Base, dyn Sum, A>> =
            UnsizedEnum::new_v0(Base(654321), C { a: 3, b: 4, c: 5 });
        e.set_v0(&Move::new(B { a: 3.0, b: 4.0 }));
    }

    #[test]
    fn test_sized() {
        let mut e = UnsizedEnum::<Base, B, A>::new_v0(Base(654321), B { a: 1.0, b: 2.0 });
        e.set_v1(A(54321));
        e.set_v0(&Move::new(B { a: 3.0, b: 4.0 }));
    }
}