Skip to main content

baracuda_kernels/shape_layout/
write_slice.rs

1//! `write_slice` plan — Phase 13.1 trailblazer.
2//!
3//! `write_slice(dest, source, ranges) -> dest`:
4//!
5//!   `dest[start_0..end_0, ..., start_{N-1}..end_{N-1}] = source`
6//!
7//! Assign semantics (not accumulate — that distinguishes
8//! [`WriteSlicePlan`] from `ScatterAddPlan`). Drives Fuel team's
9//! persistent KV-cache append during autoregressive decoding —
10//! step 9c E.3.3 of their Phase 7.6 integration.
11//!
//! Dtype coverage spans every byte-aligned kind in the baracuda element
//! bank via byte-width dispatch (`sizeof(T) ∈ {1, 2, 4, 8, 16}`), plus a
//! separate nibble-packed kernel for [`S4`] / [`U4`]. The 1-bit packed
//! `Bin` kind is out of scope and is rejected at `select` time. The bound
//! is `T: DeviceRepr + Copy + 'static` (same as [`TensorRef`]), so the
//! same plan covers `Element`-family, `IntElement`-family, and
//! `FpElement`-family dtypes uniformly.
18//!
19//! No backward — `write_slice` is non-differentiable in Fuel's
20//! autograd model.
21//!
22//! ## Fast paths
23//!
24//! 1. **Full-width minor axes** — when `ranges[i] == (0, dest_shape[i])`
25//!    for all `i > 0`, the source maps to one contiguous chunk of
26//!    `dest` starting at offset `start_0 * stride[0] * sizeof(T)`. A
27//!    single `cuMemcpyDtoDAsync` does the copy. This is the KV-cache
28//!    append shape and the most performance-critical case.
29//! 2. **Whole dest covered** — when source-shape == dest-shape and
30//!    ranges fully cover dest, a single `cuMemcpyDtoDAsync` of the
31//!    whole buffer (degenerate of case 1).
32//! 3. **Otherwise** — generic per-slab-element kernel. One thread per
33//!    source element computes the dest linear offset from the slab
34//!    coord shifted by `range_start`.
35//!
//! ## S4 / U4 constraint
//!
//! Nibble-packed dtypes pack two elements per `u8`. To avoid
//! read-modify-write across a byte boundary, the trailblazer
//! requires `start_{N-1}`, `end_{N-1}`, and the innermost destination
//! extent `dest_shape[N-1]` to all be **even**. An odd value for any of
//! them returns [`Error::Unsupported`] at `select` time.
43
44use core::ffi::c_void;
45use core::marker::PhantomData;
46
47use baracuda_cutlass::{Error, Result};
48use baracuda_driver::Stream;
49use baracuda_kernels_types::{
50    ArchSku, BackendKind, ElementKind, KernelSku, MathPrecision, OpCategory, PlanPreference,
51    PrecisionGuarantee, ShapeLayoutKind, TensorMut, TensorRef, Workspace,
52};
53use baracuda_types::DeviceRepr;
54
55/// Descriptor for a `write_slice` op.
56///
57/// `dest_shape[d]` is the per-axis extent of the destination tensor.
58/// `source_shape[d]` must equal `ranges[d].1 - ranges[d].0` for every
59/// axis (the slab extent). `ranges[d] = (start, end)` selects the
60/// inclusive-start / exclusive-end window on axis `d`.
61/// `element` is the logical element kind of both tensors (they share
62/// dtype). Used to drive byte-width / nibble dispatch.
63#[derive(Copy, Clone, Debug)]
64pub struct WriteSliceDescriptor<const N: usize> {
65    /// Shape of the destination tensor.
66    pub dest_shape: [i32; N],
67    /// Shape of the source tensor (== `ranges[i].1 - ranges[i].0`
68    /// per axis).
69    pub source_shape: [i32; N],
70    /// Per-axis `(start, end)` window. `0 ≤ start ≤ end ≤ dest_shape[d]`.
71    pub ranges: [(i32, i32); N],
72    /// Element kind of both tensors. Used to compute the byte width
73    /// (and to detect S4 / U4 for the nibble path).
74    pub element: ElementKind,
75}
76
77/// Args bundle for a `write_slice` launch.
78///
79/// `dest` is mutated in place. `source` is read once. Both must be
80/// contiguous row-major with zero offset relative to their backing
81/// device buffer (Fuel's plan layer materializes strided / offset
82/// inputs upstream via `Contiguize`).
83pub struct WriteSliceArgs<'a, T: DeviceRepr + Copy + 'static, const N: usize> {
84    /// Destination tensor — written in the per-axis range window.
85    /// Bytes outside the window are untouched.
86    pub dest: TensorMut<'a, T, N>,
87    /// Source tensor — same dtype as `dest`, shape == slab extent.
88    pub source: TensorRef<'a, T, N>,
89}
90
91/// `write_slice` plan.
92///
93/// `dest[start_0..end_0, ..., start_{N-1}..end_{N-1}] = source` —
94/// assign (not accumulate). Drives Fuel team's persistent KV-cache
95/// append.
96///
97/// **When to use**: in-place per-axis range write. Distinct from
98/// [`ScatterAddPlan`](crate::ScatterAddPlan) (which accumulates
99/// per-index) and from [`PadPlan`](crate::PadPlan) (which produces a
100/// larger output tensor). No backward — non-differentiable.
101///
102/// **Dtypes**: every byte-aligned element kind in baracuda's element
103/// bank — `f16, bf16, f32, F32Strict, f64, i32, i64, Bool, S8, U8,
104/// Fp8E4M3, Fp8E5M2, Complex32, Complex64`. Plus nibble-packed
105/// `S4 / U4` with the even-alignment constraint on the innermost axis.
106/// `Bin` (1-bit packed) is out of scope.
107///
108/// **Shape limits**: rank in `[1, 8]`; per-axis
109/// `0 ≤ start ≤ end ≤ dest_shape[d]`; `source_shape[d] = end - start`.
110///
111/// **Workspace**: none.
112///
113/// **Precision guarantee**: deterministic, bit-stable, bit-exact (no
114/// arithmetic — pure memcpy / index + copy).
115pub struct WriteSlicePlan<T: DeviceRepr + Copy + 'static, const N: usize> {
116    desc: WriteSliceDescriptor<N>,
117    sku: KernelSku,
118    byte_width: i32,
119    is_nibble: bool,
120    /// Fast-path discriminant computed once at `select` time.
121    fast_path: FastPath,
122    _marker: PhantomData<T>,
123}
124
/// Copy strategy selected at `select` time.
#[derive(Clone, Copy, Debug)]
enum FastPath {
    /// Every range spans its whole axis, so the source is the entire
    /// destination and a single whole-buffer copy suffices.
    WholeDest,
    /// Every axis except axis 0 is full-width (`ranges[i] == (0,
    /// dest_shape[i])` for `i > 0`), so the slab is one contiguous run in
    /// dest. `dest_offset_elems` is where that run begins, in elements.
    ContiguousChunk { dest_offset_elems: i64, source_numel: i64 },
    /// No fast path applies; use the generic per-element kernel.
    Generic,
}
136
137impl<T: DeviceRepr + Copy + 'static, const N: usize> WriteSlicePlan<T, N> {
138    /// Pick a kernel for `desc`. Validates rank, range bounds, source
139    /// shape consistency, dtype coverage, and the nibble-axis-alignment
140    /// constraint for S4 / U4. Detects the available fast path.
141    pub fn select(
142        _stream: &Stream,
143        desc: &WriteSliceDescriptor<N>,
144        _pref: PlanPreference,
145    ) -> Result<Self> {
146        if N == 0 {
147            return Err(Error::InvalidProblem(
148                "baracuda-kernels::WriteSlicePlan: rank-0 tensors not supported",
149            ));
150        }
151        if N > 8 {
152            return Err(Error::Unsupported(
153                "baracuda-kernels::WriteSlicePlan: tensor rank > 8 not supported",
154            ));
155        }
156        // Validate ranges + source shape.
157        for d in 0..N {
158            let (s, e) = desc.ranges[d];
159            if s < 0 || e < s || e > desc.dest_shape[d] {
160                return Err(Error::InvalidProblem(
161                    "baracuda-kernels::WriteSlicePlan: ranges[d] must satisfy \
162                     0 <= start <= end <= dest_shape[d]",
163                ));
164            }
165            if desc.source_shape[d] != e - s {
166                return Err(Error::InvalidProblem(
167                    "baracuda-kernels::WriteSlicePlan: source_shape[d] must equal \
168                     ranges[d].1 - ranges[d].0",
169                ));
170            }
171            if desc.dest_shape[d] < 0 {
172                return Err(Error::InvalidProblem(
173                    "baracuda-kernels::WriteSlicePlan: dest_shape dims must be non-negative",
174                ));
175            }
176        }
177
178        let (byte_width, is_nibble) = match dispatch_kind(desc.element) {
179            Some(b) => b,
180            None => {
181                return Err(Error::Unsupported(
182                    "baracuda-kernels::WriteSlicePlan: dtype out of scope. Supported set: \
183                     {f16, bf16, f32, F32Strict, f64, i32, i64, Bool, S8, U8, S4, U4, \
184                      Fp8E4M3, Fp8E5M2, Complex32, Complex64}",
185                ));
186            }
187        };
188
189        // Nibble-axis-alignment guard. Both start and end on the
190        // innermost axis must be even so no byte straddles two halves
191        // of the kernel write set.
192        if is_nibble {
193            let (s, e) = desc.ranges[N - 1];
194            if (s & 1) != 0 || (e & 1) != 0 {
195                return Err(Error::Unsupported(
196                    "baracuda-kernels::WriteSlicePlan: WriteSlice on S4 / U4 requires \
197                     even start/end on innermost axis (no read-modify-write at byte \
198                     boundary in the trailblazer kernel)",
199                ));
200            }
201            // Also require the innermost dest extent to be even — the
202            // nibble byte-shape on the innermost axis is dest_shape/2.
203            if (desc.dest_shape[N - 1] & 1) != 0 {
204                return Err(Error::Unsupported(
205                    "baracuda-kernels::WriteSlicePlan: WriteSlice on S4 / U4 requires \
206                     even dest_shape on innermost axis",
207                ));
208            }
209        }
210
211        let fast_path = detect_fast_path::<N>(desc);
212
213        let precision_guarantee = PrecisionGuarantee {
214            math_precision: MathPrecision::F32,
215            accumulator: ElementKind::F32,
216            // No arithmetic — pure memcpy + linear write.
217            bit_stable_on_same_hardware: true,
218            deterministic: true,
219        };
220        let sku = KernelSku {
221            category: OpCategory::ShapeLayout,
222            op: ShapeLayoutKind::WriteSlice as u16,
223            element: desc.element,
224            aux_element: None,
225            layout: None,
226            epilogue: None,
227            arch: ArchSku::Sm80,
228            backend: BackendKind::Bespoke,
229            precision_guarantee,
230        };
231        Ok(Self {
232            desc: *desc,
233            sku,
234            byte_width,
235            is_nibble,
236            fast_path,
237            _marker: PhantomData,
238        })
239    }
240
241    /// Validate `args` against the descriptor: shapes match, device
242    /// buffers are large enough.
243    pub fn can_implement(&self, args: &WriteSliceArgs<'_, T, N>) -> Result<()> {
244        if args.dest.shape != self.desc.dest_shape {
245            return Err(Error::InvalidProblem(
246                "baracuda-kernels::WriteSlicePlan: dest shape mismatch with descriptor",
247            ));
248        }
249        if args.source.shape != self.desc.source_shape {
250            return Err(Error::InvalidProblem(
251                "baracuda-kernels::WriteSlicePlan: source shape mismatch with descriptor",
252            ));
253        }
254        // The kernel assumes both tensors are contiguous row-major.
255        if !args.dest.is_contiguous() {
256            return Err(Error::Unsupported(
257                "baracuda-kernels::WriteSlicePlan: dest must be contiguous row-major",
258            ));
259        }
260        if !args.source.is_contiguous() {
261            return Err(Error::Unsupported(
262                "baracuda-kernels::WriteSlicePlan: source must be contiguous row-major",
263            ));
264        }
265        // Buffer-size checks. Nibble case: storage element count is
266        // numel/2 (rounded up — innermost extent is even by select-time
267        // guard, so numel is even too on the nibble path).
268        let dest_numel = product_i64(self.desc.dest_shape);
269        let source_numel = product_i64(self.desc.source_shape);
270        let dest_storage = if self.is_nibble { (dest_numel + 1) / 2 } else { dest_numel };
271        let source_storage = if self.is_nibble { (source_numel + 1) / 2 } else { source_numel };
272        if (args.dest.data.len() as i64) < dest_storage {
273            return Err(Error::BufferTooSmall {
274                needed: dest_storage as usize,
275                got: args.dest.data.len(),
276            });
277        }
278        if (args.source.data.len() as i64) < source_storage {
279            return Err(Error::BufferTooSmall {
280                needed: source_storage as usize,
281                got: args.source.data.len(),
282            });
283        }
284        Ok(())
285    }
286
287    /// Workspace size in bytes. Always `0`.
288    #[inline]
289    pub fn workspace_size(&self) -> usize {
290        0
291    }
292
293    /// Identity of the kernel this plan picked.
294    #[inline]
295    pub fn sku(&self) -> KernelSku {
296        self.sku
297    }
298
299    /// Numerical guarantees for this plan's kernel.
300    #[inline]
301    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
302        self.sku.precision_guarantee
303    }
304
305    /// Launch on `stream`. `workspace` is ignored (always zero).
306    pub fn run(
307        &self,
308        stream: &Stream,
309        _workspace: Workspace<'_>,
310        args: WriteSliceArgs<'_, T, N>,
311    ) -> Result<()> {
312        self.can_implement(&args)?;
313        let source_numel = product_i64(self.desc.source_shape);
314        if source_numel == 0 {
315            return Ok(());
316        }
317        let dest_ptr_u64 = args.dest.data.as_raw().0;
318        let source_ptr_u64 = args.source.data.as_raw().0;
319        let stream_ptr = stream.as_raw() as *mut c_void;
320
321        // -------------------- Fast paths --------------------
322        match self.fast_path {
323            FastPath::WholeDest | FastPath::ContiguousChunk { .. } => {
324                // Bytes to copy and per-side offsets:
325                //   - source: always starts at offset 0 with source_numel elems
326                //   - dest: starts at `dest_offset_elems` (0 for WholeDest)
327                let (dest_off_elems, copy_elems) = match self.fast_path {
328                    FastPath::WholeDest => (0i64, source_numel),
329                    FastPath::ContiguousChunk { dest_offset_elems, source_numel: n } => {
330                        (dest_offset_elems, n)
331                    }
332                    FastPath::Generic => unreachable!(),
333                };
334                // Byte counts. Nibble: 2 elements per byte (innermost
335                // axis alignment is guaranteed even by select-time
336                // guard, so both offset and count are integer bytes).
337                let (dest_off_bytes, copy_bytes) = if self.is_nibble {
338                    (dest_off_elems / 2, copy_elems / 2)
339                } else {
340                    let bw = self.byte_width as i64;
341                    (dest_off_elems * bw, copy_elems * bw)
342                };
343                return copy_d2d_async(
344                    dest_ptr_u64.wrapping_add(dest_off_bytes as u64),
345                    source_ptr_u64,
346                    copy_bytes as usize,
347                    stream_ptr,
348                );
349            }
350            FastPath::Generic => {}
351        }
352
353        // -------------------- Generic kernel path --------------------
354        let rank = N as i32;
355        let dest_shape = self.desc.dest_shape;
356        let source_shape = self.desc.source_shape;
357        let mut range_start = [0i32; N];
358        for d in 0..N {
359            range_start[d] = self.desc.ranges[d].0;
360        }
361
362        let status = if self.is_nibble {
363            // Nibble kernel: shape arrays on the innermost axis are
364            // byte-counted (= elements / 2). select() guarantees both
365            // innermost dest extent and innermost start are even, so
366            // the divisions are exact.
367            let mut dest_byte_shape = dest_shape;
368            let mut source_byte_shape = source_shape;
369            let mut range_start_bytes = range_start;
370            dest_byte_shape[N - 1] /= 2;
371            source_byte_shape[N - 1] /= 2;
372            range_start_bytes[N - 1] /= 2;
373            let source_byte_numel = source_numel / 2;
374            unsafe {
375                baracuda_kernels_sys::baracuda_kernels_write_slice_nibble_run(
376                    dest_ptr_u64 as *mut c_void,
377                    source_ptr_u64 as *const c_void,
378                    source_byte_numel,
379                    rank,
380                    dest_byte_shape.as_ptr(),
381                    source_byte_shape.as_ptr(),
382                    range_start_bytes.as_ptr(),
383                    core::ptr::null_mut(),
384                    0,
385                    stream_ptr,
386                )
387            }
388        } else {
389            // Byte-aligned: dispatch on byte width.
390            unsafe {
391                let dest = dest_ptr_u64 as *mut c_void;
392                let source = source_ptr_u64 as *const c_void;
393                let ds = dest_shape.as_ptr();
394                let ss = source_shape.as_ptr();
395                let rs = range_start.as_ptr();
396                match self.byte_width {
397                    1 => baracuda_kernels_sys::baracuda_kernels_write_slice_b1_run(
398                        dest, source, source_numel, rank, ds, ss, rs,
399                        core::ptr::null_mut(), 0, stream_ptr,
400                    ),
401                    2 => baracuda_kernels_sys::baracuda_kernels_write_slice_b2_run(
402                        dest, source, source_numel, rank, ds, ss, rs,
403                        core::ptr::null_mut(), 0, stream_ptr,
404                    ),
405                    4 => baracuda_kernels_sys::baracuda_kernels_write_slice_b4_run(
406                        dest, source, source_numel, rank, ds, ss, rs,
407                        core::ptr::null_mut(), 0, stream_ptr,
408                    ),
409                    8 => baracuda_kernels_sys::baracuda_kernels_write_slice_b8_run(
410                        dest, source, source_numel, rank, ds, ss, rs,
411                        core::ptr::null_mut(), 0, stream_ptr,
412                    ),
413                    16 => baracuda_kernels_sys::baracuda_kernels_write_slice_b16_run(
414                        dest, source, source_numel, rank, ds, ss, rs,
415                        core::ptr::null_mut(), 0, stream_ptr,
416                    ),
417                    _ => return Err(Error::Unsupported(
418                        "baracuda-kernels::WriteSlicePlan::run: unsupported byte width \
419                         (select() should have caught this)",
420                    )),
421                }
422            }
423        };
424        map_status(status)
425    }
426}
427
428/// Per-`ElementKind` byte width + nibble-flag mapping. Returns `None`
429/// for unsupported kinds (today: `Bin`).
430fn dispatch_kind(k: ElementKind) -> Option<(i32, bool)> {
431    Some(match k {
432        ElementKind::Bool => (1, false),
433        ElementKind::S8 => (1, false),
434        ElementKind::U8 => (1, false),
435        ElementKind::Fp8E4M3 => (1, false),
436        ElementKind::Fp8E5M2 => (1, false),
437        ElementKind::F16 => (2, false),
438        ElementKind::Bf16 => (2, false),
439        ElementKind::F32 => (4, false),
440        ElementKind::F32Strict => (4, false),
441        ElementKind::I32 => (4, false),
442        ElementKind::F64 => (8, false),
443        ElementKind::I64 => (8, false),
444        ElementKind::Complex32 => (8, false),
445        ElementKind::Complex64 => (16, false),
446        ElementKind::S4 => (1, true),
447        ElementKind::U4 => (1, true),
448        // Bin (1-bit packed) is out of scope — distinct packing model.
449        ElementKind::Bin => return None,
450    })
451}
452
453fn detect_fast_path<const N: usize>(desc: &WriteSliceDescriptor<N>) -> FastPath {
454    // WholeDest: ranges cover every axis fully and source_shape == dest_shape.
455    let mut whole = true;
456    for d in 0..N {
457        let (s, e) = desc.ranges[d];
458        if s != 0 || e != desc.dest_shape[d] {
459            whole = false;
460            break;
461        }
462    }
463    if whole {
464        return FastPath::WholeDest;
465    }
466
467    // ContiguousChunk: ranges[i] == (0, dest_shape[i]) for all i > 0.
468    // The slab is one contiguous block in dest's row-major layout
469    // starting at `start_0 * (product of dest_shape[1..])` elements.
470    if N == 1 {
471        // Rank-1 partial — contiguous chunk by definition (just one axis).
472        let (s, _) = desc.ranges[0];
473        let source_numel = product_i64(desc.source_shape);
474        return FastPath::ContiguousChunk {
475            dest_offset_elems: s as i64,
476            source_numel,
477        };
478    }
479    let mut minors_full = true;
480    for d in 1..N {
481        let (s, e) = desc.ranges[d];
482        if s != 0 || e != desc.dest_shape[d] {
483            minors_full = false;
484            break;
485        }
486    }
487    if minors_full {
488        let mut minor_prod: i64 = 1;
489        for d in 1..N {
490            minor_prod = minor_prod.saturating_mul(desc.dest_shape[d] as i64);
491        }
492        let start_0 = desc.ranges[0].0 as i64;
493        let source_numel = product_i64(desc.source_shape);
494        return FastPath::ContiguousChunk {
495            dest_offset_elems: start_0 * minor_prod,
496            source_numel,
497        };
498    }
499    FastPath::Generic
500}
501
/// Element count of a tensor with the given `shape`.
///
/// Returns the product of all extents, widened to `i64` and saturating at
/// the `i64` bounds instead of overflowing. A rank-0 shape yields `1`.
#[inline]
fn product_i64<const N: usize>(shape: [i32; N]) -> i64 {
    shape
        .iter()
        .fold(1i64, |acc, &extent| acc.saturating_mul(i64::from(extent)))
}
510
511/// Device-to-device async copy on `stream`. Thin wrapper around
512/// `cuMemcpyDtoDAsync_v2` — matches the same pattern used by the
513/// `kthvalue` plan's H2D / D2H helpers.
514fn copy_d2d_async(
515    dst_dev: u64,
516    src_dev: u64,
517    bytes: usize,
518    stream: *mut c_void,
519) -> Result<()> {
520    if bytes == 0 {
521        return Ok(());
522    }
523    #[allow(non_camel_case_types)]
524    type CUresult = i32;
525    unsafe extern "system" {
526        fn cuMemcpyDtoDAsync_v2(
527            dst_device: u64,
528            src_device: u64,
529            byte_count: usize,
530            h_stream: *mut c_void,
531        ) -> CUresult;
532    }
533    let status = unsafe { cuMemcpyDtoDAsync_v2(dst_dev, src_dev, bytes, stream) };
534    if status != 0 {
535        return Err(Error::CutlassInternal(-status));
536    }
537    Ok(())
538}
539
540fn map_status(code: i32) -> Result<()> {
541    match code {
542        0 => Ok(()),
543        1 => Err(Error::MisalignedOperand),
544        2 => Err(Error::InvalidProblem(
545            "baracuda-kernels-sys reported invalid problem",
546        )),
547        3 => Err(Error::Unsupported(
548            "baracuda-kernels-sys reported unsupported configuration",
549        )),
550        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
551        n => Err(Error::CutlassInternal(n)),
552    }
553}