baracuda_kernels/segment/
segment_mean.rs1use core::marker::PhantomData;
11
12use baracuda_cutlass::Result;
13use baracuda_driver::Stream;
14use baracuda_kernels_types::{
15 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
16 TensorRef, Workspace,
17};
18
19use super::segment_sum::{
20 build_sku, run_sorted_fw, validate_args, validate_desc, SegDescView, SegmentSumDescriptor,
21 SortedFwOp,
22};
23
24#[derive(Copy, Clone, Debug)]
26pub struct SegmentMeanDescriptor {
27 pub num_inputs: i32,
29 pub embedding_dim: i32,
31 pub num_segments: i32,
33 pub element: ElementKind,
35}
36
37impl SegDescView for SegmentMeanDescriptor {
38 #[inline]
39 fn view(&self) -> (i32, i32, i32, ElementKind) {
40 (
41 self.num_inputs,
42 self.embedding_dim,
43 self.num_segments,
44 self.element,
45 )
46 }
47}
48
49pub struct SegmentMeanArgs<'a, T: Element> {
51 pub input: TensorRef<'a, T, 2>,
53 pub segment_ids: TensorRef<'a, i32, 1>,
55 pub output: TensorMut<'a, T, 2>,
57}
58
59pub struct SegmentMeanPlan<T: Element> {
80 desc: SegmentMeanDescriptor,
81 sku: KernelSku,
82 _marker: PhantomData<T>,
83}
84
85impl<T: Element> SegmentMeanPlan<T> {
86 pub fn select(
88 _stream: &Stream,
89 desc: &SegmentMeanDescriptor,
90 _pref: PlanPreference,
91 ) -> Result<Self> {
92 validate_desc(*desc, T::KIND, "SegmentMeanPlan")?;
93 Ok(Self {
94 desc: *desc,
95 sku: build_sku::<T>(SegmentKind::SegmentMean),
96 _marker: PhantomData,
97 })
98 }
99
100 pub fn can_implement(&self, args: &SegmentMeanArgs<'_, T>) -> Result<()> {
102 let proxy = SegmentSumDescriptor {
103 num_inputs: self.desc.num_inputs,
104 embedding_dim: self.desc.embedding_dim,
105 num_segments: self.desc.num_segments,
106 element: self.desc.element,
107 };
108 validate_args(
109 &proxy,
110 args.input.shape,
111 args.segment_ids.shape,
112 args.output.shape,
113 "SegmentMeanPlan",
114 )
115 }
116
117 #[inline]
119 pub fn workspace_size(&self) -> usize {
120 0
121 }
122
123 #[inline]
125 pub fn sku(&self) -> KernelSku {
126 self.sku
127 }
128
129 #[inline]
131 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
132 self.sku.precision_guarantee
133 }
134
135 pub fn run(
137 &self,
138 stream: &Stream,
139 _workspace: Workspace<'_>,
140 args: SegmentMeanArgs<'_, T>,
141 ) -> Result<()> {
142 self.can_implement(&args)?;
143 let total_out = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
144 if total_out == 0 {
145 return Ok(());
146 }
147 run_sorted_fw::<T>(
148 stream,
149 self.desc.num_inputs,
150 self.desc.embedding_dim,
151 self.desc.num_segments,
152 &args.input,
153 &args.segment_ids,
154 &args.output,
155 SortedFwOp::Mean,
156 )
157 }
158}