extern crate libc;
extern crate tensorflow_sys as ffi;
use libc::{c_int, c_void, int64_t, size_t};
use std::ffi::{CStr, CString};
use std::path::Path;
macro_rules! nonnull(
($pointer:expr) => ({
let pointer = $pointer;
assert!(!pointer.is_null());
pointer
});
);
macro_rules! ok(
($status:expr) => ({
if ffi::TF_GetCode($status) != ffi::TF_OK {
panic!(CStr::from_ptr(ffi::TF_Message($status)).to_string_lossy().into_owned());
}
});
);
fn main() {
use std::mem::size_of;
use std::ptr::{null, null_mut};
use std::slice::from_raw_parts;
unsafe {
let options = nonnull!(ffi::TF_NewSessionOptions());
let status = nonnull!(ffi::TF_NewStatus());
let session = nonnull!(ffi::TF_NewDeprecatedSession(options, status));
let graph = read("examples/assets/multiplication.pb"); ffi::TF_ExtendGraph(session, graph.as_ptr() as *const _, graph.len() as size_t, status);
ok!(status);
let mut input_names = vec![];
let mut inputs = vec![];
let name = CString::new("a").unwrap();
let mut data = vec![1f32, 2.0, 3.0];
let dims = vec![data.len() as int64_t];
let tensor = nonnull!(ffi::TF_NewTensor(ffi::TF_FLOAT, dims.as_ptr(), dims.len() as c_int,
data.as_mut_ptr() as *mut _, data.len() as size_t,
Some(noop), null_mut()));
input_names.push(name.as_ptr());
inputs.push(tensor);
let name = CString::new("b").unwrap();
let mut data = vec![4f32, 5.0, 6.0];
let dims = vec![data.len() as int64_t];
let tensor = nonnull!(ffi::TF_NewTensor(ffi::TF_FLOAT, dims.as_ptr(), dims.len() as c_int,
data.as_mut_ptr() as *mut _, data.len() as size_t,
Some(noop), null_mut()));
input_names.push(name.as_ptr());
inputs.push(tensor);
let mut output_names = vec![];
let mut outputs = vec![];
let name = CString::new("c").unwrap();
output_names.push(name.as_ptr());
outputs.push(null_mut());
let mut target_names = vec![];
ffi::TF_Run(session, null(), input_names.as_mut_ptr(), inputs.as_mut_ptr(),
input_names.len() as c_int, output_names.as_mut_ptr(), outputs.as_mut_ptr(),
output_names.len() as c_int, target_names.as_mut_ptr(),
target_names.len() as c_int, null_mut(), status);
ok!(status);
let tensor = nonnull!(outputs[0]);
let data = nonnull!(ffi::TF_TensorData(tensor)) as *const f32;
let data = from_raw_parts(data, ffi::TF_TensorByteSize(tensor) / size_of::<f32>());
assert_eq!(data, &[1.0 * 4.0, 2.0 * 5.0, 3.0 * 6.0]);
ffi::TF_CloseDeprecatedSession(session, status);
ffi::TF_DeleteTensor(tensor);
ffi::TF_DeleteDeprecatedSession(session, status);
ffi::TF_DeleteStatus(status);
ffi::TF_DeleteSessionOptions(options);
}
unsafe extern "C" fn noop(_: *mut c_void, _: size_t, _: *mut c_void) {}
}
fn read<T: AsRef<Path>>(path: T) -> Vec<u8> {
use std::fs::File;
use std::io::Read;
let mut buffer = vec![];
let mut file = File::open(path).unwrap();
file.read_to_end(&mut buffer).unwrap();
buffer
}