use std::collections::HashMap;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::mem;
use std::ops::{Index, IndexMut};
#[cfg(feature = "serialize")]
use std::result::Result as StdResult;
use std::sync::Arc;
use analyser::prelude::*;
use dim::TDim;
use model::TVec;
use ops::nn::local_patch::{DataFormat, Padding};
use {DatumType, Result, Tensor};
use downcast_rs::Downcast;
use objekt;
#[cfg(feature = "serialize")]
use serde::ser::{Serialize, Serializer};
#[macro_use]
mod macros;
mod array;
mod cast;
#[cfg(features = "image_ops")]
pub mod image;
pub mod konst;
mod math;
pub mod nn;
mod unimpl;
pub mod prelude {
pub use super::{Attr, Op, OpRegister};
pub use super::{OpBuffer, QueuesBuffer, StepValue, Stream, StreamInfo, Value};
pub use dim::TDim;
pub use model::TVec;
pub use std::collections::HashMap;
pub use std::marker::PhantomData;
pub use tensor::{Datum, DatumType, Tensor};
pub use Result;
}
/// A tensor as passed between ops, held either by value or behind an `Arc`.
#[derive(Clone, Debug)]
pub enum Value {
    /// The tensor is owned outright.
    Owned(Tensor),
    /// The tensor is reference-counted and may be referenced from elsewhere.
    Shared(Arc<Tensor>),
}
impl Value {
    /// Converts this value into its shared form.
    ///
    /// An owned tensor is moved into a fresh `Arc`; an already shared value is
    /// returned unchanged. No tensor data is copied.
    pub fn into_shared(self) -> Value {
        match self {
            Value::Owned(m) => Value::Shared(Arc::new(m)),
            Value::Shared(_) => self,
        }
    }

    /// Consumes the value and returns its tensor.
    ///
    /// An owned tensor is returned as is. A shared tensor is unwrapped without
    /// copying when this handle is the only reference to it, and cloned otherwise.
    pub fn into_tensor(self) -> Tensor {
        match self {
            Value::Owned(m) => m,
            // Only deep-copy the tensor when other `Arc` handles still point at it.
            Value::Shared(shared) => {
                Arc::try_unwrap(shared).unwrap_or_else(|shared| shared.as_ref().clone())
            }
        }
    }

    /// Borrows the underlying tensor, however it is held.
    pub fn as_tensor(&self) -> &Tensor {
        match self {
            Value::Owned(m) => m,
            Value::Shared(m) => m.as_ref(),
        }
    }

    /// Makes this value shared in place and returns another handle to it.
    ///
    /// If the value is owned, its tensor is moved into an `Arc` and `self` is
    /// rewritten as `Value::Shared`. In every case the result is `self.clone()`,
    /// which for a shared value clones the `Arc` rather than the tensor.
    pub fn share(&mut self) -> Value {
        if let Value::Owned(_) = self {
            // A placeholder is needed to move the tensor out from behind `&mut self`;
            // it is overwritten right below.
            let dummy = Value::Owned(Tensor::i32s(&[], &[0]).unwrap());
            let shared = match mem::replace(self, dummy) {
                Value::Owned(m) => Value::Shared(Arc::new(m)),
                Value::Shared(_) => unreachable!("value was just checked to be owned"),
            };
            *self = shared;
        }
        self.clone()
    }

    /// Consumes the value and converts its tensor into an owned ndarray of `D`.
    ///
    /// # Errors
    /// Propagates any error from the tensor conversion (presumably a datum type
    /// mismatch with `D` — confirm against `Tensor::into_array`).
    pub fn into_array<D: ::tensor::Datum>(self) -> ::Result<::ndarray::ArrayD<D>> {
        self.into_tensor().into_array()
    }

    /// Borrows the tensor as an ndarray view of element type `D`.
    ///
    /// # Errors
    /// Propagates any error from `Tensor::to_array_view` (presumably a datum type
    /// mismatch with `D` — confirm).
    pub fn to_array_view<'a, D: ::tensor::Datum>(
        &'a self,
    ) -> ::Result<::ndarray::ArrayViewD<'a, D>> {
        self.as_tensor().to_array_view()
    }
}
impl<M> From<M> for Value
where
Tensor: From<M>,
{
fn from(m: M) -> Value {
Value::Owned(m.into())
}
}
/// Wraps an already shared tensor as `Value::Shared`, without copying it.
impl From<Arc<Tensor>> for Value {
    fn from(shared: Arc<Tensor>) -> Value {
        Value::Shared(shared)
    }
}
impl ::std::ops::Deref for Value {
type Target = Tensor;
fn deref(&self) -> &Tensor {
match self {
&Value::Owned(ref m) => &m,
&Value::Shared(ref m) => m.as_ref(),
}
}
}
/// Values compare by tensor content; owned and shared values holding equal
/// tensors are equal.
impl PartialEq for Value {
    fn eq(&self, other: &Value) -> bool {
        let (lhs, rhs) = (self.as_tensor(), other.as_tensor());
        lhs == rhs
    }
}
/// An input to a streaming step: either a constant value or a stream.
#[derive(Clone, Debug)]
pub enum StepValue {
    /// A value that is the same at every step.
    Const(Value),
    /// A streamed input and its current chunk, if any.
    Stream(Stream),
}
/// A streamed input as seen by one step.
#[derive(Clone, Debug)]
pub struct Stream {
    /// Description of the streaming axis.
    pub info: StreamInfo,
    /// NOTE(review): presumably the position of `chunk` along the streaming axis — confirm.
    pub offset: u64,
    /// The data available for this step, or `None` when there is none.
    pub chunk: Option<Value>,
}
/// Describes which axis of a tensor is streamed and its length.
#[derive(Clone, Copy, Debug, Default)]
pub struct StreamInfo {
    /// Index of the streaming axis.
    pub axis: usize,
    /// Length along the streaming axis, expressed as a `TDim`.
    pub len: TDim,
}
impl StepValue {
    /// Borrows the data carried by this step: the constant itself, or the
    /// stream's current chunk (which may be absent).
    pub fn as_value(&self) -> Option<&Value> {
        match *self {
            StepValue::Const(ref value) => Some(value),
            StepValue::Stream(ref stream) => stream.chunk.as_ref(),
        }
    }

    /// Owned counterpart of `as_value`.
    pub fn into_value(self) -> Option<Value> {
        match self {
            StepValue::Const(value) => Some(value),
            StepValue::Stream(stream) => stream.chunk,
        }
    }

    /// Borrows the constant, or returns `None` for a stream.
    pub fn as_const(&self) -> Option<&Value> {
        if let StepValue::Const(ref value) = *self {
            Some(value)
        } else {
            None
        }
    }

    /// Takes the constant, or returns `None` for a stream.
    pub fn into_const(self) -> Option<Value> {
        if let StepValue::Const(value) = self {
            Some(value)
        } else {
            None
        }
    }

    /// Borrows the stream, or returns `None` for a constant.
    pub fn as_stream(&self) -> Option<&Stream> {
        if let StepValue::Stream(ref stream) = *self {
            Some(stream)
        } else {
            None
        }
    }

    /// Takes the stream, or returns `None` for a constant.
    pub fn into_stream(self) -> Option<Stream> {
        if let StepValue::Stream(stream) = self {
            Some(stream)
        } else {
            None
        }
    }

    /// Copies out the stream description; `None` for a constant.
    pub fn stream_info(&self) -> Option<StreamInfo> {
        self.as_stream().map(|stream| stream.info)
    }

    /// Whether this step value is a constant.
    pub fn is_const(&self) -> bool {
        self.as_const().is_some()
    }
}
/// An attribute value, as returned by `Op::get_attributes`.
///
/// Keep the variant order stable: with the `serialize` feature, the derived
/// `Serialize` impl emits each variant's index.
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[derive(Clone, Debug)]
pub enum Attr {
    I64(i64),
    Usize(usize),
    DatumType(DatumType),
    DataFormat(DataFormat),
    Padding(Padding),
    Tensor(Tensor),
    UsizeVec(Vec<usize>),
    IsizeVec(Vec<isize>),
}
pub trait Op: Debug + objekt::Clone + Send + Sync + 'static + InferenceOp {
fn get_attributes(&self) -> HashMap<&'static str, Attr> {
hashmap!()
}
fn eval(&self, _inputs: TVec<Value>) -> Result<TVec<Value>> {
bail!("Unexpected call on op.eval(). {:?}", self)
}
fn new_buffer(&self) -> Box<OpBuffer> {
Box::new(EmptyBuffer {})
}
fn step(
&self,
_inputs: TVec<StepValue>,
_buffer: &mut Box<OpBuffer>,
) -> Result<Option<TVec<Value>>> {
bail!("Streaming is not available for operator {:?}", self)
}
fn infer_and_propagate(
&self,
inputs: TVec<TensorFact>,
outputs: TVec<TensorFact>,
) -> Result<(TVec<TensorFact>, TVec<TensorFact>)> {
let (infered_inputs, infered_outputs) = self.infer(inputs, outputs)?;
if infered_inputs.iter().all(|i| i.value.is_concrete()) {
let input_values = infered_inputs
.iter()
.map(|i| i.value.concretize().unwrap().clone().into())
.collect(); let output_value = self.eval(input_values)?.pop().unwrap();
Ok((
infered_inputs,
tvec![::analyser::helpers::tensor_to_fact(
output_value.into_tensor(),
)],
))
} else {
Ok((infered_inputs, infered_outputs))
}
}
fn final_prep(
&self,
_inputs: TVec<TensorFact>,
_outputs: TVec<TensorFact>,
) -> Result<Option<Box<Op>>> {
Ok(None)
}
fn const_value(&self) -> Option<Value> {
None
}
fn rounding_errors(&self) -> bool {
false
}
}
/// Fact inference, required of every `Op` (it is a supertrait of `Op`).
pub trait InferenceOp {
    /// Takes the facts currently known about the op's inputs and outputs and
    /// returns the inferred `(inputs, outputs)` facts.
    fn infer(
        &self,
        inputs: TVec<TensorFact>,
        outputs: TVec<TensorFact>,
    ) -> Result<(TVec<TensorFact>, TVec<TensorFact>)>;
}
// objekt macro: makes `Box<Op>` clonable, relying on the `objekt::Clone`
// supertrait bound of `Op`.
clone_trait_object!(Op);
/// With the `serialize` feature, an op serializes as its attribute map.
#[cfg(feature = "serialize")]
impl Serialize for Op {
    fn serialize<S>(&self, serializer: S) -> StdResult<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let attributes = self.get_attributes();
        attributes.serialize(serializer)
    }
}
/// Maps an op name (as returned by `NodeDef::get_op`) to the function that
/// builds the corresponding `Op` from its `tfpb` node definition.
pub type OpRegister = HashMap<&'static str, fn(&::tfpb::node_def::NodeDef) -> Result<Box<Op>>>;
/// Builds ops from `tfpb` node definitions, using a register of constructors.
pub struct OpBuilder(OpRegister);

impl OpBuilder {
    /// Creates a builder whose register is filled by the `array`, `cast`,
    /// `konst`, `math` and `nn` modules.
    pub fn new() -> OpBuilder {
        let mut register = OpRegister::new();
        // Keep this order: it decides which module wins if two register the same name.
        array::register_all_ops(&mut register);
        cast::register_all_ops(&mut register);
        konst::register_all_ops(&mut register);
        math::register_all_ops(&mut register);
        nn::register_all_ops(&mut register);
        OpBuilder(register)
    }

    /// Builds the op described by `pb`.
    ///
    /// If no constructor is registered under `pb.get_op()`, this returns an
    /// `unimpl::UnimplementedOp` carrying the op name and a copy of the node.
    ///
    /// # Errors
    /// Propagates any error from the registered constructor.
    pub fn build(&self, pb: &::tfpb::node_def::NodeDef) -> Result<Box<Op>> {
        let op_name = pb.get_op();
        if let Some(constructor) = self.0.get(op_name) {
            return constructor(pb);
        }
        Ok(Box::new(unimpl::UnimplementedOp(
            op_name.to_string(),
            pb.to_owned(),
        )))
    }
}
/// Per-op mutable state handed to `Op::step`, created by `Op::new_buffer`.
pub trait OpBuffer: Downcast + Debug + objekt::Clone + Send + 'static {}
// Make `Box<OpBuffer>` clonable (objekt) and downcastable to the concrete
// buffer type (downcast_rs).
clone_trait_object!(OpBuffer);
impl_downcast!(OpBuffer);
/// A buffer holding no state; the default returned by `Op::new_buffer`.
#[derive(Clone, Debug)]
pub struct EmptyBuffer {}

impl OpBuffer for EmptyBuffer {}
/// A buffer holding one FIFO queue of values per op input.
#[derive(Clone, Debug)]
pub struct QueuesBuffer(TVec<VecDeque<Value>>);

impl OpBuffer for QueuesBuffer {}
impl QueuesBuffer {
    /// Creates a buffer with `size` empty queues.
    pub fn new(size: usize) -> QueuesBuffer {
        QueuesBuffer(tvec![VecDeque::new(); size])
    }

    /// Pushes the value carried by each step value onto the queue with the
    /// same index.
    ///
    /// Step values that carry no value (a stream without a chunk) are skipped.
    /// Passing fewer step values than there are queues is allowed; the
    /// remaining queues are left untouched.
    ///
    /// # Errors
    /// Fails without modifying any queue if there are more step values than queues.
    pub fn append(&mut self, views: TVec<StepValue>) -> Result<()> {
        if views.len() > self.0.len() {
            bail!("There are more input Values than queues in the buffer.");
        }
        // The length check above guarantees that zip visits every view.
        for (queue, view) in self.0.iter_mut().zip(views) {
            if let Some(v) = view.into_value() {
                queue.push_back(v);
            }
        }
        Ok(())
    }

    /// Iterates over the queues in order. Read-only, so a shared borrow is enough.
    pub fn iter<'a>(&'a self) -> impl Iterator<Item = &'a VecDeque<Value>> {
        self.0.iter()
    }

    /// Iterates mutably over the queues in order.
    pub fn iter_mut<'a>(&'a mut self) -> impl Iterator<Item = &'a mut VecDeque<Value>> {
        self.0.iter_mut()
    }
}
/// `buffer[i]` borrows the i-th queue; it panics if `i` is out of range.
impl Index<usize> for QueuesBuffer {
    type Output = VecDeque<Value>;

    fn index(&self, idx: usize) -> &VecDeque<Value> {
        let queues = &self.0;
        &queues[idx]
    }
}
/// `buffer[i]` in a mutable context borrows the i-th queue mutably; it panics
/// if `i` is out of range.
impl IndexMut<usize> for QueuesBuffer {
    fn index_mut(&mut self, idx: usize) -> &mut VecDeque<Value> {
        let queues = &mut self.0;
        &mut queues[idx]
    }
}