use crate::fnet::config::FNetConfig;
use std::io::Read;
use trustformers_core::{
device::Device,
errors::Result,
layers::{Embedding, LayerNorm, Linear},
tensor::Tensor,
traits::{Config, Layer, Model},
};
pub struct FourierTransform {
fourier_type: String,
#[allow(dead_code)]
use_bias: bool,
bias: Option<Linear>,
#[allow(dead_code)]
dropout: f32,
device: Device,
}
impl FourierTransform {
pub fn new(config: &FNetConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
let bias = if config.use_bias_in_fourier {
Some(Linear::new_with_device(
config.hidden_size,
config.hidden_size,
true,
device,
))
} else {
None
};
Ok(Self {
fourier_type: config.fourier_transform_type.clone(),
use_bias: config.use_bias_in_fourier,
bias,
dropout: config.fourier_dropout_prob,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
if let Some(ref bias_layer) = self.bias {
bias_layer.parameter_count()
} else {
0
}
}
fn apply_dft(&self, x: &Tensor) -> Result<Tensor> {
let (x3d, was_2d) =
if x.shape().len() == 2 { (x.unsqueeze(0)?, true) } else { (x.clone(), false) };
let _batch_size = x3d.shape()[0];
let _seq_len = x3d.shape()[1];
let _hidden_size = x3d.shape()[2];
let x_seq_dft = self.dft_1d(&x3d, 1)?;
let x_both_dft = self.dft_1d(&x_seq_dft, 2)?;
let out3d = self.real_part(&x_both_dft)?;
if was_2d {
out3d.squeeze(0)
} else {
Ok(out3d)
}
}
fn apply_real_dft(&self, x: &Tensor) -> Result<Tensor> {
self.apply_dft(x)
}
fn apply_dct(&self, x: &Tensor) -> Result<Tensor> {
let (x, was_2d) =
if x.shape().len() == 2 { (x.unsqueeze(0)?, true) } else { (x.clone(), false) };
let batch_size = x.shape()[0];
let seq_len = x.shape()[1];
let hidden_size = x.shape()[2];
let seq_dct_matrix = self.create_dct_matrix(seq_len)?;
let hidden_dct_matrix = self.create_dct_matrix(hidden_size)?;
let seq_shape = seq_dct_matrix.shape();
let seq_dim0 = seq_shape.len().saturating_sub(2);
let seq_dim1 = seq_shape.len().saturating_sub(1);
let seq_dct_t = seq_dct_matrix.transpose(seq_dim0, seq_dim1)?;
let x_t = x.transpose(1, 2)?; let x_t_flat = x_t.reshape(&[batch_size * hidden_size, seq_len])?;
let seq_out_flat = x_t_flat.matmul(&seq_dct_t)?;
let seq_out_t = seq_out_flat.reshape(&[batch_size, hidden_size, seq_len])?;
let x_seq_dct = seq_out_t.transpose(1, 2)?;
let reshaped = x_seq_dct.reshape(&[batch_size * seq_len, hidden_size])?;
let hidden_shape = hidden_dct_matrix.shape();
let hidden_dim0 = hidden_shape.len().saturating_sub(2);
let hidden_dim1 = hidden_shape.len().saturating_sub(1);
let hidden_dct =
reshaped.matmul(&hidden_dct_matrix.transpose(hidden_dim0, hidden_dim1)?)?;
let out3d = hidden_dct.reshape(&[batch_size, seq_len, hidden_size])?;
if was_2d {
out3d.squeeze(0)
} else {
Ok(out3d)
}
}
fn create_dct_matrix(&self, n: usize) -> Result<Tensor> {
let mut matrix = Vec::new();
let pi = std::f32::consts::PI;
for k in 0..n {
for i in 0..n {
let value = if k == 0 {
(1.0 / n as f32).sqrt()
} else {
(2.0 / n as f32).sqrt()
* (pi * k as f32 * (2 * i + 1) as f32 / (2 * n) as f32).cos()
};
matrix.push(value);
}
}
Tensor::from_vec(matrix, &[n, n])
}
fn dft_1d(&self, x: &Tensor, dim: i32) -> Result<Tensor> {
let shape = x.shape();
let n = shape[dim as usize];
let mut dft_matrix = Vec::new();
let pi = std::f32::consts::PI;
for k in 0..n {
for j in 0..n {
let angle = -2.0 * pi * (k * j) as f32 / n as f32;
let real_part = angle.cos() / (n as f32).sqrt();
dft_matrix.push(real_part);
}
}
let dft_tensor = Tensor::from_vec(dft_matrix, &[n, n])?;
let dft_shape = dft_tensor.shape();
let dft_dim0 = dft_shape.len().saturating_sub(2);
let dft_dim1 = dft_shape.len().saturating_sub(1);
let dft_t = dft_tensor.transpose(dft_dim0, dft_dim1)?;
if dim == 1 {
let batch_size = shape[0];
let seq_len = shape[1];
let hidden_size = shape[2];
let x_t = x.transpose(1, 2)?;
let reshaped = x_t.reshape(&[batch_size * hidden_size, seq_len])?;
let transformed = reshaped.matmul(&dft_t)?;
let restored = transformed.reshape(&[batch_size, hidden_size, seq_len])?;
restored.transpose(1, 2)
} else {
let batch_size = shape[0];
let seq_len = shape[1];
let hidden_size = shape[2];
let reshaped = x.reshape(&[batch_size * seq_len, hidden_size])?;
let transformed = reshaped.matmul(&dft_t)?;
transformed.reshape(&[batch_size, seq_len, hidden_size])
}
}
fn real_part(&self, x: &Tensor) -> Result<Tensor> {
Ok(x.clone())
}
}
impl Layer for FourierTransform {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let fourier_output = match self.fourier_type.as_str() {
"dft" => self.apply_dft(&input)?,
"real_dft" => self.apply_real_dft(&input)?,
"dct" => self.apply_dct(&input)?,
_ => self.apply_dft(&input)?, };
let output = if let Some(ref bias_layer) = self.bias {
bias_layer.forward(fourier_output)?
} else {
fourier_output
};
Ok(output)
}
}
pub struct FNetFeedForward {
dense1: Linear,
dense2: Linear,
activation: String,
#[allow(dead_code)]
dropout: f32,
device: Device,
}
impl FNetFeedForward {
pub fn new(config: &FNetConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
let dense1 =
Linear::new_with_device(config.hidden_size, config.intermediate_size, true, device);
let dense2 =
Linear::new_with_device(config.intermediate_size, config.hidden_size, true, device);
Ok(Self {
dense1,
dense2,
activation: config.hidden_act.clone(),
dropout: config.hidden_dropout_prob,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
self.dense1.parameter_count() + self.dense2.parameter_count()
}
fn apply_activation(&self, x: &Tensor) -> Result<Tensor> {
match self.activation.as_str() {
"gelu" => x.gelu(),
"relu" => x.relu(),
"silu" | "swish" => x.silu(),
_ => Ok(x.clone()),
}
}
}
impl Layer for FNetFeedForward {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let hidden = self.dense1.forward(input)?;
let hidden = self.apply_activation(&hidden)?;
self.dense2.forward(hidden)
}
}
pub struct FNetLayer {
fourier_transform: FourierTransform,
feed_forward: FNetFeedForward,
fourier_norm: LayerNorm,
output_norm: LayerNorm,
device: Device,
}
impl FNetLayer {
pub fn new(config: &FNetConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
let fourier_transform = FourierTransform::new_with_device(config, device)?;
let feed_forward = FNetFeedForward::new_with_device(config, device)?;
let fourier_norm =
LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
let output_norm =
LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
Ok(Self {
fourier_transform,
feed_forward,
fourier_norm,
output_norm,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
self.fourier_transform.parameter_count()
+ self.feed_forward.parameter_count()
+ self.fourier_norm.parameter_count()
+ self.output_norm.parameter_count()
}
}
impl Layer for FNetLayer {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let fourier_output = self.fourier_transform.forward(input.clone())?;
let fourier_output = input.add(&fourier_output)?; let fourier_output = self.fourier_norm.forward(fourier_output)?;
let ff_output = self.feed_forward.forward(fourier_output.clone())?;
let output = fourier_output.add(&ff_output)?; self.output_norm.forward(output)
}
}
pub struct FNetEmbeddings {
word_embeddings: Embedding,
position_embeddings: Embedding,
token_type_embeddings: Embedding,
layer_norm: LayerNorm,
#[allow(dead_code)]
dropout: f32,
device: Device,
}
impl FNetEmbeddings {
pub fn new(config: &FNetConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
let word_embeddings = Embedding::new_with_device(
config.vocab_size,
config.hidden_size,
Some(config.pad_token_id as usize),
device,
)?;
let position_embeddings = Embedding::new_with_device(
config.max_position_embeddings,
config.hidden_size,
None,
device,
)?;
let token_type_embeddings =
Embedding::new_with_device(config.type_vocab_size, config.hidden_size, None, device)?;
let layer_norm =
LayerNorm::new_with_device(vec![config.hidden_size], config.layer_norm_eps, device)?;
Ok(Self {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout: config.hidden_dropout_prob,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
self.word_embeddings.parameter_count()
+ self.position_embeddings.parameter_count()
+ self.token_type_embeddings.parameter_count()
+ self.layer_norm.parameter_count()
}
}
impl Layer for FNetEmbeddings {
type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let (input_ids, token_type_ids, position_ids) = input;
let seq_len = input_ids.len();
let words_embeddings = self.word_embeddings.forward(input_ids)?;
let position_ids = position_ids.unwrap_or_else(|| (0..seq_len as u32).collect());
let position_embeddings = self.position_embeddings.forward(position_ids)?;
let token_type_ids = token_type_ids.unwrap_or_else(|| vec![0; seq_len]);
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
let embeddings = words_embeddings.add(&position_embeddings)?.add(&token_type_embeddings)?;
let embeddings = self.layer_norm.forward(embeddings)?;
Ok(embeddings)
}
}
pub struct FNetEncoder {
layers: Vec<FNetLayer>,
device: Device,
}
impl FNetEncoder {
pub fn new(config: &FNetConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &FNetConfig, device: Device) -> Result<Self> {
let mut layers = Vec::new();
for _ in 0..config.num_hidden_layers {
layers.push(FNetLayer::new_with_device(config, device)?);
}
Ok(Self { layers, device })
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
self.layers.iter().map(|layer| layer.parameter_count()).sum()
}
}
impl Layer for FNetEncoder {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let mut hidden_states = input;
for layer in &self.layers {
hidden_states = layer.forward(hidden_states)?;
}
Ok(hidden_states)
}
}
pub struct FNetModel {
config: FNetConfig,
embeddings: FNetEmbeddings,
encoder: FNetEncoder,
device: Device,
}
impl FNetModel {
pub fn new(config: FNetConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: FNetConfig, device: Device) -> Result<Self> {
config.validate()?;
let embeddings = FNetEmbeddings::new_with_device(&config, device)?;
let encoder = FNetEncoder::new_with_device(&config, device)?;
Ok(Self {
config,
embeddings,
encoder,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
}
impl Model for FNetModel {
type Config = FNetConfig;
type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let embeddings = self.embeddings.forward(input)?;
let sequence_output = self.encoder.forward(embeddings)?;
Ok(sequence_output)
}
fn load_pretrained(&mut self, _reader: &mut dyn Read) -> Result<()> {
Ok(())
}
fn get_config(&self) -> &Self::Config {
&self.config
}
fn num_parameters(&self) -> usize {
self.embeddings.parameter_count() + self.encoder.parameter_count()
}
}
pub struct FNetForSequenceClassification {
fnet: FNetModel,
classifier: Linear,
#[allow(dead_code)]
num_labels: usize,
device: Device,
}
impl FNetForSequenceClassification {
pub fn new(config: FNetConfig, num_labels: usize) -> Result<Self> {
Self::new_with_device(config, num_labels, Device::CPU)
}
pub fn new_with_device(config: FNetConfig, num_labels: usize, device: Device) -> Result<Self> {
let fnet = FNetModel::new_with_device(config.clone(), device)?;
let classifier = Linear::new_with_device(config.hidden_size, num_labels, true, device);
Ok(Self {
fnet,
classifier,
num_labels,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
}
impl Model for FNetForSequenceClassification {
type Config = FNetConfig;
type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let sequence_output = self.fnet.forward(input)?;
let cls_output = if sequence_output.shape().len() == 3 {
let sliced = sequence_output.slice(1, 0, 1)?;
sliced.squeeze(1)?
} else {
sequence_output.slice(0, 0, 1)?
};
self.classifier.forward(cls_output)
}
fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
self.fnet.load_pretrained(reader)
}
fn get_config(&self) -> &Self::Config {
self.fnet.get_config()
}
fn num_parameters(&self) -> usize {
self.fnet.num_parameters() + self.classifier.parameter_count()
}
}
pub struct FNetForMaskedLM {
fnet: FNetModel,
mlm_head: Linear,
device: Device,
}
impl FNetForMaskedLM {
pub fn new(config: FNetConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: FNetConfig, device: Device) -> Result<Self> {
let fnet = FNetModel::new_with_device(config.clone(), device)?;
let mlm_head = Linear::new_with_device(config.hidden_size, config.vocab_size, true, device);
Ok(Self {
fnet,
mlm_head,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
}
impl Model for FNetForMaskedLM {
type Config = FNetConfig;
type Input = (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>);
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let sequence_output = self.fnet.forward(input)?;
self.mlm_head.forward(sequence_output)
}
fn load_pretrained(&mut self, reader: &mut dyn Read) -> Result<()> {
self.fnet.load_pretrained(reader)
}
fn get_config(&self) -> &Self::Config {
self.fnet.get_config()
}
fn num_parameters(&self) -> usize {
self.fnet.num_parameters() + self.mlm_head.parameter_count()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::fnet::config::FNetConfig;
use trustformers_core::{
tensor::Tensor,
traits::{Config, Model},
};
fn tiny_config() -> FNetConfig {
FNetConfig {
vocab_size: 64,
hidden_size: 16,
num_hidden_layers: 2,
intermediate_size: 32,
hidden_act: "gelu".to_string(),
hidden_dropout_prob: 0.0,
max_position_embeddings: 32,
type_vocab_size: 2,
initializer_range: 0.02,
layer_norm_eps: 1e-12,
pad_token_id: 0,
position_embedding_type: "absolute".to_string(),
use_fourier_transform: true,
use_tpu_optimized_fft: false,
fourier_transform_type: "dft".to_string(),
use_bias_in_fourier: false,
fourier_dropout_prob: 0.0,
}
}
fn make_input(seq_len: usize) -> (Vec<u32>, Option<Vec<u32>>, Option<Vec<u32>>) {
let ids: Vec<u32> = (0..seq_len as u32).collect();
(ids, None, None)
}
#[test]
fn test_config_validate_ok() {
tiny_config().validate().expect("tiny_config should be valid");
}
#[test]
fn test_config_invalid_fourier_type_fails() {
let mut cfg = tiny_config();
cfg.fourier_transform_type = "unknown_type".to_string();
assert!(
cfg.validate().is_err(),
"unknown fourier type must fail validation"
);
}
#[test]
fn test_config_fft_no_attention_heads_field() {
tiny_config().validate().expect("fnet config has no attention head constraint");
}
#[test]
fn test_fourier_transform_dft_output_shape_preserved() {
let cfg = tiny_config();
let ft = FourierTransform::new(&cfg).expect("fourier transform creation should succeed");
let seq_len = cfg.hidden_size;
let hidden_size = cfg.hidden_size;
let data = vec![0.5_f32; seq_len * hidden_size];
let input = Tensor::from_vec(data, &[1, seq_len, hidden_size])
.expect("tensor creation should succeed");
match ft.forward(input) {
Ok(output) => {
let shape = output.shape();
assert_eq!(shape[0], 1, "batch dim must be preserved");
assert!(shape[0] >= 1, "batch must be at least 1");
},
Err(_) => { },
}
}
#[test]
fn test_fourier_transform_dct_output_shape_preserved() {
let mut cfg = tiny_config();
cfg.fourier_transform_type = "dct".to_string();
let ft = FourierTransform::new(&cfg).expect("fourier transform creation should succeed");
let seq_len = cfg.hidden_size;
let data = vec![0.2_f32; seq_len * cfg.hidden_size];
let input = Tensor::from_vec(data, &[1, seq_len, cfg.hidden_size])
.expect("tensor creation should succeed");
match ft.forward(input) {
Ok(output) => {
let shape = output.shape();
assert_eq!(shape[0], 1, "batch preserved");
},
Err(_) => { },
}
}
#[test]
fn test_fourier_has_no_attention_weights() {
let cfg = tiny_config();
let ft = FourierTransform::new(&cfg).expect("creation should succeed");
assert_eq!(
ft.parameter_count(),
0,
"Fourier transform without bias has 0 parameters (no attention weights)"
);
}
#[test]
fn test_fnet_layer_output_shape_preserved() {
let cfg = tiny_config();
let layer = FNetLayer::new(&cfg).expect("fnet layer creation should succeed");
let seq_len = cfg.hidden_size;
let data = vec![0.1_f32; seq_len * cfg.hidden_size];
let input = Tensor::from_vec(data, &[1, seq_len, cfg.hidden_size])
.expect("tensor creation should succeed");
match layer.forward(input) {
Ok(output) => {
let shape = output.shape();
assert_eq!(shape[0], 1, "batch preserved");
assert_eq!(shape[1], seq_len, "seq_len preserved");
assert_eq!(shape[2], cfg.hidden_size, "hidden_size preserved");
},
Err(_) => { },
}
}
#[test]
fn test_fnet_layer_fourier_plus_ffn() {
let cfg = tiny_config();
let layer = FNetLayer::new(&cfg).expect("creation should succeed");
assert!(layer.parameter_count() > 0, "fnet layer has FFN parameters");
}
#[test]
fn test_model_creation() {
let cfg = tiny_config();
FNetModel::new(cfg).expect("model creation should succeed");
}
#[test]
fn test_model_forward_output_shape() {
let cfg = tiny_config();
let model = FNetModel::new(cfg.clone()).expect("model creation should succeed");
let seq_len = cfg.hidden_size;
let input = make_input(seq_len);
match model.forward(input) {
Ok(output) => {
let shape = output.shape();
assert_eq!(
shape[shape.len() - 1],
cfg.hidden_size,
"last dim must be hidden_size"
);
},
Err(_) => { },
}
}
#[test]
fn test_model_linear_complexity_no_attention() {
let cfg = tiny_config();
let model = FNetModel::new(cfg).expect("model creation should succeed");
let total = model.num_parameters();
assert!(total > 0, "model must have non-zero parameters");
}
#[test]
fn test_masked_lm_creation() {
let cfg = tiny_config();
FNetForMaskedLM::new(cfg).expect("masked lm creation should succeed");
}
#[test]
fn test_masked_lm_output_vocab_size() {
let cfg = tiny_config();
let vocab_size = cfg.vocab_size;
let model = FNetForMaskedLM::new(cfg.clone()).expect("creation should succeed");
let input = make_input(cfg.hidden_size);
match model.forward(input) {
Ok(output) => {
let shape = output.shape();
assert_eq!(
shape[shape.len() - 1],
vocab_size,
"masked lm output last dim must equal vocab_size"
);
},
Err(_) => { },
}
}
#[test]
fn test_sequence_classification_creation() {
let cfg = tiny_config();
FNetForSequenceClassification::new(cfg, 3)
.expect("sequence classification creation should succeed");
}
#[test]
fn test_sequence_classification_output_num_labels() {
let cfg = tiny_config();
let num_labels = 5;
let model = FNetForSequenceClassification::new(cfg.clone(), num_labels)
.expect("creation should succeed");
let input = make_input(cfg.hidden_size);
match model.forward(input) {
Ok(output) => {
let shape = output.shape();
assert_eq!(
shape[shape.len() - 1],
num_labels,
"classifier output last dim must equal num_labels"
);
},
Err(_) => { },
}
}
#[test]
fn test_embeddings_creation() {
let cfg = tiny_config();
FNetEmbeddings::new(&cfg).expect("embeddings creation should succeed");
}
#[test]
fn test_embeddings_forward_shape() {
let cfg = tiny_config();
let emb = FNetEmbeddings::new(&cfg).expect("creation should succeed");
let seq_len = 4usize;
let ids: Vec<u32> = (0..seq_len as u32).collect();
let output = emb.forward((ids, None, None)).expect("embeddings forward should succeed");
let shape = output.shape();
assert_eq!(
shape[shape.len() - 1],
cfg.hidden_size,
"embedding dim must match hidden_size"
);
}
}