use std::f32::consts::PI;
use burn::prelude::*;
/// Computes a 4-D (x, y, z, time) Fourier feature embedding of `positions`.
///
/// The last axis of `positions` is read as four columns: columns `0..3` are used as
/// they are and column `3` is multiplied by `increment_time`. All four columns are
/// then shifted by `margin` and projected onto integer frequency vectors
/// `(fx, fy, fz, fw)`, each component in `0..freqs`, scaled by
/// `2π / (1 + 2 * margin)`. Frequency vectors are enumerated with `fw` varying
/// fastest and `fx` slowest; only the first `dimension / 2` of them are used.
///
/// # Arguments
/// * `positions` - rank-3 tensor whose last axis holds at least the four columns
///   described above (presumably coordinates normalised to `[0, 1]` - confirm with
///   the caller).
/// * `dimension` - requested embedding width; `dimension / 2` frequencies are used.
/// * `freqs` - number of integer frequencies per axis.
/// * `increment_time` - scale applied to the time column before the shift.
/// * `margin` - offset added to every column; also widens the period to
///   `1 + 2 * margin`.
/// * `device` - device on which the frequency table is created.
///
/// # Returns
/// A tensor with the same first two axes as `positions` and a last axis of
/// `2 * (dimension / 2)` entries: all cosine features followed by all sine features.
/// Note that an odd `dimension` therefore yields `dimension - 1` features.
///
/// # Panics
/// Panics if `freqs^4 < dimension / 2`, i.e. there are not enough frequency vectors
/// to fill the requested width.
pub fn fourier_embed_4d<B: Backend>(
    positions: Tensor<B, 3>,
    dimension: usize,
    freqs: usize,
    increment_time: f32,
    margin: f32,
    device: &B::Device,
) -> Tensor<B, 3> {
    let half_dim = dimension / 2;
    // Saturate instead of overflowing: the value is only compared against `half_dim`.
    let n_freq4 = freqs.saturating_pow(4);
    // Validate before any tensor work so a bad configuration fails immediately.
    assert!(
        n_freq4 >= half_dim,
        "freqs^4 = {} < half_dim = {}. Increase freqs parameter.",
        n_freq4,
        half_dim
    );

    // Rescale the time column, then reassemble (x, y, z, t) and apply the margin shift.
    let time_col = positions.clone().narrow(2, 3, 1).mul_scalar(increment_time);
    let xyz = positions.narrow(2, 0, 3);
    let pos = Tensor::cat(vec![xyz, time_col], 2);
    let pos = pos + margin;
    let width = 1.0 + 2.0 * margin;

    // Row-major [half_dim, 4] table of angular frequencies. Only the first `half_dim`
    // frequency vectors are ever used, so the enumeration stops there instead of
    // materialising (and multiplying by) all `freqs^4` of them.
    let freq_coeffs: Vec<f32> = (0..freqs)
        .flat_map(move |fx| {
            (0..freqs).flat_map(move |fy| {
                (0..freqs).flat_map(move |fz| (0..freqs).map(move |fw| [fx, fy, fz, fw]))
            })
        })
        .take(half_dim)
        .flatten()
        .map(|f| 2.0 * PI * f as f32 / width)
        .collect();

    // Shape [1, 4, half_dim]; the leading singleton axis is matched against the batch
    // axis of `pos` by `matmul`.
    let freq_t = Tensor::<B, 2>::from_data(
        TensorData::new(freq_coeffs, vec![half_dim, 4]),
        device,
    )
    .transpose()
    .unsqueeze_dim::<3>(0);

    let loc = pos.matmul(freq_t);
    let cos_part = loc.clone().cos();
    let sin_part = loc.sin();
    Tensor::cat(vec![cos_part, sin_part], 2)
}
/// Repeats every position `num_patches` times and appends a patch index to each copy.
///
/// For an input whose first two axes are `[batch, n_chans]`, each position vector is
/// tiled `num_patches` times and the values `0.0, 1.0, ..., (num_patches - 1) as f32`
/// are appended as one extra trailing column. The result is flattened to
/// `[batch, n_chans * num_patches, 4]`, with the copies of one channel stored
/// consecutively.
///
/// The final reshape hard-codes a trailing size of 4, so the input is expected to
/// carry 3 values per channel.
///
/// # Arguments
/// * `pos` - rank-3 tensor of per-channel positions.
/// * `num_patches` - number of copies (and patch indices) generated per channel.
/// * `device` - device on which the patch-index tensor is created.
pub fn add_time_patch<B: Backend>(
    pos: Tensor<B, 3>,
    num_patches: usize,
    device: &B::Device,
) -> Tensor<B, 3> {
    let [batch, n_chans, _] = pos.dims();

    // Patch indices laid out along axis 2, then tiled over the batch and channel axes.
    let patch_indices: Vec<f32> = (0..num_patches).map(|idx| idx as f32).collect();
    let time_axis = Tensor::<B, 1>::from_data(
        TensorData::new(patch_indices, vec![num_patches]),
        device,
    )
    .reshape([1, 1, num_patches, 1])
    .repeat_dim(0, batch)
    .repeat_dim(1, n_chans);

    // Insert a patch axis after the channel axis and tile the positions along it.
    let tiled_pos = pos.unsqueeze_dim::<4>(2).repeat_dim(2, num_patches);

    Tensor::cat(vec![tiled_pos, time_axis], 3).reshape([batch, n_chans * num_patches, 4])
}