cubek_convolution/launch/
strategy.rs1use std::fmt::Display;
2
3use serde::{Deserialize, Serialize};
4
5use crate::definition::ConvBlueprint;
6
/// How a convolution launch is configured.
///
/// Both variants carry the [`ConvAlgorithm`] to use. The `Display` impl in this
/// file renders `Inferred` as `"{algorithm}_{tile_kind}"` and `Forced` as
/// `"{algorithm}_forced"`.
#[derive(Clone, Debug)]
pub enum Strategy {
    /// Only the algorithm and the tile kind are specified.
    // NOTE(review): presumably the blueprint is derived by the launcher in this
    // case — confirm against the launch code.
    Inferred {
        /// Algorithm to launch.
        algorithm: ConvAlgorithm,
        /// Accelerated tile kind to use with `algorithm`.
        tile_kind: AcceleratedTileKind,
    },
    /// The algorithm is specified together with an explicit blueprint.
    Forced {
        /// Algorithm to launch.
        algorithm: ConvAlgorithm,
        /// Caller-supplied blueprint; it is not part of the `Display` output.
        blueprint: ConvBlueprint,
    },
}
29
/// Identifier of a convolution algorithm.
///
/// A field-less enum that is `Copy`, hashable and serde-serializable. The
/// `Display` impl in this file renders each variant as its snake_case name
/// (e.g. `SimpleSyncCyclic` -> `"simple_sync_cyclic"`).
// NOTE(review): the variant names read as <kernel family> + <loading scheme>
// (sync/async, cyclic/strided/tilewise/TMA); the actual semantics are defined
// elsewhere — confirm before relying on this reading.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ConvAlgorithm {
    SimpleSyncCyclic,
    SimpleSyncStrided,
    SimpleSyncTilewise,
    SimpleAsyncCyclic,
    SimpleAsyncStrided,
    SimpleAsyncTma,
    SpecializedAsyncCyclic,
    SpecializedAsyncStrided,
    SpecializedTma,
}
45
46impl Display for ConvAlgorithm {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 let s = match self {
49 ConvAlgorithm::SimpleSyncCyclic => "simple_sync_cyclic",
50 ConvAlgorithm::SimpleSyncStrided => "simple_sync_strided",
51 ConvAlgorithm::SimpleSyncTilewise => "simple_sync_tilewise",
52 ConvAlgorithm::SimpleAsyncCyclic => "simple_async_cyclic",
53 ConvAlgorithm::SimpleAsyncStrided => "simple_async_strided",
54 ConvAlgorithm::SimpleAsyncTma => "simple_async_tma",
55 ConvAlgorithm::SpecializedAsyncCyclic => "specialized_async_cyclic",
56 ConvAlgorithm::SpecializedAsyncStrided => "specialized_async_strided",
57 ConvAlgorithm::SpecializedTma => "specialized_tma",
58 };
59 f.write_str(s)
60 }
61}
62
/// Kind of accelerated tile selected alongside an algorithm in
/// `Strategy::Inferred`.
///
/// `Cmma` is the default. The `Display` impl in this file renders the variants
/// as `"cmma"` and `"mma"`.
// NOTE(review): presumably `Cmma` is the cooperative-matrix path and `Mma` the
// manual MMA-instruction path — confirm against the tile implementations.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AcceleratedTileKind {
    /// Default tile kind.
    #[default]
    Cmma,
    Mma,
}
70
71impl Display for AcceleratedTileKind {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 match self {
74 AcceleratedTileKind::Cmma => f.write_str("cmma"),
75 AcceleratedTileKind::Mma => f.write_str("mma"),
76 }
77 }
78}
79
80impl Display for Strategy {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 match self {
83 Strategy::Inferred {
84 algorithm,
85 tile_kind,
86 } => write!(f, "{algorithm}_{tile_kind}"),
87 Strategy::Forced {
88 algorithm,
89 blueprint: _,
90 } => write!(f, "{algorithm}_forced"),
91 }
92 }
93}