use crate::dims::IsDim;
use crate::network::Network;
use std::error;
use std::ffi::CString;
use std::fmt::Formatter;
use std::io;
use std::path::{Path, PathBuf};
use tensorrt_sys::{
uffparser_create_uff_parser, uffparser_destroy_uff_parser, uffparser_parse,
uffparser_register_input, uffparser_register_output,
};
#[repr(C)]
pub enum UffInputOrder {
Nchw,
Nhwc,
Nc,
}
pub struct UffFile(PathBuf);
impl UffFile {
pub fn new(file_name: &Path) -> Result<UffFile, io::Error> {
if !file_name.exists() {
return Err(io::Error::new(
io::ErrorKind::NotFound,
"UFF file does not exist",
));
}
if file_name.extension().unwrap() != "uff" {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid UFF file. UFF files should have a .uff ending",
));
}
Ok(UffFile(file_name.to_path_buf()))
}
pub fn path(&self) -> CString {
CString::new(self.0.to_str().unwrap()).unwrap()
}
}
pub struct UffParser {
internal_uffparser: *mut tensorrt_sys::UffParser_t,
}
impl UffParser {
pub fn new() -> UffParser {
let parser = unsafe { uffparser_create_uff_parser() };
UffParser {
internal_uffparser: parser,
}
}
pub fn register_input(
&self,
input_name: &str,
dims: impl IsDim,
input_order: UffInputOrder,
) -> Result<(), UFFRegistrationError> {
let res = unsafe {
uffparser_register_input(
self.internal_uffparser,
CString::new(input_name).unwrap().as_ptr(),
dims.internal_dims(),
input_order as i32,
)
};
if res {
Ok(())
} else {
Err(UFFRegistrationError::new("Input Registration Failed"))
}
}
pub fn register_output(&self, output_name: &str) -> Result<(), UFFRegistrationError> {
let res = unsafe {
uffparser_register_output(
self.internal_uffparser,
CString::new(output_name).unwrap().as_ptr(),
)
};
if res {
Ok(())
} else {
Err(UFFRegistrationError::new("Output Registration Failed"))
}
}
pub fn parse(&self, uff_file: &UffFile, network: &Network) -> Result<(), UFFParseError> {
let res = unsafe {
uffparser_parse(
self.internal_uffparser,
uff_file.path().as_ptr(),
network.internal_network,
)
};
if res {
Ok(())
} else {
Err(UFFParseError::new("Error parsing UFF file"))
}
}
}
impl Drop for UffParser {
fn drop(&mut self) {
unsafe { uffparser_destroy_uff_parser(self.internal_uffparser) };
}
}
#[derive(Debug, Clone)]
pub struct UFFRegistrationError {
message: String,
}
impl UFFRegistrationError {
pub fn new(message: &str) -> Self {
UFFRegistrationError {
message: message.to_string(),
}
}
}
impl std::fmt::Display for UFFRegistrationError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}
impl error::Error for UFFRegistrationError {}
#[derive(Debug, Clone)]
pub struct UFFParseError {
message: String,
}
impl UFFParseError {
pub fn new(message: &str) -> Self {
UFFParseError {
message: message.to_string(),
}
}
}
impl std::fmt::Display for UFFParseError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}
impl error::Error for UFFParseError {}