use cpp::cpp;
use crate::ffi::parser::Parser;
/// Capacity reserved for the dimension buffer filled by `Tensor::get_dimensions`.
///
/// NOTE(review): presumably mirrors TensorRT's `nvinfer1::Dims::MAX_DIMS` (8); confirm it
/// matches the TensorRT headers this crate is built against.
const MAX_DIMS: usize = 8;
/// Owning wrapper around a C++ `INetworkDefinition*`.
///
/// The wrapped pointer is passed to the C++ `destroy` helper when this value is dropped
/// (see the `Drop` impl for `NetworkDefinition`).
pub struct NetworkDefinition {
    /// Raw `INetworkDefinition*`, type-erased to `void*` for the `cpp!` FFI boundary.
    internal: *mut std::ffi::c_void,
    /// Optional parser associated with this network. Rust drops struct fields only after
    /// `Drop::drop` has run, so this parser is released after `internal` is destroyed.
    /// NOTE(review): presumably stored here so the parser lives as long as the network it
    /// populated — confirm against the code that assigns this field.
    pub(crate) _parser: Option<Parser>,
}
// SAFETY: NOTE(review) — asserts the underlying TensorRT network object may be moved to and
// destroyed on another thread. Nothing in this file enforces that; confirm against TensorRT's
// thread-safety documentation.
unsafe impl Send for NetworkDefinition {}
// SAFETY: NOTE(review) — `&NetworkDefinition` hands out fresh `Tensor` wrappers from
// `input`/`output`, and `Tensor::set_name` mutates the underlying tensor through `&mut Tensor`.
// Two threads can each obtain their own wrapper for the same index, so concurrent renames are
// not excluded by the type system. Confirm this is acceptable or narrow the impl.
unsafe impl Sync for NetworkDefinition {}
impl NetworkDefinition {
    /// Takes ownership of a raw `INetworkDefinition*` produced on the C++ side.
    ///
    /// The pointer is passed to `destroy` when the returned value is dropped. No parser is
    /// attached (`_parser` starts as `None`).
    pub(crate) fn wrap(internal: *mut std::ffi::c_void) -> Self {
        Self {
            internal,
            _parser: None,
        }
    }

    /// Returns a wrapper for every network input, in index order.
    pub fn inputs(&self) -> Vec<Tensor<'_>> {
        (0..self.num_inputs()).map(|index| self.input(index)).collect()
    }

    /// Number of network inputs, as reported by `INetworkDefinition::getNbInputs`.
    pub fn num_inputs(&self) -> usize {
        let internal = self.as_ptr();
        let num_inputs = cpp!(unsafe [
            internal as "const void*"
        ] -> std::os::raw::c_int as "int" {
            return ((const INetworkDefinition*) internal)->getNbInputs();
        });
        num_inputs as usize
    }

    /// Returns the input tensor at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.num_inputs()`. TensorRT documents `getInput` as returning
    /// `nullptr` for an out-of-range index. Wrapping that pointer would make every later
    /// `Tensor` method dereference null, so the bound is checked up front.
    pub fn input(&self, index: usize) -> Tensor<'_> {
        let count = self.num_inputs();
        assert!(
            index < count,
            "input index {} out of range: network has {} input(s)",
            index,
            count
        );
        let internal = self.as_ptr();
        // Cannot truncate: `index < count`, and `count` originated from a C `int`.
        let index = index as std::os::raw::c_int;
        let tensor_internal = cpp!(unsafe [
            internal as "const void*",
            index as "int"
        ] -> *mut std::ffi::c_void as "void*" {
            return ((const INetworkDefinition*) internal)->getInput(index);
        });
        Tensor::wrap(tensor_internal)
    }

    /// Returns a wrapper for every network output, in index order.
    pub fn outputs(&self) -> Vec<Tensor<'_>> {
        (0..self.num_outputs()).map(|index| self.output(index)).collect()
    }

    /// Number of network outputs, as reported by `INetworkDefinition::getNbOutputs`.
    pub fn num_outputs(&self) -> usize {
        let internal = self.as_ptr();
        let num_outputs = cpp!(unsafe [
            internal as "const void*"
        ] -> std::os::raw::c_int as "int" {
            return ((const INetworkDefinition*) internal)->getNbOutputs();
        });
        num_outputs as usize
    }

    /// Returns the output tensor at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.num_outputs()`. TensorRT documents `getOutput` as returning
    /// `nullptr` for an out-of-range index. Wrapping that pointer would make every later
    /// `Tensor` method dereference null, so the bound is checked up front.
    pub fn output(&self, index: usize) -> Tensor<'_> {
        let count = self.num_outputs();
        assert!(
            index < count,
            "output index {} out of range: network has {} output(s)",
            index,
            count
        );
        let internal = self.as_ptr();
        // Cannot truncate: `index < count`, and `count` originated from a C `int`.
        let index = index as std::os::raw::c_int;
        let tensor_internal = cpp!(unsafe [
            internal as "const void*",
            index as "int"
        ] -> *mut std::ffi::c_void as "void*" {
            return ((const INetworkDefinition*) internal)->getOutput(index);
        });
        Tensor::wrap(tensor_internal)
    }

    /// Returns the raw `INetworkDefinition*` without transferring ownership.
    #[inline]
    pub fn as_ptr(&self) -> *const std::ffi::c_void {
        self.internal
    }

    /// Returns the raw mutable `INetworkDefinition*` without transferring ownership.
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
        self.internal
    }
}
impl Drop for NetworkDefinition {
    /// Releases the wrapped `INetworkDefinition` through the C++ `destroy` helper.
    ///
    /// This runs before the struct's own fields are dropped, so `_parser` is released only
    /// after the network has been destroyed.
    fn drop(&mut self) {
        let network_ptr = self.internal;
        cpp!(unsafe [
            network_ptr as "void*"
        ] {
            destroy((INetworkDefinition*) network_ptr);
        });
    }
}
/// Flags selecting how a network definition is created.
///
/// NOTE(review): presumably maps onto TensorRT's `NetworkDefinitionCreationFlag`
/// (`ExplicitBatchSize` ↔ `kEXPLICIT_BATCH`); the mapping lives outside this file.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum NetworkDefinitionCreationFlags {
    /// No creation flags set.
    None,
    /// The explicit-batch-size creation flag.
    ExplicitBatchSize,
}
/// Wrapper around a C++ `ITensor*`.
///
/// `'parent` is the lifetime of the borrow the pointer was obtained through (for example
/// `NetworkDefinition::input`), so the wrapper cannot outlive that borrow. This file defines
/// no `Drop` for `Tensor`, so dropping the wrapper does not release the tensor.
#[derive(Debug)]
pub struct Tensor<'parent> {
    /// Raw `ITensor*`, type-erased to `void*` for the `cpp!` FFI boundary.
    internal: *mut std::ffi::c_void,
    /// Ties the wrapper to the `'parent` borrow without storing a reference.
    _phantom: std::marker::PhantomData<&'parent ()>,
}
// SAFETY: NOTE(review) — asserts the underlying `ITensor` may be accessed from another thread;
// confirm against TensorRT's thread-safety documentation.
unsafe impl<'parent> Send for Tensor<'parent> {}
// SAFETY: NOTE(review) — `&Tensor` only exposes read accessors (`name`, `get_dimensions`), but
// several `Tensor` wrappers can alias one `ITensor` (e.g. calling `NetworkDefinition::output(0)`
// twice), so `&mut Tensor` in `set_name` does not guarantee exclusive access to the C++ object.
unsafe impl<'parent> Sync for Tensor<'parent> {}
impl<'parent> Tensor<'parent> {
    /// Wraps a raw `ITensor*` returned from the C++ side.
    ///
    /// The caller chooses `'parent` (typically the lifetime of the borrow the pointer was
    /// obtained through, as in `NetworkDefinition::input`).
    #[inline]
    pub(crate) fn wrap(internal: *mut std::ffi::c_void) -> Self {
        Self {
            internal,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Returns the tensor's name as reported by `ITensor::getName`.
    ///
    /// Invalid UTF-8 is replaced with U+FFFD. A null name pointer yields an empty string
    /// instead of being handed to `CStr::from_ptr` (which would be undefined behaviour).
    pub fn name(&self) -> String {
        let internal = self.as_ptr();
        let name = cpp!(unsafe [
            internal as "const void*"
        ] -> *const std::os::raw::c_char as "const char*" {
            return ((const ITensor*) internal)->getName();
        });
        if name.is_null() {
            return String::new();
        }
        // SAFETY: `name` is non-null. NOTE(review): assumes TensorRT returns a NUL-terminated
        // string that stays valid for the duration of this copy.
        unsafe { std::ffi::CStr::from_ptr(name) }
            .to_string_lossy()
            .into_owned()
    }

    /// Renames the tensor via `ITensor::setName`.
    ///
    /// The Rust-side `CString` lives until this function returns, which covers the FFI call.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains an interior NUL byte, since it cannot be represented as a
    /// C string.
    pub fn set_name(&mut self, name: &str) {
        let internal = self.as_mut_ptr();
        let name_ffi = std::ffi::CString::new(name)
            .expect("tensor name must not contain interior NUL bytes");
        let name_ptr = name_ffi.as_ptr();
        cpp!(unsafe [
            internal as "void*",
            name_ptr as "const char*"
        ] {
            ((ITensor*) internal)->setName(name_ptr);
        });
    }

    /// Returns the tensor's dimensions as reported by `ITensor::getDimensions`.
    ///
    /// Returns an empty vector when the reported rank is zero or negative. At most
    /// `MAX_DIMS` entries are copied, so a larger reported rank can never overflow the
    /// buffer.
    pub fn get_dimensions(&self) -> Vec<i32> {
        let internal = self.as_ptr();
        let mut dims: Vec<i32> = Vec::with_capacity(MAX_DIMS);
        let dims_ptr = dims.as_mut_ptr();
        // Handed to C++ so the copy loop is bounded by the buffer's capacity rather than by
        // whatever `nbDims` happens to contain.
        let max_dims = MAX_DIMS as i32;
        let num_dimensions = cpp!(unsafe [
            internal as "const void*",
            dims_ptr as "int32_t*",
            max_dims as "int32_t"
        ] -> i32 as "int32_t" {
            auto dims = ((const ITensor*) internal)->getDimensions();
            int32_t count = dims.nbDims < max_dims ? dims.nbDims : max_dims;
            for (int32_t i = 0; i < count; ++i) {
                dims_ptr[i] = dims.d[i];
            }
            return count;
        });
        if num_dimensions > 0 {
            // SAFETY: the C++ block initialised exactly `num_dimensions` elements, and that
            // count is clamped to `MAX_DIMS`, which equals the reserved capacity.
            unsafe {
                dims.set_len(num_dimensions as usize);
            }
        }
        dims
    }

    /// Returns the raw `ITensor*` without transferring ownership.
    #[inline]
    pub fn as_ptr(&self) -> *const std::ffi::c_void {
        self.internal
    }

    /// Returns the raw mutable `ITensor*` without transferring ownership.
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut std::ffi::c_void {
        self.internal
    }
}
#[cfg(test)]
mod tests {
    use crate::tests::utils::*;

    /// The fixture network exposes exactly one input named "X" and one output named "Y".
    #[tokio::test]
    async fn test_network_inputs_and_outputs() {
        let (_, network) = simple_network!();

        assert_eq!(network.num_inputs(), 1);
        assert_eq!(network.num_outputs(), 1);

        let input_name = network
            .inputs()
            .into_iter()
            .next()
            .map(|tensor| tensor.name())
            .expect("network should have an input");
        assert_eq!(input_name, "X");

        let output_name = network
            .outputs()
            .into_iter()
            .next()
            .map(|tensor| tensor.name())
            .expect("network should have an output");
        assert_eq!(output_name, "Y");
    }

    /// Renaming an output through one wrapper is visible through a freshly fetched one.
    #[tokio::test]
    async fn test_tensor_set_name() {
        let (_, network) = simple_network!();

        let mut outputs = network.outputs();
        outputs[0].set_name("Z");

        let renamed = network.outputs();
        assert_eq!(renamed[0].name(), "Z");
    }
}