use candle_core::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Linear, Module, VarBuilder};
use crate::layers::attention::GqaConfig;
use crate::layers::transformer::TransformerBlock;
use crate::tensor_utils::{precompute_rope_freqs, RmsNorm};
use super::config::CodePredictorConfig;
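/// Multi-token-prediction (MTP) style code predictor: given the talker's
/// hidden state and the embedding of the group-0 codec token, it
/// autoregressively predicts the remaining `num_code_groups - 1` codec
/// tokens for one talker step.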
pub struct CodePredictor {
input_proj: Linear,
group_embeds: Vec<Embedding>,
layers: Vec<TransformerBlock>,
norm: RmsNorm,
group_heads: Vec<Linear>,
rope_cos: Tensor,
rope_sin: Tensor,
num_groups: usize,
config: CodePredictorConfig,
}
impl CodePredictor {
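    /// Loads all weights from `vb` and precomputes the RoPE tables on
    /// `device` in `dtype`. `talker_hidden_size` is the width of the
    /// incoming talker states, which `small_to_mtp_projection` maps to
    /// this model's `hidden_size`.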
pub fn load(
config: &CodePredictorConfig,
talker_hidden_size: usize,
vb: VarBuilder,
device: &Device,
dtype: DType,
) -> Result<Self> {
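        // Fixed hyperparameters for this head: the RoPE base and a small
        // position budget (each call is a 2-token prefill plus one position
        // per extra group, expected to stay well below 128).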
let rope_theta = 1_000_000.0;
let max_position_embeddings = 128;
let gqa_config = GqaConfig::with_head_dim(
config.hidden_size,
config.num_attention_heads,
config.num_key_value_heads,
config.head_dim,
max_position_embeddings,
rope_theta,
1e-6,
);
let input_proj = candle_nn::linear(
talker_hidden_size,
config.hidden_size,
vb.pp("small_to_mtp_projection"),
)?;
let num_extra_groups = config.num_code_groups - 1;
let model_vb = vb.pp("model");
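        // One embedding table per extra code group; the group-0 token is
        // embedded by the caller and arrives as `g0_embed`.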
let mut group_embeds = Vec::with_capacity(num_extra_groups);
for g in 0..num_extra_groups {
let emb = candle_nn::embedding(
config.vocab_size,
talker_hidden_size,
model_vb.pp(format!("codec_embedding.{}", g)),
)?;
group_embeds.push(emb);
}
let mut layers = Vec::with_capacity(config.num_hidden_layers);
for i in 0..config.num_hidden_layers {
let block = TransformerBlock::load(
&gqa_config,
config.intermediate_size,
model_vb.pp(format!("layers.{}", i)),
)?;
layers.push(block);
}
let norm = RmsNorm::load(config.hidden_size, 1e-6, model_vb.pp("norm"))?;
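        // One bias-free output head per extra code group.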
let mut group_heads = Vec::with_capacity(num_extra_groups);
for g in 0..num_extra_groups {
let head = candle_nn::linear_no_bias(
config.hidden_size,
config.vocab_size,
vb.pp(format!("lm_head.{}", g)),
)?;
group_heads.push(head);
}
let (rope_cos, rope_sin) = precompute_rope_freqs(
config.head_dim,
max_position_embeddings,
rope_theta,
device,
dtype,
)?;
Ok(Self {
input_proj,
group_embeds,
layers,
norm,
group_heads,
rope_cos,
rope_sin,
num_groups: config.num_code_groups,
config: config.clone(),
})
}
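    /// Runs one prediction step: prefills with the projected talker hidden
    /// state and the group-0 embedding, then decodes the remaining
    /// `num_groups - 1` codec tokens one position at a time. Returns the
    /// sum of all group embeddings (in talker space) together with the
    /// predicted token ids. `_g0_token` is currently unused.
    ///
    /// A minimal usage sketch; shapes are inferred from the body and the
    /// variable names are illustrative:
    ///
    /// ```ignore
    /// // past_hidden: (1, talker_hidden), g0_embed: (1, 1, talker_hidden)
    /// let (summed, codes) = predictor.predict_step_and_sum(&h, g0, &g0_emb, &device)?;
    /// assert_eq!(codes.len(), num_code_groups - 1);
    /// ```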
pub fn predict_step_and_sum(
&mut self,
past_hidden: &Tensor,
_g0_token: u32,
g0_embed: &Tensor,
device: &Device,
) -> Result<(Tensor, Vec<u32>)> {
let num_extra_groups = self.num_groups - 1;
let mut predicted_tokens: Vec<u32> = Vec::with_capacity(num_extra_groups);
let mut summed = g0_embed.clone();
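        // Each call is an independent short sequence, so start from an
        // empty KV cache.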
for layer in &mut self.layers {
layer.clear_cache();
}
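        // Build the 2-position prefill: the past talker hidden state
        // followed by the group-0 embedding, projected into this model's
        // hidden size.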
        let past_hidden_3d = past_hidden.unsqueeze(1)?;
        let prefill = Tensor::cat(&[&past_hidden_3d, g0_embed], 1)?;
        let mut hidden = self.input_proj.forward(&prefill)?;
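        // Causal mask for the 2-token prefill: position 0 must not attend
        // to position 1.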
let prefill_mask = {
let mask_data = vec![0.0f32, f32::NEG_INFINITY, 0.0f32, 0.0f32];
let mask = Tensor::from_vec(mask_data, (2, 2), device)?.to_dtype(hidden.dtype())?;
mask.unsqueeze(0)?.unsqueeze(0)?
};
for layer in &mut self.layers {
hidden = layer.forward(
&hidden,
&self.rope_cos,
&self.rope_sin,
0,
Some(&prefill_mask),
)?;
}
let hidden = self.norm.forward(&hidden)?;
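        // The prefill occupied positions 0 and 1, so decoding resumes at 2.
        // Head 0 reads the last prefill position to predict the first
        // extra group.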
let mut pos = 2usize;
        let last_h = hidden.i((.., hidden.dims()[1] - 1.., ..))?;
        let logits = self.group_heads[0].forward(&last_h)?;
let mut prev_predicted_token = sample_top_k_top_p(&logits, 50, 0.8, device)?;
predicted_tokens.push(prev_predicted_token);
let tok_tensor = Tensor::new(&[prev_predicted_token], device)?.unsqueeze(0)?;
        let tok_embed = self.group_embeds[0].forward(&tok_tensor)?;
        summed = summed.add(&tok_embed)?;
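        // Decode the remaining groups one token at a time, reusing the KV
        // cache: each step feeds the previous group's token embedding
        // through the shared input projection.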
for gen_step in 1..num_extra_groups {
let prev_tok_tensor = Tensor::new(&[prev_predicted_token], device)?.unsqueeze(0)?;
let embed_idx = (gen_step - 1).min(self.group_embeds.len() - 1);
let prev_embed = self.group_embeds[embed_idx].forward(&prev_tok_tensor)?;
let step_input = self.input_proj.forward(&prev_embed)?;
let mut h = step_input;
for layer in &mut self.layers {
h = layer.forward(&h, &self.rope_cos, &self.rope_sin, pos, None)?;
}
let h = self.norm.forward(&h)?;
pos += 1;
let logits = self.group_heads[gen_step].forward(&h)?;
let predicted_token = sample_top_k_top_p(&logits, 50, 0.8, device)?;
predicted_tokens.push(predicted_token);
let pred_tensor = Tensor::new(&[predicted_token], device)?.unsqueeze(0)?;
let pred_embed_idx = gen_step.min(self.group_embeds.len() - 1);
let pred_embed = self.group_embeds[pred_embed_idx].forward(&pred_tensor)?;
summed = summed.add(&pred_embed)?;
prev_predicted_token = predicted_token;
}
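        // Drop the per-call KV cache before returning.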
for layer in &mut self.layers {
layer.clear_cache();
}
Ok((summed, predicted_tokens))
}
}
impl std::fmt::Debug for CodePredictor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CodePredictor")
.field("num_groups", &self.num_groups)
.field("num_layers", &self.layers.len())
.field("hidden_size", &self.config.hidden_size)
.finish()
}
}
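/// Top-k then nucleus (top-p) sampling over a single-token logits tensor:
/// keep the `top_k` largest logits, softmax them, keep the shortest prefix
/// whose cumulative probability reaches `top_p`, renormalize, and draw one
/// index from that distribution.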
fn sample_top_k_top_p(logits: &Tensor, top_k: usize, top_p: f32, _device: &Device) -> Result<u32> {
let flat: Vec<f32> = logits
.to_dtype(candle_core::DType::F32)?
.flatten_all()?
.to_vec1()?;
let vocab_size = flat.len();
let mut indexed: Vec<(usize, f32)> = flat.iter().copied().enumerate().collect();
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let k = top_k.min(vocab_size);
let truncated = &indexed[..k];
let max_logit = truncated[0].1;
let exp_vals: Vec<f32> = truncated
.iter()
.map(|(_, v)| (v - max_logit).exp())
.collect();
let sum_exp: f32 = exp_vals.iter().sum();
let mut probs: Vec<(usize, f32)> = truncated
.iter()
.zip(exp_vals.iter())
.map(|((idx, _), &e)| (*idx, e / sum_exp))
.collect();
let mut cumsum = 0.0f32;
let mut cutoff = probs.len();
for (i, &(_, p)) in probs.iter().enumerate() {
cumsum += p;
if cumsum >= top_p {
cutoff = i + 1;
break;
}
}
probs.truncate(cutoff);
let total: f32 = probs.iter().map(|(_, p)| p).sum();
for entry in &mut probs {
entry.1 /= total;
}
let r: f32 = cp_rand_uniform();
let mut cumsum = 0.0;
for &(idx, p) in &probs {
cumsum += p;
if cumsum >= r {
return Ok(idx as u32);
}
}
Ok(probs.last().map(|&(idx, _)| idx as u32).unwrap_or(0))
}
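/// Uniform `f32` in `[0, 1)` from a process-wide xoshiro256++-style
/// generator, lazily seeded from the system clock. Deliberately simple and
/// not cryptographically secure; it only drives sampling.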
fn cp_rand_uniform() -> f32 {
use std::sync::Mutex;
use std::time::SystemTime;
static STATE: Mutex<Option<[u64; 4]>> = Mutex::new(None);
let mut guard = STATE.lock().unwrap();
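    // Seed lazily from the clock and run a few warm-up rounds so a
    // low-entropy seed gets mixed before first use.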
let s = guard.get_or_insert_with(|| {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map(|d| d.as_nanos() as u64)
.unwrap_or(0xcafebabe);
let mut seed = [
now ^ 0xabcdef0123456789,
now.wrapping_mul(6364136223846793005),
!now ^ 0x9876543210fedcba,
now.rotate_left(17) ^ 0x1111111111111111,
];
for _ in 0..8 {
let t = seed[1] << 17;
seed[2] ^= seed[0];
seed[3] ^= seed[1];
seed[1] ^= seed[2];
seed[0] ^= seed[3];
seed[2] ^= t;
seed[3] = seed[3].rotate_left(45);
}
seed
});
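    // xoshiro256++ output function: rotl(s0 + s3, 23) + s0, followed by one
    // standard xoshiro256 state transition.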
let result = (s[0].wrapping_add(s[3])).rotate_left(23).wrapping_add(s[0]);
let t = s[1] << 17;
s[2] ^= s[0];
s[3] ^= s[1];
s[1] ^= s[2];
s[0] ^= s[3];
s[2] ^= t;
s[3] = s[3].rotate_left(45);
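    // Keep the top 24 bits so the value fits an f32 mantissa exactly, then
    // scale to [0, 1).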
(result >> 40) as f32 / (1u64 << 24) as f32
}