baracuda_kernels/segment/
unsorted_segment_min.rs1use core::ffi::c_void;
9use core::marker::PhantomData;
10
11use baracuda_cutlass::{Error, Result};
12use baracuda_driver::Stream;
13use baracuda_kernels_types::{
14 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
15 TensorRef, Workspace,
16};
17
18use super::map_status;
19use super::segment_sum::{validate_desc, SegDescView};
20use super::unsorted_segment_sum::{build_unsorted_sku, validate_unsorted_args};
21
22#[derive(Copy, Clone, Debug)]
24pub struct UnsortedSegmentMinDescriptor {
25 pub num_inputs: i32,
27 pub embedding_dim: i32,
29 pub num_segments: i32,
31 pub element: ElementKind,
33}
34
35impl SegDescView for UnsortedSegmentMinDescriptor {
36 #[inline]
37 fn view(&self) -> (i32, i32, i32, ElementKind) {
38 (
39 self.num_inputs,
40 self.embedding_dim,
41 self.num_segments,
42 self.element,
43 )
44 }
45}
46
47pub struct UnsortedSegmentMinArgs<'a, T: Element> {
49 pub input: TensorRef<'a, T, 2>,
51 pub segment_ids: TensorRef<'a, i32, 1>,
53 pub output: TensorMut<'a, T, 2>,
55}
56
57pub struct UnsortedSegmentMinPlan<T: Element> {
77 desc: UnsortedSegmentMinDescriptor,
78 sku: KernelSku,
79 _marker: PhantomData<T>,
80}
81
82impl<T: Element> UnsortedSegmentMinPlan<T> {
83 pub fn select(
85 _stream: &Stream,
86 desc: &UnsortedSegmentMinDescriptor,
87 _pref: PlanPreference,
88 ) -> Result<Self> {
89 validate_desc(*desc, T::KIND, "UnsortedSegmentMinPlan")?;
90 Ok(Self {
91 desc: *desc,
92 sku: build_unsorted_sku::<T>(SegmentKind::UnsortedSegmentMin),
93 _marker: PhantomData,
94 })
95 }
96
97 pub fn can_implement(&self, args: &UnsortedSegmentMinArgs<'_, T>) -> Result<()> {
99 validate_unsorted_args(
100 self.desc.num_inputs,
101 self.desc.embedding_dim,
102 self.desc.num_segments,
103 args.input.shape,
104 args.segment_ids.shape,
105 args.output.shape,
106 "UnsortedSegmentMinPlan",
107 )
108 }
109
110 #[inline]
112 pub fn workspace_size(&self) -> usize {
113 0
114 }
115
116 #[inline]
118 pub fn sku(&self) -> KernelSku {
119 self.sku
120 }
121
122 #[inline]
124 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
125 self.sku.precision_guarantee
126 }
127
128 pub fn run(
130 &self,
131 stream: &Stream,
132 _workspace: Workspace<'_>,
133 args: UnsortedSegmentMinArgs<'_, T>,
134 ) -> Result<()> {
135 self.can_implement(&args)?;
136 let total = (self.desc.num_segments as i64) * (self.desc.embedding_dim as i64);
137 if total == 0 {
138 return Ok(());
139 }
140 let in_ptr = args.input.data.as_raw().0 as *const c_void;
141 let id_ptr = args.segment_ids.data.as_raw().0 as *const c_void;
142 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
143 let stream_ptr = stream.as_raw() as *mut c_void;
144 let status = match T::KIND {
145 ElementKind::F32 => unsafe {
146 baracuda_kernels_sys::baracuda_kernels_unsorted_segment_min_f32_run(
147 self.desc.num_inputs,
148 self.desc.embedding_dim,
149 self.desc.num_segments,
150 in_ptr,
151 id_ptr,
152 out_ptr,
153 core::ptr::null_mut(),
154 0,
155 stream_ptr,
156 )
157 },
158 ElementKind::F64 => unsafe {
159 baracuda_kernels_sys::baracuda_kernels_unsorted_segment_min_f64_run(
160 self.desc.num_inputs,
161 self.desc.embedding_dim,
162 self.desc.num_segments,
163 in_ptr,
164 id_ptr,
165 out_ptr,
166 core::ptr::null_mut(),
167 0,
168 stream_ptr,
169 )
170 },
171 _ => {
172 return Err(Error::Unsupported(
173 "baracuda-kernels::UnsortedSegmentMinPlan::run reached an unimplemented dtype",
174 ));
175 }
176 };
177 map_status(status)
178 }
179}