baracuda_kernels/segment/
segment_max.rs1use core::marker::PhantomData;
13
14use baracuda_cutlass::Result;
15use baracuda_driver::Stream;
16use baracuda_kernels_types::{
17 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
18 TensorRef, Workspace,
19};
20
21use super::segment_sum::{
22 build_sku, run_sorted_fw, validate_args, validate_desc, SegDescView, SegmentSumDescriptor,
23 SortedFwOp,
24};
25
26#[derive(Copy, Clone, Debug)]
28pub struct SegmentMaxDescriptor {
29 pub num_inputs: i32,
31 pub embedding_dim: i32,
33 pub num_segments: i32,
35 pub element: ElementKind,
37}
38
39impl SegDescView for SegmentMaxDescriptor {
40 #[inline]
41 fn view(&self) -> (i32, i32, i32, ElementKind) {
42 (
43 self.num_inputs,
44 self.embedding_dim,
45 self.num_segments,
46 self.element,
47 )
48 }
49}
50
51pub struct SegmentMaxArgs<'a, T: Element> {
53 pub input: TensorRef<'a, T, 2>,
55 pub segment_ids: TensorRef<'a, i32, 1>,
57 pub output: TensorMut<'a, T, 2>,
59}
60
61pub struct SegmentMaxPlan<T: Element> {
81 desc: SegmentMaxDescriptor,
82 sku: KernelSku,
83 _marker: PhantomData<T>,
84}
85
86impl<T: Element> SegmentMaxPlan<T> {
87 pub fn select(
89 _stream: &Stream,
90 desc: &SegmentMaxDescriptor,
91 _pref: PlanPreference,
92 ) -> Result<Self> {
93 validate_desc(*desc, T::KIND, "SegmentMaxPlan")?;
94 Ok(Self {
95 desc: *desc,
96 sku: build_sku::<T>(SegmentKind::SegmentMax),
97 _marker: PhantomData,
98 })
99 }
100
101 pub fn can_implement(&self, args: &SegmentMaxArgs<'_, T>) -> Result<()> {
103 let proxy = SegmentSumDescriptor {
104 num_inputs: self.desc.num_inputs,
105 embedding_dim: self.desc.embedding_dim,
106 num_segments: self.desc.num_segments,
107 element: self.desc.element,
108 };
109 validate_args(
110 &proxy,
111 args.input.shape,
112 args.segment_ids.shape,
113 args.output.shape,
114 "SegmentMaxPlan",
115 )
116 }
117
118 #[inline]
120 pub fn workspace_size(&self) -> usize {
121 0
122 }
123
124 #[inline]
126 pub fn sku(&self) -> KernelSku {
127 self.sku
128 }
129
130 #[inline]
132 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
133 self.sku.precision_guarantee
134 }
135
136 pub fn run(
138 &self,
139 stream: &Stream,
140 _workspace: Workspace<'_>,
141 args: SegmentMaxArgs<'_, T>,
142 ) -> Result<()> {
143 self.can_implement(&args)?;
144 let total_out = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
145 if total_out == 0 {
146 return Ok(());
147 }
148 run_sorted_fw::<T>(
149 stream,
150 self.desc.num_inputs,
151 self.desc.embedding_dim,
152 self.desc.num_segments,
153 &args.input,
154 &args.segment_ids,
155 &args.output,
156 SortedFwOp::Max,
157 )
158 }
159}