use crate::{
garbage_collector::GcRootPtr,
marshal::Marshal,
reflection::{ArgumentReflection, ReturnTypeReflection},
GarbageCollector, Runtime,
};
use mun_memory::{
gc::{GcPtr, GcRuntime, HasIndirectionPtr},
Type,
};
use std::{
ptr::{self, NonNull},
sync::Arc,
};
/// Newtype around the [`GcPtr`] handle of a struct object.
///
/// `#[repr(transparent)]` gives this type the same memory layout as the wrapped
/// [`GcPtr`].
#[repr(transparent)]
#[derive(Clone)]
pub struct RawStruct(GcPtr);
impl RawStruct {
    /// Returns a raw byte pointer obtained by dereferencing the wrapped GC handle.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is defined by `HasIndirectionPtr::deref`,
    /// which is not visible from this file. Presumably the handle must still refer
    /// to a live object for the returned pointer to be usable — confirm.
    pub unsafe fn get_ptr(&self) -> *const u8 {
        self.0.deref::<u8>()
    }
}
/// A struct handle paired with a borrow of the [`Runtime`] it is accessed through.
///
/// The lifetime `'s` is the lifetime of that runtime borrow.
#[derive(Clone)]
pub struct StructRef<'s> {
/// Raw handle of the referenced struct object.
raw: RawStruct,
/// Runtime used to look up type information and to marshal field values.
runtime: &'s Runtime,
}
impl<'s> StructRef<'s> {
fn new<'r>(raw: RawStruct, runtime: &'r Runtime) -> Self
where
'r: 's,
{
Self { raw, runtime }
}
pub fn into_raw(self) -> RawStruct {
self.raw
}
pub fn root(self) -> RootedStruct {
RootedStruct::new(&self.runtime.gc, self.raw)
}
pub fn type_info(&self) -> Type {
self.runtime.gc.ptr_type(self.raw.0)
}
unsafe fn get_field_ptr_unchecked<T>(&self, offset: usize) -> NonNull<T> {
let ptr = self.raw.get_ptr();
NonNull::new_unchecked(ptr.add(offset).cast::<T>() as *mut T)
}
pub fn get<T: ReturnTypeReflection + Marshal<'s>>(&self, field_name: &str) -> Result<T, String>
where
T: 's,
{
let type_info = self.type_info();
let struct_info = type_info.as_struct().unwrap();
let field_info = struct_info
.fields()
.find_by_name(field_name)
.ok_or_else(|| {
format!(
"Struct `{}` does not contain field `{}`.",
type_info.name(),
field_name
)
})?;
if !T::accepts_type(&field_info.ty()) {
return Err(format!(
"Mismatched types for `{}::{}`. Expected: `{}`. Found: `{}`.",
type_info.name(),
field_name,
T::type_hint(),
field_info.ty().name(),
));
};
let field_ptr = unsafe { self.get_field_ptr_unchecked::<T::MunType>(field_info.offset()) };
Ok(Marshal::marshal_from_ptr(
field_ptr,
self.runtime,
&field_info.ty(),
))
}
pub fn replace<T: ArgumentReflection + Marshal<'s>>(
&mut self,
field_name: &str,
value: T,
) -> Result<T, String>
where
T: 's,
{
let type_info = self.type_info();
let struct_info = type_info.as_struct().unwrap();
let field_info = struct_info
.fields()
.find_by_name(field_name)
.ok_or_else(|| {
format!(
"Struct `{}` does not contain field `{}`.",
type_info.name(),
field_name
)
})?;
let value_type = value.type_info(self.runtime);
if field_info.ty() != value_type {
return Err(format!(
"Mismatched types for `{}::{}`. Expected: `{}`. Found: `{}`.",
type_info.name(),
field_name,
value_type.name(),
field_info.ty()
));
}
let field_ptr = unsafe { self.get_field_ptr_unchecked::<T::MunType>(field_info.offset()) };
let old = Marshal::marshal_from_ptr(field_ptr, self.runtime, &field_info.ty());
Marshal::marshal_to_ptr(value, field_ptr, &field_info.ty());
Ok(old)
}
pub fn set<T: ArgumentReflection + Marshal<'s>>(
&mut self,
field_name: &str,
value: T,
) -> Result<(), String> {
let type_info = self.type_info();
let struct_info = type_info.as_struct().unwrap();
let field_info = struct_info
.fields()
.find_by_name(field_name)
.ok_or_else(|| {
format!(
"Struct `{}` does not contain field `{}`.",
type_info.name(),
field_name
)
})?;
let value_type = value.type_info(self.runtime);
if field_info.ty() != value_type {
return Err(format!(
"Mismatched types for `{}::{}`. Expected: `{}`. Found: `{}`.",
type_info.name(),
field_name,
value_type.name(),
field_info.ty()
));
}
let field_ptr = unsafe { self.get_field_ptr_unchecked::<T::MunType>(field_info.offset()) };
Marshal::marshal_to_ptr(value, field_ptr, &field_info.ty());
Ok(())
}
}
impl<'r> ArgumentReflection for StructRef<'r> {
fn type_info(&self, _runtime: &Runtime) -> Type {
self.type_info()
}
}
impl<'s> Marshal<'s> for StructRef<'s> {
    type MunType = RawStruct;

    /// Wraps a raw struct handle together with `runtime`.
    fn marshal_from<'r>(value: Self::MunType, runtime: &'r Runtime) -> Self
    where
        'r: 's,
    {
        Self::new(value, runtime)
    }

    /// Unwraps the reference into its raw handle.
    fn marshal_into<'r>(self) -> Self::MunType {
        self.raw
    }

    /// Builds a `StructRef` from the memory at `ptr`, which holds a value of the
    /// struct type `type_info`.
    ///
    /// For a value struct the bytes at `ptr` are copied into a freshly allocated
    /// GC object; otherwise the memory at `ptr` is read as a `GcPtr` handle.
    ///
    /// # Panics
    ///
    /// Panics if `type_info` is not a struct type.
    fn marshal_from_ptr<'r>(
        ptr: NonNull<Self::MunType>,
        runtime: &'r Runtime,
        type_info: &Type,
    ) -> StructRef<'s>
    where
        Self: 's,
        'r: 's,
    {
        if !type_info.as_struct().unwrap().is_value_struct() {
            // SAFETY: assumes `ptr` points at an initialized `GcPtr`, as expected
            // for a field of a non-value struct type — TODO confirm with callers.
            let handle = unsafe { ptr.cast::<GcPtr>().as_ptr().read() };
            return Self::new(RawStruct(handle), runtime);
        }

        // Value structs are stored inline, so copy their bytes into a new object.
        let mut copy = runtime.gc().alloc(type_info);
        let src: *const u8 = ptr.cast::<u8>().as_ptr();
        // SAFETY: assumes `ptr` is valid for reads of the type's value layout size
        // and that the fresh allocation is at least that large and does not
        // overlap the source — TODO confirm against the allocator contract.
        unsafe {
            let dest = copy.deref_mut::<u8>();
            ptr::copy_nonoverlapping(src, dest, type_info.value_layout().size());
        }
        Self::new(RawStruct(copy), runtime)
    }

    /// Writes `value` into the memory at `ptr`, which holds a value of the struct
    /// type `type_info`.
    ///
    /// For a value struct the object's bytes are copied to `ptr`; otherwise the
    /// raw handle itself is stored at `ptr`.
    ///
    /// # Panics
    ///
    /// Panics if `type_info` is not a struct type.
    fn marshal_to_ptr(value: Self, mut ptr: NonNull<Self::MunType>, type_info: &Type) {
        let is_value_struct = type_info.as_struct().unwrap().is_value_struct();
        let raw = value.into_raw();

        if is_value_struct {
            let dest: *mut u8 = ptr.cast::<u8>().as_ptr();
            let len = type_info.value_layout().size();
            // SAFETY: assumes `ptr` is valid for writes of `len` bytes and does
            // not overlap the memory of `raw` — TODO confirm with callers.
            unsafe { ptr::copy_nonoverlapping(raw.get_ptr(), dest, len) };
        } else {
            // SAFETY: assumes `ptr` points at a valid, writable `RawStruct` slot.
            unsafe { *ptr.as_mut() = raw };
        }
    }
}
impl ReturnTypeReflection for StructRef<'_> {
    /// A `StructRef` can be produced from any struct type.
    fn accepts_type(ty: &Type) -> bool {
        ty.is_struct()
    }

    /// Label used for this type in type-mismatch messages.
    fn type_hint() -> &'static str {
        "struct"
    }
}
/// A struct handle held through a [`GcRootPtr`].
///
/// Unlike `StructRef`, this type carries no borrow of a `Runtime`.
// NOTE(review): presumably the root keeps the object alive across collections —
// confirm against `GcRootPtr`.
#[derive(Clone)]
pub struct RootedStruct {
/// Rooted garbage-collector handle of the struct object.
handle: GcRootPtr,
}
impl RootedStruct {
    /// Creates a rooted handle for `raw` in the garbage collector `gc`.
    ///
    /// # Panics
    ///
    /// Panics if `gc` does not report a struct type for `raw`.
    fn new(gc: &Arc<GarbageCollector>, raw: RawStruct) -> Self {
        assert!(gc.ptr_type(raw.0).is_struct());
        let handle = GcRootPtr::new(gc, raw.0);
        Self { handle }
    }

    /// Produces a `StructRef` to the rooted object that borrows `runtime`.
    ///
    /// # Panics
    ///
    /// Panics if the garbage collector of `runtime` is not the one this handle
    /// was rooted in.
    pub fn as_ref<'r>(&self, runtime: &'r Runtime) -> StructRef<'r> {
        let runtime_gc = Arc::as_ptr(&runtime.gc);
        let rooted_gc = self.handle.runtime().as_ptr();
        assert_eq!(runtime_gc, rooted_gc);

        let raw = RawStruct(self.handle.handle());
        StructRef::new(raw, runtime)
    }
}