use crate::riir::variants::{GROUP_SIZE, VARIANT};
use crate::riir::io::weight_file::WeightFile;
/// Errors returned by [`embed_lookup`] when its arguments or the weight file
/// cannot be used for an embedding lookup.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
/// `token_id` was negative or not below the configured vocabulary size.
#[error("token id {token_id} out of vocabulary range (0..{vocab_size})")]
TokenOutOfRange { token_id: i32, vocab_size: usize },
/// The weight file returned nothing for the tensor called `name`.
#[error("embedding tensor '{name}' missing from manifest")]
MissingTensor { name: &'static str },
/// The shape recorded for tensor `name` (`shape`) differs from the shape the
/// lookup requires (`expected`).
#[error(
"embedding tensor '{name}' has unexpected shape {shape:?} (expected {expected:?})"
)]
ShapeMismatch {
name: &'static str,
shape: Vec<usize>,
expected: Vec<usize>,
},
/// The caller's output slice holds `got` elements instead of the `expected`
/// hidden dimension of the active variant.
#[error("output buffer length {got} does not match HIDDEN_DIM {expected}")]
OutputLen { got: usize, expected: usize },
}
// Names under which the three parts of the quantized embedding table are
// looked up in the weight file.

// Packed 4-bit codes; the decoder reads each row as little-endian u32 words
// holding eight codes apiece.
const WEIGHT_NAME: &str = "model.embed_tokens.weight";
// Per-group multipliers; the decoder reads them as little-endian bf16 values.
const SCALES_NAME: &str = "model.embed_tokens.scales";
// Per-group offsets added after scaling; also read as little-endian bf16.
const BIASES_NAME: &str = "model.embed_tokens.biases";
/// Dequantizes the embedding row of `token_id` from `wf` into `out`.
///
/// Checks run in a fixed order and the first failure is returned:
///
/// 1. `out.len()` must equal `VARIANT.hidden_dim`.
/// 2. `token_id` must lie in `0..VARIANT.vocab_size`.
/// 3. The weight, scales and biases tensors must all be present (checked in
///    that order).
/// 4. Their shapes must be `[vocab_size, hidden_dim / 8]` for the packed
///    weights and `[vocab_size, hidden_dim / GROUP_SIZE]` for both the scales
///    and the biases (checked in the same order).
///
/// Once everything passes, the row is decoded by [`embed_lookup_at`].
///
/// # Errors
///
/// * [`EmbeddingError::OutputLen`] - `out` has the wrong length.
/// * [`EmbeddingError::TokenOutOfRange`] - `token_id` is negative or too large.
/// * [`EmbeddingError::MissingTensor`] - a required tensor is absent.
/// * [`EmbeddingError::ShapeMismatch`] - a tensor has an unexpected shape.
///
/// NOTE(review): only shapes are validated here, not byte lengths or element
/// types. `embed_lookup_at` slices the raw bytes directly and will panic if a
/// buffer is shorter than its shape implies - confirm `WeightFile` guarantees
/// this.
pub fn embed_lookup(
    wf: &WeightFile,
    token_id: i32,
    out: &mut [f32],
) -> Result<(), EmbeddingError> {
    let (hidden_dim, vocab_size) = (VARIANT.hidden_dim, VARIANT.vocab_size);

    let got = out.len();
    if got != hidden_dim {
        return Err(EmbeddingError::OutputLen {
            got,
            expected: hidden_dim,
        });
    }

    let in_vocab = token_id >= 0 && (token_id as usize) < vocab_size;
    if !in_vocab {
        return Err(EmbeddingError::TokenOutOfRange {
            token_id,
            vocab_size,
        });
    }

    // Second dimension each tensor must have: eight 4-bit codes per packed
    // weight column, one scale and one bias per quantization group.
    let groups_per_row = hidden_dim / GROUP_SIZE;
    let expected_cols = [
        (WEIGHT_NAME, hidden_dim / 8),
        (SCALES_NAME, groups_per_row),
        (BIASES_NAME, groups_per_row),
    ];

    // Presence of all three tensors is established before any shape is
    // inspected, so a missing tensor is reported ahead of a shape mismatch.
    let weight = tensor_or_missing(wf, WEIGHT_NAME)?;
    let scales = tensor_or_missing(wf, SCALES_NAME)?;
    let biases = tensor_or_missing(wf, BIASES_NAME)?;

    for &(name, cols) in expected_cols.iter() {
        expect_shape(wf, name, &[vocab_size, cols])?;
    }

    embed_lookup_at(weight, scales, biases, token_id, out);
    Ok(())
}
/// Decodes one row of a 4-bit group-quantized embedding table into `out`.
///
/// The hidden dimension is taken from `out.len()`. All three byte buffers are
/// read as row-major tables indexed by `token_id`:
///
/// * `weight` - `out.len() / 8` little-endian `u32` words per row, each
///   carrying eight 4-bit codes with the lowest nibble first.
/// * `scales`, `biases` - `out.len() / GROUP_SIZE` little-endian bf16 values
///   per row, one per quantization group.
///
/// Every decoded element is `code as f32 * scale + bias`, using the scale and
/// bias of the group the element belongs to.
///
/// Only whole groups are decoded: elements of `out` beyond the last complete
/// group of `GROUP_SIZE` (and, inside a group, beyond the last complete run of
/// eight) are left untouched.
///
/// # Panics
///
/// Panics when the byte range computed for the requested row does not fit in
/// `weight`, `scales` or `biases`. `token_id` is converted with `as usize`
/// without a sign check, so callers must pass a non-negative id.
pub fn embed_lookup_at(
    weight: &[u8],
    scales: &[u8],
    biases: &[u8],
    token_id: i32,
    out: &mut [f32],
) {
    const CODES_PER_WORD: usize = 8;
    const WORD_BYTES: usize = 4;
    const BF16_BYTES: usize = 2;

    let row = token_id as usize;
    let words_per_row = out.len() / CODES_PER_WORD;
    let groups_per_row = out.len() / GROUP_SIZE;
    let words_per_group = GROUP_SIZE / CODES_PER_WORD;

    // Byte offsets of this token's row inside each table.
    let weight_start = row * words_per_row * WORD_BYTES;
    let scale_start = row * groups_per_row * BF16_BYTES;
    let bias_start = row * groups_per_row * BF16_BYTES;

    let weight_row = &weight[weight_start..weight_start + words_per_row * WORD_BYTES];
    let scale_row = &scales[scale_start..scale_start + groups_per_row * BF16_BYTES];
    let bias_row = &biases[bias_start..bias_start + groups_per_row * BF16_BYTES];

    for (group, group_out) in out.chunks_exact_mut(GROUP_SIZE).enumerate() {
        let scale = bf16_to_f32(read_u16_le(scale_row, group));
        let bias = bf16_to_f32(read_u16_le(bias_row, group));
        let first_word = group * words_per_group;

        for (word_idx, lanes) in group_out.chunks_exact_mut(CODES_PER_WORD).enumerate() {
            let word = read_u32_le(weight_row, first_word + word_idx);
            for (lane, slot) in lanes.iter_mut().enumerate() {
                // Code `lane` lives in bits 4*lane .. 4*lane + 3 of the word.
                let code = (word >> (lane * 4)) & 0xF;
                *slot = (code as f32) * scale + bias;
            }
        }
    }
}
/// Fetches the raw bytes of tensor `name` from `wf`.
///
/// # Errors
///
/// Returns [`EmbeddingError::MissingTensor`] carrying `name` when the weight
/// file has no bytes for that tensor.
fn tensor_or_missing<'a>(
    wf: &'a WeightFile,
    name: &'static str,
) -> Result<&'a [u8], EmbeddingError> {
    match wf.tensor_bytes(name) {
        Some(bytes) => Ok(bytes),
        None => Err(EmbeddingError::MissingTensor { name }),
    }
}
/// Verifies that the shape recorded for tensor `name` equals `expected`.
///
/// # Errors
///
/// * [`EmbeddingError::MissingTensor`] - `wf` has no metadata for `name`.
/// * [`EmbeddingError::ShapeMismatch`] - the recorded shape differs; the error
///   carries owned copies of both the actual and the expected shape.
fn expect_shape(
    wf: &WeightFile,
    name: &'static str,
    expected: &[usize],
) -> Result<(), EmbeddingError> {
    let info = match wf.tensor_info(name) {
        Some(info) => info,
        None => return Err(EmbeddingError::MissingTensor { name }),
    };

    if info.shape == expected {
        Ok(())
    } else {
        Err(EmbeddingError::ShapeMismatch {
            name,
            shape: info.shape.clone(),
            expected: expected.to_vec(),
        })
    }
}
/// Reads the `idx`-th little-endian `u32` from `buf`, treating `buf` as a
/// packed array of 4-byte words.
///
/// # Panics
///
/// Panics if bytes `idx * 4 .. idx * 4 + 4` are not all inside `buf`.
fn read_u32_le(buf: &[u8], idx: usize) -> u32 {
    let start = idx * 4;
    let mut word = [0u8; 4];
    word.copy_from_slice(&buf[start..start + 4]);
    u32::from_le_bytes(word)
}
/// Reads the `idx`-th little-endian `u16` from `buf`, treating `buf` as a
/// packed array of 2-byte values.
///
/// # Panics
///
/// Panics if bytes `idx * 2 .. idx * 2 + 2` are not both inside `buf`.
fn read_u16_le(buf: &[u8], idx: usize) -> u16 {
    let start = idx * 2;
    let mut half = [0u8; 2];
    half.copy_from_slice(&buf[start..start + 2]);
    u16::from_le_bytes(half)
}
/// Widens a bf16 bit pattern to the `f32` it denotes.
///
/// The 16 input bits become the upper half of the `f32` bit pattern and the
/// lower 16 bits are zero, so the conversion is exact for every input.
#[inline]
pub fn bf16_to_f32(bits: u16) -> f32 {
    let widened = u32::from(bits) << 16;
    f32::from_bits(widened)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks `bf16_to_f32` against bit patterns whose `f32` value is
    /// known exactly, plus the positive-infinity encoding.
    #[test]
    fn bf16_round_trip_known_values() {
        let exact: [(u16, f32); 3] = [(0x3F80, 1.0), (0xC000, -2.0), (0x0000, 0.0)];
        for &(bits, want) in exact.iter() {
            assert_eq!(bf16_to_f32(bits), want);
        }

        let inf = bf16_to_f32(0x7F80);
        assert!(inf.is_infinite());
    }
}