use serde::{Deserialize, Serialize};
/// Hyper-parameters describing one stack of transformer layers.
///
/// Fields carrying a serde default may be omitted from serialized input;
/// every other field is required when deserializing.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct TransformerConfig {
    /// Number of stacked layers.
    pub num_layers: usize,
    /// Hidden (model) dimension.
    pub hidden_dim: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Inner dimension of the feed-forward sub-layer.
    pub feedforward_dim: usize,
    /// Dropout probability; `0.1` when absent from the input.
    #[serde(default = "default_dropout")]
    pub dropout: f64,
    /// Maximum sequence length.
    pub max_seq_len: usize,
    /// Attention sub-layer kind; `AttentionType::Self_` when absent from the input.
    #[serde(default)]
    pub attention_type: AttentionType,
    /// Positional-encoding scheme; `PositionEncodingType::Learned` when absent.
    #[serde(default)]
    pub position_encoding: PositionEncodingType,
    /// Layer-normalization epsilon; `1e-5` when absent from the input.
    #[serde(default = "default_layer_norm_eps")]
    pub layer_norm_eps: f64,
}
/// Which attention sub-layers a transformer layer contains.
///
/// Serialized in lowercase: `"self"`, `"cross"`, `"both"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AttentionType {
    /// Self-attention only (the default).
    ///
    /// The trailing underscore exists only because `Self` is a reserved word.
    /// Without the explicit rename, `rename_all = "lowercase"` would emit the
    /// workaround on the wire as `"self_"`. The alias keeps accepting `"self_"`
    /// so configs serialized by earlier versions still deserialize.
    #[default]
    #[serde(rename = "self", alias = "self_")]
    Self_,
    /// Cross-attention only.
    Cross,
    /// Both self-attention and cross-attention.
    Both,
}
/// Scheme used to encode token positions.
///
/// Serialized in lowercase: `"learned"`, `"sinusoidal"`, `"rotary"`, `"relative"`.
/// The extra derives make this fieldless enum cheap to copy and directly
/// comparable (`==`), so callers need not fall back to `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PositionEncodingType {
    /// Learned position embeddings (the default).
    #[default]
    Learned,
    /// Fixed sinusoidal position encodings.
    Sinusoidal,
    /// Rotary position embeddings.
    Rotary,
    /// Relative position encodings.
    Relative,
}
/// An encoder stack paired with a decoder stack.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct EncoderDecoderConfig {
    /// Settings for the encoder stack.
    pub encoder: TransformerConfig,
    /// Settings for the decoder stack.
    pub decoder: TransformerConfig,
}
/// Top-level configuration for a Whisper-style speech-to-text model.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct WhisperConfig {
    /// Audio-side encoder (convolution layers plus a transformer stack).
    pub audio_encoder: AudioEncoderConfig,
    /// Text-side transformer decoder.
    pub text_decoder: TransformerConfig,
    /// Size of the text token vocabulary.
    pub vocab_size: usize,
}
/// Audio encoder settings: input mel-bin count, a transformer stack and an
/// optional list of convolution layers.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct AudioEncoderConfig {
    /// Number of mel bins per input frame. Presumably this should match the
    /// first conv layer's `in_channels`; that is not enforced here.
    pub n_mels: usize,
    /// Transformer stack applied to the audio features.
    pub transformer: TransformerConfig,
    /// Convolution layers; an empty list when omitted from the input.
    #[serde(default)]
    pub conv_layers: Vec<ConvConfig>,
}
/// Shape parameters of a single convolution layer.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ConvConfig {
    /// Number of input channels.
    pub in_channels: usize,
    /// Number of output channels.
    pub out_channels: usize,
    /// Kernel size.
    pub kernel_size: usize,
    /// Stride; `1` when omitted from the input.
    #[serde(default = "default_stride")]
    pub stride: usize,
    /// Padding; `0` when omitted from the input.
    #[serde(default = "default_padding")]
    pub padding: usize,
}
/// Serde fallback for `TransformerConfig::dropout` when the field is absent.
fn default_dropout() -> f64 {
    const DROPOUT: f64 = 0.1;
    DROPOUT
}
/// Serde fallback for `TransformerConfig::layer_norm_eps` when the field is absent.
fn default_layer_norm_eps() -> f64 {
    const EPS: f64 = 1.0e-5;
    EPS
}
/// Serde fallback for `ConvConfig::stride` when the field is absent.
fn default_stride() -> usize {
    const STRIDE: usize = 1;
    STRIDE
}
/// Serde fallback for `ConvConfig::padding` when the field is absent.
fn default_padding() -> usize {
    const PADDING: usize = 0;
    PADDING
}
impl TransformerConfig {
    /// BERT-base preset.
    ///
    /// Uses 12 layers, hidden size 768, 12 heads, feed-forward size 3072 and
    /// 512 positions. Its layer-norm epsilon is `1e-12`, which differs from
    /// the serde default of `1e-5`.
    pub fn bert_base() -> Self {
        Self {
            num_layers: 12,
            hidden_dim: 768,
            num_heads: 12,
            feedforward_dim: 3072,
            dropout: 0.1,
            max_seq_len: 512,
            attention_type: AttentionType::Self_,
            position_encoding: PositionEncodingType::Learned,
            layer_norm_eps: 1e-12,
        }
    }

    /// GPT-2 (small) preset.
    ///
    /// Uses 12 layers, hidden size 768, 12 heads, feed-forward size 3072 and
    /// 1024 positions.
    pub fn gpt2() -> Self {
        Self {
            num_layers: 12,
            hidden_dim: 768,
            num_heads: 12,
            feedforward_dim: 3072,
            dropout: 0.1,
            max_seq_len: 1024,
            attention_type: AttentionType::Self_,
            position_encoding: PositionEncodingType::Learned,
            layer_norm_eps: 1e-5,
        }
    }

    /// Encoder half of the Whisper "tiny" preset.
    ///
    /// Uses 4 layers, hidden size 384, 6 heads, feed-forward size 1536,
    /// 1500 positions and no dropout.
    ///
    /// Whisper's audio encoder adds fixed sinusoidal position embeddings; only
    /// the text decoder's are learned. The preset therefore uses
    /// [`PositionEncodingType::Sinusoidal`] here (previously it said `Learned`).
    pub fn whisper_tiny_encoder() -> Self {
        Self {
            num_layers: 4,
            hidden_dim: 384,
            num_heads: 6,
            feedforward_dim: 1536,
            dropout: 0.0,
            max_seq_len: 1500,
            attention_type: AttentionType::Self_,
            position_encoding: PositionEncodingType::Sinusoidal,
            layer_norm_eps: 1e-5,
        }
    }

    /// Decoder half of the Whisper "tiny" preset.
    ///
    /// Uses 4 layers, hidden size 384, 6 heads, feed-forward size 1536,
    /// 448 positions and no dropout. [`AttentionType::Both`] is used because
    /// each decoder layer has self-attention plus cross-attention over the
    /// encoder output. `hidden_dim` matches [`Self::whisper_tiny_encoder`].
    pub fn whisper_tiny_decoder() -> Self {
        Self {
            num_layers: 4,
            hidden_dim: 384,
            num_heads: 6,
            feedforward_dim: 1536,
            dropout: 0.0,
            max_seq_len: 448,
            attention_type: AttentionType::Both,
            position_encoding: PositionEncodingType::Learned,
            layer_norm_eps: 1e-5,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every preset shipped by `TransformerConfig`.
    fn presets() -> Vec<(&'static str, TransformerConfig)> {
        vec![
            ("bert_base", TransformerConfig::bert_base()),
            ("gpt2", TransformerConfig::gpt2()),
            ("whisper_tiny_encoder", TransformerConfig::whisper_tiny_encoder()),
            ("whisper_tiny_decoder", TransformerConfig::whisper_tiny_decoder()),
        ]
    }

    /// A JSON round trip must preserve every field of every preset, not just
    /// `num_layers`. Comparing the re-serialized `Value`s covers every field
    /// without requiring `PartialEq` on the config types.
    #[test]
    fn test_config_serialization() {
        for (name, config) in presets() {
            let json = serde_json::to_string_pretty(&config).unwrap();
            let deserialized: TransformerConfig = serde_json::from_str(&json).unwrap();
            assert_eq!(
                serde_json::to_value(&config).unwrap(),
                serde_json::to_value(&deserialized).unwrap(),
                "round trip changed preset {name}"
            );
        }
    }

    /// Omitted optional fields must pick up their serde defaults.
    #[test]
    fn test_defaults_applied_when_fields_omitted() {
        let json = r#"{
            "num_layers": 2,
            "hidden_dim": 64,
            "num_heads": 4,
            "feedforward_dim": 256,
            "max_seq_len": 128
        }"#;
        let config: TransformerConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.dropout, 0.1);
        assert_eq!(config.layer_norm_eps, 1e-5);
        assert!(matches!(config.attention_type, AttentionType::Self_));
        assert!(matches!(config.position_encoding, PositionEncodingType::Learned));

        let conv: ConvConfig =
            serde_json::from_str(r#"{"in_channels": 80, "out_channels": 384, "kernel_size": 3}"#)
                .unwrap();
        assert_eq!(conv.stride, 1);
        assert_eq!(conv.padding, 0);
    }

    /// Multi-head attention splits `hidden_dim` evenly across heads.
    #[test]
    fn test_presets_have_integral_head_dim() {
        for (name, config) in presets() {
            assert!(config.num_heads > 0, "preset {name} has zero heads");
            assert_eq!(
                config.hidden_dim % config.num_heads,
                0,
                "preset {name}: hidden_dim not divisible by num_heads"
            );
        }
    }

    #[test]
    fn test_whisper_config() {
        let encoder = TransformerConfig::whisper_tiny_encoder();
        let decoder = TransformerConfig::whisper_tiny_decoder();
        assert_eq!(encoder.num_layers, 4);
        assert_eq!(encoder.hidden_dim, 384);
        assert_eq!(decoder.num_layers, 4);
        assert_eq!(decoder.hidden_dim, 384);
        // The decoder cross-attends to encoder output, so the widths must agree.
        assert_eq!(encoder.hidden_dim, decoder.hidden_dim);
        assert!(matches!(decoder.attention_type, AttentionType::Both));
    }
}