baracuda_kernels/segment/
segment_sum_backward.rs1use core::ffi::c_void;
12use core::marker::PhantomData;
13
14use baracuda_cutlass::{Error, Result};
15use baracuda_driver::Stream;
16use baracuda_kernels_types::{
17 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SegmentKind, TensorMut,
18 TensorRef, Workspace,
19};
20
21use super::map_status;
22use super::segment_sum::{build_sku, validate_desc, SegDescView};
23
24#[derive(Copy, Clone, Debug)]
26pub struct SegmentSumBackwardDescriptor {
27 pub num_inputs: i32,
29 pub embedding_dim: i32,
31 pub num_segments: i32,
33 pub element: ElementKind,
35}
36
37impl SegDescView for SegmentSumBackwardDescriptor {
38 #[inline]
39 fn view(&self) -> (i32, i32, i32, ElementKind) {
40 (
41 self.num_inputs,
42 self.embedding_dim,
43 self.num_segments,
44 self.element,
45 )
46 }
47}
48
49pub struct SegmentSumBackwardArgs<'a, T: Element> {
51 pub d_output: TensorRef<'a, T, 2>,
53 pub segment_ids: TensorRef<'a, i32, 1>,
55 pub d_input: TensorMut<'a, T, 2>,
57}
58
59pub struct SegmentSumBackwardPlan<T: Element> {
79 desc: SegmentSumBackwardDescriptor,
80 sku: KernelSku,
81 _marker: PhantomData<T>,
82}
83
84impl<T: Element> SegmentSumBackwardPlan<T> {
85 pub fn select(
87 _stream: &Stream,
88 desc: &SegmentSumBackwardDescriptor,
89 _pref: PlanPreference,
90 ) -> Result<Self> {
91 validate_desc(*desc, T::KIND, "SegmentSumBackwardPlan")?;
92 Ok(Self {
93 desc: *desc,
94 sku: build_sku::<T>(SegmentKind::SegmentSumBackward),
95 _marker: PhantomData,
96 })
97 }
98
99 pub fn can_implement(&self, args: &SegmentSumBackwardArgs<'_, T>) -> Result<()> {
101 if args.d_output.shape != [self.desc.num_segments, self.desc.embedding_dim] {
102 return Err(Error::InvalidProblem(
103 "baracuda-kernels::SegmentSumBackwardPlan: d_output shape != [num_segments, D]",
104 ));
105 }
106 if args.segment_ids.shape != [self.desc.num_inputs] {
107 return Err(Error::InvalidProblem(
108 "baracuda-kernels::SegmentSumBackwardPlan: segment_ids shape != [num_inputs]",
109 ));
110 }
111 if args.d_input.shape != [self.desc.num_inputs, self.desc.embedding_dim] {
112 return Err(Error::InvalidProblem(
113 "baracuda-kernels::SegmentSumBackwardPlan: d_input shape != [num_inputs, D]",
114 ));
115 }
116 Ok(())
117 }
118
119 #[inline]
121 pub fn workspace_size(&self) -> usize {
122 0
123 }
124
125 #[inline]
127 pub fn sku(&self) -> KernelSku {
128 self.sku
129 }
130
131 #[inline]
133 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
134 self.sku.precision_guarantee
135 }
136
137 pub fn run(
139 &self,
140 stream: &Stream,
141 _workspace: Workspace<'_>,
142 args: SegmentSumBackwardArgs<'_, T>,
143 ) -> Result<()> {
144 self.can_implement(&args)?;
145 let total = (self.desc.num_inputs as i64) * (self.desc.embedding_dim as i64);
146 if total == 0 {
147 return Ok(());
148 }
149 let do_ptr = args.d_output.data.as_raw().0 as *const c_void;
150 let id_ptr = args.segment_ids.data.as_raw().0 as *const c_void;
151 let di_ptr = args.d_input.data.as_raw().0 as *mut c_void;
152 let stream_ptr = stream.as_raw() as *mut c_void;
153 let status = match T::KIND {
154 ElementKind::F32 => unsafe {
155 baracuda_kernels_sys::baracuda_kernels_segment_sum_backward_f32_run(
156 self.desc.num_inputs,
157 self.desc.embedding_dim,
158 self.desc.num_segments,
159 do_ptr,
160 id_ptr,
161 di_ptr,
162 core::ptr::null_mut(),
163 0,
164 stream_ptr,
165 )
166 },
167 ElementKind::F64 => unsafe {
168 baracuda_kernels_sys::baracuda_kernels_segment_sum_backward_f64_run(
169 self.desc.num_inputs,
170 self.desc.embedding_dim,
171 self.desc.num_segments,
172 do_ptr,
173 id_ptr,
174 di_ptr,
175 core::ptr::null_mut(),
176 0,
177 stream_ptr,
178 )
179 },
180 _ => {
181 return Err(Error::Unsupported(
182 "baracuda-kernels::SegmentSumBackwardPlan::run reached an unimplemented dtype",
183 ))
184 }
185 };
186 map_status(status)
187 }
188}