tensorrt-rs 0.3.0

Rust library for using Nvidia's TensorRT deep learning acceleration library
use crate::context::Context;
use crate::runtime::Runtime;
use crate::dims::Dims;
use std::convert::TryInto;
use std::ffi::{CStr, CString};
use std::os::raw::c_void;
use std::slice;
use tensorrt_sys::{deserialize_cuda_engine, destroy_cuda_engine, destroy_host_memory, engine_create_execution_context, engine_serialize, get_binding_index, get_binding_name, get_nb_bindings, host_memory_get_data, host_memory_get_size, get_binding_dimensions};

/// Handle to an engine object managed through the `tensorrt_sys` bindings.
///
/// The wrapped pointer is passed to `destroy_cuda_engine` when the value is
/// dropped.
#[derive(Debug)]
pub struct Engine {
    /// Raw engine pointer handed to every `tensorrt_sys` call made by this type.
    pub(crate) internal_engine: *mut tensorrt_sys::Engine_t,
}

impl Engine {
    /// Deserializes an engine from the serialized bytes in `buffer`.
    ///
    /// `runtime` is taken by value; its raw pointer, the start of `buffer` and
    /// the buffer length are forwarded to `deserialize_cuda_engine`, and the
    /// pointer that call returns is stored without further inspection.
    ///
    /// NOTE(review): a failed deserialization is not detected here — confirm
    /// what the C wrapper returns in that case before relying on the result.
    pub fn new(runtime: Runtime, buffer: Vec<u8>) -> Engine {
        // SAFETY: `buffer` outlives the call, so the pointer/length pair
        // describes readable memory for the whole duration of the call.
        let engine = unsafe {
            deserialize_cuda_engine(
                runtime.internal_runtime,
                buffer.as_ptr() as *const c_void,
                buffer.len() as u64,
            )
        };

        Engine {
            internal_engine: engine,
        }
    }

    /// Returns the number of bindings reported by the engine.
    pub fn get_nb_bindings(&self) -> i32 {
        unsafe { get_nb_bindings(self.internal_engine) }
    }

    /// Returns the name of the binding at `binding_index`.
    ///
    /// Returns `None` when `binding_index` is negative, is not smaller than
    /// [`Engine::get_nb_bindings`], or when the underlying call yields a null
    /// pointer. Bytes that are not valid UTF-8 are replaced with U+FFFD instead
    /// of causing a panic.
    pub fn get_binding_name(&self, binding_index: i32) -> Option<String> {
        // Reject both ends of the range: a negative index must never reach the
        // FFI call (or the integer conversion below).
        if binding_index < 0 || binding_index >= self.get_nb_bindings() {
            return None;
        }

        let raw_binding_name = unsafe {
            get_binding_name(self.internal_engine, binding_index.try_into().ok()?)
        };
        if raw_binding_name.is_null() {
            return None;
        }

        // SAFETY: the pointer was just checked to be non-null.
        // NOTE(review): assumes the C side returns a NUL-terminated string that
        // stays valid while it is copied below — confirm against the wrapper.
        let binding_name = unsafe { CStr::from_ptr(raw_binding_name) };

        Some(binding_name.to_string_lossy().into_owned())
    }

    /// Looks up the index of the binding called `binding_name`.
    ///
    /// Returns `None` when the underlying call reports a negative index, or
    /// when `binding_name` contains an interior NUL byte and therefore cannot
    /// be passed across the C boundary.
    pub fn get_binding_index(&self, binding_name: &str) -> Option<i32> {
        // Keep the C string in a named local so that it is guaranteed to
        // outlive the pointer handed to the FFI call.
        let c_binding_name = CString::new(binding_name).ok()?;

        let binding_index =
            unsafe { get_binding_index(self.internal_engine, c_binding_name.as_ptr()) };

        if binding_index < 0 {
            None
        } else {
            Some(binding_index)
        }
    }

    /// Returns the dimensions of the binding at `binding_index`.
    ///
    /// `binding_index` is forwarded to `get_binding_dimensions` as-is; no range
    /// check is performed on the Rust side.
    pub fn get_binding_dimensions(&self, binding_index: i32) -> Dims {
        let raw_dims = unsafe { get_binding_dimensions(self.internal_engine, binding_index) };

        Dims {
            internal_dims: raw_dims,
        }
    }

    /// Creates an execution context for this engine.
    ///
    /// The returned [`Context`] keeps a reference to the engine it was created
    /// from.
    pub fn create_execution_context(&self) -> Context {
        let execution_context = unsafe { engine_create_execution_context(self.internal_engine) };

        Context {
            internal_context: execution_context,
            _engine: self,
        }
    }

    /// Serializes the engine and returns the resulting [`HostMemory`].
    pub fn serialize(&self) -> HostMemory {
        let memory = unsafe { engine_serialize(self.internal_engine) };
        HostMemory { memory }
    }
}

// Raw pointers are not `Send` by default, so this impl is what allows an
// `Engine` to be moved to another thread.
// NOTE(review): presumably relies on the underlying engine being usable from a
// thread other than the one that created it — confirm against TensorRT's
// threading guarantees.
unsafe impl Send for Engine {}

impl Drop for Engine {
    /// Passes the wrapped engine pointer to `destroy_cuda_engine`.
    fn drop(&mut self) {
        let raw_engine = self.internal_engine;
        unsafe {
            destroy_cuda_engine(raw_engine);
        }
    }
}

/// Handle to a host memory object returned by the `tensorrt_sys` bindings
/// (for example by `engine_serialize`).
///
/// The wrapped pointer is passed to `destroy_host_memory` when the value is
/// dropped.
pub struct HostMemory {
    /// Raw host-memory pointer handed to the `host_memory_*` FFI calls.
    pub(crate) memory: *mut tensorrt_sys::HostMemory_t,
}

impl HostMemory {
    /// Returns the bytes held by this host memory object.
    ///
    /// The pointer and length come from `host_memory_get_data` and
    /// `host_memory_get_size`. An empty slice is returned when the pointer is
    /// null or the size is zero, because `slice::from_raw_parts` requires a
    /// non-null pointer even for zero-length slices.
    pub fn data(&self) -> &[u8] {
        let ptr = unsafe { host_memory_get_data(self.memory) };
        let size = unsafe { host_memory_get_size(self.memory) };

        if ptr.is_null() || size == 0 {
            return &[];
        }

        // SAFETY: `ptr` is non-null and `size` is non-zero at this point.
        // NOTE(review): assumes the C side guarantees `size` readable bytes at
        // `ptr` for as long as `self` is alive — confirm against the wrapper.
        unsafe { slice::from_raw_parts(ptr as *const u8, size) }
    }
}

impl AsRef<[u8]> for HostMemory {
    fn as_ref(&self) -> &[u8] {
        self.data()
    }
}

impl Drop for HostMemory {
    /// Passes the wrapped host-memory pointer to `destroy_host_memory`.
    fn drop(&mut self) {
        let raw_memory = self.memory;
        unsafe { destroy_host_memory(raw_memory) };
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::Builder;
    use crate::dims::DimsCHW;
    use crate::runtime::{Logger, Runtime};
    use crate::uff::{UffFile, UffInputOrder, UffParser};
    use lazy_static::lazy_static;
    use std::fs::{read, remove_file, write};
    use std::path::Path;
    use std::sync::Mutex;

    lazy_static! {
        // One engine is built on first use and shared by every test behind a mutex.
        static ref ENGINE: Mutex<Engine> = Mutex::new(setup_engine_test_uff());
    }

    /// Builds an engine from `../lenet5.uff` with input `in` (1x28x28, NCHW)
    /// and output `out`.
    fn setup_engine_test_uff() -> Engine {
        let logger = Logger::new();
        let builder = Builder::new(&logger);

        let uff_parser = UffParser::new();
        let dim = DimsCHW::new(1, 28, 28);

        uff_parser
            .register_input("in", dim, UffInputOrder::Nchw)
            .unwrap();
        uff_parser.register_output("out").unwrap();
        let uff_file = UffFile::new(Path::new("../lenet5.uff")).unwrap();
        uff_parser.parse(&uff_file, builder.get_network()).unwrap();

        builder.build_cuda_engine()
    }

    #[test]
    fn get_nb_bindings() {
        let engine = ENGINE.lock().unwrap();

        assert_eq!(2, engine.get_nb_bindings());
    }

    #[test]
    fn get_engine_binding_name() {
        let engine = ENGINE.lock().unwrap();
        assert_eq!("in", engine.get_binding_name(0).unwrap());
    }

    #[test]
    fn get_invalid_engine_binding() {
        let engine = ENGINE.lock().unwrap();
        assert_eq!(None, engine.get_binding_name(3));
    }

    #[test]
    fn get_binding_index() {
        let engine = ENGINE.lock().unwrap();

        assert_eq!(Some(0), engine.get_binding_index("in"));
    }

    #[test]
    fn write_and_read_engine() {
        let guard = ENGINE.lock().unwrap();
        let uff_engine: &Engine = &guard;

        let serialized_path = Path::new("../lenet5.engine");
        write(serialized_path, uff_engine.serialize()).unwrap();

        assert!(serialized_path.exists());

        let logger = Logger::new();
        let runtime = Runtime::new(&logger);

        let buffer = read(serialized_path).unwrap();
        // The bytes are in memory now; remove the file right away so that a
        // failing assertion below does not leave it behind.
        remove_file(serialized_path).unwrap();

        let serialized_engine = Engine::new(runtime, buffer);

        assert_eq!(
            uff_engine.get_nb_bindings(),
            serialized_engine.get_nb_bindings()
        );

        for i in 0..uff_engine.get_nb_bindings() {
            // Both engines must report the same name for every binding slot...
            let name = uff_engine.get_binding_name(i);
            assert_eq!(name, serialized_engine.get_binding_name(i));

            // ...and looking that name up again must lead back to the same slot.
            let name = name.unwrap();
            assert_eq!(Some(i), uff_engine.get_binding_index(&name));
            assert_eq!(Some(i), serialized_engine.get_binding_index(&name));
        }
    }
}