use alloc::vec::Vec;
use crate as burn;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use crate::tensor::TensorData;
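// In `no_std` builds, float methods such as `ln`, `exp`, `sin` and `cos` come from
// the `Float` trait provided by `num_traits`.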
#[cfg(not(feature = "std"))]
use num_traits::Float;
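/// Configuration to create a [PositionalEncoding] layer using the
/// [init function](PositionalEncodingConfig::init).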
#[derive(Config)]
pub struct PositionalEncodingConfig {
#[config(default = "5_000")]
pub max_sequence_size: usize,
pub d_model: usize,
#[config(default = "10_000")]
pub max_timescale: usize,
}
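/// Positional encoding layer for transformer models.
///
/// Adds position-dependent sinusoidal signals to the input embeddings so the model
/// can make use of token order, as introduced in
/// [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// Should be created with [PositionalEncodingConfig].
///
/// A minimal usage sketch (illustrative; the backend `B`, `device`, and an `input`
/// tensor of shape `[batch_size, seq_length, 512]` are assumed):
///
/// ```rust,ignore
/// let pe = PositionalEncodingConfig::new(512).init::<B>(&device);
/// let output = pe.forward(input); // same shape as `input`
/// ```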
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionalEncoding<B: Backend> {
/// Precomputed sinusoids, of shape `[1, max_sequence_size, d_model]`.
pub sinusoids: Tensor<B, 3>,
/// Maximum sequence size supported by the precomputed sinusoids.
pub max_sequence_size: usize,
/// Maximum timescale used to compute the sinusoid frequencies.
pub max_timescale: usize,
}
impl<B: Backend> ModuleDisplay for PositionalEncoding<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
let [_, _, d_model] = self.sinusoids.shape().dims;
content
.add("d_model", &d_model)
.add("max_sequence_size", &self.max_sequence_size)
.add("max_timescale", &self.max_timescale)
.optional()
}
}
impl PositionalEncodingConfig {
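/// Initialize a new [PositionalEncoding] module, precomputing the sinusoids on the given device.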
pub fn init<B: Backend>(&self, device: &B::Device) -> PositionalEncoding<B> {
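// Precompute the [max_sequence_size, d_model] table once and add a leading
// dimension so it can be broadcast over the batch at forward time.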
let sinusoids = generate_sinusoids::<B>(
self.max_sequence_size,
self.d_model,
self.max_timescale,
device,
)
.unsqueeze::<3>();
PositionalEncoding {
sinusoids,
max_sequence_size: self.max_sequence_size,
max_timescale: self.max_timescale,
}
}
}
impl<B: Backend> PositionalEncoding<B> {
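/// Applies the forward pass by adding the precomputed sinusoids to the input.
///
/// # Shapes
///
/// * input: `[batch_size, seq_length, d_model]`
/// * output: `[batch_size, seq_length, d_model]`
///
/// # Panics
///
/// * If `seq_length` is greater than `max_sequence_size`.
/// * If the input `d_model` differs from the `d_model` of the precomputed sinusoids.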
pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
let [_, seq_length, d_model_input] = input.dims();
let [batch_size, max_sequence_size, d_model] = self.sinusoids.dims();
assert!(
max_sequence_size >= seq_length,
"max_sequence_size ({max_sequence_size}) must be greater than or equal to the input sequence length ({seq_length})",
);
assert!(
d_model_input == d_model,
"d_model of the input ({d_model_input}) must be equal to the d_model of the encoding ({d_model})",
);
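// Slice the precomputed table down to the input sequence length; its leading
// dimension of size 1 broadcasts over the input batch.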
let slices = [0..batch_size, 0..seq_length, 0..d_model];
input.add(self.sinusoids.clone().slice(slices))
}
}
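/// Returns a `[length, d_model]` tensor of sinusoids for positional embedding,
/// as introduced in [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// For position `pos` and even column index `k`:
///
/// * `PE(pos, k)     = sin(pos / max_timescale^(k / d_model))`
/// * `PE(pos, k + 1) = cos(pos / max_timescale^(k / d_model))`
///
/// # Panics
///
/// * If `d_model` is odd.
/// * If `max_timescale` is less than `length`.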
pub fn generate_sinusoids<B: Backend>(
length: usize,
d_model: usize,
max_timescale: usize,
device: &B::Device,
) -> Tensor<B, 2> {
assert!(d_model % 2 == 0, "d_model must be even");
assert!(
max_timescale >= length,
"max_timescale must be greater than or equal to length"
);
let log_timescale_increment = -(max_timescale as f32).ln() / d_model as f32;
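// For each position `i`, the column pair (k, k + 1) stores
// sin(i / max_timescale^(k / d_model)) and cos(i / max_timescale^(k / d_model)).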
let mut scaled_time_sin_cos = Vec::with_capacity(length);
for i in 0..length {
let mut row = Vec::with_capacity(d_model); // each pair pushes both a sin and a cos value
for k in (0..d_model).step_by(2) {
let div_term = (k as f32 * log_timescale_increment).exp();
row.push((div_term * i as f32).sin());
row.push((div_term * i as f32).cos());
}
scaled_time_sin_cos.push(row);
}
let data = TensorData::new(
scaled_time_sin_cos.into_iter().flatten().collect(),
[length, d_model],
);
Tensor::<B, 2>::from_data(data, device)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
#[test]
fn test_module() {
let d_model = 6;
let length = 3;
let batch_size = 2;
let device = Default::default();
let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
let tensor = Tensor::zeros([batch_size, length, d_model], &device);
let output = pe.forward(tensor);
assert_eq!(output.shape().dims, [batch_size, length, d_model]);
let expected = Tensor::<TestBackend, 3>::from_floats(
[
[
[0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
[0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
[0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
],
[
[0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
[0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
[0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
],
],
&device,
);
output.to_data().assert_approx_eq(&expected.to_data(), 5);
}
#[test]
fn test_generate_sinusoids() {
let device = Default::default();
let sinusoids = generate_sinusoids::<TestBackend>(12, 6, 10_000, &device);
let expected = Tensor::<TestBackend, 2>::from_floats(
[
[0.00000, 1.00000, 0.00000, 1.00000, 0.00000, 1.00000],
[0.84147, 0.54030, 0.04640, 0.99892, 0.00215, 1.00000],
[0.90930, -0.41615, 0.09270, 0.99569, 0.00431, 0.99999],
[0.14112, -0.98999, 0.13880, 0.99032, 0.00646, 0.99998],
[-0.75680, -0.65364, 0.18460, 0.98281, 0.00862, 0.99996],
[-0.95892, 0.28366, 0.23000, 0.97319, 0.01077, 0.99994],
[-0.27942, 0.96017, 0.27491, 0.96147, 0.01293, 0.99992],
[0.65699, 0.75390, 0.31922, 0.94768, 0.01508, 0.99989],
[0.98936, -0.14550, 0.36285, 0.93185, 0.01723, 0.99985],
[0.41212, -0.91113, 0.40570, 0.91401, 0.01939, 0.99981],
[-0.54402, -0.83907, 0.44767, 0.89420, 0.02154, 0.99977],
[-0.99999, 0.00443, 0.48868, 0.87246, 0.02370, 0.99972],
],
&device,
);
sinusoids.to_data().assert_approx_eq(&expected.to_data(), 5);
}
#[test]
#[should_panic]
fn d_model_input_should_match() {
let d_model = 8;
let device = Default::default();
let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
let input = Tensor::zeros([1, 5, 10], &device);
let _output = pe.forward(input);
}
#[test]
#[should_panic]
fn input_length_should_not_exceed_max_sequence_size() {
let d_model = 8;
let device = Default::default();
let pe = PositionalEncodingConfig::new(d_model).init::<TestBackend>(&device);
let input = Tensor::zeros([1, 6_000, d_model], &device);
let _output = pe.forward(input);
}
#[test]
fn display() {
let config = PositionalEncodingConfig::new(4);
let pe = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", pe),
"PositionalEncoding {d_model: 4, max_sequence_size: 5000, max_timescale: 10000}"
);
}
}