use crate::config::MoonVitConfig;
/// Base `theta` passed to [`freqs_cis_for_grid`] by [`rope_cos_sin_halves_for_grid`];
/// pair `i` of a head uses the angular frequency `ROPE_THETA^(-4i / head_dim)`.
const ROPE_THETA: f64 = 10_000.0;
/// Builds the 2D rotary-embedding `(cos, sin)` table for a `grid_h` × `grid_w` patch grid.
///
/// The result holds `grid_h * grid_w * head_dim` values: one `head_dim`-long row per
/// position, positions in row-major order (`t = y * grid_w + x`). Each row is laid out as
///
/// ```text
/// [cos(x·f0), sin(x·f0), …, cos(x·f{q-1}), sin(x·f{q-1}),   // first head_dim / 2
///  cos(y·f0), sin(y·f0), …, cos(y·f{q-1}), sin(y·f{q-1})]   // second head_dim / 2
/// ```
///
/// where `q = head_dim / 4` and `f_i = device_theta^(-4i / head_dim)`. Angles are
/// evaluated in `f64` and rounded to `f32` when stored.
///
/// # Panics
///
/// Panics if `cfg.head_dim()` is not a multiple of 4, or if `grid_h` / `grid_w`
/// exceed `cfg.init_pos_emb_height` / `cfg.init_pos_emb_width`.
pub fn freqs_cis_for_grid(
    cfg: &MoonVitConfig,
    grid_h: usize,
    grid_w: usize,
    device_theta: f64,
) -> Vec<f32> {
    let head_dim = cfg.head_dim();
    assert!(
        head_dim.is_multiple_of(4),
        "head_dim ({head_dim}) must be a multiple of 4 for 2D RoPE"
    );
    let max_h = cfg.init_pos_emb_height;
    let max_w = cfg.init_pos_emb_width;
    // Positions are bounded by the configured maximum grid. Reject larger grids
    // explicitly; flattening an over-wide grid against `max_w` would alias columns
    // into the next row and silently produce wrong angles.
    assert!(
        grid_h <= max_h && grid_w <= max_w,
        "grid {grid_h}x{grid_w} exceeds the maximum position grid {max_h}x{max_w}"
    );

    // The angular frequencies depend only on the pair index, so compute them once
    // rather than once per position.
    let inv_freq: Vec<f64> = (0..head_dim / 4)
        .map(|i| 1.0 / device_theta.powf((4 * i) as f64 / head_dim as f64))
        .collect();

    // NOTE(review): x pairs fill the first half of each row and y pairs the second.
    // The upstream MoonViT reference appears to interleave x/y per complex pair
    // instead — confirm the q/k weights are laid out (or permuted) to match.
    let mut out = Vec::with_capacity(grid_h * grid_w * head_dim);
    for y in 0..grid_h {
        for x in 0..grid_w {
            for coord in [x, y] {
                for &freq in &inv_freq {
                    let angle = coord as f64 * freq;
                    out.push(angle.cos() as f32);
                    out.push(angle.sin() as f32);
                }
            }
        }
    }
    out
}
/// Splits the 2D RoPE table for a `grid_h` × `grid_w` grid into four planes.
///
/// The table comes from [`freqs_cis_for_grid`] with base [`ROPE_THETA`]. Each
/// position's `head_dim`-long row is de-interleaved into `(cos_x, sin_x, cos_y, sin_y)`:
/// the first half of the row feeds the x planes and the second half the y planes,
/// taking even slots as cosines and odd slots as sines. Every plane holds
/// `head_dim / 4` values per position, with positions in table order.
///
/// # Panics
///
/// Panics under the same conditions as [`freqs_cis_for_grid`].
pub fn rope_cos_sin_halves_for_grid(
    cfg: &MoonVitConfig,
    grid_h: usize,
    grid_w: usize,
) -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
    let table = freqs_cis_for_grid(cfg, grid_h, grid_w, ROPE_THETA);
    let head_dim = cfg.head_dim();
    let half = head_dim / 2;
    let positions = grid_h * grid_w;
    let plane_len = positions * (head_dim / 4);

    let mut cos_x = Vec::with_capacity(plane_len);
    let mut sin_x = Vec::with_capacity(plane_len);
    let mut cos_y = Vec::with_capacity(plane_len);
    let mut sin_y = Vec::with_capacity(plane_len);

    for pos in 0..positions {
        let row = &table[pos * head_dim..(pos + 1) * head_dim];
        let (x_pairs, y_pairs) = row.split_at(half);
        for pair in x_pairs.chunks_exact(2) {
            cos_x.push(pair[0]);
            sin_x.push(pair[1]);
        }
        for pair in y_pairs.chunks_exact(2) {
            cos_y.push(pair[0]);
            sin_y.push(pair[1]);
        }
    }

    (cos_x, sin_x, cos_y, sin_y)
}
/// Applies 2D rotary position embeddings to `q` and `k` in place.
///
/// `q` and `k` are indexed as `(t * heads + h) * head_dim + d` and `freqs` as
/// `t * head_dim + d`: one row of interleaved `(cos, sin)` pairs per position, the
/// layout produced by [`freqs_cis_for_grid`]. Every adjacent pair `(a, b)` of each
/// head is rotated by the matching `(c, s)` pair of its position's row, giving
/// `(a·c − b·s, a·s + b·c)`. All heads of a position share the same row.
///
/// # Panics
///
/// Panics if `head_dim` is not a multiple of 4, or if `q`, `k` or `freqs` is shorter
/// than the size implied by `seq`, `heads` and `head_dim`.
pub fn apply_rope_2d(
    q: &mut [f32],
    k: &mut [f32],
    freqs: &[f32],
    seq: usize,
    heads: usize,
    head_dim: usize,
) {
    /// Rotates each `(a, b)` pair of `v` by the corresponding `(cos, sin)` pair of `cis`.
    fn rotate_pairs(v: &mut [f32], cis: &[f32]) {
        for (pair, cs) in v.chunks_exact_mut(2).zip(cis.chunks_exact(2)) {
            let (a, b) = (pair[0], pair[1]);
            let (c, s) = (cs[0], cs[1]);
            pair[0] = a * c - b * s;
            pair[1] = a * s + b * c;
        }
    }

    // Each frequency row stores x pairs in its first half and y pairs in its second,
    // so each half must hold a whole number of pairs. Without this check an odd half
    // would silently leave elements unrotated and misalign the (cos, sin) pairs.
    assert!(
        head_dim.is_multiple_of(4),
        "head_dim ({head_dim}) must be a multiple of 4 for 2D RoPE"
    );
    let qk_len = seq * heads * head_dim;
    let freqs_len = seq * head_dim;
    assert!(
        q.len() >= qk_len,
        "q has {} elements, expected at least {qk_len}",
        q.len()
    );
    assert!(
        k.len() >= qk_len,
        "k has {} elements, expected at least {qk_len}",
        k.len()
    );
    assert!(
        freqs.len() >= freqs_len,
        "freqs has {} elements, expected at least {freqs_len}",
        freqs.len()
    );

    for t in 0..seq {
        let cis = &freqs[t * head_dim..(t + 1) * head_dim];
        for h in 0..heads {
            let start = (t * heads + h) * head_dim;
            let end = start + head_dim;
            rotate_pairs(&mut q[start..end], cis);
            rotate_pairs(&mut k[start..end], cis);
        }
    }
}