use burn::module::{Param, ParamId};
use burn::nn::Linear;
use burn::prelude::*;
use crate::error::BrainHarmonyError;
use crate::model::linear_zeros;
/// Positional-embedding module over a `grid_h * grid_w` token grid.
///
/// The encoder embedding is the concatenation (feature axis) of two halves: the sin/cos
/// table `emb_h` and a second component whose source depends on `mode`:
/// - `"gradient_geoh"`: projected at forward time from gradient and geometric-harmonics
///   inputs through `grad_proj` / `geoh_proj`;
/// - `"sincos"`: the precomputed table `emb_w`.
///
/// When `use_decoder` is set, an analogous pair is built at the decoder width.
#[derive(Module, Debug)]
pub struct BrainHarmonyPosEmbed<B: Backend> {
/// `[grid_h * grid_w, embed_dim / 2]` sin/cos table built by `new`; each cell is encoded by
/// its grid row index `h`.
pub emb_h: Param<Tensor<B, 2>>,
/// Linear map `grad_dim -> embed_dim / 2`; `Some` only in `"gradient_geoh"` mode.
pub grad_proj: Option<Linear<B>>,
/// Linear map `geoh_dim -> embed_dim / 2`; `Some` only in `"gradient_geoh"` mode.
pub geoh_proj: Option<Linear<B>>,
/// `[grid_h * grid_w, embed_dim / 2]` sin/cos table in which each cell is encoded by its grid
/// column index `w`; `Some` only in `"sincos"` mode.
pub emb_w: Option<Param<Tensor<B, 2>>>,
/// `[grid_h * grid_w, pred_embed_dim / 2]` decoder counterpart of `emb_h`; `Some` only when
/// `use_decoder` is set.
pub emb_h_decoder: Option<Param<Tensor<B, 2>>>,
/// Linear map `embed_dim / 2 -> pred_embed_dim / 2` applied to the encoder's second component
/// in `forward`; `Some` only with `use_decoder` in `"gradient_geoh"` mode.
pub decoder_pos_embed_proj: Option<Linear<B>>,
/// `[grid_h * grid_w, pred_embed_dim / 2]` decoder counterpart of `emb_w`; `Some` only with
/// `use_decoder` in `"sincos"` mode.
pub emb_w_decoder: Option<Param<Tensor<B, 2>>>,
/// Encoder embedding width as passed to `new`.
pub embed_dim: usize,
/// Number of grid rows.
pub grid_h: usize,
/// Number of grid columns; in `"gradient_geoh"` mode each projected row is repeated this many
/// times.
pub grid_w: usize,
/// `"gradient_geoh"` or `"sincos"`; `new` rejects any other value.
pub mode: String,
/// When set, `forward` prepends an all-zero CLS-token row to each embedding.
pub use_cls_token: bool,
/// When set, `forward` also returns a decoder positional embedding.
pub use_decoder: bool,
}
impl<B: Backend> BrainHarmonyPosEmbed<B> {
pub fn new(
grad_dim: usize,
geoh_dim: usize,
embed_dim: usize,
pred_embed_dim: usize,
grid_size: (usize, usize),
mode: &str,
use_cls_token: bool,
use_decoder: bool,
device: &B::Device,
) -> crate::error::Result<Self> {
let (gh, gw) = grid_size;
let n = gh * gw;
let half_dim = embed_dim / 2;
let emb_h_data = sincos_1d_grid(half_dim, gh, gw);
let emb_h = Param::initialized(
ParamId::new(),
Tensor::<B, 2>::from_data(TensorData::new(emb_h_data, vec![n, half_dim]), device),
);
let (grad_proj, geoh_proj, emb_w) = match mode {
"gradient_geoh" => {
let gp = linear_zeros(grad_dim, half_dim, true, device);
let ghp = linear_zeros(geoh_dim, half_dim, true, device);
(Some(gp), Some(ghp), None)
}
"sincos" => {
let emb_w_data = sincos_1d_width(half_dim, gh, gw);
let t = Param::initialized(
ParamId::new(),
Tensor::<B, 2>::from_data(
TensorData::new(emb_w_data, vec![n, half_dim]),
device,
),
);
(None, None, Some(t))
}
_ => {
return Err(BrainHarmonyError::InvalidPosMode {
mode: mode.to_string(),
})
}
};
let (emb_h_decoder, decoder_pos_embed_proj, emb_w_decoder) = if use_decoder {
let pred_half = pred_embed_dim / 2;
let emb_h_dec_data = sincos_1d_grid(pred_half, gh, gw);
let emb_h_dec = Param::initialized(
ParamId::new(),
Tensor::<B, 2>::from_data(TensorData::new(emb_h_dec_data, vec![n, pred_half]), device),
);
match mode {
"gradient_geoh" => {
let proj = linear_zeros(half_dim, pred_half, true, device);
(Some(emb_h_dec), Some(proj), None)
}
"sincos" => {
let emb_w_dec_data = sincos_1d_width(pred_half, gh, gw);
let t = Param::initialized(
ParamId::new(),
Tensor::<B, 2>::from_data(
TensorData::new(emb_w_dec_data, vec![n, pred_half]),
device,
),
);
(Some(emb_h_dec), None, Some(t))
}
_ => (None, None, None),
}
} else {
(None, None, None)
};
Ok(Self {
emb_h,
grad_proj,
geoh_proj,
emb_w,
emb_h_decoder,
decoder_pos_embed_proj,
emb_w_decoder,
embed_dim,
grid_h: gh,
grid_w: gw,
mode: mode.to_string(),
use_cls_token,
use_decoder,
})
}
pub fn forward(
&self,
gradient: Option<&Tensor<B, 2>>,
geoh: Option<&Tensor<B, 2>>,
) -> (Tensor<B, 3>, Option<Tensor<B, 3>>) {
let emb_w = if self.mode == "gradient_geoh" {
let grad = gradient.expect("BUG: gradient tensor required for gradient_geoh mode");
let geoh_data = geoh.expect("BUG: geoh tensor required for gradient_geoh mode");
let grad_proj = self.grad_proj.as_ref().unwrap();
let geoh_proj = self.geoh_proj.as_ref().unwrap();
let grad_emb = grad_proj.forward(grad.clone()); let geoh_emb = geoh_proj.forward(geoh_data.clone());
let pos_embed = (grad_emb + geoh_emb).mul_scalar(0.5f32);
let repeated = repeat_interleave_dim0(pos_embed, self.grid_w);
let min_val: f32 = repeated.clone().min().into_scalar().elem();
let max_val: f32 = repeated.clone().max().into_scalar().elem();
let range = (max_val - min_val).max(1e-8);
repeated
.sub_scalar(min_val)
.div_scalar(range)
.mul_scalar(2.0f32)
.sub_scalar(1.0f32)
} else {
self.emb_w
.as_ref()
.expect("BUG: emb_w missing in sincos mode")
.val()
};
let emb_encoder = Tensor::cat(vec![self.emb_h.val(), emb_w.clone()], 1).unsqueeze_dim::<3>(0);
let pos_embed_encoder = if self.use_cls_token {
let [_, _n, d] = emb_encoder.dims();
let cls_zeros = Tensor::<B, 3>::zeros([1, 1, d], &emb_encoder.device());
Tensor::cat(vec![cls_zeros, emb_encoder], 1)
} else {
emb_encoder
};
let pos_embed_decoder = if self.use_decoder {
let emb_h_dec = self.emb_h_decoder.as_ref().unwrap().val();
let emb_w_dec = if self.mode == "gradient_geoh" {
let proj = self.decoder_pos_embed_proj.as_ref().unwrap();
proj.forward(emb_w)
} else {
self.emb_w_decoder.as_ref().unwrap().val()
};
let emb_decoder = Tensor::cat(vec![emb_h_dec, emb_w_dec], 1).unsqueeze_dim::<3>(0);
let dec = if self.use_cls_token {
let [_, _n, d] = emb_decoder.dims();
let cls_zeros = Tensor::<B, 3>::zeros([1, 1, d], &emb_decoder.device());
Tensor::cat(vec![cls_zeros, emb_decoder], 1)
} else {
emb_decoder
};
Some(dec)
} else {
None
};
(pos_embed_encoder, pos_embed_decoder)
}
}
/// Repeats every row of a 2-D tensor `repeats` times in place, like
/// `torch.repeat_interleave(t, repeats, dim=0)`.
///
/// Row `i` of the input becomes rows `i * repeats .. (i + 1) * repeats` of the output, which has
/// shape `[rows * repeats, cols]`.
fn repeat_interleave_dim0<B: Backend>(t: Tensor<B, 2>, repeats: usize) -> Tensor<B, 2> {
    let [rows, cols] = t.dims();
    // Add a length-1 axis after the rows, broadcast it to `repeats`, then fold it back in.
    let with_repeat_axis = t.unsqueeze_dim::<3>(1);
    let broadcast = with_repeat_axis.expand([rows, repeats, cols]);
    broadcast.reshape([rows * repeats, cols])
}
/// Builds a row-major `[grid_h * grid_w, half_dim]` table of 1-D sin/cos embeddings in which
/// every cell `(h, w)` (flat index `h * grid_w + w`) is encoded by its grid row index `h`.
///
/// With `q = half_dim / 2` and `ω_k = 1 / 10000^(k / q)`, each row is
/// `[sin(h·ω_0) .. sin(h·ω_{q-1}), cos(h·ω_0) .. cos(h·ω_{q-1})]`. When `half_dim` is odd the
/// final column is left at `0.0`.
fn sincos_1d_grid(half_dim: usize, grid_h: usize, grid_w: usize) -> Vec<f32> {
    let quarter = half_dim / 2;
    // ω_k depends only on `k`: compute the frequency table once, not once per cell.
    let omegas: Vec<f64> = (0..quarter)
        .map(|k| 1.0 / 10000.0_f64.powf(k as f64 / quarter as f64))
        .collect();
    let mut data = Vec::with_capacity(grid_h * grid_w * half_dim);
    // Scratch row; for odd `half_dim` its last column is never written and stays 0.0.
    let mut row = vec![0.0f32; half_dim];
    for h in 0..grid_h {
        let pos = h as f64;
        for (k, &omega) in omegas.iter().enumerate() {
            let angle = pos * omega;
            row[k] = angle.sin() as f32;
            row[quarter + k] = angle.cos() as f32;
        }
        // The encoding depends only on `h`, so all `grid_w` cells of this row share it.
        for _ in 0..grid_w {
            data.extend_from_slice(&row);
        }
    }
    data
}
/// Builds a row-major `[grid_h * grid_w, half_dim]` table of 1-D sin/cos embeddings in which
/// every cell `(h, w)` (flat index `h * grid_w + w`) is encoded by its grid column index `w`.
///
/// With `q = half_dim / 2` and `ω_k = 1 / 10000^(k / q)`, each row is
/// `[sin(w·ω_0) .. sin(w·ω_{q-1}), cos(w·ω_0) .. cos(w·ω_{q-1})]`. When `half_dim` is odd the
/// final column is left at `0.0`.
fn sincos_1d_width(half_dim: usize, grid_h: usize, grid_w: usize) -> Vec<f32> {
    let quarter = half_dim / 2;
    // ω_k depends only on `k`: compute the frequency table once, not once per cell.
    let omegas: Vec<f64> = (0..quarter)
        .map(|k| 1.0 / 10000.0_f64.powf(k as f64 / quarter as f64))
        .collect();
    // Encoding of one grid row (`grid_w` cells). For odd `half_dim` each cell's last column is
    // never written and stays 0.0.
    let mut grid_row = vec![0.0f32; grid_w * half_dim];
    for w in 0..grid_w {
        let cell = &mut grid_row[w * half_dim..(w + 1) * half_dim];
        let pos = w as f64;
        for (k, &omega) in omegas.iter().enumerate() {
            let angle = pos * omega;
            cell[k] = angle.sin() as f32;
            cell[quarter + k] = angle.cos() as f32;
        }
    }
    // The encoding does not depend on `h`, so every grid row is identical: tile it.
    grid_row.repeat(grid_h)
}
/// Builds a row-major `[n_positions, embed_dim]` table of 1-D sin/cos positional embeddings.
///
/// Row `p` stores `sin(p·ω_k)` in columns `0..embed_dim / 2`, then `cos(p·ω_k)` in the next
/// `embed_dim / 2` columns, with `ω_k = 1 / 10000^(k / (embed_dim / 2))`. When `embed_dim`
/// is odd, the final column of every row stays `0.0`.
pub fn sincos_1d_flat(embed_dim: usize, n_positions: usize) -> Vec<f32> {
    let freq_count = embed_dim / 2;
    let mut table = vec![0.0f32; n_positions * embed_dim];
    for p in 0..n_positions {
        let start = p * embed_dim;
        let row = &mut table[start..start + embed_dim];
        // First half of the row receives sines, the following `freq_count` slots cosines.
        let (sines, tail) = row.split_at_mut(freq_count);
        let cosines = &mut tail[..freq_count];
        for (k, (s, c)) in sines.iter_mut().zip(cosines.iter_mut()).enumerate() {
            let omega = 1.0 / 10000.0_f64.powf(k as f64 / freq_count as f64);
            let angle = p as f64 * omega;
            *s = angle.sin() as f32;
            *c = angle.cos() as f32;
        }
    }
    table
}