use serde::{Deserialize, Serialize};
use trustformers_core::{errors::invalid_config, traits::Config};
/// Hyperparameters for the FNet architecture.
///
/// Field order is part of the serialized layout produced by the serde
/// derives, so new fields should be appended rather than inserted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FNetConfig {
    /// Vocabulary size (default `32000`).
    pub vocab_size: usize,
    /// Hidden dimension (default `768`); `recommended_batch_size` keys off this value.
    pub hidden_size: usize,
    /// Number of hidden layers (default `12`).
    pub num_hidden_layers: usize,
    /// Intermediate size (default `3072`).
    pub intermediate_size: usize,
    /// Name of the hidden activation (default `"gelu"`). Not validated in this file.
    pub hidden_act: String,
    /// Hidden dropout probability (default `0.1`).
    pub hidden_dropout_prob: f32,
    /// Maximum number of positions (default `512`).
    ///
    /// `validate` rejects values above `8192`; `complexity_advantage` and
    /// `is_efficient_config` use this value as the sequence length.
    pub max_position_embeddings: usize,
    /// Token-type vocabulary size (default `4`).
    pub type_vocab_size: usize,
    /// Initializer range (default `0.02`).
    pub initializer_range: f32,
    /// Layer-norm epsilon (default `1e-12`).
    pub layer_norm_eps: f32,
    /// Padding token id (default `0`).
    pub pad_token_id: u32,
    /// Position embedding type (default `"absolute"`). Not validated in this file.
    pub position_embedding_type: String,
    /// Whether the Fourier transform is used (default `true`).
    pub use_fourier_transform: bool,
    /// Whether the TPU-optimized FFT is requested (default `false`; `fnet_tpu` sets it).
    pub use_tpu_optimized_fft: bool,
    /// Transform variant (default `"dft"`); `validate` accepts only
    /// `"dft"`, `"real_dft"` or `"dct"`.
    pub fourier_transform_type: String,
    /// Whether a bias is used in the Fourier layer (default `true`).
    pub use_bias_in_fourier: bool,
    /// Dropout probability for the Fourier layer (default `0.0`).
    pub fourier_dropout_prob: f32,
}
impl Default for FNetConfig {
fn default() -> Self {
Self {
vocab_size: 32000, hidden_size: 768,
num_hidden_layers: 12,
intermediate_size: 3072,
hidden_act: "gelu".to_string(),
hidden_dropout_prob: 0.1,
max_position_embeddings: 512,
type_vocab_size: 4, initializer_range: 0.02,
layer_norm_eps: 1e-12,
pad_token_id: 0,
position_embedding_type: "absolute".to_string(),
use_fourier_transform: true,
use_tpu_optimized_fft: false,
fourier_transform_type: "dft".to_string(),
use_bias_in_fourier: true,
fourier_dropout_prob: 0.0, }
}
}
impl Config for FNetConfig {
    /// Checks the Fourier-related settings of this configuration.
    ///
    /// # Errors
    ///
    /// Returns an invalid-configuration error when:
    /// * `fourier_transform_type` is not one of `"dft"`, `"real_dft"` or `"dct"`;
    /// * `max_position_embeddings` is greater than `8192`.
    ///
    /// Each error names the offending field so callers can tell which
    /// setting to correct.
    fn validate(&self) -> trustformers_core::errors::Result<()> {
        const SUPPORTED_TRANSFORMS: [&str; 3] = ["dft", "real_dft", "dct"];
        const MAX_POSITIONS: usize = 8192;

        if !SUPPORTED_TRANSFORMS.contains(&self.fourier_transform_type.as_str()) {
            return Err(invalid_config(
                "fourier_transform_type",
                "fourier_transform_type must be one of: dft, real_dft, dct",
            ));
        }

        if self.max_position_embeddings > MAX_POSITIONS {
            // Report the real field name rather than a generic placeholder.
            return Err(invalid_config(
                "max_position_embeddings",
                "max_position_embeddings > 8192 may be inefficient for FFT. Consider chunking.",
            ));
        }

        Ok(())
    }

    /// Returns the architecture identifier, always `"FNet"`.
    fn architecture(&self) -> &'static str {
        "FNet"
    }
}
impl FNetConfig {
    /// Base preset; identical to [`FNetConfig::default`].
    pub fn fnet_base() -> Self {
        Self::default()
    }

    /// Large preset: hidden size 1024, 24 layers, intermediate size 4096;
    /// every other field keeps its default.
    pub fn fnet_large() -> Self {
        Self {
            hidden_size: 1024,
            num_hidden_layers: 24,
            intermediate_size: 4096,
            ..Self::default()
        }
    }

    /// TPU preset: enables `use_tpu_optimized_fft`, selects `"real_dft"` and
    /// raises `max_position_embeddings` to 1024.
    pub fn fnet_tpu() -> Self {
        Self {
            use_tpu_optimized_fft: true,
            fourier_transform_type: "real_dft".to_string(),
            max_position_embeddings: 1024,
            ..Self::default()
        }
    }

    /// DCT preset: selects `"dct"` and raises `max_position_embeddings` to 1024.
    pub fn fnet_dct() -> Self {
        Self {
            fourier_transform_type: "dct".to_string(),
            max_position_embeddings: 1024,
            ..Self::default()
        }
    }

    /// Long-sequence preset: `max_position_embeddings` of 4096 with `"real_dft"`.
    pub fn fnet_long() -> Self {
        Self {
            max_position_embeddings: 4096,
            fourier_transform_type: "real_dft".to_string(),
            ..Self::default()
        }
    }

    /// Ratio of the attention cost `n²` to the Fourier cost `n·log₂(n)`,
    /// where `n` is `max_position_embeddings`.
    ///
    /// The ratio is computed in its simplified form `n / log₂(n)`. For
    /// `n <= 1` the Fourier cost term is zero or undefined, so the method
    /// returns `1.0` (no advantage) instead of `inf`/`NaN`.
    pub fn complexity_advantage(&self) -> f32 {
        if self.max_position_embeddings <= 1 {
            // log2(1) == 0 and log2(0) == -inf would yield inf / NaN below.
            return 1.0;
        }
        let n = self.max_position_embeddings as f32;
        n / n.log2()
    }

    /// Returns `true` when `max_position_embeddings` is a non-zero power of two.
    pub fn is_efficient_config(&self) -> bool {
        self.max_position_embeddings.is_power_of_two()
    }

    /// Suggested batch size derived from `hidden_size`:
    /// 64 for 768, 32 for 1024 and 16 for anything else.
    pub fn recommended_batch_size(&self) -> usize {
        match self.hidden_size {
            768 => 64,
            1024 => 32,
            _ => 16,
        }
    }
}