1use cubecl_core::{Runtime, client::ComputeClient, prelude::TensorHandleRef};
2
3use cubecl_std::tensor::TensorHandle;
4
5use crate::{
6 components::{
7 AttentionIdent, AttentionPartitionSize, AttentionPrecision, AttentionProblem,
8 AttentionSelection, AttentionSetupError, AttentionStageSize, AttentionTileSize,
9 AttentionTilingScheme, AvailableLineSizes, args::TensorInputsLaunch, attention_types::*,
10 batch::HypercubeSelection,
11 },
12 kernels::{Algorithm, dummy::DummyRegisterAlgorithm},
13};
14
15use crate::components::batch::BatchAttentionConfig;
16use crate::components::batch::BatchAttentionFamily;
17
/// Strategy selecting which attention implementation to launch.
///
/// Currently only a temporary placeholder exists; it routes to the
/// dummy register algorithm. Derives `Debug` (public types should be
/// debuggable) and `Default` so callers can write `Strategy::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub enum Strategy {
    /// Temporary placeholder strategy backed by `DummyRegisterAlgorithm`.
    #[default]
    Tmp,
}
22
23#[allow(clippy::result_large_err)]
24pub fn launch<R: Runtime, AP: AttentionPrecision>(
25 strategy: &Strategy,
26 client: &ComputeClient<R::Server>,
27 query: TensorHandle<R, QG<AP>>,
28 key: TensorHandle<R, KG<AP>>,
29 value: TensorHandle<R, VG<AP>>,
30 mask: Option<TensorHandle<R, MSK<AP>>>,
31 out: TensorHandle<R, OG<AP>>,
32) -> Result<(), AttentionSetupError> {
33 launch_ref::<R, AP>(
34 strategy,
35 client,
36 &query.as_ref(),
37 &key.as_ref(),
38 &value.as_ref(),
39 &mask.as_ref().map(|m| m.as_ref()),
40 &out.as_ref(),
41 )
42}
43
44#[allow(clippy::result_large_err)]
45pub fn launch_ref<R: Runtime, AP: AttentionPrecision>(
46 strategy: &Strategy,
47 client: &ComputeClient<R::Server>,
48 query: &TensorHandleRef<R>,
49 key: &TensorHandleRef<R>,
50 value: &TensorHandleRef<R>,
51 mask: &Option<TensorHandleRef<R>>,
52 out: &TensorHandleRef<R>,
53) -> Result<(), AttentionSetupError> {
54 match strategy {
55 Strategy::Tmp => launch_tmp::<R, AP>(client, query, key, value, mask, out),
56 }
57}
58
/// Temporary launch path: runs attention through `DummyRegisterAlgorithm`
/// with a hardcoded selection (8×8×8×8 tiles, unit partition sizes, single
/// stage, plane dim 32, no causal masking).
///
/// Tensors are indexed as `[batch, seq, heads, dim]` — see the
/// `AttentionProblem` construction below, which reads the shapes that way.
///
/// # Errors
/// Returns [`AttentionSetupError`] when `DummyRegisterAlgorithm::setup`
/// rejects the problem/selection/line-size combination.
pub fn launch_tmp<R: Runtime, AP: AttentionPrecision>(
    client: &ComputeClient<R::Server>,
    query: &TensorHandleRef<R>,
    key: &TensorHandleRef<R>,
    value: &TensorHandleRef<R>,
    mask: &Option<TensorHandleRef<R>>,
    out: &TensorHandleRef<R>,
) -> Result<(), AttentionSetupError> {
    // Candidate line sizes from the query/mask/out element sizes.
    let line_sizes = AvailableLineSizes::from_elem_types::<R>(
        query.elem_size,
        size_of::<MSK<AP>>(),
        out.elem_size,
    );
    // Narrow the candidates by the algorithm's own filter and by each tensor's
    // strides/shape, then keep the largest surviving line size.
    // NOTE(review): the mask tensor is never passed to `filter_with_tensor`,
    // and `pick_max()` is unwrapped — an empty candidate set panics instead of
    // surfacing an `AttentionSetupError`. Presumably acceptable for this tmp
    // path — TODO confirm.
    let line_sizes = DummyRegisterAlgorithm::filter_line_sizes(line_sizes)
        .filter_with_tensor(AttentionIdent::Query, query.strides, query.shape)
        .filter_with_tensor(AttentionIdent::Key, key.strides, key.shape)
        .filter_with_tensor(AttentionIdent::Value, value.strides, value.shape)
        .filter_with_tensor(AttentionIdent::Out, out.strides, out.shape)
        .pick_max()
        .unwrap();

    // Problem description read straight off the tensor shapes
    // ([batch, seq, heads, dim] layout). Causal masking is hardcoded off.
    let problem = AttentionProblem {
        batch: query.shape[0],
        seq_q: query.shape[1],
        seq_kv: key.shape[1],
        num_heads: query.shape[2],
        head_dim: query.shape[3],
        val_dim: value.shape[3],
        masked: mask.is_some(),
        causal: false,
    };

    // Hardcoded 8×8×8×8 tile for the tmp strategy.
    let tile_size = AttentionTileSize {
        seq_q: 8,
        head_dim: 8,
        seq_kv: 8,
        val_dim: 8,
    };

    // Minimal selection: one tile per partition dimension, one stage,
    // 32-lane planes, no key/value reuse.
    let selection = AttentionSelection {
        hypercube_selection: HypercubeSelection {},
        tiling_scheme: AttentionTilingScheme {
            tile_size,
            partition_size: AttentionPartitionSize {
                seq_q: 1,
                head_dim: 1,
                seq_kv: 1,
                val_dim: 1,
            },
            stage_size: AttentionStageSize { seq_q: 1 },
        },
        plane_dim: 32,
        reuse_key_value: false,
        two_rows_in_array_tile: false,
    };

    // Validate the whole configuration; this is the only fallible step.
    let config = DummyRegisterAlgorithm::setup::<AP, R>(client, &problem, &selection, &line_sizes)?;

    // Derive the cube-count plan from the validated hypercube config.
    let cube_count_plan = config
        .hypercube_config()
        .cube_count_plan(&problem, &selection);

    // SAFETY: presumably `launch_unchecked` relies on the invariants checked
    // by `setup` above (shapes, line sizes, cube dims) — TODO confirm its
    // documented contract.
    unsafe {
        <DummyRegisterAlgorithm as Algorithm>::BatchAttention::launch_unchecked::<AP, R>(
            client,
            config.cube_dim(),
            cube_count_plan.resolve(),
            TensorInputsLaunch::new(
                query.as_tensor_arg(line_sizes.query),
                key.as_tensor_arg(line_sizes.key),
                value.as_tensor_arg(line_sizes.value),
                // NOTE(review): the mask reuses the *output* line size; no
                // mask-specific line size is read here — verify intentional.
                mask.as_ref()
                    .map(|it| it.as_tensor_arg(line_sizes.out))
                    .into(),
            ),
            out.as_tensor_arg(line_sizes.out),
            cube_count_plan.as_args(),
            config,
        );
    }

    Ok(())
}