use burn::module::{
AutodiffModule, Content, Devices, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
ModuleVisitor,
};
use burn::tensor::backend::{AutodiffBackend, Backend};
use crate::kernel::{BlockPattern1d, BlockPattern2d, BlockSparseConfig};
use crate::positional::RotaryEmbedding;
/// Plain-data settings grouped under the "fused kernel" name.
///
/// NOTE(review): how consumers interpret these fields is not visible in this
/// file; the per-field notes below state only what this file establishes and
/// mark everything else as an assumption.
#[derive(Clone, Debug)]
pub struct FusedKernelConfig {
/// Boolean switch; `Default` sets it to `false`. Presumably gates the fused
/// kernels — confirm at the use site.
pub enabled: bool,
/// Block-sparsity settings; holds a `latent` and a `time` pattern.
pub block_sparse: BlockSparseConfig,
/// `Default` sets this to `65_536.0`. Presumably the rotary (RoPE) base —
/// TODO confirm.
pub rope_theta: f32,
/// `Default` sets this to `0.0`. Presumably a ReLU cut-off — TODO confirm.
pub relu_threshold: f32,
/// Optional caller-supplied slopes; `None` by default.
pub alibi_slopes: Option<Vec<f32>>,
/// Rotary embedding value; `Default` uses `RotaryEmbedding::default()`.
pub rotary_embedding: RotaryEmbedding,
}
impl Default for FusedKernelConfig {
fn default() -> Self {
Self {
enabled: false,
block_sparse: BlockSparseConfig::dense(64, 64),
rope_theta: 65_536.0,
relu_threshold: 0.0,
alibi_slopes: None,
rotary_embedding: RotaryEmbedding::default(),
}
}
}
impl FusedKernelConfig {
    /// Consuming counterpart of [`Self::set_block_sizes`]: applies the new block
    /// sizes and hands the configuration back for chaining.
    pub fn with_block_sizes(self, latent: usize, time: usize) -> Self {
        let mut config = self;
        config.set_block_sizes(latent, time);
        config
    }

    /// Replaces `block_sparse` with a configuration whose `latent` pattern is
    /// `BlockPattern1d::dense(latent)` and whose `time` pattern is
    /// `BlockPattern2d::dense(time)`.
    pub fn set_block_sizes(&mut self, latent: usize, time: usize) {
        // Both patterns are built before the assignment, so `block_sparse` is
        // only overwritten once the complete replacement exists.
        let latent = BlockPattern1d::dense(latent);
        let time = BlockPattern2d::dense(time);
        self.block_sparse = BlockSparseConfig { latent, time };
    }

    /// Stores `slopes` as the ALiBi slopes, replacing any previous value.
    pub fn set_alibi_slopes(&mut self, slopes: Vec<f32>) {
        let slopes = Some(slopes);
        self.alibi_slopes = slopes;
    }

    /// Overwrites the stored rotary embedding with `rotary_embedding`.
    pub fn set_rotary_embedding(&mut self, rotary_embedding: RotaryEmbedding) {
        self.rotary_embedding = rotary_embedding;
    }
}
/// Module implementation for a value that carries no record: every method
/// either returns its input untouched or does nothing.
impl<B> Module<B> for FusedKernelConfig
where
    B: Backend,
{
    /// Nothing is recorded for this type.
    type Record = ();

    /// Returns `devices` unchanged.
    fn collect_devices(&self, devices: Devices<B>) -> Devices<B> {
        devices
    }

    /// Returns `self` unchanged; the device argument is ignored.
    fn fork(self, _: &B::Device) -> Self {
        self
    }

    /// Returns `self` unchanged; the device argument is ignored.
    fn to_device(self, _: &B::Device) -> Self {
        self
    }

    /// Does nothing; the visitor is never invoked.
    fn visit<V>(&self, _: &mut V)
    where
        V: ModuleVisitor<B>,
    {
    }

    /// Returns `self` unchanged; the mapper is never invoked.
    fn map<M>(self, _: &mut M) -> Self
    where
        M: ModuleMapper<B>,
    {
        self
    }

    /// Returns `self` unchanged; the unit record carries no data.
    fn load_record(self, _: Self::Record) -> Self {
        self
    }

    /// Produces the unit record.
    fn into_record(self) -> Self::Record {}
}
/// The type is its own inner module, so `valid` simply hands back a copy.
impl<B> AutodiffModule<B> for FusedKernelConfig
where
    B: AutodiffBackend,
{
    type InnerModule = Self;

    /// Returns a clone of this configuration.
    fn valid(&self) -> Self::InnerModule {
        Clone::clone(self)
    }
}
impl ModuleDisplayDefault for FusedKernelConfig {
    /// Renders the configuration as a single `key=value` summary line attached
    /// to `content` under the top-level type name `FusedKernelConfig`.
    ///
    /// `custom_alibi` reports the number of stored ALiBi slopes, or `0` when
    /// none are set.
    fn content(&self, content: Content) -> Option<Content> {
        let latent_block = self.block_sparse.latent.block_size();
        let time_block = self.block_sparse.time.block_size();
        let custom_alibi = self.alibi_slopes.as_ref().map_or(0, Vec::len);

        let summary = format!(
            "enabled={}, rotary_embedding={}, relu_threshold={}, rope_theta={}, latent_block={}, time_block={}, custom_alibi={}",
            self.enabled,
            self.rotary_embedding,
            self.relu_threshold,
            self.rope_theta,
            latent_block,
            time_block,
            custom_alibi
        );

        let labelled = content.set_top_level_type("FusedKernelConfig");
        labelled.add_formatted(&summary).optional()
    }
}
// Empty impl: no `ModuleDisplay` methods are overridden for this type.
impl ModuleDisplay for FusedKernelConfig {}
/// Size and regularisation settings for a BDH model, plus the nested
/// fused-kernel settings.
///
/// NOTE(review): field meanings beyond what `latent_per_head` and the `Default`
/// impl in this file show are inferred from names — confirm against the model
/// code.
#[derive(Clone, Debug)]
pub struct BDHConfig {
/// `Default` sets this to `6`. Presumably the number of layers — TODO confirm.
pub n_layer: usize,
/// Multiplied by `mlp_internal_dim_multiplier` to obtain the total latent size.
pub n_embd: usize,
/// `Default` sets this to `0.1`. Presumably a dropout probability — TODO confirm.
pub dropout: f64,
/// Number of heads the total latent size is split across; must divide it.
pub n_head: usize,
/// Factor applied to `n_embd` to obtain the total latent size.
pub mlp_internal_dim_multiplier: usize,
/// Number of experts; must divide the per-head latent size.
pub n_expert: usize,
/// `Default` sets this to `256`. Presumably the token vocabulary size — TODO confirm.
pub vocab_size: usize,
/// Nested fused-kernel settings.
pub fused_kernels: FusedKernelConfig,
}
impl Default for BDHConfig {
fn default() -> Self {
Self {
n_layer: 6,
n_embd: 256,
dropout: 0.1,
n_head: 4,
mlp_internal_dim_multiplier: 4,
n_expert: 1,
vocab_size: 256,
fused_kernels: FusedKernelConfig::default(),
}
}
}
impl BDHConfig {
    /// Returns the latent width assigned to each head.
    ///
    /// The total latent size is `mlp_internal_dim_multiplier * n_embd`, split
    /// evenly across `n_head` heads.
    ///
    /// # Panics
    ///
    /// Panics if `n_head` is zero or does not divide the total latent size, or
    /// if `n_expert` is zero or does not divide the per-head latent size.
    pub fn latent_per_head(&self) -> usize {
        let total = self.mlp_internal_dim_multiplier * self.n_embd;
        // The explicit zero checks matter: `0usize.is_multiple_of(0)` is `true`,
        // so a zero divisor paired with a zero dividend would otherwise slip past
        // the assertion and surface as a bare divide-by-zero panic.
        assert!(
            self.n_head != 0 && total.is_multiple_of(self.n_head),
            "latent size must be divisible by the number of heads"
        );
        let latent_per_head = total / self.n_head;
        assert!(
            self.n_expert != 0 && latent_per_head.is_multiple_of(self.n_expert),
            "latent per head {} must be divisible by experts {}",
            latent_per_head,
            self.n_expert
        );
        latent_per_head
    }

    /// Returns the latent width summed over all heads
    /// (`latent_per_head() * n_head`).
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::latent_per_head`].
    pub fn latent_total(&self) -> usize {
        self.latent_per_head() * self.n_head
    }

    /// Returns the share of a head's latent width that falls to each expert
    /// (`latent_per_head() / n_expert`).
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::latent_per_head`], which
    /// guarantees `n_expert` is non-zero before the division happens.
    pub fn latent_per_expert(&self) -> usize {
        self.latent_per_head() / self.n_expert
    }
}