baracuda_kernels/segment/
segment_prod.rs1use core::marker::PhantomData;
15
16use baracuda_cutlass::Result;
17use baracuda_driver::Stream;
18use baracuda_kernels_types::{
19 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
20 TensorRef, Workspace,
21};
22
23use super::segment_sum::{
24 build_sku, run_sorted_fw, validate_args, validate_desc, SegDescView, SegmentSumDescriptor,
25 SortedFwOp,
26};
27
28#[derive(Copy, Clone, Debug)]
30pub struct SegmentProdDescriptor {
31 pub num_inputs: i32,
33 pub embedding_dim: i32,
35 pub num_segments: i32,
37 pub element: ElementKind,
39}
40
41impl SegDescView for SegmentProdDescriptor {
42 #[inline]
43 fn view(&self) -> (i32, i32, i32, ElementKind) {
44 (
45 self.num_inputs,
46 self.embedding_dim,
47 self.num_segments,
48 self.element,
49 )
50 }
51}
52
53pub struct SegmentProdArgs<'a, T: Element> {
55 pub input: TensorRef<'a, T, 2>,
57 pub segment_ids: TensorRef<'a, i32, 1>,
59 pub output: TensorMut<'a, T, 2>,
61}
62
63pub struct SegmentProdPlan<T: Element> {
84 desc: SegmentProdDescriptor,
85 sku: KernelSku,
86 _marker: PhantomData<T>,
87}
88
89impl<T: Element> SegmentProdPlan<T> {
90 pub fn select(
92 _stream: &Stream,
93 desc: &SegmentProdDescriptor,
94 _pref: PlanPreference,
95 ) -> Result<Self> {
96 validate_desc(*desc, T::KIND, "SegmentProdPlan")?;
97 Ok(Self {
98 desc: *desc,
99 sku: build_sku::<T>(SegmentKind::SegmentProd),
100 _marker: PhantomData,
101 })
102 }
103
104 pub fn can_implement(&self, args: &SegmentProdArgs<'_, T>) -> Result<()> {
106 let proxy = SegmentSumDescriptor {
107 num_inputs: self.desc.num_inputs,
108 embedding_dim: self.desc.embedding_dim,
109 num_segments: self.desc.num_segments,
110 element: self.desc.element,
111 };
112 validate_args(
113 &proxy,
114 args.input.shape,
115 args.segment_ids.shape,
116 args.output.shape,
117 "SegmentProdPlan",
118 )
119 }
120
121 #[inline]
123 pub fn workspace_size(&self) -> usize {
124 0
125 }
126
127 #[inline]
129 pub fn sku(&self) -> KernelSku {
130 self.sku
131 }
132
133 #[inline]
135 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
136 self.sku.precision_guarantee
137 }
138
139 pub fn run(
141 &self,
142 stream: &Stream,
143 _workspace: Workspace<'_>,
144 args: SegmentProdArgs<'_, T>,
145 ) -> Result<()> {
146 self.can_implement(&args)?;
147 let total_out = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
148 if total_out == 0 {
149 return Ok(());
150 }
151 run_sorted_fw::<T>(
152 stream,
153 self.desc.num_inputs,
154 self.desc.embedding_dim,
155 self.desc.num_segments,
156 &args.input,
157 &args.segment_ids,
158 &args.output,
159 SortedFwOp::Prod,
160 )
161 }
162}