1use crate::tensor::DeviceTensorExt;
2use tract_core::internal::*;
3use tract_core::ops::nn as core_ops_nn;
4
5use crate::tensor::DeviceTensor;
6
/// Backend kernel entry point: (input tensor, softmax axis, pre-allocated output tensor).
pub type DispatchSoftmaxFn = fn(&DeviceTensor, usize, &DeviceTensor) -> TractResult<()>;
8
/// Backend-agnostic GPU softmax operator: holds the reduction axes, the
/// backend's display name, and the function pointer that launches the kernel.
#[derive(Clone)]
pub struct GpuSoftmax {
    /// Softmax reduction axes; `new` enforces exactly one entry.
    pub axes: TVec<usize>,
    /// Static backend label used as the "...Softmax" prefix in `name`/`Debug`.
    pub backend_name: &'static str,
    /// Kernel dispatch callback invoked at eval time.
    pub dispatch: DispatchSoftmaxFn,
}
15
16impl GpuSoftmax {
17 pub fn new(
18 axes: TVec<usize>,
19 backend_name: &'static str,
20 dispatch: DispatchSoftmaxFn,
21 ) -> TractResult<Self> {
22 ensure!(
23 axes.len() == 1,
24 "Only one axis of softmax is supported by {}Softmax",
25 backend_name
26 );
27 Ok(Self { axes, backend_name, dispatch })
28 }
29
30 pub fn from_tract_core(
31 core_softmax: &core_ops_nn::Softmax,
32 backend_name: &'static str,
33 dispatch: DispatchSoftmaxFn,
34 ) -> TractResult<Self> {
35 ensure!(core_softmax.quant_output_dt.is_none());
36 Self::new(core_softmax.axes.clone(), backend_name, dispatch)
37 }
38}
39
40impl std::fmt::Debug for GpuSoftmax {
41 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
42 write!(f, "{}Softmax(axes: {:?})", self.backend_name, self.axes)
43 }
44}
45
46impl PartialEq for GpuSoftmax {
47 fn eq(&self, other: &Self) -> bool {
48 self.backend_name == other.backend_name && self.axes == other.axes
49 }
50}
51
// Marker impl: `eq` compares only `&str` and `TVec<usize>`, both full equivalences.
impl Eq for GpuSoftmax {}
53
54impl std::hash::Hash for GpuSoftmax {
55 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
56 self.backend_name.hash(state);
57 self.axes.hash(state);
58 }
59}
60
61impl Op for GpuSoftmax {
62 fn name(&self) -> StaticName {
63 format!("{}Softmax", self.backend_name).into()
64 }
65
66 fn info(&self) -> TractResult<Vec<String>> {
67 Ok(vec![format!("axes: {:?}", self.axes)])
68 }
69
70 op_as_typed_op!();
71}
72
73impl EvalOp for GpuSoftmax {
74 fn is_stateless(&self) -> bool {
75 true
76 }
77
78 fn eval_with_session(
79 &self,
80 node_id: usize,
81 session: &TurnState,
82 inputs: TVec<TValue>,
83 ) -> TractResult<TVec<TValue>> {
84 let input_value = args_1!(inputs);
85 let input = input_value.to_device_tensor()?;
86 let output = crate::session_handler::make_tensor_for_node(
87 session,
88 node_id,
89 input.datum_type(),
90 input.shape(),
91 )?;
92 (self.dispatch)(input, self.axes[0], &output)?;
93 Ok(tvec!(output.into_tensor().into_tvalue()))
94 }
95}
96
97impl TypedOp for GpuSoftmax {
98 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
99 crate::utils::facts_to_device_facts(inputs, |facts| {
100 let dt = facts[0].datum_type;
101 let fact = dt.fact(facts[0].shape.clone());
102 Ok(tvec!(fact))
103 })
104 .with_context(|| format!("Error while computing facts for {:?}", self.name()))
105 }
106
107 fn axes_mapping(
108 &self,
109 inputs: &[&TypedFact],
110 outputs: &[&TypedFact],
111 ) -> TractResult<AxesMapping> {
112 AxesMapping::natural(inputs, outputs)
113 }
114
115 fn change_axes(
116 &self,
117 model: &TypedModel,
118 node: &TypedNode,
119 _io: InOut,
120 change: &AxisOp,
121 ) -> TractResult<Option<AxisChangeConsequence>> {
122 let axes: Option<TVec<usize>> =
123 self.axes.iter().map(|it| change.transform_axis(*it)).collect();
124 if let Some(axes) = axes {
125 Ok(Some(AxisChangeConsequence::new(
126 model,
127 node,
128 Some(Box::new(GpuSoftmax {
129 axes,
130 backend_name: self.backend_name,
131 dispatch: self.dispatch,
132 })),
133 change,
134 )))
135 } else {
136 Ok(None)
137 }
138 }
139
140 as_op!();
141}