baracuda_kernels/sort/
sort_backward.rs1use core::ffi::c_void;
7use core::marker::PhantomData;
8
9use baracuda_cutlass::{Error, Result};
10use baracuda_driver::Stream;
11use baracuda_kernels_types::{
12 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, SortKind, TensorMut,
13 TensorRef, Workspace,
14};
15
16use super::map_status;
17use super::sort::{build_sku, validate_sort_desc};
18
19#[derive(Copy, Clone, Debug)]
21pub struct SortBackwardDescriptor {
22 pub batch: i32,
24 pub row_len: i32,
26 pub element: ElementKind,
28}
29
30pub struct SortBackwardArgs<'a, T: Element> {
32 pub dy: TensorRef<'a, T, 2>,
34 pub indices: TensorRef<'a, i32, 2>,
36 pub dx: TensorMut<'a, T, 2>,
38}
39
40pub struct SortBackwardPlan<T: Element> {
58 desc: SortBackwardDescriptor,
59 sku: KernelSku,
60 _marker: PhantomData<T>,
61}
62
63impl<T: Element> SortBackwardPlan<T> {
64 pub fn select(
66 _stream: &Stream,
67 desc: &SortBackwardDescriptor,
68 _pref: PlanPreference,
69 ) -> Result<Self> {
70 validate_sort_desc(
71 desc.batch,
72 desc.row_len,
73 desc.element,
74 T::KIND,
75 "SortBackwardPlan",
76 )?;
77 if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
78 return Err(Error::Unsupported(
79 "baracuda-kernels::SortBackwardPlan: today only f32 / f64 grads supported",
80 ));
81 }
82 let sku = build_sku::<T>(SortKind::SortBackward);
83 Ok(Self {
84 desc: *desc,
85 sku,
86 _marker: PhantomData,
87 })
88 }
89
90 pub fn can_implement(&self, args: &SortBackwardArgs<'_, T>) -> Result<()> {
92 let expected = [self.desc.batch, self.desc.row_len];
93 if args.dy.shape != expected
94 || args.indices.shape != expected
95 || args.dx.shape != expected
96 {
97 return Err(Error::InvalidProblem(
98 "baracuda-kernels::SortBackwardPlan: tensor shapes != [batch, row_len]",
99 ));
100 }
101 Ok(())
102 }
103
104 #[inline]
106 pub fn workspace_size(&self) -> usize {
107 0
108 }
109
110 #[inline]
112 pub fn sku(&self) -> KernelSku {
113 self.sku
114 }
115
116 #[inline]
118 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
119 self.sku.precision_guarantee
120 }
121
122 pub fn run(
124 &self,
125 stream: &Stream,
126 _workspace: Workspace<'_>,
127 args: SortBackwardArgs<'_, T>,
128 ) -> Result<()> {
129 self.can_implement(&args)?;
130 if self.desc.batch == 0 || self.desc.row_len == 0 {
131 return Ok(());
132 }
133 let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
134 let idx_ptr = args.indices.data.as_raw().0 as *const c_void;
135 let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
136 let stream_ptr = stream.as_raw() as *mut c_void;
137
138 let status = match T::KIND {
139 ElementKind::F32 => unsafe {
140 baracuda_kernels_sys::baracuda_kernels_sort_backward_f32_run(
141 self.desc.batch,
142 self.desc.row_len,
143 dy_ptr,
144 idx_ptr,
145 dx_ptr,
146 core::ptr::null_mut(),
147 0,
148 stream_ptr,
149 )
150 },
151 ElementKind::F64 => unsafe {
152 baracuda_kernels_sys::baracuda_kernels_sort_backward_f64_run(
153 self.desc.batch,
154 self.desc.row_len,
155 dy_ptr,
156 idx_ptr,
157 dx_ptr,
158 core::ptr::null_mut(),
159 0,
160 stream_ptr,
161 )
162 },
163 _ => {
164 return Err(Error::Unsupported(
165 "baracuda-kernels::SortBackwardPlan::run reached an unimplemented dtype \
166 — select() should have caught this",
167 ));
168 }
169 };
170 map_status(status)
171 }
172}