use ffi::{TF_DataType, TF_Session, TF_Tensor};
use libc::{c_int, size_t};
use std::ffi::CString;
use std::{mem, ptr};
use Result;
use buffer::Buffer;
use options::Options;
use status::Status;
use tensor::Tensor;
use value::Value;
/// A session wrapping a raw `TF_Session` pointer.
pub struct Session {
/// Status object whose raw handle is passed to the C API calls made by this session.
status: Status,
/// Raw session pointer obtained from `TF_NewSession`; closed and deleted on drop.
raw: *mut TF_Session,
}
/// A named input of a session run.
pub struct Input {
/// Name whose C-string pointer is passed to `TF_Run`.
name: CString,
/// Type-erased tensor; `None` once `get` has taken it and until `set` stores a new one.
tensor: Option<Box<Flexor>>,
}
/// A named output of a session run.
pub struct Output {
/// Name whose C-string pointer is passed to `TF_Run`.
name: CString,
/// Raw tensor stored by `Session::run`; taken by `get` and otherwise deleted on drop.
tensor: Option<*mut TF_Tensor>,
}
/// A named target of a session run.
pub struct Target {
/// Name whose C-string pointer is passed to `TF_Run`.
name: CString,
}
/// Type-erased interface over `Tensor<T>`, letting `Input` hold a tensor of any element type.
trait Flexor {
/// Produce a raw tensor pointer; `Session::run` calls this once per input before `TF_Run`.
fn copy_raw(&self) -> Result<*mut TF_Tensor>;
/// Report the data type of the underlying tensor; `Input::get` compares it with `T::kind()`.
fn kind(&self) -> TF_DataType;
}
impl Session {
pub fn new(options: &Options) -> Result<Self> {
let status = try!(Status::new());
let raw = nonnull!(ffi!(TF_NewSession(options.as_raw(), status.as_raw())), &status);
Ok(Session { status: status, raw: raw })
}
pub fn extend(&mut self, definition: &Buffer) -> Result<()> {
let definition = definition.as_ref();
ok!(ffi!(TF_ExtendGraph(self.raw, definition.as_ptr() as *const _,
definition.len() as size_t, self.status.as_raw())), &self.status);
Ok(())
}
pub fn run(&mut self, inputs: &[Input], outputs: &mut [Output], targets: &[Target],
options: Option<&Buffer>, metadata: Option<&mut Buffer>) -> Result<()> {
let ni = inputs.len();
let mut input_names = vec![ptr::null(); ni];
let mut input_tensors = vec![ptr::null_mut(); ni];
macro_rules! cleanup(() => ({
for tensor in input_tensors.drain(..) {
ffi!(TF_DeleteTensor(tensor));
}
}));
for i in 0..ni {
input_names[i] = inputs[i].name.as_ptr();
input_tensors[i] = match inputs[i].tensor.as_ref().map(|tensor| tensor.copy_raw()) {
Some(Ok(tensor)) => tensor,
Some(Err(error)) => {
cleanup!();
return Err(error);
},
_ => {
cleanup!();
raise!("some of the inputs have not been set");
},
};
}
let no = outputs.len();
let mut output_names = vec![ptr::null(); no];
let mut output_tensors = vec![ptr::null_mut(); no];
for i in 0..no {
output_names[i] = outputs[i].name.as_ptr();
}
let nt = targets.len();
let mut target_names = vec![ptr::null(); nt];
for i in 0..nt {
target_names[i] = targets[i].name.as_ptr();
}
let options_buffer = if let Some(buffer) = options {
buffer.as_raw()
} else {
ptr::null_mut()
};
let metadata_buffer = if let Some(ref buffer) = metadata {
buffer.as_raw()
} else {
ptr::null_mut()
};
ok!(ffi!(TF_Run(self.raw, options_buffer, input_names.as_mut_ptr(),
input_tensors.as_mut_ptr(), ni as c_int, output_names.as_mut_ptr(),
output_tensors.as_mut_ptr(), no as c_int, target_names.as_mut_ptr(),
nt as c_int, metadata_buffer, self.status.as_raw())), &self.status);
for i in 0..no {
outputs[i].set(output_tensors[i]);
}
if let Some(buffer) = metadata {
unsafe { buffer.reset() };
}
Ok(())
}
}
/// Releases the underlying `TF_Session` when the wrapper goes out of scope.
impl Drop for Session {
// Closes the session and then deletes it. The status written by either call is
// not inspected, so failures here go unreported.
#[inline]
fn drop(&mut self) {
ffi!(TF_CloseSession(self.raw, self.status.as_raw()));
ffi!(TF_DeleteSession(self.raw, self.status.as_raw()));
}
}
impl Input {
#[inline]
pub fn new<T, U>(name: T, tensor: Tensor<U>) -> Self where T: Into<String>, U: Value {
Input { name: into_cstring!(name), tensor: Some(Box::new(tensor)) }
}
pub fn get<T>(&mut self) -> Result<Tensor<T>> where T: Value {
if self.tensor.is_none() {
raise!("the tensor has not been set");
}
if self.tensor.as_ref().unwrap().kind() != T::kind() {
raise!("the data types do not match");
}
let tensor = self.tensor.take().unwrap();
Ok(*unsafe { Box::from_raw(Box::into_raw(tensor) as *mut _) })
}
#[inline]
pub fn set<T>(&mut self, tensor: Tensor<T>) where T: Value {
self.tensor = Some(Box::new(tensor));
}
}
impl Output {
#[inline]
pub fn new<T>(name: T) -> Self where T: Into<String> {
Output { name: into_cstring!(name), tensor: None }
}
pub fn get<T>(&mut self) -> Result<Tensor<T>> where T: Value {
match self.tensor.take() {
Some(tensor) => Tensor::from_raw(tensor),
_ => raise!("the tensor has not been set"),
}
}
#[inline]
fn set(&mut self, tensor: *mut TF_Tensor) {
if let Some(tensor) = mem::replace(&mut self.tensor, Some(tensor)) {
ffi!(TF_DeleteTensor(tensor));
}
}
}
/// Deletes the raw tensor still held by the output, if any.
impl Drop for Output {
    #[inline]
    fn drop(&mut self) {
        match self.tensor.take() {
            Some(leftover) => {
                ffi!(TF_DeleteTensor(leftover));
            },
            None => {},
        }
    }
}
impl Target {
#[inline]
pub fn new<T>(name: T) -> Self where T: Into<String> {
Target { name: into_cstring!(name) }
}
}
/// Lets any `Tensor<T>` be stored behind a `Flexor` trait object.
impl<T> Flexor for Tensor<T> where T: Value {
// NOTE(review): the call below has the same name as this trait method. It presumably
// resolves to an inherent `Tensor::copy_raw` (inherent methods take precedence over
// trait methods) — confirm one exists, otherwise this call would recurse into itself.
#[inline]
fn copy_raw(&self) -> Result<*mut TF_Tensor> {
self.copy_raw()
}
// Reports the data type associated with the element type `T`; the tensor's contents
// are not consulted.
#[inline]
fn kind(&self) -> TF_DataType {
T::kind()
}
}
#[cfg(test)]
mod tests {
    use session::Input;
    use tensor::Tensor;

    /// A tensor handed to an `Input` can be taken back out with its data intact.
    #[test]
    fn input_get() {
        let tensor = Tensor::new(vec![42.0, 69.0], &[2]).unwrap();
        let mut input = Input::new("a", tensor);
        let recovered = input.get::<f64>().unwrap();
        assert_eq!(&recovered[..], &[42.0, 69.0]);
    }
}