Skip to main content

cubek_convolution/launch/strategy.rs

1use std::fmt::Display;
2
3use serde::{Deserialize, Serialize};
4
5use crate::definition::ConvBlueprint;
6
7/// Top-level user-facing strategy for `launch_ref`.
8///
9/// `Specific` selects an algorithm and tile-matmul kind, letting the routine
10/// infer the rest. `Forced` bypasses inference and uses the supplied blueprint
11/// directly (the algorithm tag must still be provided so the kernel-side
12/// generic dispatch can pick the right reading-strategy implementation).
13#[derive(Clone, Debug)]
14pub enum Strategy {
15    /// User picks the algorithm and tile-matmul kind. Tiling/swizzle/etc. are
16    /// inferred from the problem.
17    Inferred {
18        algorithm: ConvAlgorithm,
19        tile_kind: AcceleratedTileKind,
20    },
21    /// User supplies a pre-built blueprint. The algorithm tag tells the launcher
22    /// which kernel generic to instantiate; the tiling/swizzle/etc. come from
23    /// the blueprint. The tile-matmul kind comes from the blueprint as well.
24    Forced {
25        algorithm: ConvAlgorithm,
26        blueprint: ConvBlueprint,
27    },
28}
29
30/// The convolution-side algorithm enum. Subsumes the previous
31/// `ReadingStrategy` axis and the Simple/Specialized split. A single value
32/// here picks one concrete `Routine` impl (see `crate::routines`).
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
34pub enum ConvAlgorithm {
35    SimpleSyncCyclic,
36    SimpleSyncStrided,
37    SimpleSyncTilewise,
38    SimpleAsyncCyclic,
39    SimpleAsyncStrided,
40    SimpleAsyncTma,
41    SpecializedAsyncCyclic,
42    SpecializedAsyncStrided,
43    SpecializedTma,
44}
45
46impl Display for ConvAlgorithm {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        let s = match self {
49            ConvAlgorithm::SimpleSyncCyclic => "simple_sync_cyclic",
50            ConvAlgorithm::SimpleSyncStrided => "simple_sync_strided",
51            ConvAlgorithm::SimpleSyncTilewise => "simple_sync_tilewise",
52            ConvAlgorithm::SimpleAsyncCyclic => "simple_async_cyclic",
53            ConvAlgorithm::SimpleAsyncStrided => "simple_async_strided",
54            ConvAlgorithm::SimpleAsyncTma => "simple_async_tma",
55            ConvAlgorithm::SpecializedAsyncCyclic => "specialized_async_cyclic",
56            ConvAlgorithm::SpecializedAsyncStrided => "specialized_async_strided",
57            ConvAlgorithm::SpecializedTma => "specialized_tma",
58        };
59        f.write_str(s)
60    }
61}
62
63#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
64/// Which tile matmul to use for accelerated algorithms.
65pub enum AcceleratedTileKind {
66    #[default]
67    Cmma,
68    Mma,
69}
70
71impl Display for AcceleratedTileKind {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        match self {
74            AcceleratedTileKind::Cmma => f.write_str("cmma"),
75            AcceleratedTileKind::Mma => f.write_str("mma"),
76        }
77    }
78}
79
80impl Display for Strategy {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        match self {
83            Strategy::Inferred {
84                algorithm,
85                tile_kind,
86            } => write!(f, "{algorithm}_{tile_kind}"),
87            Strategy::Forced {
88                algorithm,
89                blueprint: _,
90            } => write!(f, "{algorithm}_forced"),
91        }
92    }
93}