use std::any::TypeId;
use std::collections::{HashMap, HashSet};
use std::error::Error;
use std::fmt::{Debug, Display, Formatter};
use std::mem::MaybeUninit;
use std::ops::{Deref, DerefMut};
/// Marker trait stating that `Self` may be viewed as a `T`.
///
/// # Safety
///
/// `Typed::upcast_ref` / `Typed::upcast_mut` reinterpret a `Typed<Self>` as a `Typed<T>`
/// purely because this impl exists. `ObjectBuilder::insert_value` writes a `T` at the start
/// of a `Self`'s storage, also purely because of this impl. Implementing it therefore
/// promises that a valid `T` lives at offset 0 of `Self` with a compatible layout.
pub unsafe trait IsA<T> {}

// SAFETY: the `value` part of a `Typed<()>` is zero-sized, so viewing any object as `()`
// reads nothing beyond the shared `type_ids` field.
unsafe impl<T: Object> IsA<()> for T {}

// SAFETY: `()` has no fields and keeps the default `Object` methods; for `Self = ()` the
// default `type_ids()` only reports `()` itself.
unsafe impl Object for () {}
/// A value of type `T` paired with the set of [`TypeId`]s it may be cast to.
///
/// `#[repr(C)]` pins `type_ids` as the first field. Presumably this is so references can be
/// reinterpreted between different `Typed<_>` instantiations (as `upcast_ref` does).
#[derive(Clone)]
#[repr(C)]
pub struct Typed<T> {
    type_ids: HashSet<TypeId>,
    value: T,
}

impl<T> Debug for Typed<T> {
    /// Prints only the recorded type ids, so `T` does not need to implement `Debug`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Typed");
        out.field("type_ids", &self.type_ids);
        out.finish_non_exhaustive()
    }
}

impl<T> Deref for Typed<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &Self::Target {
        let Self { value, .. } = self;
        value
    }
}

impl<T> DerefMut for Typed<T> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { value, .. } = self;
        value
    }
}
impl<T: 'static> Typed<T> {
    /// Wraps `value`, recording every type id that `T` reports through [`Object::type_ids`].
    pub fn new(value: T) -> Self
    where
        T: Object,
    {
        let type_ids = T::type_ids();
        Self { type_ids, value }
    }

    /// Views this wrapper as a wrapper around `Target`, which `T` declares itself to be via `IsA`.
    #[must_use]
    pub fn upcast_ref<Target: 'static>(&self) -> &Typed<Target>
    where
        T: IsA<Target>,
    {
        let raw = (self as *const Self).cast::<Typed<Target>>();
        // SAFETY: relies on the `unsafe impl IsA<Target> for T` vouching that a `Typed<T>` may be
        // read as a `Typed<Target>`; the pointer comes from a live reference with the same lifetime.
        unsafe { &*raw }
    }

    /// Mutable counterpart of [`Typed::upcast_ref`].
    ///
    /// NOTE(review): the returned reference also exposes `type_ids`. Replacing or
    /// `std::mem::swap`-ing the whole `Typed<Target>` through it moves ids to storage that
    /// may not back them, and `try_cast_*` would later trust those ids. Confirm whether
    /// this needs guarding.
    #[must_use]
    pub fn upcast_mut<Target: 'static>(&mut self) -> &mut Typed<Target>
    where
        T: IsA<Target>,
    {
        let raw = (self as *mut Self).cast::<Typed<Target>>();
        // SAFETY: as in `upcast_ref`; `raw` is derived from the unique `&mut self` borrow.
        unsafe { &mut *raw }
    }

    /// Casts to `&Typed<Target>` when `Target`'s id is among the recorded type ids.
    ///
    /// Returns `Err(())` otherwise.
    pub fn try_cast_ref<Target: Object + 'static>(&self) -> Result<&Typed<Target>, ()> {
        if !self.type_ids.contains(&TypeId::of::<Target>()) {
            return Err(());
        }
        // SAFETY: the id set was produced by an `unsafe impl Object`, which vouches for every
        // type it lists.
        Ok(unsafe { &*(self as *const Self).cast::<Typed<Target>>() })
    }

    /// Casts to `&mut Typed<Target>` when `Target`'s id is among the recorded type ids.
    ///
    /// Returns `Err(())` otherwise.
    pub fn try_cast_mut<Target: 'static>(&mut self) -> Result<&mut Typed<Target>, ()> {
        if !self.type_ids.contains(&TypeId::of::<Target>()) {
            return Err(());
        }
        // SAFETY: as in `try_cast_ref`; the pointer is derived from the unique `&mut self` borrow.
        Ok(unsafe { &mut *(self as *mut Self).cast::<Typed<Target>>() })
    }

    /// Unwraps the inner value, discarding the recorded type ids.
    pub fn extract_value(self) -> T {
        let Self { value, .. } = self;
        value
    }
}
/// A type that can be wrapped in a [`Typed`] and assembled with an [`ObjectBuilder`].
///
/// # Safety
///
/// Every `TypeId` returned by [`Object::type_ids`] is stored in the `Typed` wrapper. It is
/// trusted by `Typed::try_cast_ref` / `Typed::try_cast_mut`, which reinterpret the wrapper as
/// the matching type whenever the id is present. Implementors must only report ids of types
/// that a `Typed<Self>` can soundly be reinterpreted as. The default reports `Self` and `()`.
pub unsafe trait Object
where
    Self: Sized + 'static,
{
    /// Returns the ids of every type a `Typed<Self>` may be cast to.
    ///
    /// The default contains exactly `TypeId::of::<Self>()` and `TypeId::of::<()>()`.
    fn type_ids() -> HashSet<TypeId> {
        HashSet::from([TypeId::of::<Self>(), TypeId::of::<()>()])
    }

    /// Returns an empty [`ObjectBuilder`] for `Self`.
    fn builder() -> ObjectBuilder<Self> {
        ObjectBuilder::new()
    }

    /// Drops whatever part of `*uninit_self` is known to be initialized.
    ///
    /// `ObjectBuilder::insert_value` calls this right before writing a new `Self` over the
    /// same memory. That memory may still be completely uninitialized (e.g. on a fresh builder).
    ///
    /// * `_init_offsets` — the offsets recorded in the builder (unused by the default).
    /// * `size_of_0` — the byte count recorded at offset 0, if any.
    ///
    /// The default implementation knows nothing about `Self`'s fields. It can only tell that a
    /// whole `Self` is present when at least `size_of::<Self>()` bytes were recorded at offset 0.
    /// In that case it drops the value; otherwise it does nothing. Skipping the drop can at worst
    /// leak, whereas dropping uninitialized memory is undefined behavior.
    ///
    /// # Safety
    ///
    /// `uninit_self` must be non-null, aligned and valid for reads and writes of `Self`.
    /// `size_of_0` may only be `Some(n)` with `n >= size_of::<Self>()` when the first
    /// `size_of::<Self>()` bytes behind `uninit_self` hold an initialized `Self`.
    #[doc(hidden)]
    unsafe fn uninit_selective_drop(
        uninit_self: *mut Self,
        _init_offsets: &Vec<usize>,
        size_of_0: Option<usize>,
    ) {
        if size_of_0.is_some_and(|written| written >= size_of::<Self>()) {
            // SAFETY: per the caller's contract, a complete `Self` was written at offset 0.
            unsafe { uninit_self.drop_in_place() }
        }
    }
}
/// Incrementally initializes a `T` in place.
///
/// `_initialized_fields_size` maps a byte offset inside `T` to the number of bytes recorded as
/// initialized there. `insert_value` records a whole parent value under offset `0`. `_value` is
/// the storage being filled, possibly only partially initialized.
#[derive(Debug)]
pub struct ObjectBuilder<T>
where
    T: Object + 'static,
{
    #[doc(hidden)]
    pub _initialized_fields_size: HashMap<usize, usize>,
    #[doc(hidden)]
    pub _value: MaybeUninit<T>,
}
/// Returned by `ObjectBuilder::try_build` / `try_build_as_value` when the recorded bytes do
/// not cover the whole object.
///
/// The wrapped number is how many entries the builder had recorded in
/// `_initialized_fields_size` when building failed.
#[derive(Debug, Copy, Clone)]
pub struct ObjectInitError(usize);

impl Display for ObjectInitError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "called `ObjectBuilder::build(...)` before fully initializing the object. ({} fields were initialized)",
            self.0
        ))
    }
}

// NOTE(review): currently empty; presumably a placeholder (e.g. for sealing traits). TODO confirm or remove.
mod private {}

impl Error for ObjectInitError {}
impl<T: Object> ObjectBuilder<T> {
pub fn new() -> Self {
Self {
_value: MaybeUninit::uninit(),
_initialized_fields_size: HashMap::new(),
}
}
pub fn insert_typed<P: 'static>(self, obj: Typed<P>) -> Self where T: IsA<P>, P: Object {
self.insert_value(obj.value)
}
pub fn insert_value<P: 'static>(mut self, obj: P) -> Self where T: IsA<P>, P: Object {
unsafe {
let ptr = self._value.as_mut_ptr() as *mut P;
P::uninit_selective_drop(
ptr,
&self._initialized_fields_size.iter().map(|p| *p.0).collect(),
self._initialized_fields_size.get(&0).map(|p| *p)
);
ptr.write(obj);
}
self._initialized_fields_size.insert(0, _align(align_of::<P>(), size_of::<P>()));
self
}
#[must_use]
pub fn is_init(&self) -> bool {
let potential_parent_field_size = self._initialized_fields_size.get(&0).unwrap_or(&0);
let struct_init_bytes = self._initialized_fields_size.iter()
.filter(|(offset, _)| **offset != 0) .filter(|(offset, _size)| *offset >= potential_parent_field_size)
.map(|(_, size)| size)
.sum::<usize>() + potential_parent_field_size;
struct_init_bytes == size_of::<T>()
}
#[must_use]
pub fn build(self) -> Typed<T> {
match self.try_build() {
Ok(val) => val,
Err(err) => panic!("{err}"),
}
}
pub fn try_build(self) -> Result<Typed<T>, ObjectInitError> {
self.try_build_as_value().map(|val| Typed::new(val))
}
#[must_use]
pub fn build_as_value(self) -> T {
match self.try_build_as_value() {
Ok(val) => val,
Err(err) => panic!("{err}"),
}
}
#[must_use]
pub fn try_build_as_value(self) -> Result<T, ObjectInitError> {
if self.is_init() {
Ok(unsafe { self._value.assume_init() })
} else {
Err(ObjectInitError(self._initialized_fields_size.len()))
}
}
}
impl<P: 'static, C: 'static> From<Typed<P>> for ObjectBuilder<C>
where
    C: IsA<P> + Object,
    P: Object,
{
    /// Starts a builder for `C` whose parent part is taken from the wrapped `val`.
    fn from(val: Typed<P>) -> Self {
        Self::new().insert_typed(val)
    }
}

impl<P: 'static, C: 'static> From<P> for ObjectBuilder<C>
where
    C: IsA<P> + Object,
    P: Object,
{
    /// Starts a builder for `C` whose parent part is `val`.
    fn from(val: P) -> Self {
        Self::new().insert_value(val)
    }
}
/// Rounds `number` up to the next multiple of `to`.
///
/// A `number` that already is a multiple of `to` is returned unchanged. In particular
/// `_align(to, 0) == 0`, which matters because zero-sized types (such as `()`) have size 0.
///
/// # Panics
///
/// Panics if `to` is zero. With overflow checks enabled (the default in debug builds), it
/// also panics if the rounded value does not fit in `usize`.
pub fn _align(to: usize, number: usize) -> usize {
    number.next_multiple_of(to)
}