cubek_convolution/kernels/launch.rs

1use std::fmt::Display;
2
3use serde::{Deserialize, Serialize};
4
/// Geometry arguments of a convolution with `N_SPATIAL` spatial dimensions.
///
/// Every field holds exactly one value per spatial dimension.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    /// Stride for each spatial dimension.
    pub stride: [usize; N_SPATIAL],
    /// Padding for each spatial dimension.
    // NOTE(review): whether this amount is applied to both sides of the
    // dimension is decided by the consumer, which is not visible here — confirm.
    pub padding: [usize; N_SPATIAL],
    /// Dilation for each spatial dimension.
    pub dilation: [usize; N_SPATIAL],
}
11
12pub enum Strategy {
13    Simple {
14        read_strategy: ReadingStrategy,
15        tile_kind: AcceleratedTileKind,
16    },
17}
18
/// Which reader to use in simple algorithms.
///
/// Derives `PartialEq`/`Eq`/`Hash` so strategies can be compared and used as
/// map or set keys, matching the comparison traits of `AcceleratedTileKind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReadingStrategy {
    Cyclic,
    Strided,
    Tilewise,
    AsyncCyclic,
    AsyncStrided,
    Tma,
}
29
30#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
31/// Which tile matmul to use for accelerated algorithms
32pub enum AcceleratedTileKind {
33    #[default]
34    Cmma,
35    Mma,
36}
37
38// Display implementations are used to combine and save names when autotuning.
39
40impl Display for AcceleratedTileKind {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        match self {
43            AcceleratedTileKind::Cmma => f.write_str("cmma"),
44            AcceleratedTileKind::Mma => f.write_str("mma"),
45        }
46    }
47}
48
49impl Display for ReadingStrategy {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        match self {
52            ReadingStrategy::Cyclic => f.write_str("cyclic"),
53            ReadingStrategy::Strided => f.write_str("strided"),
54            ReadingStrategy::Tilewise => f.write_str("tilewise"),
55            ReadingStrategy::AsyncCyclic => f.write_str("async_cyclic"),
56            ReadingStrategy::AsyncStrided => f.write_str("async_strided"),
57            ReadingStrategy::Tma => f.write_str("tma"),
58        }
59    }
60}