use crate::tensor::DeviceTensorExt;
use crate::utils::compute_broadcast_strides;
use tract_core::internal::*;
use tract_core::ops::array::Slice;
/// GPU counterpart of the core [`Slice`] operator.
///
/// The slicing parameters (`axis`, `start`, `end`) are carried by the wrapped
/// core op; this type only supplies the device-side evaluation.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct GpuSlice {
    /// The wrapped core slice description.
    pub inner: Slice,
}
impl GpuSlice {
pub fn new(inner: Slice) -> Self {
Self { inner }
}
}
impl Op for GpuSlice {
fn name(&self) -> StaticName {
"GpuSlice".into()
}
fn info(&self) -> TractResult<Vec<String>> {
self.inner.info()
}
op_as_typed_op!();
}
impl EvalOp for GpuSlice {
    /// This op keeps no state between evaluations.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Evaluates the slice on a device tensor.
    ///
    /// The symbolic `start` and `end` bounds are resolved against the session's
    /// symbol values, validated against the input shape, and the selected range
    /// is copied into a freshly obtained output tensor through the device
    /// context's `copy_nd`.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - the single input cannot be viewed as a device tensor,
    /// - `start` or `end` do not resolve to a `usize`,
    /// - `axis` is not smaller than the input rank,
    /// - the range is not within `0..=input_shape[axis]` or `start > end`,
    /// - the stride along `axis` is negative,
    /// - the output allocation, stride computation or device copy fails.
    fn eval_with_session(
        &self,
        node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let input_value = args_1!(inputs);
        let input = input_value.to_device_tensor()?;

        let start = self.inner.start.eval(&session.resolved_symbols).to_usize()?;
        let end = self.inner.end.eval(&session.resolved_symbols).to_usize()?;
        let axis = self.inner.axis;

        let input_shape = input.shape();
        let input_strides = input.strides();
        let input_dt = input.datum_type();

        // Validate the axis first: indexing the shape with an out-of-range
        // axis would panic instead of reporting a proper error.
        ensure!(
            axis < input_shape.len(),
            "Invalid axis {} for slicing {:?}",
            axis,
            input
        );
        ensure!(
            end <= input_shape[axis] && start <= end,
            "Invalid range {}..{} for slicing {:?} on axis {}",
            start,
            end,
            input,
            axis
        );

        let mut o_shape: TVec<usize> = input_shape.into();
        o_shape[axis] = end - start;

        // A negative stride would wrap around when cast to `usize` and yield
        // a bogus source offset, so reject it explicitly.
        let axis_stride = input_strides[axis];
        ensure!(
            axis_stride >= 0,
            "Unsupported negative stride {} on axis {} for slicing {:?}",
            axis_stride,
            axis,
            input
        );
        // Source offset of the first selected element: element offset along
        // the axis scaled by the size of one datum.
        let offset = (start * axis_stride as usize) * input_dt.size_of();

        let output =
            crate::session_handler::make_tensor_for_node(session, node_id, input_dt, &o_shape)?;

        // Nothing to copy when the slice is empty along the sliced axis.
        if o_shape[axis] != 0 {
            let broadcast_strides: TVec<isize> =
                compute_broadcast_strides(&o_shape, input_strides)?;
            let ctx = crate::device::get_context()?;
            ctx.copy_nd(
                input,
                offset,
                &broadcast_strides,
                &output,
                0,
                output.shape(),
                output.strides(),
            )?;
        }
        Ok(tvec![output.into_tensor().into_tvalue()])
    }
}
impl TypedOp for GpuSlice {
    /// Computes the output facts by handing the wrapped slice op's
    /// `output_facts` to `facts_to_device_facts`; any failure is annotated
    /// with this op's name.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let device_facts =
            crate::utils::facts_to_device_facts(inputs, |facts| self.inner.output_facts(facts));
        device_facts.with_context(|| format!("Error while computing facts for {:?}", self.name()))
    }

    /// Rebuilds this op with its `start` and `end` bounds evaluated against
    /// `values`, then wires it into `target` under the original node name,
    /// using the inputs translated through `mapping`.
    fn concretize_dims(
        &self,
        _source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
        values: &SymbolValues,
    ) -> TractResult<TVec<OutletId>> {
        let axis = self.inner.axis;
        let start = self.inner.start.eval(values);
        let end = self.inner.end.eval(values);
        let concretized = Self { inner: Slice { axis, start, end } };

        // Translate every input outlet of the source node into the target model.
        let wired_inputs: TVec<OutletId> =
            node.inputs.iter().map(|outlet| mapping[outlet]).collect();
        target.wire_node(&node.name, concretized, &wired_inputs)
    }

    as_op!();
}