impl QuantizedAprTransformer {
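    /// Allocates a model with zero-initialized weight buffers sized from
    /// `config` and `quant_type`; no weights are loaded here.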
#[must_use]
pub fn new(config: AprTransformerConfig, quant_type: AprQuantizationType) -> Self {
let hidden_dim = config.hidden_dim;
let vocab_size = config.vocab_size;
        let embed_size = vocab_size * hidden_dim;
        let layer_weight_size = Self::calculate_layer_bytes(&config, quant_type);
let lm_head_size = Self::calculate_quantized_bytes(hidden_dim * vocab_size, quant_type);
let layer_weights = (0..config.num_layers)
.map(|_| vec![0u8; layer_weight_size])
.collect();
Self {
config,
quant_type,
token_embedding: vec![0.0; embed_size],
layer_weights,
output_norm_weight: vec![1.0; hidden_dim],
lm_head_weight: vec![0u8; lm_head_size],
}
}
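    /// Creates a quantized model shaped like `f32_model`'s config.
    ///
    /// Note: only the configuration is carried over; all weight buffers are
    /// left zero-initialized by `Self::new`. Quantizing the source model's
    /// f32 weights into those buffers is not done here.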
#[must_use]
pub fn from_f32_transformer(
f32_model: &AprTransformer,
quant_type: AprQuantizationType,
) -> Self {
let config = f32_model.config.clone();
Self::new(config, quant_type)
}
#[must_use]
pub fn quantization_type(&self) -> AprQuantizationType {
self.quant_type
}
#[must_use]
pub fn bits_per_weight(&self) -> f64 {
self.quant_type.bits_per_weight()
}
#[must_use]
pub fn config(&self) -> &AprTransformerConfig {
&self.config
}
#[must_use]
pub fn weight_bytes(&self) -> usize {
        let embed_bytes = self.token_embedding.len() * 4;
        let layer_bytes: usize = self.layer_weights.iter().map(std::vec::Vec::len).sum();
        let norm_bytes = self.output_norm_weight.len() * 4;
        let lm_head_bytes = self.lm_head_weight.len();
embed_bytes + layer_bytes + norm_bytes + lm_head_bytes
}
#[must_use]
pub fn f32_equivalent_bytes(&self) -> usize {
let num_params = self.num_parameters();
        num_params * 4
    }
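    /// Counts the dense parameters implied by the config: token embedding
    /// and LM head (`vocab * hidden` each), and per layer a `hidden`-length
    /// vector (presumably a norm weight) plus the fused QKV, attention
    /// output, and MLP up/down projections.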
#[must_use]
pub fn num_parameters(&self) -> usize {
let hidden = self.config.hidden_dim;
let vocab = self.config.vocab_size;
let layers = self.config.num_layers;
let intermediate = self.config.intermediate_dim;
let embed_params = vocab * hidden * 2;
let layer_params = hidden
+ (hidden * 3 * hidden)
+ (hidden * hidden)
+ (hidden * intermediate)
+ (intermediate * hidden);
let norm_params = hidden;
embed_params + (layers * layer_params) + norm_params
}
fn calculate_layer_bytes(
config: &AprTransformerConfig,
quant_type: AprQuantizationType,
) -> usize {
let hidden = config.hidden_dim;
let intermediate = config.intermediate_dim;
let weight_elements = (hidden * 3 * hidden)
+ (hidden * hidden)
+ (hidden * intermediate)
+ (intermediate * hidden);
Self::calculate_quantized_bytes(weight_elements, quant_type)
}
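    /// Rounds `num_elements` up to whole quantization blocks and returns the
    /// total byte count. As a worked example (assuming GGML-style Q8_0
    /// constants of 32 values per 34-byte block, which this crate's
    /// `AprQuantizationType` may or may not use): 100 elements round up to
    /// ceil(100 / 32) = 4 blocks, i.e. 136 bytes.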
pub(crate) fn calculate_quantized_bytes(
num_elements: usize,
quant_type: AprQuantizationType,
) -> usize {
let values_per_block = quant_type.values_per_block();
let bytes_per_block = quant_type.bytes_per_block();
let num_blocks = num_elements.div_ceil(values_per_block);
num_blocks * bytes_per_block
}
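    /// Minimal forward pass: embedding lookup, final layer norm, and LM
    /// head, returning logits for the last position. Out-of-range token ids
    /// fall back to zero embeddings instead of erroring.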
pub fn forward(&self, token_ids: &[u32]) -> Result<Vec<f32>> {
if token_ids.is_empty() {
return Err(RealizarError::InvalidShape {
reason: "Token sequence cannot be empty".to_string(),
});
}
let hidden_dim = self.config.hidden_dim;
let mut hidden = Vec::with_capacity(token_ids.len() * hidden_dim);
for &token_id in token_ids {
let offset = (token_id as usize) * hidden_dim;
if offset + hidden_dim <= self.token_embedding.len() {
hidden.extend_from_slice(&self.token_embedding[offset..offset + hidden_dim]);
} else {
hidden.extend(std::iter::repeat_n(0.0, hidden_dim));
}
}
        // NOTE: the per-layer quantized attention/MLP pass is not implemented
        // yet; layer weights are skipped and only the final norm and LM head
        // run below.
let seq_len = token_ids.len();
let eps = self.config.eps;
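        // Final layer norm (learned scale, no bias), applied per position.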
let mut normed = Vec::with_capacity(hidden.len());
for s in 0..seq_len {
let start = s * hidden_dim;
let slice = &hidden[start..start + hidden_dim];
let mean: f32 = slice.iter().sum::<f32>() / hidden_dim as f32;
let variance: f32 =
slice.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / hidden_dim as f32;
let std_dev = (variance + eps).sqrt();
for (i, &x) in slice.iter().enumerate() {
let normalized = (x - mean) / std_dev;
normed.push(normalized * self.output_norm_weight[i]);
}
}
let last_hidden_start = (seq_len - 1) * hidden_dim;
let last_hidden = &normed[last_hidden_start..last_hidden_start + hidden_dim];
let logits = self.compute_lm_head_logits(last_hidden)?;
Ok(logits)
}
    fn compute_lm_head_logits(&self, _hidden: &[f32]) -> Result<Vec<f32>> {
        let vocab_size = self.config.vocab_size;
        // Placeholder: the quantized matmul against `lm_head_weight` is not
        // implemented yet, so every path returns zeroed logits.
        let logits = vec![0.0f32; vocab_size];
        match self.quant_type {
            AprQuantizationType::F32 => {
                // TODO: interpret the weights as f32 and compute hidden * W^T.
            }
            AprQuantizationType::Q4_K => {
                // TODO: dequantize Q4_K blocks, then matmul.
            }
            AprQuantizationType::Q8_0 => {
                // TODO: dequantize Q8_0 blocks, then matmul.
            }
        }
        Ok(logits)
    }
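    /// Serializes the model as a fixed-size header followed by the raw
    /// tensor buffers: f32 token embedding, per-layer quantized bytes, f32
    /// output norm, and the quantized LM head.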
pub fn to_bytes(&self) -> Result<Vec<u8>> {
let mut bytes = Vec::new();
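        // Header layout (little-endian): magic at 0..4, version u32 at 4,
        // hidden_dim at 8, num_layers at 12, num_heads at 16, num_kv_heads at
        // 20, vocab_size at 24, intermediate_dim at 28, context_length at 32,
        // rope_theta f32 at 36, eps f32 at 40, tensor_offset u32 at 44,
        // quant_type byte at 48, then zero padding up to
        // APR_TRANSFORMER_HEADER_SIZE. `from_bytes` reads the same offsets.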
bytes.extend_from_slice(&MAGIC);
bytes.extend_from_slice(&1u32.to_le_bytes());
bytes.extend_from_slice(&(self.config.hidden_dim as u32).to_le_bytes());
bytes.extend_from_slice(&(self.config.num_layers as u32).to_le_bytes());
bytes.extend_from_slice(&(self.config.num_heads as u32).to_le_bytes());
bytes.extend_from_slice(&(self.config.num_kv_heads as u32).to_le_bytes());
bytes.extend_from_slice(&(self.config.vocab_size as u32).to_le_bytes());
bytes.extend_from_slice(&(self.config.intermediate_dim as u32).to_le_bytes());
bytes.extend_from_slice(&(self.config.context_length as u32).to_le_bytes());
bytes.extend_from_slice(&self.config.rope_theta.to_le_bytes());
bytes.extend_from_slice(&self.config.eps.to_le_bytes());
let tensor_offset = APR_TRANSFORMER_HEADER_SIZE as u32;
bytes.extend_from_slice(&tensor_offset.to_le_bytes());
bytes.push(self.quant_type.to_byte());
while bytes.len() < APR_TRANSFORMER_HEADER_SIZE {
bytes.push(0);
}
for &v in &self.token_embedding {
bytes.extend_from_slice(&v.to_le_bytes());
}
for layer in &self.layer_weights {
bytes.extend_from_slice(layer);
}
for &v in &self.output_norm_weight {
bytes.extend_from_slice(&v.to_le_bytes());
}
bytes.extend_from_slice(&self.lm_head_weight);
Ok(bytes)
}
pub fn from_bytes(data: &[u8]) -> Result<Self> {
if data.len() < APR_TRANSFORMER_HEADER_SIZE {
return Err(RealizarError::FormatError {
reason: format!("Data too small: {} bytes", data.len()),
});
}
if data[0..4] != MAGIC {
return Err(RealizarError::FormatError {
reason: "Invalid APR magic".to_string(),
});
}
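        // Bytes 4..8 hold the format version (written as 1 by `to_bytes`);
        // it is not validated here.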
let hidden_dim = u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize;
let num_layers = u32::from_le_bytes([data[12], data[13], data[14], data[15]]) as usize;
let num_heads = u32::from_le_bytes([data[16], data[17], data[18], data[19]]) as usize;
let num_kv_heads = u32::from_le_bytes([data[20], data[21], data[22], data[23]]) as usize;
let vocab_size = u32::from_le_bytes([data[24], data[25], data[26], data[27]]) as usize;
let intermediate_dim =
u32::from_le_bytes([data[28], data[29], data[30], data[31]]) as usize;
let context_length = u32::from_le_bytes([data[32], data[33], data[34], data[35]]) as usize;
let rope_theta = f32::from_le_bytes([data[36], data[37], data[38], data[39]]);
let eps = f32::from_le_bytes([data[40], data[41], data[42], data[43]]);
let quant_type =
AprQuantizationType::from_byte(data[48]).ok_or_else(|| RealizarError::FormatError {
reason: format!("Invalid quantization type: {}", data[48]),
})?;
let config = AprTransformerConfig {
architecture: "apr".to_string(),
hidden_dim,
num_layers,
num_heads,
num_kv_heads,
vocab_size,
intermediate_dim,
context_length,
rope_theta,
eps,
            eos_token_id: None,
            ..Default::default()
};
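        // NOTE: only the header is parsed; the tensor payload that follows is
        // not deserialized, so the returned model has zero-initialized
        // weights.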
Ok(Self::new(config, quant_type))
}
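    /// Single-token decode step: embeds `token_id`, appends (currently
    /// zeroed) K/V entries for each layer to `cache`, applies the final
    /// norm, and returns LM-head logits. `_position` is accepted for API
    /// compatibility but currently unused.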
pub fn forward_with_cache(
&self,
token_id: u32,
cache: &mut AprKVCache,
_position: usize,
) -> Result<Vec<f32>> {
let hidden_dim = self.config.hidden_dim;
let num_heads = self.config.num_heads;
let num_kv_heads = self.config.num_kv_heads;
let head_dim = hidden_dim / num_heads;
let mut hidden = Vec::with_capacity(hidden_dim);
let offset = (token_id as usize) * hidden_dim;
if offset + hidden_dim <= self.token_embedding.len() {
hidden.extend_from_slice(&self.token_embedding[offset..offset + hidden_dim]);
} else {
hidden.extend(std::iter::repeat_n(0.0, hidden_dim));
}
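        // Placeholder: real attention is not implemented; zero K/V vectors
        // are appended so the cache still tracks sequence length per layer.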
for layer_idx in 0..self.config.num_layers {
let kv_size = num_kv_heads * head_dim;
let k = vec![0.0f32; kv_size];
let v = vec![0.0f32; kv_size];
cache.append(layer_idx, &k, &v);
}
let eps = self.config.eps;
let mean: f32 = hidden.iter().sum::<f32>() / hidden_dim as f32;
let variance: f32 =
hidden.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / hidden_dim as f32;
let std_dev = (variance + eps).sqrt();
let mut normed = Vec::with_capacity(hidden_dim);
for (i, &x) in hidden.iter().enumerate() {
let normalized = (x - mean) / std_dev;
normed.push(normalized * self.output_norm_weight[i]);
}
let logits = self.compute_lm_head_logits(&normed)?;
Ok(logits)
}
}