baracuda_kernels/segment/
segment_min.rs1use core::marker::PhantomData;
10
11use baracuda_cutlass::Result;
12use baracuda_driver::Stream;
13use baracuda_kernels_types::{
14 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
15 TensorRef, Workspace,
16};
17
18use super::segment_sum::{
19 build_sku, run_sorted_fw, validate_args, validate_desc, SegDescView, SegmentSumDescriptor,
20 SortedFwOp,
21};
22
23#[derive(Copy, Clone, Debug)]
25pub struct SegmentMinDescriptor {
26 pub num_inputs: i32,
28 pub embedding_dim: i32,
30 pub num_segments: i32,
32 pub element: ElementKind,
34}
35
36impl SegDescView for SegmentMinDescriptor {
37 #[inline]
38 fn view(&self) -> (i32, i32, i32, ElementKind) {
39 (
40 self.num_inputs,
41 self.embedding_dim,
42 self.num_segments,
43 self.element,
44 )
45 }
46}
47
48pub struct SegmentMinArgs<'a, T: Element> {
50 pub input: TensorRef<'a, T, 2>,
52 pub segment_ids: TensorRef<'a, i32, 1>,
54 pub output: TensorMut<'a, T, 2>,
56}
57
58pub struct SegmentMinPlan<T: Element> {
78 desc: SegmentMinDescriptor,
79 sku: KernelSku,
80 _marker: PhantomData<T>,
81}
82
83impl<T: Element> SegmentMinPlan<T> {
84 pub fn select(
86 _stream: &Stream,
87 desc: &SegmentMinDescriptor,
88 _pref: PlanPreference,
89 ) -> Result<Self> {
90 validate_desc(*desc, T::KIND, "SegmentMinPlan")?;
91 Ok(Self {
92 desc: *desc,
93 sku: build_sku::<T>(SegmentKind::SegmentMin),
94 _marker: PhantomData,
95 })
96 }
97
98 pub fn can_implement(&self, args: &SegmentMinArgs<'_, T>) -> Result<()> {
100 let proxy = SegmentSumDescriptor {
101 num_inputs: self.desc.num_inputs,
102 embedding_dim: self.desc.embedding_dim,
103 num_segments: self.desc.num_segments,
104 element: self.desc.element,
105 };
106 validate_args(
107 &proxy,
108 args.input.shape,
109 args.segment_ids.shape,
110 args.output.shape,
111 "SegmentMinPlan",
112 )
113 }
114
115 #[inline]
117 pub fn workspace_size(&self) -> usize {
118 0
119 }
120
121 #[inline]
123 pub fn sku(&self) -> KernelSku {
124 self.sku
125 }
126
127 #[inline]
129 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
130 self.sku.precision_guarantee
131 }
132
133 pub fn run(
135 &self,
136 stream: &Stream,
137 _workspace: Workspace<'_>,
138 args: SegmentMinArgs<'_, T>,
139 ) -> Result<()> {
140 self.can_implement(&args)?;
141 let total_out = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
142 if total_out == 0 {
143 return Ok(());
144 }
145 run_sorted_fw::<T>(
146 stream,
147 self.desc.num_inputs,
148 self.desc.embedding_dim,
149 self.desc.num_segments,
150 &args.input,
151 &args.segment_ids,
152 &args.output,
153 SortedFwOp::Min,
154 )
155 }
156}