1use core::ffi::c_void;
45use core::marker::PhantomData;
46
47use baracuda_cutlass::{Error, Result};
48use baracuda_driver::Stream;
49use baracuda_kernels_types::{
50 ArchSku, BackendKind, ElementKind, KernelSku, MathPrecision, OpCategory, PlanPreference,
51 PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
52};
53use baracuda_types::DeviceRepr;
54
/// Static description of a write-slice problem: copy a `source` tensor into
/// the per-axis half-open box `ranges[d].0..ranges[d].1` of a `dest` tensor.
///
/// [`WriteSlicePlan::select`] requires, for every axis `d`:
/// `0 <= ranges[d].0 <= ranges[d].1 <= dest_shape[d]` and
/// `source_shape[d] == ranges[d].1 - ranges[d].0`.
#[derive(Copy, Clone, Debug)]
pub struct WriteSliceDescriptor<const N: usize> {
    /// Extent of the destination tensor along each axis (must be non-negative).
    pub dest_shape: [i32; N],
    /// Extent of the source tensor along each axis; must equal the length of
    /// the corresponding range.
    pub source_shape: [i32; N],
    /// Per-axis `(start, end)` bounds (half-open) of the destination region
    /// that receives the source.
    pub ranges: [(i32, i32); N],
    /// Element kind shared by source and destination; determines the byte
    /// width used for the copy (S4/U4 are packed two per byte).
    pub element: ElementKind,
}
76
/// Runtime operands for [`WriteSlicePlan::run`].
///
/// Both tensors must match the plan's descriptor shapes and be contiguous
/// row-major (checked by [`WriteSlicePlan::can_implement`]).
pub struct WriteSliceArgs<'a, T: DeviceRepr + Copy + 'static, const N: usize> {
    /// Destination tensor with shape `dest_shape`; receives the source in the
    /// box described by the plan's `ranges`.
    pub dest: TensorMut<'a, T, N>,
    /// Source tensor with shape `source_shape`; only read.
    pub source: TensorRef<'a, T, N>,
}
90
/// Validated write-slice plan produced by [`WriteSlicePlan::select`] and
/// executed with [`WriteSlicePlan::run`].
pub struct WriteSlicePlan<T: DeviceRepr + Copy + 'static, const N: usize> {
    /// Copy of the descriptor the plan was selected for.
    desc: WriteSliceDescriptor<N>,
    /// Kernel identity and precision guarantee reported by `sku()`.
    sku: KernelSku,
    /// Bytes per element as returned by `dispatch_kind` (1 for S4/U4).
    byte_width: i32,
    /// True for S4/U4, whose elements are packed two per byte.
    is_nibble: bool,
    /// Copy strategy chosen once at `select` time.
    fast_path: FastPath,
    /// Binds the element type `T` without storing a value of it.
    _marker: PhantomData<T>,
}
124
/// How `WriteSlicePlan::run` performs the write.
#[derive(Copy, Clone, Debug)]
enum FastPath {
    /// Every range spans its full destination axis: the source overwrites the
    /// whole destination, so a single memcpy from offset 0 suffices.
    WholeDest,
    /// The destination region is one contiguous run of `source_numel` elements
    /// starting at element (not byte) offset `dest_offset_elems`.
    ContiguousChunk { dest_offset_elems: i64, source_numel: i64 },
    /// Strided region: handled by the `baracuda_kernels_write_slice_*` kernels.
    Generic,
}
136
137impl<T: DeviceRepr + Copy + 'static, const N: usize> WriteSlicePlan<T, N> {
138 pub fn select(
142 _stream: &Stream,
143 desc: &WriteSliceDescriptor<N>,
144 _pref: PlanPreference,
145 ) -> Result<Self> {
146 if N == 0 {
147 return Err(Error::InvalidProblem(
148 "baracuda-kernels::WriteSlicePlan: rank-0 tensors not supported",
149 ));
150 }
151 if N > 8 {
152 return Err(Error::Unsupported(
153 "baracuda-kernels::WriteSlicePlan: tensor rank > 8 not supported",
154 ));
155 }
156 for d in 0..N {
158 let (s, e) = desc.ranges[d];
159 if s < 0 || e < s || e > desc.dest_shape[d] {
160 return Err(Error::InvalidProblem(
161 "baracuda-kernels::WriteSlicePlan: ranges[d] must satisfy \
162 0 <= start <= end <= dest_shape[d]",
163 ));
164 }
165 if desc.source_shape[d] != e - s {
166 return Err(Error::InvalidProblem(
167 "baracuda-kernels::WriteSlicePlan: source_shape[d] must equal \
168 ranges[d].1 - ranges[d].0",
169 ));
170 }
171 if desc.dest_shape[d] < 0 {
172 return Err(Error::InvalidProblem(
173 "baracuda-kernels::WriteSlicePlan: dest_shape dims must be non-negative",
174 ));
175 }
176 }
177
178 let (byte_width, is_nibble) = match dispatch_kind(desc.element) {
179 Some(b) => b,
180 None => {
181 return Err(Error::Unsupported(
182 "baracuda-kernels::WriteSlicePlan: dtype out of scope. Supported set: \
183 {f16, bf16, f32, F32Strict, f64, i32, i64, Bool, S8, U8, S4, U4, \
184 Fp8E4M3, Fp8E5M2, Complex32, Complex64}",
185 ));
186 }
187 };
188
189 if is_nibble {
193 let (s, e) = desc.ranges[N - 1];
194 if (s & 1) != 0 || (e & 1) != 0 {
195 return Err(Error::Unsupported(
196 "baracuda-kernels::WriteSlicePlan: WriteSlice on S4 / U4 requires \
197 even start/end on innermost axis (no read-modify-write at byte \
198 boundary in the trailblazer kernel)",
199 ));
200 }
201 if (desc.dest_shape[N - 1] & 1) != 0 {
204 return Err(Error::Unsupported(
205 "baracuda-kernels::WriteSlicePlan: WriteSlice on S4 / U4 requires \
206 even dest_shape on innermost axis",
207 ));
208 }
209 }
210
211 let fast_path = detect_fast_path::<N>(desc);
212
213 let precision_guarantee = PrecisionGuarantee {
214 math_precision: MathPrecision::F32,
215 accumulator: ElementKind::F32,
216 bit_stable_on_same_hardware: true,
218 deterministic: true,
219 };
220 let sku = KernelSku {
221 category: OpCategory::ShapeLayout,
222 op: ShapeLayoutKind::WriteSlice as u16,
223 element: desc.element,
224 aux_element: None,
225 layout: None,
226 epilogue: None,
227 arch: ArchSku::Sm80,
228 backend: BackendKind::Bespoke,
229 precision_guarantee,
230 };
231 Ok(Self {
232 desc: *desc,
233 sku,
234 byte_width,
235 is_nibble,
236 fast_path,
237 _marker: PhantomData,
238 })
239 }
240
241 pub fn can_implement(&self, args: &WriteSliceArgs<'_, T, N>) -> Result<()> {
244 if args.dest.shape != self.desc.dest_shape {
245 return Err(Error::InvalidProblem(
246 "baracuda-kernels::WriteSlicePlan: dest shape mismatch with descriptor",
247 ));
248 }
249 if args.source.shape != self.desc.source_shape {
250 return Err(Error::InvalidProblem(
251 "baracuda-kernels::WriteSlicePlan: source shape mismatch with descriptor",
252 ));
253 }
254 if !args.dest.is_contiguous() {
256 return Err(Error::Unsupported(
257 "baracuda-kernels::WriteSlicePlan: dest must be contiguous row-major",
258 ));
259 }
260 if !args.source.is_contiguous() {
261 return Err(Error::Unsupported(
262 "baracuda-kernels::WriteSlicePlan: source must be contiguous row-major",
263 ));
264 }
265 let dest_numel = product_i64(self.desc.dest_shape);
269 let source_numel = product_i64(self.desc.source_shape);
270 let dest_storage = if self.is_nibble { (dest_numel + 1) / 2 } else { dest_numel };
271 let source_storage = if self.is_nibble { (source_numel + 1) / 2 } else { source_numel };
272 if (args.dest.data.len() as i64) < dest_storage {
273 return Err(Error::BufferTooSmall {
274 needed: dest_storage as usize,
275 got: args.dest.data.len(),
276 });
277 }
278 if (args.source.data.len() as i64) < source_storage {
279 return Err(Error::BufferTooSmall {
280 needed: source_storage as usize,
281 got: args.source.data.len(),
282 });
283 }
284 Ok(())
285 }
286
287 #[inline]
289 pub fn workspace_size(&self) -> usize {
290 0
291 }
292
293 #[inline]
295 pub fn sku(&self) -> KernelSku {
296 self.sku
297 }
298
299 #[inline]
301 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
302 self.sku.precision_guarantee
303 }
304
305 pub fn run(
307 &self,
308 stream: &Stream,
309 _workspace: Workspace<'_>,
310 args: WriteSliceArgs<'_, T, N>,
311 ) -> Result<()> {
312 self.can_implement(&args)?;
313 let source_numel = product_i64(self.desc.source_shape);
314 if source_numel == 0 {
315 return Ok(());
316 }
317 let dest_ptr_u64 = args.dest.data.as_raw().0;
318 let source_ptr_u64 = args.source.data.as_raw().0;
319 let stream_ptr = stream.as_raw() as *mut c_void;
320
321 match self.fast_path {
323 FastPath::WholeDest | FastPath::ContiguousChunk { .. } => {
324 let (dest_off_elems, copy_elems) = match self.fast_path {
328 FastPath::WholeDest => (0i64, source_numel),
329 FastPath::ContiguousChunk { dest_offset_elems, source_numel: n } => {
330 (dest_offset_elems, n)
331 }
332 FastPath::Generic => unreachable!(),
333 };
334 let (dest_off_bytes, copy_bytes) = if self.is_nibble {
338 (dest_off_elems / 2, copy_elems / 2)
339 } else {
340 let bw = self.byte_width as i64;
341 (dest_off_elems * bw, copy_elems * bw)
342 };
343 return copy_d2d_async(
344 dest_ptr_u64.wrapping_add(dest_off_bytes as u64),
345 source_ptr_u64,
346 copy_bytes as usize,
347 stream_ptr,
348 );
349 }
350 FastPath::Generic => {}
351 }
352
353 let rank = N as i32;
355 let dest_shape = self.desc.dest_shape;
356 let source_shape = self.desc.source_shape;
357 let mut range_start = [0i32; N];
358 for d in 0..N {
359 range_start[d] = self.desc.ranges[d].0;
360 }
361
362 let status = if self.is_nibble {
363 let mut dest_byte_shape = dest_shape;
368 let mut source_byte_shape = source_shape;
369 let mut range_start_bytes = range_start;
370 dest_byte_shape[N - 1] /= 2;
371 source_byte_shape[N - 1] /= 2;
372 range_start_bytes[N - 1] /= 2;
373 let source_byte_numel = source_numel / 2;
374 unsafe {
375 baracuda_kernels_sys::baracuda_kernels_write_slice_nibble_run(
376 dest_ptr_u64 as *mut c_void,
377 source_ptr_u64 as *const c_void,
378 source_byte_numel,
379 rank,
380 dest_byte_shape.as_ptr(),
381 source_byte_shape.as_ptr(),
382 range_start_bytes.as_ptr(),
383 core::ptr::null_mut(),
384 0,
385 stream_ptr,
386 )
387 }
388 } else {
389 unsafe {
391 let dest = dest_ptr_u64 as *mut c_void;
392 let source = source_ptr_u64 as *const c_void;
393 let ds = dest_shape.as_ptr();
394 let ss = source_shape.as_ptr();
395 let rs = range_start.as_ptr();
396 match self.byte_width {
397 1 => baracuda_kernels_sys::baracuda_kernels_write_slice_b1_run(
398 dest, source, source_numel, rank, ds, ss, rs,
399 core::ptr::null_mut(), 0, stream_ptr,
400 ),
401 2 => baracuda_kernels_sys::baracuda_kernels_write_slice_b2_run(
402 dest, source, source_numel, rank, ds, ss, rs,
403 core::ptr::null_mut(), 0, stream_ptr,
404 ),
405 4 => baracuda_kernels_sys::baracuda_kernels_write_slice_b4_run(
406 dest, source, source_numel, rank, ds, ss, rs,
407 core::ptr::null_mut(), 0, stream_ptr,
408 ),
409 8 => baracuda_kernels_sys::baracuda_kernels_write_slice_b8_run(
410 dest, source, source_numel, rank, ds, ss, rs,
411 core::ptr::null_mut(), 0, stream_ptr,
412 ),
413 16 => baracuda_kernels_sys::baracuda_kernels_write_slice_b16_run(
414 dest, source, source_numel, rank, ds, ss, rs,
415 core::ptr::null_mut(), 0, stream_ptr,
416 ),
417 _ => return Err(Error::Unsupported(
418 "baracuda-kernels::WriteSlicePlan::run: unsupported byte width \
419 (select() should have caught this)",
420 )),
421 }
422 }
423 };
424 map_status(status)
425 }
426}
427
428fn dispatch_kind(k: ElementKind) -> Option<(i32, bool)> {
431 Some(match k {
432 ElementKind::Bool => (1, false),
433 ElementKind::S8 => (1, false),
434 ElementKind::U8 => (1, false),
435 ElementKind::Fp8E4M3 => (1, false),
436 ElementKind::Fp8E5M2 => (1, false),
437 ElementKind::F16 => (2, false),
438 ElementKind::Bf16 => (2, false),
439 ElementKind::F32 => (4, false),
440 ElementKind::F32Strict => (4, false),
441 ElementKind::I32 => (4, false),
442 ElementKind::F64 => (8, false),
443 ElementKind::I64 => (8, false),
444 ElementKind::Complex32 => (8, false),
445 ElementKind::Complex64 => (16, false),
446 ElementKind::S4 => (1, true),
447 ElementKind::U4 => (1, true),
448 ElementKind::Bin => return None,
450 })
451}
452
453fn detect_fast_path<const N: usize>(desc: &WriteSliceDescriptor<N>) -> FastPath {
454 let mut whole = true;
456 for d in 0..N {
457 let (s, e) = desc.ranges[d];
458 if s != 0 || e != desc.dest_shape[d] {
459 whole = false;
460 break;
461 }
462 }
463 if whole {
464 return FastPath::WholeDest;
465 }
466
467 if N == 1 {
471 let (s, _) = desc.ranges[0];
473 let source_numel = product_i64(desc.source_shape);
474 return FastPath::ContiguousChunk {
475 dest_offset_elems: s as i64,
476 source_numel,
477 };
478 }
479 let mut minors_full = true;
480 for d in 1..N {
481 let (s, e) = desc.ranges[d];
482 if s != 0 || e != desc.dest_shape[d] {
483 minors_full = false;
484 break;
485 }
486 }
487 if minors_full {
488 let mut minor_prod: i64 = 1;
489 for d in 1..N {
490 minor_prod = minor_prod.saturating_mul(desc.dest_shape[d] as i64);
491 }
492 let start_0 = desc.ranges[0].0 as i64;
493 let source_numel = product_i64(desc.source_shape);
494 return FastPath::ContiguousChunk {
495 dest_offset_elems: start_0 * minor_prod,
496 source_numel,
497 };
498 }
499 FastPath::Generic
500}
501
/// Product of all extents in `shape` as `i64`, starting from 1 (so a rank-0
/// shape yields 1).
///
/// Multiplication saturates rather than wrapping, so an oversized shape can
/// never alias a small element count.
#[inline]
fn product_i64<const N: usize>(shape: [i32; N]) -> i64 {
    shape
        .iter()
        .fold(1_i64, |acc, &extent| acc.saturating_mul(i64::from(extent)))
}
510
511fn copy_d2d_async(
515 dst_dev: u64,
516 src_dev: u64,
517 bytes: usize,
518 stream: *mut c_void,
519) -> Result<()> {
520 if bytes == 0 {
521 return Ok(());
522 }
523 #[allow(non_camel_case_types)]
524 type CUresult = i32;
525 unsafe extern "system" {
526 fn cuMemcpyDtoDAsync_v2(
527 dst_device: u64,
528 src_device: u64,
529 byte_count: usize,
530 h_stream: *mut c_void,
531 ) -> CUresult;
532 }
533 let status = unsafe { cuMemcpyDtoDAsync_v2(dst_dev, src_dev, bytes, stream) };
534 if status != 0 {
535 return Err(Error::CutlassInternal(-status));
536 }
537 Ok(())
538}
539
540fn map_status(code: i32) -> Result<()> {
541 match code {
542 0 => Ok(()),
543 1 => Err(Error::MisalignedOperand),
544 2 => Err(Error::InvalidProblem(
545 "baracuda-kernels-sys reported invalid problem",
546 )),
547 3 => Err(Error::Unsupported(
548 "baracuda-kernels-sys reported unsupported configuration",
549 )),
550 4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
551 n => Err(Error::CutlassInternal(n)),
552 }
553}