use core::ops::Not;
use core::pin::Pin;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use executorch_sys as sys;
use crate::memory::{Storable, Storage};
use crate::tensor::TensorAny;
use crate::util::IntoCpp;
use crate::{Error, Result};
use super::{EValue, Tag};
/// A list of [`EValue`]s handed to the `sys` layer as a boxed list, together with storage for
/// the unwrapped elements of type `T`.
///
/// Built by [`BoxedEvalueList::new`]. `'a` is the lifetime for which the source
/// [`EValuePtrList`] and the element storage are borrowed.
pub struct BoxedEvalueList<'a, T: BoxedEvalueListElement<'a>>(
    // The `sys` list object selected by `T::__ListImpl`.
    pub(crate) T::__ListImpl,
    // Ties the list to the borrows taken in `new`.
    PhantomData<&'a ()>,
);
impl<'a, T: BoxedEvalueListElement<'a>> BoxedEvalueList<'a, T> {
    /// Builds a boxed list over `wrapped_vals`, using `unwrapped_vals` as the element storage.
    ///
    /// Every entry of `wrapped_vals` is validated before the underlying list is created.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidArgument`] if `wrapped_vals` and `unwrapped_vals` differ in length.
    /// - [`Error::InvalidType`] if a non-null entry's tag differs from the element tag of `T`,
    ///   or if an entry is null and `T` does not accept null elements.
    /// - Any error returned while constructing the underlying `sys` list.
    pub fn new(
        wrapped_vals: &'a EValuePtrList<'_>,
        unwrapped_vals: Pin<&'a mut [Storage<T>]>,
    ) -> Result<Self> {
        let entries = wrapped_vals.as_slice();
        let count = entries.len();
        if unwrapped_vals.len() != count {
            crate::log::error!("wrapped and unwrapped lengths do not match");
            return Err(Error::InvalidArgument);
        }

        for index in 0..count {
            // `index < count`, so the outer `Option` is always `Some`.
            match wrapped_vals.get(index).unwrap() {
                Some(value) => {
                    if value.tag() != T::__ELEMENT_TAG {
                        crate::log::error!("Value does not match T");
                        return Err(Error::InvalidType);
                    }
                }
                None => {
                    if !T::__ALLOW_NULL_ELEMENT {
                        crate::log::error!("T does not allow null elements");
                        return Err(Error::InvalidType);
                    }
                }
            }
        }

        let raw_array = sys::ArrayRefEValuePtr {
            data: entries.as_ptr(),
            len: count,
        };
        // SAFETY: the lengths match, every non-null entry carries `T::__ELEMENT_TAG`, and null
        // entries are present only if `T::__ALLOW_NULL_ELEMENT`. The pointer array is borrowed
        // from `wrapped_vals` for `'a`, which the returned list does not outlive.
        let list = unsafe { T::__ListImpl::__new(raw_array, unwrapped_vals)? };
        Ok(Self(list, PhantomData))
    }
}
/// Element types that can be held by a [`BoxedEvalueList`].
///
/// Implemented in this module (via `impl_boxed_evalue_list!`) for `i64`, `TensorAny` and
/// `Option<TensorAny>`. NOTE(review): `private_decl!` presumably seals the trait against
/// outside implementations — confirm.
pub trait BoxedEvalueListElement<'a>: Storable {
    /// The tag every non-null entry of the list must carry.
    #[doc(hidden)]
    const __ELEMENT_TAG: Tag;
    /// Whether null entries are accepted in the list.
    #[doc(hidden)]
    const __ALLOW_NULL_ELEMENT: bool;
    /// The `sys` list type that backs a list of this element type.
    #[doc(hidden)]
    type __ListImpl: __BoxedEvalueListImpl<Element<'a> = Self>;
    private_decl! {}
}
/// Implementation side of [`BoxedEvalueListElement`]: a `sys` list type that can be built from
/// a raw array of value pointers plus storage for the unwrapped elements.
#[doc(hidden)]
pub trait __BoxedEvalueListImpl {
    /// The element type this list holds; its `__ListImpl` points back to `Self`.
    type Element<'a>: BoxedEvalueListElement<'a, __ListImpl = Self>;
    /// Builds the list from raw parts.
    ///
    /// # Safety
    ///
    /// NOTE(review): contract inferred from the checks the only visible caller,
    /// `BoxedEvalueList::new`, performs before calling — confirm:
    /// - `wrapped_vals` points to `len` entries that stay valid for as long as the returned list
    ///   is used;
    /// - `unwrapped_vals` has the same length as `wrapped_vals`;
    /// - every non-null entry carries `Self::Element::__ELEMENT_TAG`, and null entries appear
    ///   only if `Self::Element::__ALLOW_NULL_ELEMENT` is true.
    unsafe fn __new(
        wrapped_vals: sys::ArrayRefEValuePtr,
        unwrapped_vals: Pin<&mut [Storage<Self::Element<'_>>]>,
    ) -> Result<Self>
    where
        Self: Sized;
    private_decl! {}
}
macro_rules! ptr2ref {
($ptr:expr) => {
$ptr
};
($ptr:expr, $ref_type:path) => {{
$ref_type {
ptr: $ptr as *mut _,
}
}};
}
// Implements `BoxedEvalueListElement` for `$element` and `__BoxedEvalueListImpl` for the
// matching `sys` list type `$list_impl`.
//
// Arguments:
// - `$element_tag`: the `Tag` variant every non-null entry must carry.
// - `$allow_null_element`: whether null entries are accepted.
// - `$unwrapped_type` (optional): a `sys` wrapper with a `ptr` field. When present, the element
//   storage pointer is wrapped in it; when absent, the raw pointer is stored directly (see
//   `ptr2ref!`).
macro_rules! impl_boxed_evalue_list {
    ($element:path, $list_impl:path, $element_tag:ident, $allow_null_element:expr $(, $unwrapped_type:ty)?) => {
        impl<'a> BoxedEvalueListElement<'a> for $element {
            const __ELEMENT_TAG: Tag = Tag::$element_tag;
            const __ALLOW_NULL_ELEMENT: bool = $allow_null_element;
            type __ListImpl = $list_impl;
            private_impl! {}
        }
        impl __BoxedEvalueListImpl for $list_impl {
            type Element<'a> = $element;
            unsafe fn __new(
                wrapped_vals: sys::ArrayRefEValuePtr,
                unwrapped_vals: Pin<&mut [Storage<Self::Element<'_>>]>,
            ) -> Result<Self> {
                // SAFETY: only a raw pointer to the storage is taken; nothing is moved out of
                // the pinned slice here.
                let unwrapped_vals = unsafe { unwrapped_vals.get_unchecked_mut() };
                // NOTE(review): this cast assumes `Storage<E>` is layout-compatible with
                // `E::__Storage` (presumably a transparent wrapper) — confirm in `crate::memory`.
                let unwrapped_vals_ptr = unwrapped_vals.as_mut_ptr() as *mut <Self::Element<'_> as Storable>::__Storage;
                Ok(Self {
                    wrapped_vals,
                    // Raw pointer, or the `$unwrapped_type { ptr }` wrapper if one was given.
                    unwrapped_vals: ptr2ref!(unwrapped_vals_ptr $(, $unwrapped_type)?),
                })
            }
            private_impl! {}
        }
    };
}
// Plain `i64` elements: null entries are rejected, and the element storage is handed to the
// `sys` list as a raw pointer (no wrapper type).
impl_boxed_evalue_list!(
    i64,
    sys::BoxedEvalueListI64,
    Int,
    false
);

// Non-null tensor elements: the element storage is wrapped in `sys::TensorRefMut`.
impl_boxed_evalue_list!(
    TensorAny<'a>,
    sys::BoxedEvalueListTensor,
    Tensor,
    false,
    sys::TensorRefMut
);

// Optional tensor elements: null entries are accepted, and the element storage is wrapped in
// `sys::OptionalTensorRefMut`.
impl_boxed_evalue_list!(
    Option<TensorAny<'a>>,
    sys::BoxedEvalueListOptionalTensor,
    Tensor,
    true,
    sys::OptionalTensorRefMut
);
/// A list of pointers to [`EValue`]s, where an entry may be null to represent a missing value.
///
/// The entries live either in a heap-allocated vector (with the `alloc` feature) or in
/// caller-provided storage. Used as the `wrapped_vals` input of [`BoxedEvalueList::new`].
pub struct EValuePtrList<'a>(EValuePtrListInner<'a>);
/// Backing store for the entries of an [`EValuePtrList`].
// NOTE(review): `'a` is already used by the `Slice` variant's reference, so the
// `PhantomData<&'a ()>` markers in both variants appear redundant.
enum EValuePtrListInner<'a> {
    /// Entries owned in a heap-allocated vector (requires the `alloc` feature).
    #[cfg(feature = "alloc")]
    Vec(
        (
            crate::alloc::Vec<sys::EValueRef>,
            PhantomData<&'a ()>,
        ),
    ),
    /// Entries held in caller-provided storage, borrowed for `'a`.
    Slice(
        (
            &'a [sys::EValueRef],
            PhantomData<&'a ()>,
        ),
    ),
}
impl<'a> EValuePtrList<'a> {
#[cfg(feature = "alloc")]
fn new_impl(values: impl IntoIterator<Item = Option<&'a EValue<'a>>>) -> Self {
let values: crate::alloc::Vec<sys::EValueRef> = values
.into_iter()
.map(|value| {
value.map(|value| value.cpp()).unwrap_or(sys::EValueRef {
ptr: std::ptr::null(),
})
})
.collect();
Self(EValuePtrListInner::Vec((values, PhantomData)))
}
#[cfg(feature = "alloc")]
pub fn new(values: impl IntoIterator<Item = &'a EValue<'a>>) -> Self {
Self::new_impl(values.into_iter().map(Some))
}
#[cfg(feature = "alloc")]
pub fn new_optional(values: impl IntoIterator<Item = Option<&'a EValue<'a>>>) -> Self {
Self::new_impl(values)
}
fn new_in_storage_impl(
values: impl IntoIterator<Item = Option<&'a EValue<'a>>>,
storage: Pin<&'a mut [Storage<EValuePtrListElem>]>,
) -> Self {
let mut values = values.into_iter();
let storage = unsafe { storage.get_unchecked_mut() };
let storage = unsafe {
std::mem::transmute::<
&mut [Storage<EValuePtrListElem>],
&mut [MaybeUninit<<EValuePtrListElem as Storable>::__Storage>],
>(storage)
};
let mut storage_iter = storage.iter_mut();
loop {
match (values.next(), storage_iter.next()) {
(Some(value), Some(storage)) => {
storage.write(value.map(|value| value.cpp()).unwrap_or(sys::EValueRef {
ptr: std::ptr::null(),
}));
}
(None, None) => break,
_ => panic!("Mismatched lengths"),
}
}
let storage = unsafe {
std::mem::transmute::<
&mut [MaybeUninit<<EValuePtrListElem as Storable>::__Storage>],
&mut [<EValuePtrListElem as Storable>::__Storage],
>(storage)
};
Self(EValuePtrListInner::Slice((storage, PhantomData)))
}
pub fn new_in_storage(
values: impl IntoIterator<Item = &'a EValue<'a>>,
storage: Pin<&'a mut [Storage<EValuePtrListElem>]>,
) -> Self {
Self::new_in_storage_impl(values.into_iter().map(Some), storage)
}
pub fn new_optional_in_storage(
values: impl IntoIterator<Item = Option<&'a EValue<'a>>>,
storage: Pin<&'a mut [Storage<EValuePtrListElem>]>,
) -> Self {
Self::new_in_storage_impl(values, storage)
}
fn as_slice(&self) -> &[sys::EValueRef] {
match &self.0 {
#[cfg(feature = "alloc")]
EValuePtrListInner::Vec((values, _)) => values.as_slice(),
EValuePtrListInner::Slice((values, _)) => values,
}
}
fn get(&self, index: usize) -> Option<Option<EValue<'_>>> {
let ptr = *self.as_slice().get(index)?;
Some(
ptr.ptr
.is_null()
.not()
.then(|| unsafe { EValue::from_inner_ref(ptr) }),
)
}
}
/// Element type of the storage passed to [`EValuePtrList::new_in_storage`] and
/// [`EValuePtrList::new_optional_in_storage`]. Each storage slot holds one `sys::EValueRef`.
// `#[allow(unused)]`: the field is never read in this file; the type appears to exist only to
// select the `Storable::__Storage` type below.
pub struct EValuePtrListElem(#[allow(unused)] sys::EValueRef);
// Each slot stores the raw `sys::EValueRef` entry directly.
impl Storable for EValuePtrListElem {
    type __Storage = sys::EValueRef;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::evalue::EValue;
    use crate::storage;

    #[test]
    #[should_panic(expected = "Mismatched lengths")]
    fn evalue_ptr_list_length_mismatch() {
        let evalue1_storage = storage!(EValue);
        let evalue2_storage = storage!(EValue);
        let evalue3_storage = storage!(EValue);
        let evalue1 = EValue::new_in_storage(42, evalue1_storage);
        let evalue2 = EValue::new_in_storage(17, evalue2_storage);
        let evalue3 = EValue::new_in_storage(6, evalue3_storage);
        let wrapped_vals_storage = storage!(EValuePtrListElem, [2]);
        let _ =
            EValuePtrList::new_in_storage([&evalue1, &evalue2, &evalue3], wrapped_vals_storage);
    }

    #[test]
    fn evalue_ptr_list_get() {
        let evalue1_storage = storage!(EValue);
        let evalue1 = EValue::new_in_storage(42, evalue1_storage);
        let wrapped_vals_storage = storage!(EValuePtrListElem, [2]);
        let wrapped_vals =
            EValuePtrList::new_optional_in_storage([Some(&evalue1), None], wrapped_vals_storage);
        assert!(matches!(wrapped_vals.get(0), Some(Some(_))));
        assert!(matches!(wrapped_vals.get(1), Some(None)));
        assert!(wrapped_vals.get(2).is_none());
    }

    #[test]
    fn valid_i64_list() {
        let evalue1_storage = storage!(EValue);
        let evalue2_storage = storage!(EValue);
        let evalue1 = EValue::new_in_storage(42, evalue1_storage);
        let evalue2 = EValue::new_in_storage(17, evalue2_storage);
        let wrapped_vals_storage = storage!(EValuePtrListElem, [2]);
        let wrapped_vals =
            EValuePtrList::new_in_storage([&evalue1, &evalue2], wrapped_vals_storage);
        let unwrapped_vals = storage!(i64, [2]);
        let res = BoxedEvalueList::new(&wrapped_vals, unwrapped_vals);
        assert!(res.is_ok());
    }

    #[test]
    fn length_mismatch() {
        let evalue1_storage = storage!(EValue);
        let evalue2_storage = storage!(EValue);
        let evalue3_storage = storage!(EValue);
        let evalue1 = EValue::new_in_storage(42, evalue1_storage);
        let evalue2 = EValue::new_in_storage(17, evalue2_storage);
        let evalue3 = EValue::new_in_storage(6, evalue3_storage);
        let wrapped_vals_storage = storage!(EValuePtrListElem, [3]);
        let wrapped_vals =
            EValuePtrList::new_in_storage([&evalue1, &evalue2, &evalue3], wrapped_vals_storage);
        let unwrapped_vals = storage!(i64, [2]);
        let res = BoxedEvalueList::new(&wrapped_vals, unwrapped_vals);
        assert!(matches!(res, Err(Error::InvalidArgument)));
    }

    #[test]
    fn wrong_type() {
        let evalue1_storage = storage!(EValue);
        let evalue2_storage = storage!(EValue);
        let evalue3_storage = storage!(EValue);
        let evalue1 = EValue::new_in_storage(42, evalue1_storage);
        let evalue2 = EValue::new_in_storage(17, evalue2_storage);
        let evalue3 = EValue::new_in_storage(6.5, evalue3_storage);
        let wrapped_vals_storage = storage!(EValuePtrListElem, [3]);
        let wrapped_vals =
            EValuePtrList::new_in_storage([&evalue1, &evalue2, &evalue3], wrapped_vals_storage);
        let unwrapped_vals = storage!(i64, [3]);
        let res = BoxedEvalueList::new(&wrapped_vals, unwrapped_vals);
        assert!(matches!(res, Err(Error::InvalidType)));
    }

    #[test]
    fn null_element() {
        let evalue1_storage = storage!(EValue);
        let evalue2_storage = storage!(EValue);
        let evalue1 = EValue::new_in_storage(42, evalue1_storage);
        let evalue2 = EValue::new_in_storage(17, evalue2_storage);
        let evalue3 = None;
        let wrapped_vals_storage = storage!(EValuePtrListElem, [3]);
        let wrapped_vals = EValuePtrList::new_optional_in_storage(
            [Some(&evalue1), Some(&evalue2), evalue3],
            wrapped_vals_storage,
        );
        let unwrapped_vals = storage!(i64, [3]);
        let res = BoxedEvalueList::new(&wrapped_vals, unwrapped_vals);
        assert!(matches!(res, Err(Error::InvalidType)));
    }
}