use analyser::rules::prelude::*;
use ops::prelude::*;
/// Builds a boxed `Squeeze` operator from the node definition `pb`.
///
/// Reads the optional integer-list attribute `squeeze_dims` and the
/// datum-type attribute `T`. When axes are present they are stored in
/// descending order.
///
/// # Errors
///
/// Propagates any error raised while reading either attribute.
pub fn squeeze(pb: &::tfpb::node_def::NodeDef) -> Result<Box<Op>> {
    let mut axes: Option<Vec<isize>> = pb.get_attr_opt_list_int("squeeze_dims")?;
    if let Some(axes) = axes.as_mut() {
        // Largest axis first: same outcome as an ascending sort followed by
        // a reversal.
        axes.sort_unstable_by(|lhs, rhs| rhs.cmp(lhs));
    }
    let datum_type = pb.get_attr_datum_type("T")?;
    Ok(boxed_new!(Squeeze(datum_type)(axes)))
}
/// Squeeze operator: removes axes of extent 1 from its input tensor.
#[derive(Debug, Clone, new)]
pub struct Squeeze<T: Datum> {
    /// Axes eligible for removal. `None` or an empty list makes every axis
    /// of extent 1 eligible.
    squeeze_dims: Option<Vec<isize>>,
    /// Zero-sized marker tying the operator to its element type `T`.
    t: PhantomData<T>,
}
impl<T: Datum> Squeeze<T> {
fn squeezable(&self, ix: usize, d: usize, stream_dim: Option<usize>) -> bool {
stream_dim != Some(ix) && d == 1
&& self
.squeeze_dims
.as_ref()
.filter(|v| v.len() > 0)
.map(|squeeze_dims| squeeze_dims.contains(&(ix as _)))
.unwrap_or(true)
}
fn squeeze_shape(&self, shape: &[usize], stream_dim: Option<usize>) -> Vec<usize> {
shape
.iter()
.enumerate()
.filter_map(|(ix, d)| {
if self.squeezable(ix, *d, stream_dim) {
None
} else {
Some(*d)
}
})
.collect()
}
}
impl<T: Datum> Op for Squeeze<T> {
fn eval(&self, mut inputs: TVec<Value>) -> Result<TVec<Value>> {
let input = args_1!(inputs);
let data = input.into_array::<T>()?;
let shape = self.squeeze_shape(data.shape(), None);
Ok(tvec![data.into_shape(shape)?.into()])
}
fn get_attributes(&self) -> HashMap<&'static str, Attr> {
let mut attrs = hashmap!{ "T" => Attr::DatumType(T::datum_type()) };
if let Some(dim) = self.squeeze_dims.as_ref() {
attrs.insert("squeeze_dims", Attr::IsizeVec(dim.clone()));
}
attrs
}
fn step(
&self,
mut inputs: TVec<StepValue>,
_buffer: &mut Box<OpBuffer>,
) -> Result<Option<TVec<Value>>> {
let input = args_1!(inputs);
if let Some(Stream {
info,
chunk: Some(chunk),
..
}) = input.into_stream()
{
let chunk = chunk.into_tensor().into_array::<T>()?;
let shape = self.squeeze_shape(chunk.shape(), Some(info.axis));
Ok(Some(tvec![Tensor::from(chunk.into_shape(shape)?).into(),]))
} else {
Ok(None)
}
}
}
impl<T: Datum> InferenceRulesOp for Squeeze<T> {
fn rules<'r, 'p: 'r, 's: 'r>(
&'s self,
solver: &mut Solver<'r>,
inputs: &'p TensorsProxy,
outputs: &'p TensorsProxy,
) {
solver
.equals(&inputs.len, 1)
.equals(&outputs.len, 1)
.equals(&inputs[0].datum_type, &outputs[0].datum_type)
.equals(&inputs[0].datum_type, T::datum_type())
.given(&inputs[0].shape, move |solver, shape: Vec<TDim>| {
let stream_dim = shape.iter().position(|d| d.is_stream());
let shape: Vec<TDim> = shape
.into_iter()
.enumerate()
.filter(|(ix, d)| {
!self.squeezable(*ix, d.to_integer().unwrap_or(1) as usize, stream_dim)
})
.map(|(_, d)| d)
.collect();
solver.equals(&outputs[0].shape, shape);
});
}
}
#[cfg(test)]
mod tests {
    #![allow(non_snake_case)]
    use super::*;
    use dim::TDim;
    use ndarray::*;
    use ops::InferenceOp;
    use Tensor;

    /// Evaluates `op` on a single input and returns its only output tensor.
    fn run<I: Into<Tensor>>(op: Squeeze<i32>, input: I) -> Tensor {
        let mut outputs = op.eval(tvec![input.into().into()]).unwrap();
        outputs.pop().unwrap().into_tensor()
    }

    /// Squeezes a zero-filled tensor of shape `[1, 2, 1, 3, 1, 1]` with the
    /// given axis selection and returns the resulting shape.
    fn squeezed_shape(axes: Option<Vec<isize>>) -> Vec<usize> {
        run(Squeeze::new(axes), Array::from_elem([1, 2, 1, 3, 1, 1], 0))
            .shape()
            .to_vec()
    }

    // Without an axis list, every axis of extent 1 is removed.
    #[test]
    fn squeeze_1() {
        assert_eq!(squeezed_shape(None), vec![2, 3]);
    }

    // With an axis list, only the listed axes are removed.
    #[test]
    fn squeeze_2() {
        assert_eq!(squeezed_shape(Some(vec![2, 4])), vec![1, 2, 3, 1]);
    }

    // An empty axis list behaves like no axis list at all.
    #[test]
    fn squeeze_3() {
        assert_eq!(squeezed_shape(Some(vec![])), vec![2, 3]);
    }

    // Shape inference removes the listed axis and keeps the streaming one.
    #[test]
    fn squeeze_inference_1() {
        let op = Squeeze::<TDim>::new(Some(vec![1]));
        let input_fact = TensorFact::default()
            .with_datum_type(DatumType::TDim)
            .with_shape(shapefact![1, 1, (TDim::stream() - 2), 16]);
        let expected_fact = TensorFact::default()
            .with_datum_type(DatumType::TDim)
            .with_shape(shapefact![1, (TDim::stream() - 2), 16]);

        let inferred = op
            .infer(tvec!(input_fact), tvec!(TensorFact::default()))
            .unwrap();

        let expected: TVec<_> = tvec!(expected_fact);
        assert_eq!(inferred.1, expected);
    }
}