use derive_new::new;
use tract_core::internal::tract_smallvec::ToSmallVec;
use tract_core::internal::*;
use tract_core::ops::OpStateFreeze;
use tract_gpu::ops::change_axes::GpuAxisOp;
use tract_gpu::tensor::{DeviceTensor, DeviceTensorExt};
/// Wraps an op and applies axis ops that were fused onto its inputs.
///
/// `grouped_axis_ops[i]` holds the axis ops applied, in order, to input `i` before `op`
/// runs; an empty group leaves that input untouched. Name, info, statefulness and
/// evaluation are delegated to `op`.
#[derive(Debug, Clone, PartialEq, Eq, new)]
pub struct MetalFusedAxisOp {
    /// One (possibly empty) group of axis ops per input, in input order.
    pub grouped_axis_ops: TVec<TVec<GpuAxisOp>>,
    /// The op evaluated on the transformed inputs.
    pub op: Box<dyn TypedOp>,
}
/// Runtime state of a [`MetalFusedAxisOp`] whose wrapped op is stateful.
#[derive(new, Clone, Debug)]
pub struct MetalFusedAxisOpState {
    /// State created by the wrapped op; every `OpState` hook is forwarded to it.
    pub op_state: Box<dyn OpState>,
}
fn compute_reshaped_inputs(
inputs: TVec<TValue>,
grouped_axis_ops: &TVec<TVec<GpuAxisOp>>,
session: &TurnState,
) -> TractResult<TVec<TValue>> {
inputs
.into_iter()
.zip(grouped_axis_ops.iter())
.map(|(input, axis_ops)| {
if axis_ops.is_empty() {
return Ok(input);
};
let m_input = input.to_device_tensor()?;
let reshaped_input = axis_ops.iter().try_fold(
m_input.clone(),
|t, axis_op| -> TractResult<DeviceTensor> {
let new_shape = match &axis_op.inner {
AxisOp::Reshape(skip, from, to) => {
let from =
from.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
let to = to.iter().map(|d| d.eval(&session.resolved_symbols)).collect();
let mut shape: TVec<usize> = t.shape().into();
AxisOp::Reshape(*skip, from, to)
.change_shape_array(&mut shape, false)?;
shape
}
AxisOp::Add(_) | AxisOp::Rm(_) | AxisOp::Move(..) => {
let mut shape: TVec<usize> = t.shape().into();
axis_op.inner.change_shape_array(&mut shape, false)?;
shape
}
};
if let AxisOp::Move(from, to) = axis_op.inner {
let mut out_strides: TVec<isize> = t.strides().to_smallvec();
let removed_stride = out_strides.remove(from);
out_strides.insert(to, removed_stride);
let tmp_t = t.reshaped(new_shape)?;
tmp_t.restrided(out_strides)
} else {
t.reshaped(new_shape)
}
},
)?;
Ok(reshaped_input.into_tensor().into())
})
.collect::<TractResult<TVec<_>>>()
}
/// Forwards every `OpState` hook to the wrapped op's state. `eval` first applies the fused
/// axis ops to the inputs, then evaluates the wrapped state against the wrapped op rather
/// than the fused wrapper.
impl OpState for MetalFusedAxisOpState {
    fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
        self.op_state.init_tensor_fact()
    }

    fn load_from(
        &mut self,
        session: &mut TurnState,
        states: &mut dyn Iterator<Item = tract_core::value::TValue>,
    ) -> TractResult<()> {
        self.op_state.load_from(session, states)
    }

    fn save_to(&self, states: &mut Vec<TValue>) -> TractResult<()> {
        self.op_state.save_to(states)
    }

    fn resolve_symbols(&mut self, session: &mut TurnState) -> TractResult<()> {
        self.op_state.resolve_symbols(session)
    }

    fn eval(
        &mut self,
        session: &mut TurnState,
        op: &dyn Op,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        // `op` is expected to be the MetalFusedAxisOp whose `state` created this state.
        // Report a mismatch as an error instead of panicking.
        let fused_axis_op = op
            .downcast_ref::<MetalFusedAxisOp>()
            .context("MetalFusedAxisOpState::eval expects a MetalFusedAxisOp")?;
        let inputs = compute_reshaped_inputs(inputs, &fused_axis_op.grouped_axis_ops, session)?;
        self.op_state.eval(session, fused_axis_op.op.as_op(), inputs)
    }
}
/// Frozen counterpart of [`MetalFusedAxisOpState`], holding the wrapped op's frozen state.
#[derive(Clone, Debug)]
pub struct FrozenMetalFusedAxisOpState {
    /// Frozen state of the wrapped op.
    pub op_state: Box<dyn FrozenOpState>,
}
impl OpStateFreeze for MetalFusedAxisOpState {
fn freeze(&self) -> Box<dyn FrozenOpState + 'static> {
Box::new(FrozenMetalFusedAxisOpState { op_state: self.op_state.freeze() })
}
fn freeze_into(self: Box<Self>) -> Box<dyn FrozenOpState> {
Box::new(FrozenMetalFusedAxisOpState { op_state: self.op_state.freeze_into() })
}
}
impl FrozenOpState for FrozenMetalFusedAxisOpState {
    /// Unfreezes the wrapped state and re-wraps it as a live [`MetalFusedAxisOpState`].
    fn unfreeze(&self) -> Box<dyn OpState> {
        let thawed = self.op_state.unfreeze();
        Box::new(MetalFusedAxisOpState { op_state: thawed })
    }
}
impl Op for MetalFusedAxisOp {
    /// Reports the wrapped op's name, so the fused node displays like the op it wraps.
    fn name(&self) -> StaticName {
        self.op.name()
    }

    /// Returns the wrapped op's info lines, followed by one summary line per input that has
    /// fused axis ops. Each axis op is rendered as `name - info`.
    fn info(&self) -> TractResult<Vec<String>> {
        let mut lines = self.op.info()?;
        let non_empty_groups =
            self.grouped_axis_ops.iter().enumerate().filter(|(_, group)| !group.is_empty());
        for (input_ix, group) in non_empty_groups {
            let mut descriptions: Vec<String> = Vec::with_capacity(group.len());
            for axis_op in group.iter() {
                let details = axis_op.info()?.join(" | ");
                descriptions.push(format!("{} - {}", axis_op.name(), details));
            }
            let summary = descriptions.join(" | ");
            lines.push(format!("Fused axis Op on Input #{input_ix}: {summary}"));
        }
        Ok(lines)
    }

    op_as_typed_op!();
}
impl EvalOp for MetalFusedAxisOp {
    /// Stateless exactly when the wrapped op is.
    fn is_stateless(&self) -> bool {
        self.op.is_stateless()
    }

    /// Wraps the wrapped op's state, if it has one, so that the fused axis ops are applied
    /// on every stateful evaluation.
    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
        let inner_state = self.op.state(session, node_id)?;
        Ok(inner_state.map(|op_state| -> Box<dyn OpState> {
            Box::new(MetalFusedAxisOpState { op_state })
        }))
    }

    /// Applies the fused axis ops to the inputs, then evaluates the wrapped op statelessly.
    fn eval_with_session(
        &self,
        node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let reshaped = compute_reshaped_inputs(inputs, &self.grouped_axis_ops, session)?;
        self.op.eval_with_session(node_id, session, reshaped)
    }
}
impl TypedOp for MetalFusedAxisOp {
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
ensure!(
inputs.len() == self.grouped_axis_ops.len(),
"Number of inputs and fused axis ops are not aligned"
);
let inputs = inputs
.iter()
.zip(self.grouped_axis_ops.iter())
.map(|(i, axis_ops)| {
axis_ops.iter().try_fold((*i).clone(), |reshaped_i, axis_op| {
Ok(axis_op.output_facts(&[&reshaped_i])?[0].clone())
})
})
.collect::<TractResult<TVec<_>>>()?;
let inputs_ref = inputs.iter().collect::<TVec<_>>();
self.op.output_facts(&inputs_ref)
}
as_op!();
}