use alloc::{boxed::Box, format, sync::Arc, vec::Vec};
use core::{
fmt::{self, Debug, Display},
marker::PhantomData,
ptr::{self}
};
use super::{DowncastableTarget, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker, format_value_type};
use crate::{
AsPointer, ErrorCode,
error::{Error, Result},
memory::Allocator,
ortsys,
value::DynValueTypeMarker
};
/// Marker trait for [`ValueTypeMarker`]s that describe a sequence value.
///
/// Implemented in this module by [`DynSequenceValueType`] and [`SequenceValueType`]; it bounds the
/// `impl` block that provides `try_extract_sequence`.
pub trait SequenceValueTypeMarker: ValueTypeMarker {
	// NOTE(review): `private_trait!()` presumably seals this trait against outside implementations — confirm.
	private_trait!();
}
/// Value type marker for a sequence whose element type is not fixed at compile time.
///
/// Any value whose [`ValueType`] is `Sequence`, regardless of element type, can be downcast to this marker.
#[derive(Debug)]
pub struct DynSequenceValueType;

impl DowncastableTarget for DynSequenceValueType {
	fn can_downcast(dtype: &ValueType) -> bool {
		// Element type is irrelevant here; only the outer "is a sequence" shape matters.
		match dtype {
			ValueType::Sequence(..) => true,
			_ => false
		}
	}

	private_impl!();
}

impl ValueTypeMarker for DynSequenceValueType {
	/// Writes the type's display name, `DynSequence`.
	fn fmt(f: &mut fmt::Formatter) -> fmt::Result {
		write!(f, "DynSequence")
	}

	private_impl!();
}

impl SequenceValueTypeMarker for DynSequenceValueType {
	private_impl!();
}
#[derive(Debug)]
pub struct SequenceValueType<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized>(PhantomData<T>);
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> ValueTypeMarker for SequenceValueType<T> {
fn fmt(f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("Sequence<")?;
format_value_type::<T>().fmt(f)?;
f.write_str(">")
}
private_impl!();
}
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> SequenceValueTypeMarker for SequenceValueType<T> {
private_impl!();
}
impl<T: ValueTypeMarker + DowncastableTarget + Debug + ?Sized> DowncastableTarget for SequenceValueType<T> {
fn can_downcast(dtype: &ValueType) -> bool {
match dtype {
ValueType::Sequence(ty) => T::can_downcast(ty),
_ => false
}
}
private_impl!();
}
/// An owned sequence [`Value`] whose element type is not fixed at compile time.
pub type DynSequence = Value<DynSequenceValueType>;
/// An owned sequence [`Value`] whose elements have the value type `T`.
pub type Sequence<T> = Value<SequenceValueType<T>>;
/// A [`ValueRef`] to a sequence whose element type is not fixed at compile time.
pub type DynSequenceRef<'v> = ValueRef<'v, DynSequenceValueType>;
/// A [`ValueRefMut`] to a sequence whose element type is not fixed at compile time.
pub type DynSequenceRefMut<'v> = ValueRefMut<'v, DynSequenceValueType>;
/// A [`ValueRef`] to a sequence whose elements have the value type `T`.
pub type SequenceRef<'v, T> = ValueRef<'v, SequenceValueType<T>>;
/// A [`ValueRefMut`] to a sequence whose elements have the value type `T`.
pub type SequenceRefMut<'v, T> = ValueRefMut<'v, SequenceValueType<T>>;
impl<Type: SequenceValueTypeMarker + Sized> Value<Type> {
pub fn try_extract_sequence<'s, OtherType: ValueTypeMarker + DowncastableTarget + Debug + Sized>(&'s self) -> Result<Vec<ValueRef<'s, OtherType>>> {
match self.dtype() {
ValueType::Sequence(_) => {
let allocator = Allocator::default();
let mut len = 0;
ortsys![unsafe GetValueCount(self.ptr(), &mut len)?];
let mut vec = Vec::with_capacity(len);
for i in 0..len {
let value = extract_from_sequence(self.ptr(), i, &allocator)?;
let value_type = value.dtype();
if !OtherType::can_downcast(value.dtype()) {
return Err(Error::new_with_code(
ErrorCode::InvalidArgument,
format!("Cannot extract Sequence<{}> from {value_type:?}", format_value_type::<OtherType>())
));
}
vec.push(value.downcast()?);
}
Ok(vec)
}
t => Err(Error::new(format!("Cannot extract Sequence<{}> from {t}", format_value_type::<OtherType>())))
}
}
}
impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized + 'static> Value<SequenceValueType<T>> {
pub fn new(values: impl IntoIterator<Item = Value<T>>) -> Result<Self> {
let mut value_ptr = ptr::null_mut();
let values: Vec<Value<T>> = values.into_iter().collect();
let value_ptrs: Vec<*const ort_sys::OrtValue> = values.iter().map(|c| c.ptr()).collect();
ortsys![
unsafe CreateValue(value_ptrs.as_ptr(), values.len(), ort_sys::ONNXType::ONNX_TYPE_SEQUENCE, &mut value_ptr)?;
nonNull(value_ptr)
];
Ok(Value {
inner: ValueInner::new_backed(
value_ptr,
ValueType::Sequence(Box::new(values[0].inner.dtype.clone())),
None,
true,
Box::new(values)
),
_markers: PhantomData
})
}
}
impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValueType<T>> {
	/// Extracts every element of this sequence as a borrowed value of type `T`.
	///
	/// # Panics
	/// Panics if extraction fails; see `try_extract_sequence` for the failure conditions.
	pub fn extract_sequence<'s>(&'s self) -> Vec<ValueRef<'s, T>> {
		self.try_extract_sequence().expect("Failed to extract sequence")
	}

	/// Returns the number of elements in this sequence, as reported by ONNX Runtime.
	///
	/// # Panics
	/// Panics if ONNX Runtime fails to report the element count.
	#[inline]
	pub fn len(&self) -> usize {
		let mut len = 0;
		ortsys![unsafe GetValueCount(self.ptr(), &mut len).expect("infallible")];
		len
	}

	/// Returns `true` if this sequence contains no elements.
	///
	/// # Panics
	/// Panics under the same conditions as [`Self::len`].
	#[inline]
	pub fn is_empty(&self) -> bool {
		self.len() == 0
	}

	/// Returns a borrowed reference to the element at `index`.
	///
	/// Returns `None` if `index` is out of bounds, if the element cannot be fetched, or if it cannot be
	/// downcast to `T`.
	pub fn get(&self, index: usize) -> Option<ValueRef<'_, T>> {
		// Check bounds on the Rust side so an out-of-range index yields `None`
		// without depending on how ONNX Runtime reacts to it.
		if index >= self.len() {
			return None;
		}
		extract_from_sequence(self.ptr(), index, &Allocator::default())
			.ok()
			.and_then(|x| x.downcast().ok())
	}

	/// Converts this typed sequence into a [`DynSequence`], dropping the static element type.
	#[inline]
	pub fn upcast(self) -> DynSequence {
		unsafe { self.transmute_type() }
	}

	/// Returns a [`DynSequenceRef`] that shares this sequence's underlying value.
	#[inline]
	pub fn upcast_ref(&self) -> DynSequenceRef<'_> {
		DynSequenceRef::new(Value {
			inner: Arc::clone(&self.inner),
			_markers: PhantomData
		})
	}

	/// Returns a [`DynSequenceRefMut`] that shares this sequence's underlying value.
	#[inline]
	pub fn upcast_mut(&mut self) -> DynSequenceRefMut<'_> {
		DynSequenceRefMut::new(Value {
			inner: Arc::clone(&self.inner),
			_markers: PhantomData
		})
	}
}
fn extract_from_sequence<'s>(ptr: *const ort_sys::OrtValue, i: usize, allocator: &Allocator) -> Result<ValueRef<'s, DynValueTypeMarker>> {
let mut value_ptr = ptr::null_mut();
ortsys![unsafe GetValue(ptr, i as _, allocator.ptr().cast_mut(), &mut value_ptr)?; nonNull(value_ptr)];
let mut value = ValueRef::new(unsafe { Value::from_ptr(value_ptr, None) });
value.upgradable = false;
Ok(value)
}