use burn::tensor::{Tensor, backend::Backend};
/// Builds a `[length, channels]` table of sinusoidal positional embeddings on `device`.
///
/// Row `pos` holds `sin(pos * inv_timescale[i])` in columns `0..channels / 2`, followed by
/// `cos(pos * inv_timescale[i])` in columns `channels / 2..channels`. The inverse timescales
/// decay geometrically from `1` (at `i = 0`) down to `1 / 10000` (at `i = channels / 2 - 1`):
///
/// `inv_timescale[i] = exp(-ln(10000) * i / (channels / 2 - 1))`
///
/// Angles are computed in `f64` and rounded to `f32` before `sin`/`cos` are taken.
///
/// # Panics
///
/// Panics if `channels` is odd, or if `channels == 2` (the timescale spacing
/// `ln(10000) / (channels / 2 - 1)` would divide by zero and fill the table with NaNs).
pub fn sinusoids<B: Backend>(length: usize, channels: usize, device: &B::Device) -> Tensor<B, 2> {
    assert!(channels % 2 == 0, "channels must be even, got {channels}");
    assert_ne!(
        channels, 2,
        "channels must not be 2: the timescale spacing ln(10000) / (channels / 2 - 1) divides by zero"
    );
    let half = channels / 2;
    let log_timescale_increment = 10000.0_f64.ln() / (half as f64 - 1.0);

    // The inverse timescales depend only on the channel index, so compute them once
    // instead of once per (position, channel) pair.
    let inv_timescales: Vec<f64> = (0..half)
        .map(|i| (-log_timescale_increment * i as f64).exp())
        .collect();

    let mut data = vec![0.0f32; length * channels];
    for pos in 0..length {
        // Slice ranges (rather than `chunks_exact_mut(channels)`) keep `channels == 0` valid.
        let row = &mut data[pos * channels..(pos + 1) * channels];
        let (sin_half, cos_half) = row.split_at_mut(half);
        for ((sin_out, cos_out), &inv_timescale) in sin_half
            .iter_mut()
            .zip(cos_half.iter_mut())
            .zip(&inv_timescales)
        {
            let angle = (pos as f64 * inv_timescale) as f32;
            *sin_out = angle.sin();
            *cos_out = angle.cos();
        }
    }

    Tensor::from_data(
        burn::tensor::TensorData::new(data, [length, channels]),
        device,
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type TestBackend = NdArray;

    /// Builds the embedding table on the default `NdArray` device and flattens it row-major.
    fn table(length: usize, channels: usize) -> Vec<f32> {
        let device = Default::default();
        sinusoids::<TestBackend>(length, channels, &device)
            .to_data()
            .to_vec::<f32>()
            .unwrap()
    }

    #[test]
    fn test_sinusoids_shape() {
        let device = Default::default();
        let emb = sinusoids::<TestBackend>(1500, 512, &device);
        assert_eq!(emb.shape().dims, [1500, 512]);
    }

    #[test]
    fn test_sinusoids_first_position_is_zero_sin() {
        // At pos = 0 every angle is 0, so the sin half is 0 and the cos half is 1.
        let vals = table(10, 4);
        assert!(vals[0].abs() < 1e-6, "sin at pos=0 should be 0");
        assert!(vals[1].abs() < 1e-6, "sin at pos=0 should be 0");
        assert!((vals[2] - 1.0).abs() < 1e-6, "cos at pos=0 should be 1");
        assert!((vals[3] - 1.0).abs() < 1e-6, "cos at pos=0 should be 1");
    }

    #[test]
    fn test_sinusoids_matches_formula_at_nonzero_position() {
        // channels = 4 -> half = 2 and the spacing is ln(10000) / 1, so the inverse
        // timescales are exp(0) = 1 and exp(-ln(10000)) = 1e-4.
        let channels = 4;
        let vals = table(3, channels);
        let pos = 2;
        let row = &vals[pos * channels..(pos + 1) * channels];
        let expected = [
            2.0_f64.sin(),
            2.0e-4_f64.sin(),
            2.0_f64.cos(),
            2.0e-4_f64.cos(),
        ];
        for (col, (&got, &want)) in row.iter().zip(&expected).enumerate() {
            assert!(
                (f64::from(got) - want).abs() < 1e-6,
                "pos={pos} col={col}: got {got}, want {want}"
            );
        }
    }

    #[test]
    fn test_sinusoids_sin_cos_pairs_are_unit_norm() {
        let (length, channels) = (50, 16);
        let half = channels / 2;
        let vals = table(length, channels);
        for pos in 0..length {
            for i in 0..half {
                let s = vals[pos * channels + i];
                let c = vals[pos * channels + half + i];
                assert!(
                    (s * s + c * c - 1.0).abs() < 1e-5,
                    "pos={pos} i={i}: sin^2 + cos^2 = {}",
                    s * s + c * c
                );
            }
        }
    }

    #[test]
    #[should_panic(expected = "channels must be even")]
    fn test_sinusoids_rejects_odd_channels() {
        let _ = table(4, 5);
    }

    #[test]
    #[should_panic]
    fn test_sinusoids_rejects_two_channels() {
        let _ = table(4, 2);
    }
}