use super::config::TransformerConfig;
/// Shape description of one transformer layer, used for counting its
/// weight parameters.
///
/// The counts produced by the methods below cover projection matrices and
/// normalization weights only; no bias terms are included.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransformerBlock {
    /// Zero-based position of this block within the layer stack.
    pub layer_index: usize,
    /// Model (hidden) dimension: input width of the block's projections and
    /// output width of its final projections.
    pub hidden_size: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads; the K and V projections are sized by this
    /// rather than by `num_heads`.
    pub num_kv_heads: usize,
    /// Per-head width, derived in [`TransformerBlock::new`] as
    /// `hidden_size / num_heads` (integer division).
    pub head_dim: usize,
    /// Inner width of the feed-forward network.
    pub intermediate_size: usize,
    /// When `true`, the feed-forward network is counted with two input
    /// projections instead of one (gated SwiGLU layout).
    pub use_swiglu: bool,
}

impl TransformerBlock {
    /// Builds the description of layer `layer_index` from `config`.
    ///
    /// # Panics
    ///
    /// Panics if `config.num_heads` is zero, because `head_dim` is derived by
    /// dividing `hidden_size` by it.
    pub fn new(config: &TransformerConfig, layer_index: usize) -> Self {
        assert!(
            config.num_heads > 0,
            "TransformerConfig::num_heads must be non-zero"
        );
        Self {
            layer_index,
            hidden_size: config.hidden_size,
            num_heads: config.num_heads,
            num_kv_heads: config.num_kv_heads,
            head_dim: config.hidden_size / config.num_heads,
            intermediate_size: config.intermediate_size,
            use_swiglu: config.use_swiglu,
        }
    }

    /// Number of weights in the Q, K, V and output projections (no biases).
    ///
    /// Q and output widths use `num_heads * head_dim`, and K/V widths use
    /// `num_kv_heads * head_dim`. This keeps all four projections consistent
    /// even when `hidden_size` is not an exact multiple of `num_heads`.
    pub fn attention_params(&self) -> usize {
        let q_dim = self.num_heads * self.head_dim;
        let kv_dim = self.num_kv_heads * self.head_dim;
        let q_proj = self.hidden_size * q_dim;
        let k_proj = self.hidden_size * kv_dim;
        let v_proj = self.hidden_size * kv_dim;
        // The output projection maps the concatenated heads back to `hidden_size`.
        let o_proj = q_dim * self.hidden_size;
        q_proj + k_proj + v_proj + o_proj
    }

    /// Number of weights in the feed-forward network (no biases).
    ///
    /// Counts the `hidden_size x intermediate_size` input projections (two
    /// when `use_swiglu`, otherwise one) plus one
    /// `intermediate_size x hidden_size` output projection.
    pub fn ffn_params(&self) -> usize {
        let input_projections = if self.use_swiglu { 2 } else { 1 };
        self.hidden_size * self.intermediate_size * (input_projections + 1)
    }

    /// Total weights in the block: attention, feed-forward, and
    /// `2 * hidden_size` normalization weights.
    pub fn total_params(&self) -> usize {
        // NOTE(review): the `2 * hidden_size` term presumably covers two bias-free
        // per-layer norms (pre-attention and pre-FFN) — confirm against the model code.
        self.attention_params() + self.ffn_params() + 2 * self.hidden_size
    }
}
/// Instantiates one [`TransformerBlock`] per layer of `config`.
///
/// The returned vector has `config.num_layers` entries, and the entry at
/// position `i` is built with layer index `i`.
pub fn create_blocks(config: &TransformerConfig) -> Vec<TransformerBlock> {
    let layer_count = config.num_layers;
    let mut blocks = Vec::with_capacity(layer_count);
    for layer_index in 0..layer_count {
        blocks.push(TransformerBlock::new(config, layer_index));
    }
    blocks
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::local::architectures::SmallLmConfig;

    #[test]
    fn test_create_blocks() {
        let config = SmallLmConfig::tiny();
        let blocks = create_blocks(&config);
        // The tiny preset is expected to describe a 12-layer model.
        assert_eq!(blocks.len(), 12);
        // Every block must record its own position in the stack, not just the ends.
        for (i, block) in blocks.iter().enumerate() {
            assert_eq!(block.layer_index, i);
        }
    }

    #[test]
    fn test_block_params() {
        let config = SmallLmConfig::tiny();
        let block = TransformerBlock::new(&config, 0);
        assert!(block.attention_params() > 0);
        assert!(block.ffn_params() > 0);
        // The total is attention + FFN + two `hidden_size`-length norm weight vectors.
        assert_eq!(
            block.total_params(),
            block.attention_params() + block.ffn_params() + 2 * block.hidden_size
        );
    }

    #[test]
    fn test_param_formulas() {
        // Grouped K/V heads (8 query heads, 2 KV heads) with a SwiGLU FFN,
        // built directly so the expectations are independent of any preset.
        let mut block = TransformerBlock {
            layer_index: 0,
            hidden_size: 64,
            num_heads: 8,
            num_kv_heads: 2,
            head_dim: 8,
            intermediate_size: 128,
            use_swiglu: true,
        };
        // Q and O: 64x64 each; K and V: 64x16 each.
        assert_eq!(block.attention_params(), 10_240);
        // Two 64x128 input projections plus one 128x64 output projection.
        assert_eq!(block.ffn_params(), 24_576);
        assert_eq!(block.total_params(), 34_944);

        // Without SwiGLU only one input projection remains.
        block.use_swiglu = false;
        assert_eq!(block.ffn_params(), 16_384);
    }
}