use std::ffi::CStr;
use std::marker::PhantomData;
use std::ptr;
use executorch_sys as sys;
use crate::data_loader::DataLoader;
use crate::data_map::{AbstractNamedDataMap, NamedDataMap};
use crate::evalue::{EValue, Tag};
use crate::event_tracer::EventTracer;
use crate::memory::MemoryManager;
use crate::tensor::ScalarType;
use crate::util::{try_c_new, ArrayRef, IntoCpp, IntoRust, __ArrayRefImpl, chars2str, FfiChar};
use crate::{Error, Result};
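/// A deserialized ExecuTorch program binary.
///
/// Wraps the C++ `executorch::runtime::Program` and borrows the [`DataLoader`]
/// it was loaded from for the lifetime `'a`.
///
/// A minimal sketch of loading a program from an in-memory buffer, assuming
/// the crate is used as `executorch` and `model_bytes` holds the contents of
/// a serialized `.pte` file:
///
/// ```ignore
/// use executorch::data_loader::BufferDataLoader;
/// use executorch::program::{Program, ProgramVerification};
///
/// let loader = BufferDataLoader::new(model_bytes);
/// let program = Program::load(&loader, Some(ProgramVerification::Minimal))?;
/// assert_eq!(program.num_methods(), 1);
/// ```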
pub struct Program<'a>(sys::Program, PhantomData<&'a ()>);
impl<'a> Program<'a> {
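    /// Loads a `Program` from the given data loader.
    ///
    /// If `verification` is `None`, [`ProgramVerification::Minimal`] is used.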
pub fn load(
data_loader: &'a dyn DataLoader,
verification: Option<ProgramVerification>,
) -> Result<Self> {
let data_loader = sys::DataLoaderRefMut {
ptr: data_loader._cpp_ptr().cast_mut(),
};
let verification = verification.unwrap_or(ProgramVerification::Minimal).cpp();
let program = unsafe {
try_c_new(|program| sys::executorch_Program_load(data_loader, verification, program))?
};
Ok(Self(program, PhantomData))
}
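    /// Returns the number of methods in the program.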
pub fn num_methods(&self) -> usize {
unsafe { sys::executorch_Program_num_methods(&self.0) }
}
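    /// Returns the name of the method at `method_index`.
    ///
    /// Fails if the index is out of range or the name is not valid UTF-8.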
pub fn get_method_name(&self, method_index: usize) -> Result<&str> {
let method_name = unsafe {
try_c_new(|method_name| {
sys::executorch_Program_get_method_name(&self.0, method_index, method_name)
})?
};
let method_name = unsafe { CStr::from_ptr(method_name) };
method_name.to_str().map_err(|_| Error::InvalidString)
}
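    /// Returns the named data map embedded in the program, if one exists.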
pub fn get_named_data_map(&self) -> Result<&dyn NamedDataMap> {
let data_map = unsafe {
try_c_new(|data_map| sys::executorch_Program_get_named_data_map(&self.0, data_map))?
};
let data_map = data_map.ptr.cast::<AbstractNamedDataMap>();
Ok(unsafe { &*data_map })
}
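    /// Loads the named method, allocating its non-constant memory through
    /// `memory_manager`.
    ///
    /// `event_tracer` and `named_data_map` are optional; pass `None` to omit
    /// them.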
pub fn load_method<'b>(
&'b self,
method_name: &CStr,
memory_manager: &'b MemoryManager,
event_tracer: Option<&'b mut EventTracer>,
named_data_map: Option<&'b dyn NamedDataMap>,
) -> Result<Method<'b>> {
let memory_manager = memory_manager.0.get();
let event_tracer = event_tracer
.map(|tracer| tracer as *mut EventTracer)
.unwrap_or(ptr::null_mut());
let event_tracer = sys::EventTracerRefMut {
ptr: event_tracer as *mut _,
};
let named_data_map = named_data_map
.map(|map| map._cpp_ptr())
.unwrap_or(ptr::null());
let named_data_map = sys::NamedDataMapRef {
ptr: named_data_map,
};
let method = unsafe {
try_c_new(|method| {
sys::executorch_Program_load_method(
&self.0,
method_name.as_ptr(),
memory_manager,
event_tracer,
named_data_map,
method,
)
})?
};
Ok(Method(method, PhantomData))
}
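    /// Returns metadata about the named method without loading it.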
pub fn method_meta<'b>(&'b self, method_name: &CStr) -> Result<MethodMeta<'b>> {
let meta = unsafe {
try_c_new(|meta| {
sys::executorch_Program_method_meta(&self.0, method_name.as_ptr(), meta)
})?
};
Ok(unsafe { MethodMeta::new(meta) })
}
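    /// Checks whether `data` begins with a recognizable ExecuTorch program
    /// header. Slices too short to contain a header yield
    /// [`HeaderStatus::ShortData`].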
pub fn check_header(data: &[u8]) -> HeaderStatus {
unsafe { sys::executorch_Program_check_header(data.as_ptr() as *const _, data.len()) }.rs()
}
}
impl Drop for Program<'_> {
fn drop(&mut self) {
unsafe { sys::executorch_Program_destructor(&mut self.0) };
}
}
#[repr(u8)]
#[doc = " Types of validation that the Program can do before parsing the data."]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum ProgramVerification {
#[doc = " Do minimal verification of the data, ensuring that the header appears\n correct.\n\n Has minimal runtime overhead."]
Minimal = sys::ProgramVerification::ProgramVerification_Minimal as u8,
#[doc = " Do full verification of the data, ensuring that internal pointers are\n self-consistent and that the data has not been truncated or obviously\n corrupted. May not catch all types of corruption, but should guard\n against illegal memory operations during parsing.\n\n Will have higher runtime overhead, scaling with the complexity of the\n proram data."]
InternalConsistency = sys::ProgramVerification::ProgramVerification_InternalConsistency as u8,
}
impl IntoCpp for ProgramVerification {
type CppType = sys::ProgramVerification;
fn cpp(self) -> Self::CppType {
match self {
ProgramVerification::Minimal => sys::ProgramVerification::ProgramVerification_Minimal,
ProgramVerification::InternalConsistency => {
sys::ProgramVerification::ProgramVerification_InternalConsistency
}
}
}
}
#[repr(u32)]
#[doc = " Describes the presence of an ExecuTorch program header."]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum HeaderStatus {
#[doc = " An ExecuTorch program header is present, and its version is compatible\n with this version of the runtime."]
CompatibleVersion = sys::ProgramHeaderStatus::ProgramHeaderStatus_CompatibleVersion as u32,
#[doc = " An ExecuTorch program header is present, but its version is not\n compatible with this version of the runtime."]
IncompatibleVersion = sys::ProgramHeaderStatus::ProgramHeaderStatus_IncompatibleVersion as u32,
#[doc = " An ExecuTorch program header is not present."]
NotPresent = sys::ProgramHeaderStatus::ProgramHeaderStatus_NotPresent as u32,
#[doc = " The data provided was too short to find the program header."]
ShortData = sys::ProgramHeaderStatus::ProgramHeaderStatus_ShortData as u32,
}
impl IntoRust for sys::ProgramHeaderStatus {
type RsType = HeaderStatus;
fn rs(self) -> Self::RsType {
match self {
sys::ProgramHeaderStatus::ProgramHeaderStatus_CompatibleVersion => {
HeaderStatus::CompatibleVersion
}
sys::ProgramHeaderStatus::ProgramHeaderStatus_IncompatibleVersion => {
HeaderStatus::IncompatibleVersion
}
sys::ProgramHeaderStatus::ProgramHeaderStatus_NotPresent => HeaderStatus::NotPresent,
sys::ProgramHeaderStatus::ProgramHeaderStatus_ShortData => HeaderStatus::ShortData,
}
}
}
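/// Metadata about a method in an ExecuTorch program: its name, input and
/// output tags and tensor shapes, attributes, memory-planned buffer sizes,
/// and the backends it uses.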
pub struct MethodMeta<'a>(sys::MethodMeta, PhantomData<&'a ()>);
impl MethodMeta<'_> {
pub(crate) unsafe fn new(meta: sys::MethodMeta) -> Self {
Self(meta, PhantomData)
}
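    /// Returns the method's name.
    ///
    /// # Panics
    ///
    /// Panics if the name is not valid UTF-8.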
pub fn name(&self) -> &str {
let name = unsafe { sys::executorch_MethodMeta_name(&self.0) };
let name = unsafe { CStr::from_ptr(name) };
        name.to_str().expect("method name is not valid UTF-8")
}
pub fn num_inputs(&self) -> usize {
unsafe { sys::executorch_MethodMeta_num_inputs(&self.0) }
}
pub fn input_tag(&self, idx: usize) -> Result<Tag> {
unsafe {
try_c_new(|tag| sys::executorch_MethodMeta_input_tag(&self.0, idx, tag))
.map(IntoRust::rs)
}
}
pub fn input_tensor_meta(&self, idx: usize) -> Result<TensorInfo<'_>> {
let info = unsafe {
try_c_new(|info| sys::executorch_MethodMeta_input_tensor_meta(&self.0, idx, info))?
};
Ok(unsafe { TensorInfo::new(info) })
}
pub fn num_outputs(&self) -> usize {
unsafe { sys::executorch_MethodMeta_num_outputs(&self.0) }
}
pub fn output_tag(&self, idx: usize) -> Result<Tag> {
unsafe {
try_c_new(|tag| sys::executorch_MethodMeta_output_tag(&self.0, idx, tag))
.map(IntoRust::rs)
}
}
pub fn output_tensor_meta(&self, idx: usize) -> Result<TensorInfo<'_>> {
let info = unsafe {
try_c_new(|info| sys::executorch_MethodMeta_output_tensor_meta(&self.0, idx, info))?
};
Ok(unsafe { TensorInfo::new(info) })
}
pub fn num_attributes(&self) -> usize {
unsafe { sys::executorch_MethodMeta_num_attributes(&self.0) }
}
pub fn attribute_tensor_meta(&self, idx: usize) -> Result<TensorInfo<'_>> {
let info = unsafe {
try_c_new(|info| sys::executorch_MethodMeta_attribute_tensor_meta(&self.0, idx, info))?
};
Ok(unsafe { TensorInfo::new(info) })
}
pub fn num_memory_planned_buffers(&self) -> usize {
unsafe { sys::executorch_MethodMeta_num_memory_planned_buffers(&self.0) }
}
pub fn memory_planned_buffer_size(&self, idx: usize) -> Result<usize> {
let size = unsafe {
try_c_new(|size| {
sys::executorch_MethodMeta_memory_planned_buffer_size(&self.0, idx, size)
})?
};
Ok(size as usize)
}
pub fn uses_backend(&self, backend_name: &CStr) -> bool {
unsafe { sys::executorch_MethodMeta_uses_backend(&self.0, backend_name.as_ptr()) }
}
pub fn num_backends(&self) -> usize {
unsafe { sys::executorch_MethodMeta_num_backends(&self.0) }
}
pub fn get_backend_name(&self, index: usize) -> Result<&str> {
let backend_name = unsafe {
try_c_new(|name| sys::executorch_MethodMeta_get_backend_name(&self.0, index, name))?
};
let backend_name = unsafe { CStr::from_ptr(backend_name) };
backend_name.to_str().map_err(|_| Error::InvalidString)
}
}
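/// Metadata about a tensor input, output, or attribute of a method: its
/// sizes, dimension order, scalar type, and size in bytes.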
pub struct TensorInfo<'a>(sys::TensorInfo, PhantomData<&'a ()>);
impl<'a> TensorInfo<'a> {
pub(crate) unsafe fn new(info: sys::TensorInfo) -> Self {
Self(info, PhantomData)
}
pub fn sizes(&self) -> &[i32] {
let span = unsafe { sys::executorch_TensorInfo_sizes(&self.0) };
unsafe { ArrayRef::from_inner(span) }.as_slice()
}
pub fn dim_order(&self) -> &[u8] {
let span = unsafe { sys::executorch_TensorInfo_dim_order(&self.0) };
unsafe { ArrayRef::from_inner(span) }.as_slice()
}
pub fn scalar_type(&self) -> ScalarType {
unsafe { sys::executorch_TensorInfo_scalar_type(&self.0) }.rs()
}
pub fn is_memory_planned(&self) -> bool {
unsafe { sys::executorch_TensorInfo_is_memory_planned(&self.0) }
}
pub fn nbytes(&self) -> usize {
unsafe { sys::executorch_TensorInfo_nbytes(&self.0) }
}
pub fn name(&self) -> Result<&str, std::str::Utf8Error> {
chars2str(self.name_chars())
}
pub fn name_chars(&self) -> &[std::ffi::c_char] {
let chars = unsafe { sys::executorch_TensorInfo_name(&self.0).as_slice() };
FfiChar::slice_to_ffi(chars)
}
}
impl std::fmt::Debug for TensorInfo<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("TensorInfo")
.field("name", &self.name().unwrap_or("<invalid utf8>"))
.field("sizes", &self.sizes())
.field("dim_order", &self.dim_order())
.field("scalar_type", &self.scalar_type())
.field("nbytes", &self.nbytes())
.finish()
}
}
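/// An executable instance of a method, loaded via [`Program::load_method`].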
pub struct Method<'a>(sys::Method, PhantomData<&'a ()>);
impl Method<'_> {
pub fn start_execution(&mut self) -> Execution<'_> {
Execution::new(&mut self.0)
}
pub fn inputs_size(&self) -> usize {
unsafe { sys::executorch_Method_inputs_size(&self.0) }
}
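    /// Returns the attribute tensor with the given name, if the method has
    /// one.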
#[cfg(feature = "alloc")]
pub fn get_attribute<'b>(&'b mut self, name: &str) -> Result<crate::tensor::TensorAny<'b>> {
let name = ArrayRef::from_chars(crate::util::str2chars(name));
let tensor = unsafe {
crate::util::NonTriviallyMovable::try_new_boxed(|tensor: *mut sys::TensorStorage| {
let tensor = sys::TensorRefMut { ptr: tensor.cast() };
sys::executorch_Method_get_attribute(&mut self.0, name.0, tensor).rs()
})?
};
unsafe {
Ok(crate::tensor::TensorAny::from_raw_tensor(
crate::tensor::RawTensor::new_impl(tensor),
))
}
}
}
impl Drop for Method<'_> {
fn drop(&mut self) {
unsafe { sys::executorch_Method_destructor(&mut self.0) };
}
}
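/// A single execution of a [`Method`].
///
/// Every input must be set with [`set_input`](Execution::set_input) before
/// [`execute`](Execution::execute) is called; set inputs are tracked in a
/// `u64` bitmask, which is why methods are limited to 64 inputs.
///
/// A minimal sketch of the flow, assuming `method` came from
/// [`Program::load_method`] and `input0`/`input1` are `EValue`s matching the
/// method's signature:
///
/// ```ignore
/// let mut execution = method.start_execution();
/// execution.set_input(&input0, 0)?;
/// execution.set_input(&input1, 1)?;
/// let outputs = execution.execute()?;
/// let first_output = outputs.get(0);
/// ```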
pub struct Execution<'a> {
method: &'a mut sys::Method,
set_inputs: u64,
}
impl<'a> Execution<'a> {
fn new(method: &'a mut sys::Method) -> Self {
        assert!(
            unsafe { sys::executorch_Method_inputs_size(method) } <= u64::BITS as usize,
            "method has more than 64 inputs, which is unsupported"
        );
Self {
method,
set_inputs: 0,
}
}
pub fn set_input(&mut self, input: &'a EValue, input_idx: usize) -> Result<()> {
unsafe { sys::executorch_Method_set_input(self.method, input.cpp(), input_idx) }.rs()?;
self.set_inputs |= 1 << input_idx;
Ok(())
}
    pub fn execute(self) -> Result<Outputs<'a>> {
        let num_inputs = unsafe { sys::executorch_Method_inputs_size(self.method) };
        // `checked_shl` avoids the shift overflow that `1 << 64` would cause
        // when a method has exactly 64 inputs.
        let all_inputs_mask = 1u64
            .checked_shl(num_inputs as u32)
            .map_or(u64::MAX, |bit| bit - 1);
        if self.set_inputs != all_inputs_mask {
            crate::log::error!("Not all inputs were set before executing the method");
            return Err(Error::InvalidArgument);
        }
unsafe { sys::executorch_Method_execute(self.method) }.rs()?;
Ok(Outputs::new(self.method))
}
}
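/// The outputs of a successful [`Execution::execute`] call.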
pub struct Outputs<'a> {
method: &'a mut sys::Method,
}
impl<'a> Outputs<'a> {
fn new(method: &'a mut sys::Method) -> Self {
Self { method }
}
pub fn len(&self) -> usize {
unsafe { sys::executorch_Method_outputs_size(self.method) }
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn get(&self, index: usize) -> EValue<'_> {
let value = unsafe { sys::executorch_Method_get_output(self.method as *const _, index) };
unsafe { EValue::from_inner_ref(value) }
}
}
#[cfg(test)]
mod tests {
use crate::data_loader::BufferDataLoader;
use crate::tests::ADD_MODEL_BYTES;
use super::*;
#[test]
fn load() {
let loader = BufferDataLoader::new(ADD_MODEL_BYTES);
let program = Program::load(&loader, None);
assert!(program.is_ok());
let loader = BufferDataLoader::new(ADD_MODEL_BYTES);
let program = Program::load(&loader, Some(ProgramVerification::Minimal));
assert!(program.is_ok());
let loader = BufferDataLoader::new(ADD_MODEL_BYTES);
let program = Program::load(&loader, Some(ProgramVerification::InternalConsistency));
assert!(program.is_ok());
let loader = BufferDataLoader::new(&[]);
let program = Program::load(&loader, None);
assert!(program.is_err());
}
#[test]
fn num_methods() {
let loader = BufferDataLoader::new(ADD_MODEL_BYTES);
let program = Program::load(&loader, None).unwrap();
assert_eq!(program.num_methods(), 1);
}
#[test]
fn get_method_name() {
let loader = BufferDataLoader::new(ADD_MODEL_BYTES);
let program = Program::load(&loader, None).unwrap();
assert_eq!(program.get_method_name(0).ok(), Some("forward"));
assert_eq!(program.get_method_name(1).ok(), None);
}
#[test]
fn method_meta() {
let loader = BufferDataLoader::new(ADD_MODEL_BYTES);
let program = Program::load(&loader, None).unwrap();
let method_meta = program.method_meta(c"forward").unwrap();
assert!(program.method_meta(c"non-existing-method").is_err());
assert_eq!(method_meta.name(), "forward");
assert_eq!(method_meta.num_inputs(), 2);
assert_eq!(method_meta.input_tag(0).unwrap(), Tag::Tensor);
assert_eq!(method_meta.input_tag(1).unwrap(), Tag::Tensor);
assert!(method_meta.input_tag(2).is_err());
let tinfo1 = method_meta.input_tensor_meta(1).unwrap();
let tinfo2 = method_meta.input_tensor_meta(0).unwrap();
for tinfo in [tinfo1, tinfo2] {
assert_eq!(tinfo.sizes(), &[1]);
assert_eq!(tinfo.dim_order(), &[0]);
assert_eq!(tinfo.scalar_type(), ScalarType::Float);
assert!(tinfo.is_memory_planned());
assert_eq!(tinfo.nbytes(), 4);
}
assert_eq!(method_meta.num_outputs(), 1);
assert_eq!(method_meta.output_tag(0).unwrap(), Tag::Tensor);
assert!(method_meta.output_tag(1).is_err());
let tinfo = method_meta.output_tensor_meta(0).unwrap();
assert_eq!(tinfo.sizes(), &[1]);
assert_eq!(tinfo.dim_order(), &[0]);
assert_eq!(tinfo.scalar_type(), ScalarType::Float);
assert!(tinfo.is_memory_planned());
assert_eq!(tinfo.nbytes(), 4);
assert!(method_meta.output_tensor_meta(1).is_err());
assert_eq!(method_meta.num_attributes(), 0);
assert!(method_meta.attribute_tensor_meta(0).is_err());
for i in 0..method_meta.num_memory_planned_buffers() {
assert!(method_meta.memory_planned_buffer_size(i).is_ok());
}
assert!(method_meta
.memory_planned_buffer_size(method_meta.num_memory_planned_buffers())
.is_err());
for i in 0..method_meta.num_backends() {
let backend_name = method_meta.get_backend_name(i).unwrap();
assert!(!backend_name.is_empty());
#[cfg(feature = "alloc")]
assert!(
method_meta.uses_backend(std::ffi::CString::new(backend_name).unwrap().as_c_str())
);
}
assert!(!method_meta.uses_backend(c"non-existing-backend"));
}
#[cfg(tests_with_kernels)]
#[test]
fn load_method() {
use crate::memory::{BufferMemoryAllocator, HierarchicalAllocator, MemoryAllocatorExt};
use crate::util::Span;
let mut buffer = [0_u8; 4096];
let allocator = BufferMemoryAllocator::new(&mut buffer);
let data_loader = BufferDataLoader::new(ADD_MODEL_BYTES);
let program =
Program::load(&data_loader, Some(ProgramVerification::InternalConsistency)).unwrap();
let method_meta = program.method_meta(c"forward").unwrap();
let num_memory_planned_buffers = method_meta.num_memory_planned_buffers();
let planned_arenas = allocator
.allocate_arr_fn(num_memory_planned_buffers, |idx| {
let buf_size = method_meta.memory_planned_buffer_size(idx).unwrap();
Span::from_slice(allocator.allocate_arr::<u8>(buf_size).unwrap())
})
.unwrap();
let mut planned_memory = HierarchicalAllocator::new(planned_arenas);
let memory_manager = MemoryManager::new(&allocator, Some(&mut planned_memory), None);
assert!(program
.load_method(c"non-existing-method", &memory_manager, None, None)
.is_err());
assert!(program
.load_method(c"forward", &memory_manager, None, None)
.is_ok());
}
#[test]
fn check_header() {
assert_ne!(Program::check_header(&[]), HeaderStatus::CompatibleVersion);
assert_ne!(
Program::check_header(&[42, 6, 17]),
HeaderStatus::CompatibleVersion
);
assert_eq!(
Program::check_header(ADD_MODEL_BYTES),
HeaderStatus::CompatibleVersion
);
}
#[cfg(tests_with_kernels)]
#[test]
fn method_execution() {
use crate::memory::{BufferMemoryAllocator, HierarchicalAllocator, MemoryAllocatorExt};
use crate::tensor::{Tensor, TensorImpl};
use crate::util::Span;
let mut buffer = [0_u8; 4096];
let allocator = BufferMemoryAllocator::new(&mut buffer);
let data_loader = BufferDataLoader::new(ADD_MODEL_BYTES);
let program =
Program::load(&data_loader, Some(ProgramVerification::InternalConsistency)).unwrap();
let method_meta = program.method_meta(c"forward").unwrap();
let num_memory_planned_buffers = method_meta.num_memory_planned_buffers();
let planned_arenas = allocator
.allocate_arr_fn(num_memory_planned_buffers, |idx| {
let buf_size = method_meta.memory_planned_buffer_size(idx).unwrap();
Span::from_slice(allocator.allocate_arr::<u8>(buf_size).unwrap())
})
.unwrap();
let mut planned_memory = HierarchicalAllocator::new(planned_arenas);
let memory_manager = MemoryManager::new(&allocator, Some(&mut planned_memory), None);
let mut method = program
.load_method(c"forward", &memory_manager, None, None)
.unwrap();
assert_eq!(method.inputs_size(), 2);
assert!(method.get_attribute("non-existing-attr").is_err());
let execution = method.start_execution();
        assert!(matches!(execution.execute(), Err(Error::InvalidArgument)));
let mut execution = method.start_execution();
let sizes = [1];
let data = [1.0_f32];
let dim_order = [0];
let strides = [1];
let tensor_impl = TensorImpl::from_slice(&sizes, &data, &dim_order, &strides).unwrap();
let tensor = Tensor::new_in_allocator(&tensor_impl, &allocator);
let input1 = EValue::new_in_allocator(tensor, &allocator);
        // The second input reuses the same backing arrays; `from_slice` only
        // takes shared borrows.
        let tensor_impl = TensorImpl::from_slice(&sizes, &data, &dim_order, &strides).unwrap();
let tensor = Tensor::new_in_allocator(&tensor_impl, &allocator);
let input2 = EValue::new_in_allocator(tensor, &allocator);
assert!(execution.set_input(&input1, 2).is_err());
execution.set_input(&input1, 0).unwrap();
execution.set_input(&input2, 1).unwrap();
let outputs = execution.execute().unwrap();
assert!(!outputs.is_empty());
assert_eq!(outputs.len(), 1);
let output = outputs.get(0);
assert_eq!(output.tag(), Tag::Tensor);
let output = output.as_tensor().into_typed::<f32>();
assert_eq!(output.sizes(), [1]);
assert_eq!(output[&[0]], 2.0);
}
}