use std::marker::PhantomData;
use cubecl::ir::DeviceProperties;
use cubek_matmul::{
components::{
global::{
GlobalReaderConfig, GlobalWriterConfig, InputLoadFlow, PartitionedStageFamily,
PlaneFlowConfig, PlaneFlowPartitionRule,
memory::{GlobalMemoryConfig, ViewDirection},
multi_stage::EventLoadingMode,
read::ReaderMode,
},
stage::StridedStageFamily,
},
definition::LoadingPrecomputeStrategy,
};
use cubek_std::{MatrixLayout, StageIdent};
use crate::{
components::{
global::{
GlobalAttentionFamily,
simple::{SimpleGlobalAttention, config::SimpleGlobalAttentionConfig},
},
stage::{StageAttentionConfig as _, StageAttentionFamily},
},
definition::{AttentionBlueprint, AttentionElems, AttentionPrecision, AttentionSetupError},
};
pub struct SimpleGlobalAttentionFamily<SA: StageAttentionFamily> {
_phantom: PhantomData<SA>,
}
impl<
SA: StageAttentionFamily<
KeyStage = StridedStageFamily,
ValueStage = StridedStageFamily,
OutStage = PartitionedStageFamily,
>,
> GlobalAttentionFamily for SimpleGlobalAttentionFamily<SA>
{
type Attention<AP: AttentionPrecision> = SimpleGlobalAttention<AP, SA::Attention<AP>>;
type Config = SimpleGlobalAttentionConfig<SA::Config>;
fn expand_config(
device_props: &DeviceProperties,
blueprint: &AttentionBlueprint,
dtypes: &AttentionElems,
) -> Result<Self::Config, AttentionSetupError> {
let stage_config = SA::expand_config(device_props, blueprint, dtypes)?;
let precompute_job = LoadingPrecomputeStrategy::Never.into();
let plane_dim = stage_config.plane_dim();
let reader_mode = ReaderMode::Relaxed;
let event_loading_mode = EventLoadingMode::Relaxed;
let specialization_tensor_config = InputLoadFlow::MainOnly;
let plane_flow_config = PlaneFlowConfig::new_unspecialized(stage_config.num_planes());
let query_gmem_config = GlobalMemoryConfig {
vector_size: blueprint.vector_sizes.query,
check_row_bounds: blueprint.check_bounds.seq_q,
check_col_bounds: blueprint.check_bounds.head_dim,
matrix_layout: MatrixLayout::RowMajor,
view_direction: ViewDirection::None,
dtype: dtypes.query_global,
};
let mask_gmem_config = GlobalMemoryConfig {
vector_size: blueprint.vector_sizes.mask,
check_row_bounds: blueprint.check_bounds.seq_q,
check_col_bounds: blueprint.check_bounds.seq_kv,
matrix_layout: MatrixLayout::RowMajor,
view_direction: ViewDirection::Col,
dtype: dtypes.mask,
};
let key_gmem_config = GlobalMemoryConfig {
vector_size: blueprint.vector_sizes.key,
check_row_bounds: blueprint.check_bounds.seq_kv,
check_col_bounds: blueprint.check_bounds.head_dim,
matrix_layout: MatrixLayout::RowMajor,
view_direction: ViewDirection::Row,
dtype: dtypes.key_global,
};
let value_gmem_config = GlobalMemoryConfig {
vector_size: blueprint.vector_sizes.value,
check_row_bounds: blueprint.check_bounds.seq_kv,
check_col_bounds: blueprint.check_bounds.val_dim,
matrix_layout: MatrixLayout::RowMajor,
view_direction: ViewDirection::Row,
dtype: dtypes.value_global,
};
let out_gmem_config = GlobalMemoryConfig {
vector_size: blueprint.vector_sizes.out,
check_row_bounds: blueprint.check_bounds.seq_q,
check_col_bounds: blueprint.check_bounds.val_dim,
matrix_layout: MatrixLayout::RowMajor,
view_direction: ViewDirection::None,
dtype: dtypes.out_global,
};
let key_reader_config = GlobalReaderConfig {
gmem_config: key_gmem_config,
smem_config: stage_config.key_smem_config(),
precompute_job,
plane_dim,
reader_mode,
event_loading_mode,
input_load_flow: specialization_tensor_config,
plane_flow_config,
stage_ident: StageIdent::Rhs,
};
let value_reader_config = GlobalReaderConfig {
gmem_config: value_gmem_config,
smem_config: stage_config.value_smem_config(),
precompute_job,
plane_dim,
reader_mode,
event_loading_mode,
input_load_flow: specialization_tensor_config,
plane_flow_config,
stage_ident: StageIdent::Rhs,
};
let writer_config = GlobalWriterConfig {
gmem_config: out_gmem_config,
smem_config: stage_config.out_smem_config(),
plane_flow_partition_rule: PlaneFlowPartitionRule::MainFlowOnly,
plane_dim,
};
Ok(SimpleGlobalAttentionConfig {
stage_config,
key_reader_config,
value_reader_config,
query_gmem_config,
mask_gmem_config,
writer_config,
})
}
}