use analyser::rules::prelude::*;
use ndarray::prelude::*;
use ops::prelude::*;
/// Creates a boxed `ConcatV2` operator from a TensorFlow `NodeDef`.
///
/// The `N`, `T` and `Tidx` attributes are read in that order; the first one
/// that is missing or malformed aborts construction with its error.
pub fn build(pb: &::tfpb::node_def::NodeDef) -> Result<Box<Op>> {
    let input_count = pb.get_attr_int("N")?;
    let element_type = pb.get_attr_datum_type("T")?;
    let index_type = pb.get_attr_datum_type("Tidx")?;
    Ok(boxed_new!(ConcatV2(element_type)(input_count, index_type)))
}
/// TensorFlow `ConcatV2` operator: joins `n` tensors of element type `T`
/// along an axis that is supplied as an extra, final input.
#[derive(Debug, Clone, new)]
pub struct ConcatV2<T: Datum> {
    /// Number of tensors to join (the `N` attribute). The op takes `n + 1`
    /// inputs, the last one being the axis.
    n: usize,
    /// Datum type of the axis input (the `Tidx` attribute). Note: evaluation
    /// and inference currently only accept an `i32` axis regardless of this.
    tidx: DatumType,
    /// Element-type marker; holds no data. Presumably filled in automatically
    /// by `derive(new)`, since `build` passes only `n` and `tidx`.
    t: PhantomData<T>,
}
/// Resolves a possibly negative concatenation `axis` against `rank`.
///
/// Like TensorFlow, a negative axis counts from the end (`-1` is the last
/// dimension). Returns `None` when `axis` lies outside `[-rank, rank)`.
fn concat_v2_axis(axis: i32, rank: usize) -> Option<usize> {
    let rank = rank as i64;
    let axis = if axis < 0 {
        i64::from(axis) + rank
    } else {
        i64::from(axis)
    };
    if axis >= 0 && axis < rank {
        Some(axis as usize)
    } else {
        None
    }
}

impl<T: Datum> Op for ConcatV2<T> {
    /// Exposes the operator's attributes: input count, element type and axis type.
    fn get_attributes(&self) -> HashMap<&'static str, Attr> {
        hashmap!{
            "n" => Attr::Usize(self.n),
            "t" => Attr::DatumType(T::datum_type()),
            "tidx" => Attr::DatumType(self.tidx),
        }
    }

    /// Joins all inputs but the last along the axis given by the last input.
    ///
    /// The axis must be an `i32` scalar; negative values count from the end.
    ///
    /// # Errors
    /// Fails if the axis is missing or not an `i32` scalar, if it is out of
    /// range for the inputs' rank, if an input is not of element type `T`, or
    /// if the inputs' shapes are incompatible.
    fn eval(&self, mut inputs: TVec<Value>) -> Result<TVec<Value>> {
        let axis: i32 = inputs
            .pop()
            .and_then(|t| t.as_i32())
            .ok_or("Expected a i32 scalar")?;
        let mats = inputs
            .iter()
            .map(|mat| mat.to_array_view())
            .collect::<Result<Vec<ArrayViewD<T>>>>()?;
        // The first input's rank resolves a negative axis; inputs of a
        // different rank are rejected by `ndarray::stack` below.
        let rank = mats.first().map(|m| m.ndim()).unwrap_or(0);
        let axis = concat_v2_axis(axis, rank).ok_or("ConcatV2 axis is out of range")?;
        // NOTE: relies on `ndarray::stack` joining along an *existing* axis,
        // which is its behaviour before ndarray 0.15 (renamed `concatenate`).
        let result = ::ndarray::stack(Axis(axis), &*mats)?;
        Ok(tvec![result.into()])
    }

    /// Allocates one queue per data input for the buffered streaming path.
    fn new_buffer(&self) -> Box<OpBuffer> {
        Box::new(QueuesBuffer::new(self.n))
    }

    /// Streaming evaluation.
    ///
    /// The last input is the axis and must be a constant. If every data input
    /// is streamed along that axis, the incoming values are forwarded as-is.
    /// Otherwise they are queued in the `QueuesBuffer`, and once every queue
    /// holds a value, one value per queue is joined with `eval`.
    fn step(
        &self,
        mut inputs: TVec<StepValue>,
        buffer: &mut Box<OpBuffer>,
    ) -> Result<Option<TVec<Value>>> {
        let n = inputs
            .pop()
            .ok_or("Unexpectedly found zero inputs in ConcatV2")?;
        let axis_tensor = n.into_const().ok_or("Axis input should not be streamed.")?;
        let axis: i32 = axis_tensor.as_i32().ok_or("Expected a i32 scalar")?;

        // A negative axis is not resolved against the rank here, so it never
        // matches a stream axis and takes the buffered path, where `eval`
        // resolves it.
        let streamed_on_axis = axis >= 0
            && inputs
                .iter()
                .all(|i| i.stream_info().map(|s| s.axis) == Some(axis as usize));

        if streamed_on_axis {
            // NOTE(review): this forwards one value per data input, while the
            // inference rules declare a single output — confirm intended.
            let chunk = inputs
                .into_iter()
                .map(|sv| sv.into_value().ok_or("Expected a value".into()))
                .collect::<Result<TVec<Value>>>()?;
            return Ok(Some(chunk));
        }

        let buffer = buffer
            .downcast_mut::<QueuesBuffer>()
            .ok_or("The buffer can't be downcasted to QueuesBuffer.")?;
        buffer.append(inputs)?;
        if buffer.iter_mut().any(|q| q.is_empty()) {
            return Ok(None);
        }
        // Every queue was just checked to be non-empty, so each pop succeeds.
        let mut chunks = buffer
            .iter_mut()
            .map(|q| q.pop_front().unwrap())
            .collect::<TVec<_>>();
        chunks.push(axis_tensor);
        Ok(Some(self.eval(chunks)?))
    }
}
impl<T: Datum> InferenceRulesOp for ConcatV2<T> {
    /// Type and shape inference for `ConcatV2`.
    ///
    /// Inputs `0..n` are the tensors to join and input `n` is the `i32`
    /// scalar axis. The data inputs and the single output share datum type
    /// and rank. Every dimension except the axis is equal across them, and
    /// the output's axis dimension is the sum of the inputs' ones. A negative
    /// axis counts from the end, so it is resolved once the rank is known.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        solver: &mut Solver<'r>,
        inputs: &'p TensorsProxy,
        outputs: &'p TensorsProxy,
    ) {
        let n = self.n;
        solver
            .equals(&inputs.len, n as isize + 1)
            .equals(&outputs.len, 1)
            .equals_all((0..n).map(|i| (&inputs[i].datum_type).bex()).collect())
            .equals(&outputs[0].datum_type, &inputs[0].datum_type)
            .equals(&inputs[n].datum_type, DatumType::I32)
            .equals_all((0..n).map(|i| (&inputs[i].rank).bex()).collect())
            .equals(&inputs[n].rank, 0)
            .equals(&outputs[0].rank, &inputs[0].rank)
            .given(&inputs[n].value, move |solver, axis: Tensor| {
                let axis = axis
                    .as_i32()
                    .expect("ConcatV2 axis is constrained to be an i32 scalar");
                trace!("axis for Concatv2: {}", axis);
                if axis >= 0 {
                    let axis = axis as usize;
                    // Leading dimensions are known without the rank.
                    for d in 0..axis {
                        solver.equals_all((0..n).map(|i| (&inputs[i].shape[d]).bex()).collect());
                        solver.equals(&inputs[0].shape[d], &outputs[0].shape[d]);
                    }
                    // Trailing dimensions can only be enumerated once the rank is known.
                    solver.given(&inputs[0].rank, move |solver, rank: isize| {
                        trace!("Given rank {}", rank);
                        for d in (axis + 1)..(rank as usize) {
                            solver.equals_all((0..n).map(|i| (&inputs[i].shape[d]).bex()).collect());
                            solver.equals(&inputs[0].shape[d], &outputs[0].shape[d]);
                        }
                    });
                    // output.shape[axis] == sum of inputs' shape[axis]
                    let mut concat_dim = -1isize * outputs[0].shape[axis].bex();
                    for i in 0..n {
                        concat_dim = concat_dim + inputs[i].shape[axis].bex();
                    }
                    solver.equals_zero(concat_dim);
                } else {
                    // A negative axis counts from the end: the rank is needed
                    // before any dimension can be identified.
                    solver.given(&inputs[0].rank, move |solver, rank: isize| {
                        trace!("Given rank {}", rank);
                        let axis = rank + axis as isize;
                        if axis < 0 {
                            // Out of range: nothing can be inferred from it.
                            return;
                        }
                        let axis = axis as usize;
                        for d in (0..(rank as usize)).filter(|&d| d != axis) {
                            solver.equals_all((0..n).map(|i| (&inputs[i].shape[d]).bex()).collect());
                            solver.equals(&inputs[0].shape[d], &outputs[0].shape[d]);
                        }
                        // output.shape[axis] == sum of inputs' shape[axis]
                        let mut concat_dim = -1isize * outputs[0].shape[axis].bex();
                        for i in 0..n {
                            concat_dim = concat_dim + inputs[i].shape[axis].bex();
                        }
                        solver.equals_zero(concat_dim);
                    });
                }
            });
    }
}