// cubek_convolution/kernels/launch.rs
use std::fmt::Display;
2
3use serde::{Deserialize, Serialize};
4
/// Convolution arguments for `N_SPATIAL` spatial dimensions.
///
/// Each field is a fixed-size array holding exactly `N_SPATIAL` entries, i.e. one value per
/// spatial dimension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    /// Stride, one entry per spatial dimension.
    pub stride: [usize; N_SPATIAL],
    /// Padding, one entry per spatial dimension.
    pub padding: [usize; N_SPATIAL],
    /// Dilation, one entry per spatial dimension.
    pub dilation: [usize; N_SPATIAL],
}
11
12pub enum Strategy {
13 Simple {
14 read_strategy: ReadingStrategy,
15 tile_kind: AcceleratedTileKind,
16 },
17}
18
/// Reading strategy carried by [`Strategy::Simple`].
///
/// The `Display` implementation renders each variant as the lowercase, snake_case name noted
/// on the variant below.
///
/// NOTE(review): what each strategy does is implemented outside this part of the file; the
/// variant docs therefore only state the rendered name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadingStrategy {
    /// Rendered as `cyclic`.
    Cyclic,
    /// Rendered as `strided`.
    Strided,
    /// Rendered as `tilewise`.
    Tilewise,
    /// Rendered as `async_cyclic`.
    AsyncCyclic,
    /// Rendered as `async_strided`.
    AsyncStrided,
    /// Rendered as `tma`.
    ///
    /// NOTE(review): presumably "Tensor Memory Accelerator" - confirm.
    Tma,
}
29
30#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
31pub enum AcceleratedTileKind {
33 #[default]
34 Cmma,
35 Mma,
36}
37
38impl Display for AcceleratedTileKind {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 match self {
43 AcceleratedTileKind::Cmma => f.write_str("cmma"),
44 AcceleratedTileKind::Mma => f.write_str("mma"),
45 }
46 }
47}
48
49impl Display for ReadingStrategy {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 match self {
52 ReadingStrategy::Cyclic => f.write_str("cyclic"),
53 ReadingStrategy::Strided => f.write_str("strided"),
54 ReadingStrategy::Tilewise => f.write_str("tilewise"),
55 ReadingStrategy::AsyncCyclic => f.write_str("async_cyclic"),
56 ReadingStrategy::AsyncStrided => f.write_str("async_strided"),
57 ReadingStrategy::Tma => f.write_str("tma"),
58 }
59 }
60}