use cubek_std::GlobalPartitionSize;
use crate::{
components::global::memory::GlobalLayoutConfig,
components::{batch::BatchConfig, global::GlobalConfig},
};
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct PartitionedBatchConfig<G: GlobalConfig> {
pub global_config: G,
pub global_partition_size: GlobalPartitionSize,
}
impl<G: GlobalConfig> BatchConfig for PartitionedBatchConfig<G> {
fn lhs_global_layout_config(&self) -> GlobalLayoutConfig {
self.global_config.lhs_reader_config().gmem_config.into()
}
fn rhs_global_layout_config(&self) -> GlobalLayoutConfig {
self.global_config.rhs_reader_config().gmem_config.into()
}
fn out_global_layout_config(&self) -> GlobalLayoutConfig {
self.global_config.writer_config().gmem_config.into()
}
}
impl<G: GlobalConfig> PartitionedBatchConfig<G> {
pub fn new(global_config: G, global_partition_size: GlobalPartitionSize) -> Self {
Self {
global_config,
global_partition_size,
}
}
}