cubecl_attention/
base.rs

1use cubecl_core::{Runtime, client::ComputeClient, prelude::TensorHandleRef};
2
3use cubecl_std::tensor::TensorHandle;
4
5use crate::{
6    components::{
7        AttentionIdent, AttentionPartitionSize, AttentionPrecision, AttentionProblem,
8        AttentionSelection, AttentionSetupError, AttentionStageSize, AttentionTileSize,
9        AttentionTilingScheme, AvailableLineSizes, args::TensorInputsLaunch, attention_types::*,
10        batch::HypercubeSelection,
11    },
12    kernels::{Algorithm, dummy::DummyRegisterAlgorithm},
13};
14
15use crate::components::batch::BatchAttentionConfig;
16use crate::components::batch::BatchAttentionFamily;
17
/// Selects which attention implementation [`launch`] / [`launch_ref`] will run.
///
/// Only a single placeholder strategy exists for now; additional variants are
/// expected as real algorithms land.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Strategy {
    /// Temporary implementation
    Tmp,
}
22
23#[allow(clippy::result_large_err)]
24pub fn launch<R: Runtime, AP: AttentionPrecision>(
25    strategy: &Strategy,
26    client: &ComputeClient<R::Server>,
27    query: TensorHandle<R, QG<AP>>,
28    key: TensorHandle<R, KG<AP>>,
29    value: TensorHandle<R, VG<AP>>,
30    mask: Option<TensorHandle<R, MSK<AP>>>,
31    out: TensorHandle<R, OG<AP>>,
32) -> Result<(), AttentionSetupError> {
33    launch_ref::<R, AP>(
34        strategy,
35        client,
36        &query.as_ref(),
37        &key.as_ref(),
38        &value.as_ref(),
39        &mask.as_ref().map(|m| m.as_ref()),
40        &out.as_ref(),
41    )
42}
43
44#[allow(clippy::result_large_err)]
45pub fn launch_ref<R: Runtime, AP: AttentionPrecision>(
46    strategy: &Strategy,
47    client: &ComputeClient<R::Server>,
48    query: &TensorHandleRef<R>,
49    key: &TensorHandleRef<R>,
50    value: &TensorHandleRef<R>,
51    mask: &Option<TensorHandleRef<R>>,
52    out: &TensorHandleRef<R>,
53) -> Result<(), AttentionSetupError> {
54    match strategy {
55        Strategy::Tmp => launch_tmp::<R, AP>(client, query, key, value, mask, out),
56    }
57}
58
/// Launches the temporary ("dummy register") attention implementation.
///
/// Pipeline:
/// 1. Derive candidate line (vectorization) sizes from the element widths,
///    filter them against the algorithm and each tensor's layout, keep the max.
/// 2. Describe the problem from the tensor shapes.
/// 3. Use a fixed, hard-coded selection (8×8×8×8 tiles, unit partition sizes).
/// 4. Set up the batch config and launch the kernel.
///
/// # Errors
/// Returns [`AttentionSetupError`] if the algorithm setup rejects the
/// problem/selection combination.
pub fn launch_tmp<R: Runtime, AP: AttentionPrecision>(
    client: &ComputeClient<R::Server>,
    query: &TensorHandleRef<R>,
    key: &TensorHandleRef<R>,
    value: &TensorHandleRef<R>,
    mask: &Option<TensorHandleRef<R>>,
    out: &TensorHandleRef<R>,
) -> Result<(), AttentionSetupError> {
    // Candidate line sizes from element byte widths: query, mask element
    // (from the precision's mask type, since `mask` may be absent), output.
    let line_sizes = AvailableLineSizes::from_elem_types::<R>(
        query.elem_size,
        size_of::<MSK<AP>>(),
        out.elem_size,
    );
    // Narrow candidates by algorithm constraints and each tensor's
    // strides/shape, then take the largest size that survives.
    // NOTE(review): `unwrap` panics when no line size remains after
    // filtering — consider propagating an AttentionSetupError instead;
    // depends on `pick_max`'s error type (not visible here).
    let line_sizes = DummyRegisterAlgorithm::filter_line_sizes(line_sizes)
        .filter_with_tensor(AttentionIdent::Query, query.strides, query.shape)
        .filter_with_tensor(AttentionIdent::Key, key.strides, key.shape)
        .filter_with_tensor(AttentionIdent::Value, value.strides, value.shape)
        .filter_with_tensor(AttentionIdent::Out, out.strides, out.shape)
        .pick_max()
        .unwrap();

    // Problem description read off the shapes; assumes the layout is
    // (batch, seq, num_heads, dim) for all four tensors — TODO confirm
    // against the kernel's expectations.
    let problem = AttentionProblem {
        batch: query.shape[0],
        seq_q: query.shape[1],
        seq_kv: key.shape[1],
        num_heads: query.shape[2],
        head_dim: query.shape[3],
        val_dim: value.shape[3],
        masked: mask.is_some(),
        causal: false,
    };

    // Hard-coded 8×8×8×8 tile; no tuning is performed for this dummy path.
    let tile_size = AttentionTileSize {
        seq_q: 8,
        head_dim: 8,
        seq_kv: 8,
        val_dim: 8,
    };

    // Fixed selection: one tile per partition, one seq_q partition per stage,
    // plane width 32 (typical warp size — presumably matched to the target).
    let selection = AttentionSelection {
        hypercube_selection: HypercubeSelection {},
        tiling_scheme: AttentionTilingScheme {
            tile_size,
            partition_size: AttentionPartitionSize {
                seq_q: 1,
                head_dim: 1,
                seq_kv: 1,
                val_dim: 1,
            },
            stage_size: AttentionStageSize { seq_q: 1 },
        },
        plane_dim: 32,
        reuse_key_value: false,
        two_rows_in_array_tile: false,
    };

    let config = DummyRegisterAlgorithm::setup::<AP, R>(client, &problem, &selection, &line_sizes)?;

    // Map the problem onto a concrete cube count for dispatch.
    let cube_count_plan = config
        .hypercube_config()
        .cube_count_plan(&problem, &selection);

    // SAFETY: relies on `launch_unchecked`'s contract; the config/problem
    // built above are assumed to satisfy it (shapes and line sizes were
    // validated by `setup`).
    unsafe {
        <DummyRegisterAlgorithm as Algorithm>::BatchAttention::launch_unchecked::<AP, R>(
            client,
            config.cube_dim(),
            cube_count_plan.resolve(),
            TensorInputsLaunch::new(
                query.as_tensor_arg(line_sizes.query),
                key.as_tensor_arg(line_sizes.key),
                value.as_tensor_arg(line_sizes.value),
                // NOTE(review): the mask arg uses `line_sizes.out` — no
                // dedicated mask line size is visible here; confirm this is
                // intentional and not a copy-paste slip.
                mask.as_ref()
                    .map(|it| it.as_tensor_arg(line_sizes.out))
                    .into(),
            ),
            out.as_tensor_arg(line_sizes.out),
            cube_count_plan.as_args(),
            config,
        );
    }

    Ok(())
}