use super::config::DinoV2Config;
use anyhow::{Result, ensure};
use rlx_core::weight_map::WeightMap;
/// Host-side tensors and dimensions used to turn an image into the initial
/// token sequence (class token, optional register tokens, patch embeddings,
/// plus positional embedding) consumed by `assemble_hidden`.
pub struct DinoV2PreprocessWeights {
/// Patch projection weight stored transposed, row-major as
/// `[patch_dim, embed_dim]`: element `(d, e)` lives at `d * embed_dim + e`.
pub proj_w: Vec<f32>,
/// Patch projection bias; every patch row is initialised with a copy of it,
/// so it is expected to hold `embed_dim` values.
pub proj_b: Vec<f32>,
/// Class token written to row 0 of every sequence (`embed_dim` values).
pub cls_token: Vec<f32>,
/// Register tokens written to the rows right after the class token
/// (`num_register_tokens * embed_dim` values); empty when the model has none.
pub register_tokens: Vec<f32>,
/// Positional embedding added element-wise to the whole sequence
/// (`seq * embed_dim` values).
pub pos_embed: Vec<f32>,
/// Width of one token row (taken from the config's `hidden_size`).
pub embed_dim: usize,
/// Length of one flattened patch (taken from the config's `patch_dim()`).
pub patch_dim: usize,
/// Number of patches per image (taken from the config's `num_patches()`).
pub num_patches: usize,
/// Number of register tokens placed between the class token and the patches.
pub num_register_tokens: usize,
/// Number of token rows per image (taken from the config's `seq_len()`).
pub seq: usize,
}
/// Removes the tensors needed for CPU-side patch embedding from `weights` and
/// packages them, together with the relevant config dimensions, into a
/// [`DinoV2PreprocessWeights`].
///
/// The patch projection weight is transposed from `[embed_dim, patch_dim]` to
/// `[patch_dim, embed_dim]` so that `assemble_hidden` can walk one contiguous
/// weight row per input value.
///
/// # Errors
///
/// Returns an error when a required tensor is missing from `weights`, or when
/// a tensor's shape or length disagrees with `cfg`. Every length that
/// `assemble_hidden` later relies on is checked here, so a malformed
/// checkpoint is reported as an error instead of panicking during assembly.
pub(super) fn extract_preprocess_weights(
    weights: &mut WeightMap,
    cfg: &DinoV2Config,
) -> Result<DinoV2PreprocessWeights> {
    let embed_dim = cfg.hidden_size;
    let patch_dim = cfg.patch_dim();
    let num_patches = cfg.num_patches();
    let seq = cfg.seq_len();

    let (proj_raw, proj_shape) = weights.take("patch_embed.proj.weight")?;
    ensure!(
        proj_shape.len() == 4
            && proj_shape[0] == embed_dim
            && proj_shape[1] * proj_shape[2] * proj_shape[3] == patch_dim,
        "patch_embed.proj.weight expected [E={embed_dim}, 3, ps, ps] (patch_dim={patch_dim}), got {proj_shape:?}"
    );
    // The shape can be right while the payload is truncated; the transpose
    // below indexes the full E*patch_dim range.
    ensure!(
        proj_raw.len() == embed_dim * patch_dim,
        "patch_embed.proj.weight holds {} values, expected E*patch_dim ({embed_dim}*{patch_dim})",
        proj_raw.len()
    );

    // Transpose [E, patch_dim] -> [patch_dim, E].
    let mut proj_w = vec![0f32; embed_dim * patch_dim];
    for e in 0..embed_dim {
        for d in 0..patch_dim {
            proj_w[d * embed_dim + e] = proj_raw[e * patch_dim + d];
        }
    }

    let (proj_b, _) = weights.take("patch_embed.proj.bias")?;
    ensure!(
        proj_b.len() == embed_dim,
        "patch_embed.proj.bias length {} does not match E ({embed_dim})",
        proj_b.len()
    );

    let (cls_token, _) = weights.take("cls_token")?;
    ensure!(
        cls_token.len() == embed_dim,
        "cls_token length {} does not match E ({embed_dim})",
        cls_token.len()
    );

    let (pos_embed, pos_shape) = weights.take("pos_embed")?;
    ensure!(
        pos_embed.len() == seq * embed_dim,
        "pos_embed length {} does not match seq*E ({}*{}) — interpolation of \
         pretrained pos_embed not yet supported; use cfg.img_size matching the checkpoint. \
         shape={pos_shape:?}",
        pos_embed.len(),
        seq,
        embed_dim
    );

    let register_tokens = if cfg.num_register_tokens > 0 {
        let (data, shape) = weights.take("register_tokens")?;
        ensure!(
            shape.len() == 3 && shape[1] == cfg.num_register_tokens && shape[2] == embed_dim,
            "register_tokens expected [1, {n}, {embed_dim}], got {shape:?}",
            n = cfg.num_register_tokens
        );
        // Also rejects a leading dimension other than 1, which the shape
        // check above does not look at.
        ensure!(
            data.len() == cfg.num_register_tokens * embed_dim,
            "register_tokens holds {} values, expected {n}*{embed_dim}; shape={shape:?}",
            data.len(),
            n = cfg.num_register_tokens
        );
        data
    } else {
        Vec::new()
    };

    Ok(DinoV2PreprocessWeights {
        proj_w,
        proj_b,
        cls_token,
        register_tokens,
        pos_embed,
        embed_dim,
        patch_dim,
        num_patches,
        num_register_tokens: cfg.num_register_tokens,
        seq,
    })
}
/// Builds the initial hidden-state tensor for a batch of images.
///
/// `image` is read as planar `[batch, 3, img_size, img_size]` data. For each
/// batch element the output holds `pre.seq` rows of `pre.embed_dim` values:
/// row 0 is the class token, the next `pre.num_register_tokens` rows are the
/// register tokens, and the following rows are the patch embeddings in
/// row-major patch order (`bias + patch · proj_w`). The positional embedding
/// is then added to the whole sequence.
///
/// Returns the flattened `[batch, seq, embed_dim]` buffer.
///
/// # Errors
///
/// Returns an error when `patch_size` is zero, when `image` does not hold
/// `batch * 3 * img_size * img_size` values, when `img_size / patch_size`
/// disagrees with `pre.num_patches`, or when the dimensions and buffer
/// lengths stored in `pre` are inconsistent with each other or with
/// `patch_size`. These conditions would otherwise cause a panic or writes
/// into the wrong rows.
pub fn assemble_hidden(
    pre: &DinoV2PreprocessWeights,
    image: &[f32],
    batch: usize,
    patch_size: usize,
    img_size: usize,
) -> Result<Vec<f32>> {
    let e = pre.embed_dim;
    let np = pre.num_patches;
    let seq = pre.seq;
    let pd = pre.patch_dim;
    let n_reg = pre.num_register_tokens;

    ensure!(patch_size > 0, "patch_size must be non-zero");
    let n_side = img_size / patch_size;
    ensure!(
        image.len() == batch * 3 * img_size * img_size,
        "image length {} != B·3·H·W ({}·3·{}·{})",
        image.len(),
        batch,
        img_size,
        img_size
    );
    ensure!(
        np == n_side * n_side,
        "num_patches mismatch — img_size/patch_size inconsistent"
    );
    // The gather below writes exactly 3*ps*ps values per patch and uses the
    // same index to select a projection row, so the two must agree.
    ensure!(
        pd == 3 * patch_size * patch_size,
        "patch_dim {pd} != 3*patch_size^2 (patch_size={patch_size})"
    );
    // Without this, patch rows would run past the sequence into the next
    // batch element (or off the end of the buffer).
    ensure!(
        1 + n_reg + np <= seq,
        "seq {seq} too small for 1 cls + {n_reg} register + {np} patch tokens"
    );
    ensure!(
        pre.cls_token.len() == e,
        "cls_token length {} != embed_dim {e}",
        pre.cls_token.len()
    );
    ensure!(
        pre.proj_b.len() == e,
        "proj_b length {} != embed_dim {e}",
        pre.proj_b.len()
    );
    ensure!(
        n_reg == 0 || pre.register_tokens.len() == n_reg * e,
        "register_tokens length {} != {n_reg}*{e}",
        pre.register_tokens.len()
    );
    ensure!(
        pre.proj_w.len() >= pd * e,
        "proj_w length {} < patch_dim*embed_dim ({pd}*{e})",
        pre.proj_w.len()
    );
    ensure!(
        pre.pos_embed.len() >= seq * e,
        "pos_embed length {} < seq*embed_dim ({seq}*{e})",
        pre.pos_embed.len()
    );

    let plane = img_size * img_size;
    let patch_row_base = 1 + n_reg;
    let mut hidden = vec![0f32; batch * seq * e];
    // One scratch patch reused everywhere: every entry is overwritten for
    // each patch because pd == 3*ps*ps.
    let mut patch_buf = vec![0f32; pd];

    for b in 0..batch {
        let img_off = b * 3 * plane;
        let out_off = b * seq * e;

        hidden[out_off..out_off + e].copy_from_slice(&pre.cls_token);
        if n_reg > 0 {
            hidden[out_off + e..out_off + e * (1 + n_reg)].copy_from_slice(&pre.register_tokens);
        }

        for py in 0..n_side {
            for px in 0..n_side {
                // Flatten the patch channel-major: [c][ry][rx].
                for c in 0..3 {
                    for ry in 0..patch_size {
                        let src = img_off
                            + c * plane
                            + (py * patch_size + ry) * img_size
                            + px * patch_size;
                        let dst = (c * patch_size + ry) * patch_size;
                        patch_buf[dst..dst + patch_size]
                            .copy_from_slice(&image[src..src + patch_size]);
                    }
                }

                let row = patch_row_base + py * n_side + px;
                let out_row = &mut hidden[out_off + row * e..out_off + (row + 1) * e];
                out_row.copy_from_slice(&pre.proj_b);
                for (d, &v) in patch_buf.iter().enumerate() {
                    // Zero inputs contribute nothing; skip the whole weight row.
                    if v == 0.0 {
                        continue;
                    }
                    let w_row = &pre.proj_w[d * e..(d + 1) * e];
                    for (acc, &w) in out_row.iter_mut().zip(w_row) {
                        *acc += v * w;
                    }
                }
            }
        }

        let seq_rows = &mut hidden[out_off..out_off + seq * e];
        for (h, &p) in seq_rows.iter_mut().zip(&pre.pos_embed) {
            *h += p;
        }
    }
    Ok(hidden)
}
/// Resizes an interleaved 8-bit RGB image to `img_size × img_size` with
/// bilinear interpolation and normalises it with the ImageNet mean/std.
///
/// `rgb` is read as `h_in` rows of `w_in` pixels, three bytes per pixel
/// (`rgb[(y * w_in + x) * 3 + c]`). Sampling is corner-aligned: output
/// coordinate `i` maps to input coordinate `i * (in - 1) / (img_size - 1)`.
/// Each interpolated value is scaled to `[0, 1]` and then transformed as
/// `(v - IMAGENET_MEAN[c]) / IMAGENET_STD[c]`.
///
/// Returns a planar buffer of `3 * img_size * img_size` values laid out as
/// `out[c * img_size * img_size + y * img_size + x]`. When `img_size` is 0
/// the result is empty and `rgb` is not inspected.
///
/// # Panics
///
/// Panics if `img_size > 0` and either `h_in` or `w_in` is zero, or `rgb`
/// holds fewer than `h_in * w_in * 3` bytes.
pub fn rgb_u8_to_imagenet_nchw(rgb: &[u8], h_in: usize, w_in: usize, img_size: usize) -> Vec<f32> {
    use super::config::{IMAGENET_MEAN, IMAGENET_STD};

    let plane = img_size * img_size;
    let mut out = vec![0f32; 3 * plane];
    if img_size == 0 {
        return out;
    }
    // Fail up front with a clear message instead of underflowing `h_in - 1`
    // or hitting an opaque index panic deep inside the sampling loop.
    assert!(
        h_in > 0 && w_in > 0,
        "rgb_u8_to_imagenet_nchw: input dimensions must be non-zero (got {h_in}x{w_in})"
    );
    assert!(
        rgb.len() >= h_in * w_in * 3,
        "rgb_u8_to_imagenet_nchw: rgb has {} bytes, expected at least {} for {h_in}x{w_in} RGB",
        rgb.len(),
        h_in * w_in * 3
    );

    // Corner-aligned scale; the denominator is clamped so a 1-pixel output
    // samples the top-left input pixel rather than dividing by zero.
    let denom = (img_size as f32 - 1.0).max(1.0);
    let sx = (w_in as f32 - 1.0) / denom;
    let sy = (h_in as f32 - 1.0) / denom;

    for y in 0..img_size {
        let fy = y as f32 * sy;
        let y0 = fy.floor() as usize;
        let y1 = (y0 + 1).min(h_in - 1);
        let dy = fy - y0 as f32;
        for x in 0..img_size {
            let fx = x as f32 * sx;
            let x0 = fx.floor() as usize;
            let x1 = (x0 + 1).min(w_in - 1);
            let dx = fx - x0 as f32;
            for c in 0..3 {
                let p00 = rgb[(y0 * w_in + x0) * 3 + c] as f32;
                let p01 = rgb[(y0 * w_in + x1) * 3 + c] as f32;
                let p10 = rgb[(y1 * w_in + x0) * 3 + c] as f32;
                let p11 = rgb[(y1 * w_in + x1) * 3 + c] as f32;
                let top = p00 * (1.0 - dx) + p01 * dx;
                let bot = p10 * (1.0 - dx) + p11 * dx;
                let v = (top * (1.0 - dy) + bot * dy) / 255.0;
                out[c * plane + y * img_size + x] = (v - IMAGENET_MEAN[c]) / IMAGENET_STD[c];
            }
        }
    }
    out
}