//! tract_gpu/ops/softmax.rs — backend-agnostic GPU softmax operator.
1use crate::tensor::DeviceTensorExt;
2use tract_core::internal::*;
3use tract_core::ops::nn as core_ops_nn;
4
5use crate::tensor::DeviceTensor;
6
7pub type DispatchSoftmaxFn = fn(&DeviceTensor, usize, &DeviceTensor) -> TractResult<()>;
8
/// Backend-agnostic GPU softmax operator: the axis to normalize over plus a
/// backend-supplied kernel launcher.
#[derive(Clone)]
pub struct GpuSoftmax {
    // Softmax axes. `GpuSoftmax::new` enforces exactly one entry.
    pub axes: TVec<usize>,
    // Backend label used to format the op name and Debug output, and as part
    // of equality/hashing.
    pub backend_name: &'static str,
    // Kernel launcher `(input, axis, output)`. Deliberately excluded from the
    // manual PartialEq/Hash impls below.
    pub dispatch: DispatchSoftmaxFn,
}

16impl GpuSoftmax {
17    pub fn new(
18        axes: TVec<usize>,
19        backend_name: &'static str,
20        dispatch: DispatchSoftmaxFn,
21    ) -> TractResult<Self> {
22        ensure!(
23            axes.len() == 1,
24            "Only one axis of softmax is supported by {}Softmax",
25            backend_name
26        );
27        Ok(Self { axes, backend_name, dispatch })
28    }
29
30    pub fn from_tract_core(
31        core_softmax: &core_ops_nn::Softmax,
32        backend_name: &'static str,
33        dispatch: DispatchSoftmaxFn,
34    ) -> TractResult<Self> {
35        ensure!(core_softmax.quant_output_dt.is_none());
36        Self::new(core_softmax.axes.clone(), backend_name, dispatch)
37    }
38}
39
// Manual Debug: the `dispatch` fn pointer has no useful textual form, so only
// the backend name and axes are rendered, e.g. `<backend>Softmax(axes: [1])`.
impl std::fmt::Debug for GpuSoftmax {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}Softmax(axes: {:?})", self.backend_name, self.axes)
    }
}

46impl PartialEq for GpuSoftmax {
47    fn eq(&self, other: &Self) -> bool {
48        self.backend_name == other.backend_name && self.axes == other.axes
49    }
50}
51
// Sound: `eq` only compares `backend_name` and `axes`, both of which are `Eq`.
impl Eq for GpuSoftmax {}

54impl std::hash::Hash for GpuSoftmax {
55    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
56        self.backend_name.hash(state);
57        self.axes.hash(state);
58    }
59}
60
impl Op for GpuSoftmax {
    fn name(&self) -> StaticName {
        // Name is prefixed with the backend label, e.g. "<backend>Softmax".
        format!("{}Softmax", self.backend_name).into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        // Extra detail shown in model dumps: the softmax axes.
        Ok(vec![format!("axes: {:?}", self.axes)])
    }

    op_as_typed_op!();
}

impl EvalOp for GpuSoftmax {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval_with_session(
        &self,
        node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        // Exactly one input expected (args_1! enforces the arity).
        let input_value = args_1!(inputs);
        let input = input_value.to_device_tensor()?;
        // Output shares the input's dtype and shape. Allocation goes through
        // the session handler — presumably so buffers can be reused per node
        // across turns; confirm in crate::session_handler.
        let output = crate::session_handler::make_tensor_for_node(
            session,
            node_id,
            input.datum_type(),
            input.shape(),
        )?;
        // `new()` guarantees a single axis, so axes[0] is the softmax axis.
        (self.dispatch)(input, self.axes[0], &output)?;
        Ok(tvec!(output.into_tensor().into_tvalue()))
    }
}

97impl TypedOp for GpuSoftmax {
98    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
99        crate::utils::facts_to_device_facts(inputs, |facts| {
100            let dt = facts[0].datum_type;
101            let fact = dt.fact(facts[0].shape.clone());
102            Ok(tvec!(fact))
103        })
104        .with_context(|| format!("Error while computing facts for {:?}", self.name()))
105    }
106
107    fn axes_mapping(
108        &self,
109        inputs: &[&TypedFact],
110        outputs: &[&TypedFact],
111    ) -> TractResult<AxesMapping> {
112        AxesMapping::natural(inputs, outputs)
113    }
114
115    fn change_axes(
116        &self,
117        model: &TypedModel,
118        node: &TypedNode,
119        _io: InOut,
120        change: &AxisOp,
121    ) -> TractResult<Option<AxisChangeConsequence>> {
122        let axes: Option<TVec<usize>> =
123            self.axes.iter().map(|it| change.transform_axis(*it)).collect();
124        if let Some(axes) = axes {
125            Ok(Some(AxisChangeConsequence::new(
126                model,
127                node,
128                Some(Box::new(GpuSoftmax {
129                    axes,
130                    backend_name: self.backend_name,
131                    dispatch: self.dispatch,
132                })),
133                change,
134            )))
135        } else {
136            Ok(None)
137        }
138    }
139
140    as_op!();
141}