use super::voxtral_llama::{VoxtralLlama, VoxtralLlamaCache, VoxtralLlamaConfig};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{
layer_norm, linear, linear_no_bias, Conv1d, Dropout, LayerNorm, Linear, VarBuilder,
};
use rand::Rng;
/// Configuration of the Whisper-style audio encoder tower.
#[derive(Debug, Clone)]
pub struct VoxtralEncoderConfig {
    /// Vocabulary size used by the encoder checkpoint.
    pub vocab_size: usize,
    /// Width of the encoder hidden states.
    pub hidden_size: usize,
    /// Width of the feed-forward (fc1) projection inside each layer.
    pub intermediate_size: usize,
    /// Number of transformer encoder layers.
    pub num_hidden_layers: usize,
    /// Number of attention heads per layer.
    pub num_attention_heads: usize,
    /// Number of key/value heads (equal to `num_attention_heads` here).
    pub num_key_value_heads: usize,
    /// Per-head dimension.
    pub head_dim: usize,
    /// Whether token embeddings are scaled (unused by the encoder forward pass here).
    pub scale_embedding: bool,
    /// Activation inside the feed-forward block ("gelu" or "relu").
    pub activation_function: String,
    /// Number of mel filterbank channels expected in the input features.
    pub num_mel_bins: usize,
    /// Maximum number of (post-convolution) positions with learned embeddings.
    pub max_source_positions: usize,
    /// Weight-init stddev (training-time concern; unused at inference).
    pub initializer_range: f64,
    /// Dropout applied to attention weights.
    pub attention_dropout: f64,
    /// Dropout applied after attention/MLP blocks.
    pub dropout: f64,
    /// Probability of skipping a whole layer during training (stochastic depth).
    pub layerdrop: f64,
    /// Dropout applied after the feed-forward activation.
    pub activation_dropout: f64,
}
/// Top-level model configuration: audio encoder + text decoder + projector.
#[derive(Debug, Clone)]
pub struct VoxtralConfig {
    /// Audio encoder tower configuration.
    pub audio_config: VoxtralEncoderConfig,
    /// Llama-style text decoder configuration.
    pub text_config: VoxtralLlamaConfig,
    /// Token id that marks audio placeholder positions in the prompt.
    pub audio_token_id: usize,
    /// Activation used inside the multi-modal projector ("gelu" or "relu").
    pub projector_hidden_act: String,
}
impl Default for VoxtralConfig {
fn default() -> Self {
Self {
audio_config: VoxtralEncoderConfig::default(),
text_config: VoxtralLlamaConfig::voxtral_3b(),
audio_token_id: 24,
projector_hidden_act: "gelu".to_string(),
}
}
}
impl Default for VoxtralEncoderConfig {
    // Defaults appear to mirror a Whisper large-style encoder (1280 hidden,
    // 32 layers, 20 heads, 128 mel bins) — TODO confirm against the upstream
    // model config.
    fn default() -> Self {
        Self {
            vocab_size: 51866,
            hidden_size: 1280,
            intermediate_size: 5120,
            num_hidden_layers: 32,
            num_attention_heads: 20,
            num_key_value_heads: 20,
            head_dim: 64,
            scale_embedding: false,
            activation_function: "gelu".to_string(),
            num_mel_bins: 128,
            // 30 s of audio at 100 mel frames/s, halved by the stride-2 conv.
            max_source_positions: 1500,
            initializer_range: 0.02,
            // All dropout-style knobs default to 0 (inference-friendly).
            attention_dropout: 0.0,
            dropout: 0.0,
            layerdrop: 0.0,
            activation_dropout: 0.0,
        }
    }
}
impl VoxtralEncoderConfig {
    /// Returns the config with every dropout-style knob zeroed, matching
    /// Whisper-style inference behavior.
    pub fn with_whisper_compatibility(self) -> Self {
        Self {
            dropout: 0.0,
            layerdrop: 0.0,
            activation_dropout: 0.0,
            ..self
        }
    }
}
/// Generation-time state: the decoder KV cache plus memoized audio
/// embeddings, so audio is encoded at most once per generation.
#[derive(Debug, Clone)]
pub struct VoxtralCache {
    // Underlying decoder KV cache.
    cache: VoxtralLlamaCache,
    // True once audio has been encoded and spliced into the prompt embeds.
    audio_processed: bool,
    // Projected audio embeddings from the first forward call.
    cached_audio_embeds: Option<Tensor>,
    // (batch_idx, seq_idx) of each audio placeholder token.
    cached_audio_positions: Option<Vec<(usize, usize)>>,
}
/// Options controlling `VoxtralForConditionalGeneration::generate`.
#[derive(Debug, Clone)]
pub struct VoxtralGenerationConfig {
    /// Maximum number of new tokens to sample (0 returns the prompt as-is).
    pub max_new_tokens: usize,
    /// 0.0 selects greedy decoding; > 0.0 enables temperature sampling.
    pub temperature: f64,
    /// Optional nucleus-sampling threshold; must lie in [0, 1] when set.
    pub top_p: Option<f64>,
    /// Device on which generation tensors are created.
    pub device: Device,
    /// Optional pre-built cache; a fresh one is created when `None`.
    pub cache: Option<VoxtralCache>,
}
impl VoxtralGenerationConfig {
pub fn new(device: Device) -> Self {
Self {
max_new_tokens: 500,
temperature: 0.0,
top_p: None,
device,
cache: None,
}
}
}
impl VoxtralCache {
    /// Wraps a fresh decoder KV cache; audio state starts unprocessed.
    pub fn new(
        use_kv_cache: bool,
        dtype: DType,
        config: &VoxtralLlamaConfig,
        device: &Device,
    ) -> Result<Self> {
        let cache = VoxtralLlamaCache::new(use_kv_cache, dtype, config, device)?;
        Ok(Self {
            cache,
            audio_processed: false,
            cached_audio_embeds: None,
            cached_audio_positions: None,
        })
    }

    /// Clears the memoized audio state so the next forward call with audio
    /// features re-encodes them.
    pub fn reset(&mut self) {
        self.audio_processed = false;
        self.cached_audio_embeds = None;
        self.cached_audio_positions = None;
    }
}
/// Clamps activations to a range that is safe for f16 storage.
///
/// f16 overflows past 65504; 64504 leaves a margin of 1000 below that
/// (presumably mirroring the HF Whisper implementation — TODO confirm).
/// bf16 and wider dtypes have enough range and pass through unchanged.
fn safe_clamp(x: &Tensor) -> Result<Tensor> {
    if x.dtype() == DType::F16 {
        let max_val = 64504.0;
        x.clamp(-max_val, max_val)
    } else {
        Ok(x.clone())
    }
}
/// Splices rows of `audio_embeds` into `inputs_embeds` at every
/// (batch_idx, seq_idx) listed in `audio_positions`.
///
/// The substitution is expressed as mask arithmetic
/// (`result * (1 - mask) + embed * mask`) rather than indexed assignment,
/// keeping the op backend-agnostic. Errors when there are fewer embedding
/// rows than placeholder positions, or when a position is out of bounds.
pub fn replace_audio_tokens(
    inputs_embeds: &Tensor,
    audio_embeds: &Tensor,
    audio_positions: &[(usize, usize)],
    device: &Device,
) -> Result<Tensor> {
    if audio_positions.is_empty() {
        return Ok(inputs_embeds.clone());
    }
    let (batch_size, seq_len, hidden_size) = inputs_embeds.dims3()?;
    let num_audio_tokens = audio_positions.len();
    let (total_audio_embeds, _) = audio_embeds.dims2()?;
    if total_audio_embeds < num_audio_tokens {
        candle::bail!(
            "Not enough audio embeddings: need {}, got {}. Input sequence should have {} audio tokens.",
            num_audio_tokens,
            total_audio_embeds,
            total_audio_embeds
        );
    }
    // Use exactly as many embedding rows as there are placeholder tokens.
    let audio_embeds = if total_audio_embeds == num_audio_tokens {
        audio_embeds.clone()
    } else {
        audio_embeds.i(0..num_audio_tokens)?
    };
    let mut result = inputs_embeds.clone();
    for (idx, &(batch_idx, seq_idx)) in audio_positions.iter().enumerate() {
        if batch_idx >= batch_size || seq_idx >= seq_len {
            candle::bail!(
                "Invalid audio position: ({}, {}) for tensor shape ({}, {}, {})",
                batch_idx,
                seq_idx,
                batch_size,
                seq_len,
                hidden_size
            );
        }
        // One-hot (batch, seq, 1) mask selecting this single position.
        let mut mask_host = vec![0f32; batch_size * seq_len];
        mask_host[batch_idx * seq_len + seq_idx] = 1.0;
        let mask = Tensor::new(mask_host.as_slice(), device)?
            .reshape((batch_size, seq_len, 1))?
            .to_dtype(inputs_embeds.dtype())?;
        let embed_row = audio_embeds
            .i(idx)?
            .unsqueeze(0)?
            .unsqueeze(0)?
            .broadcast_as((batch_size, seq_len, hidden_size))?;
        let keep = (1.0 - &mask)?;
        result = (result.broadcast_mul(&keep)? + embed_row.broadcast_mul(&mask)?)?;
    }
    Ok(result)
}
/// Returns every (batch_idx, seq_idx) whose token equals `audio_token_id`,
/// scanned in row-major order over the 2-D `input_ids` tensor.
pub fn find_audio_token_positions(
    input_ids: &Tensor,
    audio_token_id: usize,
) -> Result<Vec<(usize, usize)>> {
    // Normalize to i64 so a single to_vec2 call covers u32 and i64 inputs.
    let ids = if input_ids.dtype() == candle::DType::U32 {
        input_ids.to_dtype(candle::DType::I64)?
    } else {
        input_ids.clone()
    };
    let rows = ids.to_vec2::<i64>()?;
    let positions = rows
        .iter()
        .enumerate()
        .flat_map(|(batch_idx, row)| {
            row.iter()
                .enumerate()
                .filter(|&(_, &tok)| tok as usize == audio_token_id)
                .map(move |(seq_idx, _)| (batch_idx, seq_idx))
        })
        .collect();
    Ok(positions)
}
/// Multi-head self-attention for the audio encoder.
#[derive(Debug, Clone)]
struct VoxtralAttention {
    q_proj: Linear,
    // Key projection carries no bias (see `new`), unlike q/v/out.
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,
    num_heads: usize,
    head_dim: usize,
    // 1/sqrt(head_dim) scale applied to the queries.
    scaling: f64,
    attention_dropout: Dropout,
}
impl VoxtralAttention {
    /// Builds the attention projections. The key projection is bias-free
    /// while query/value/output carry biases (as constructed below).
    fn new(cfg: &VoxtralEncoderConfig, vb: VarBuilder) -> Result<Self> {
        let embed_dim = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        if head_dim * num_heads != embed_dim {
            candle::bail!(
                "embed_dim must be divisible by num_heads ({} % {} != 0)",
                embed_dim,
                num_heads
            );
        }
        Ok(Self {
            q_proj: linear(embed_dim, embed_dim, vb.pp("q_proj"))?,
            k_proj: linear_no_bias(embed_dim, embed_dim, vb.pp("k_proj"))?,
            v_proj: linear(embed_dim, embed_dim, vb.pp("v_proj"))?,
            out_proj: linear(embed_dim, embed_dim, vb.pp("out_proj"))?,
            num_heads,
            head_dim,
            // 1/sqrt(head_dim) attention scale.
            scaling: (head_dim as f64).powf(-0.5),
            attention_dropout: Dropout::new(cfg.attention_dropout as f32),
        })
    }

    /// (batch, seq, embed) -> contiguous (batch, heads, seq, head_dim).
    fn reshape_for_scores(&self, x: &Tensor, seq_len: usize, bsz: usize) -> Result<Tensor> {
        let split = x.reshape((bsz, seq_len, self.num_heads, self.head_dim))?;
        split.transpose(1, 2)?.contiguous()
    }
}
impl Module for VoxtralAttention {
    /// Scaled dot-product self-attention over the full sequence (no mask).
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let (bsz, seq_len, _) = x.dims3()?;
        // Scale the queries up front so the score matmul needs no extra op.
        let queries = (self.q_proj.forward(x)? * self.scaling)?;
        let queries = self.reshape_for_scores(&queries, seq_len, bsz)?;
        let keys = self.reshape_for_scores(&self.k_proj.forward(x)?, seq_len, bsz)?;
        let values = self.reshape_for_scores(&self.v_proj.forward(x)?, seq_len, bsz)?;
        let scores = queries.matmul(&keys.transpose(D::Minus2, D::Minus1)?)?;
        let weights = candle_nn::ops::softmax_last_dim(&scores)?;
        // Dropout is invoked with training = false here, so it is a no-op.
        let weights = self.attention_dropout.forward(&weights, false)?;
        let context = weights
            .matmul(&values)?
            .transpose(1, 2)?
            .contiguous()?
            .reshape((bsz, seq_len, self.num_heads * self.head_dim))?;
        self.out_proj.forward(&context)
    }
}
/// One pre-norm transformer encoder layer: self-attention block followed by
/// a two-layer feed-forward block, each with a residual connection.
#[derive(Debug, Clone)]
struct VoxtralEncoderLayer {
    self_attn: VoxtralAttention,
    // LayerNorm applied before self-attention (pre-norm).
    self_attn_layer_norm: LayerNorm,
    fc1: Linear,
    fc2: Linear,
    // LayerNorm applied before the feed-forward block.
    final_layer_norm: LayerNorm,
    activation: candle_nn::Activation,
    dropout: Dropout,
    activation_dropout: Dropout,
}
impl VoxtralEncoderLayer {
    /// Builds one pre-norm encoder layer from the config and var store.
    fn new(cfg: &VoxtralEncoderConfig, vb: VarBuilder) -> Result<Self> {
        let embed_dim = cfg.hidden_size;
        let activation = match cfg.activation_function.as_str() {
            "gelu" => candle_nn::Activation::Gelu,
            "relu" => candle_nn::Activation::Relu,
            _ => candle::bail!(
                "Unsupported activation function: {}",
                cfg.activation_function
            ),
        };
        let self_attn = VoxtralAttention::new(cfg, vb.pp("self_attn"))?;
        let self_attn_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("self_attn_layer_norm"))?;
        let fc1 = linear(embed_dim, cfg.intermediate_size, vb.pp("fc1"))?;
        let fc2 = linear(cfg.intermediate_size, embed_dim, vb.pp("fc2"))?;
        let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
        Ok(Self {
            self_attn,
            self_attn_layer_norm,
            fc1,
            fc2,
            final_layer_norm,
            activation,
            dropout: Dropout::new(cfg.dropout as f32),
            activation_dropout: Dropout::new(cfg.activation_dropout as f32),
        })
    }

    /// Output width of the first feed-forward projection (intermediate size).
    pub fn get_fc1_out_dim(&self) -> usize {
        self.fc1.weight().dims()[0]
    }

    /// Pre-norm attention block, then pre-norm MLP block, each with a
    /// residual connection; the result is clamped to the f16-safe range.
    fn forward(&self, x: &Tensor, training: bool) -> Result<Tensor> {
        // Attention sub-block with residual.
        let normed = self.self_attn_layer_norm.forward(x)?;
        let attn_out = self.self_attn.forward(&normed)?;
        let x = (self.dropout.forward(&attn_out, training)? + x)?;
        // Feed-forward sub-block with residual.
        let mlp = self.final_layer_norm.forward(&x)?;
        let mlp = self.fc1.forward(&mlp)?.apply(&self.activation)?;
        let mlp = self.activation_dropout.forward(&mlp, training)?;
        let mlp = self.dropout.forward(&self.fc2.forward(&mlp)?, training)?;
        safe_clamp(&(mlp + &x)?)
    }
}
/// Whisper-style audio encoder: two downsampling convolutions, learned
/// absolute position embeddings, a transformer stack and a final layer norm.
#[derive(Debug, Clone)]
pub struct VoxtralEncoder {
    conv1: Conv1d,
    // Stride-2 convolution: halves the time dimension.
    conv2: Conv1d,
    // Learned (max_source_positions, hidden_size) position table.
    embed_positions: Tensor,
    layers: Vec<VoxtralEncoderLayer>,
    layer_norm: LayerNorm,
    dropout: Dropout,
    // Probability of skipping a layer during training (stochastic depth).
    layerdrop: f64,
}
impl VoxtralEncoder {
    /// Builds the encoder: conv1 (padding 1) + conv2 (stride 2, padding 1),
    /// learned position embeddings, `num_hidden_layers` transformer layers,
    /// and a final layer norm.
    pub fn new(cfg: &VoxtralEncoderConfig, vb: VarBuilder) -> Result<Self> {
        // Inference never uses dropout/layerdrop; zero them for parity with
        // the Whisper-style reference implementation.
        let cfg = cfg.clone().with_whisper_compatibility();
        let embed_dim = cfg.hidden_size;
        let conv1 = candle_nn::conv1d(
            cfg.num_mel_bins,
            embed_dim,
            3,
            candle_nn::Conv1dConfig {
                padding: 1,
                ..Default::default()
            },
            vb.pp("conv1"),
        )?;
        // conv2 halves the time dimension (stride 2).
        let conv2 = candle_nn::conv1d(
            embed_dim,
            embed_dim,
            3,
            candle_nn::Conv1dConfig {
                stride: 2,
                padding: 1,
                ..Default::default()
            },
            vb.pp("conv2"),
        )?;
        let embed_positions = vb.get(
            (cfg.max_source_positions, embed_dim),
            "embed_positions.weight",
        )?;
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for i in 0..cfg.num_hidden_layers {
            layers.push(VoxtralEncoderLayer::new(&cfg, vb.pp(format!("layers.{i}")))?);
        }
        let layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("layer_norm"))?;
        let dropout = Dropout::new(cfg.dropout as f32);
        Ok(Self {
            conv1,
            conv2,
            embed_positions,
            layers,
            layer_norm,
            dropout,
            layerdrop: cfg.layerdrop,
        })
    }

    /// Inference-mode forward pass (dropout disabled).
    pub fn forward(&self, input_features: &Tensor) -> Result<Tensor> {
        self.forward_with_training(input_features, false)
    }

    /// Encodes mel features of shape (batch, num_mel_bins, time) into hidden
    /// states of shape (batch, out_time, hidden_size), where out_time is the
    /// time axis after the stride-2 convolution.
    ///
    /// Note: the previous version carried unreachable `if false { ... }`
    /// debug branches around both convolutions and the position add; they
    /// have been removed without behavior change.
    pub fn forward_with_training(&self, input_features: &Tensor, training: bool) -> Result<Tensor> {
        // Match the convolution weights' dtype so mixed-precision inputs work.
        let expected_dtype = self.conv1.weight().dtype();
        let input_features = if input_features.dtype() != expected_dtype {
            input_features.to_dtype(expected_dtype)?
        } else {
            input_features.clone()
        };
        let x = self.conv1.forward(&input_features)?.gelu()?;
        let x = self.conv2.forward(&x)?.gelu()?;
        // (batch, hidden, time) -> (batch, time, hidden)
        let x = x.transpose(1, 2)?;
        let seq_len = x.dim(1)?;
        let positions = self.embed_positions.i(..seq_len)?;
        // Add position embeddings; on a dtype mismatch, upcast the input to
        // f32 for the addition and cast back (assumes the position table is
        // f32 in that case — TODO confirm).
        let x = if x.dtype() != positions.dtype() {
            let x_f32 = x.to_dtype(candle::DType::F32)?;
            let result_f32 = x_f32.broadcast_add(&positions)?;
            result_f32.to_dtype(x.dtype())?
        } else {
            x.broadcast_add(&positions)?
        };
        let mut x = self.dropout.forward(&x, training)?;
        for (idx, layer) in self.layers.iter().enumerate() {
            x = self.forward_layer_with_dropout(&x, layer, idx, training)?;
        }
        self.layer_norm.forward(&x)
    }

    /// Applies one layer, randomly skipping it with probability `layerdrop`
    /// during training (stochastic depth). At inference this is a plain call.
    fn forward_layer_with_dropout(
        &self,
        x: &Tensor,
        layer: &VoxtralEncoderLayer,
        _layer_idx: usize,
        training: bool,
    ) -> Result<Tensor> {
        if training && self.layerdrop > 0.0 {
            let mut rng = rand::rng();
            let keep_prob = 1.0 - self.layerdrop;
            if rng.random::<f64>() >= keep_prob {
                // Skipped layer: pass the input through unchanged.
                return Ok(x.clone());
            }
        }
        layer.forward(x, training)
    }

    /// Intermediate (fc1) width of the first layer; falls back to the
    /// default 5120 when no layers were built.
    pub fn get_intermediate_size(&self) -> usize {
        if !self.layers.is_empty() {
            self.layers[0].get_fc1_out_dim()
        } else {
            5120
        }
    }

    /// Encodes audio longer than `chunk_size` mel frames by processing
    /// overlapping chunks and averaging the overlapping encoder outputs.
    ///
    /// NOTE(review): `overlap / 2` converts the input-frame overlap into
    /// output frames because conv2 halves the time axis — confirm if the
    /// stride ever changes.
    ///
    /// # Errors
    /// Fails when `chunk_size <= overlap` (previously this panicked inside
    /// `step_by(0)` or degenerated), and propagates any encoder error.
    pub fn process_long_audio(
        &self,
        input_features: &Tensor,
        chunk_size: usize,
        overlap: usize,
    ) -> Result<Tensor> {
        let (_batch_size, _num_mel, seq_len) = input_features.dims3()?;
        if seq_len <= chunk_size {
            return self.forward(input_features);
        }
        // Guard the stride: step_by(0) panics.
        if chunk_size <= overlap {
            candle::bail!(
                "chunk_size ({chunk_size}) must be greater than overlap ({overlap})"
            );
        }
        let step = chunk_size - overlap;
        let mut outputs = Vec::new();
        for start in (0..seq_len).step_by(step) {
            let end = (start + chunk_size).min(seq_len);
            let chunk = input_features.i((.., .., start..end))?;
            let output = self.forward(&chunk)?;
            if !outputs.is_empty() && overlap > 0 {
                // Average the trailing frames of the previous chunk with the
                // leading frames of this one, then keep the remainder.
                let overlap_frames = overlap / 2;
                let last_output: &mut Tensor = outputs.last_mut().unwrap();
                let last_len = last_output.dim(1)?;
                let overlap_start = last_len.saturating_sub(overlap_frames);
                let overlap_new = output.i((.., ..overlap_frames, ..))?;
                let overlap_old = last_output.i((.., overlap_start.., ..))?;
                let averaged = ((overlap_old + overlap_new)? * 0.5)?;
                *last_output =
                    Tensor::cat(&[&last_output.i((.., ..overlap_start, ..))?, &averaged], 1)?;
                outputs.push(output.i((.., overlap_frames.., ..))?);
            } else {
                outputs.push(output);
            }
        }
        let outputs_ref: Vec<&Tensor> = outputs.iter().collect();
        Tensor::cat(&outputs_ref, 1)
    }
}
/// Projects audio encoder features into the text model's embedding space:
/// two bias-free linear layers with an activation in between.
#[derive(Debug, Clone)]
pub struct VoxtralMultiModalProjector {
    linear_1: Linear,
    linear_2: Linear,
    activation: candle_nn::Activation,
}
impl VoxtralMultiModalProjector {
    /// Builds the projector: intermediate_size -> text hidden -> text hidden,
    /// both layers bias-free, with the configured activation between them.
    pub fn new(cfg: &VoxtralConfig, vb: VarBuilder) -> Result<Self> {
        let activation = match cfg.projector_hidden_act.as_str() {
            "gelu" => candle_nn::Activation::Gelu,
            "relu" => candle_nn::Activation::Relu,
            _ => candle::bail!(
                "Unsupported projector activation: {}",
                cfg.projector_hidden_act
            ),
        };
        let linear_1 = linear_no_bias(
            cfg.audio_config.intermediate_size,
            cfg.text_config.hidden_size,
            vb.pp("linear_1"),
        )?;
        let linear_2 = linear_no_bias(
            cfg.text_config.hidden_size,
            cfg.text_config.hidden_size,
            vb.pp("linear_2"),
        )?;
        Ok(Self {
            linear_1,
            linear_2,
            activation,
        })
    }

    /// linear_1 -> activation -> linear_2.
    pub fn forward(&self, audio_features: &Tensor) -> Result<Tensor> {
        let hidden = audio_features
            .apply(&self.linear_1)?
            .apply(&self.activation)?;
        hidden.apply(&self.linear_2)
    }
}
/// Audio-conditioned causal language model: audio encoder tower, projector
/// into the text embedding space, and a Llama-style decoder.
#[derive(Debug, Clone)]
pub struct VoxtralForConditionalGeneration {
    audio_tower: VoxtralEncoder,
    language_model: VoxtralLlama,
    multi_modal_projector: VoxtralMultiModalProjector,
    // Token id that marks audio placeholder positions in the prompt.
    audio_token_id: usize,
    audio_config: VoxtralEncoderConfig,
    text_config: VoxtralLlamaConfig,
}
impl VoxtralForConditionalGeneration {
pub fn new(cfg: &VoxtralConfig, vb: VarBuilder) -> Result<Self> {
let audio_tower = VoxtralEncoder::new(&cfg.audio_config, vb.pp("audio_tower"))?;
let language_model = VoxtralLlama::load(vb.pp("language_model"), &cfg.text_config)?;
let multi_modal_projector =
VoxtralMultiModalProjector::new(cfg, vb.pp("multi_modal_projector"))?;
Ok(Self {
audio_tower,
language_model,
multi_modal_projector,
audio_token_id: cfg.audio_token_id,
audio_config: cfg.audio_config.clone(),
text_config: cfg.text_config.clone(),
})
}
pub fn audio_token_id(&self) -> usize {
self.audio_token_id
}
pub fn text_config(&self) -> &VoxtralLlamaConfig {
&self.text_config
}
pub fn audio_config(&self) -> &VoxtralEncoderConfig {
&self.audio_config
}
pub fn get_audio_embeds(&self, input_features: &Tensor) -> Result<Tensor> {
let audio_outputs = self.audio_tower.forward(input_features)?;
let (batch_size, seq_len, hidden_size) = audio_outputs.dims3()?;
let total_elements = batch_size * seq_len * hidden_size;
let new_batch_size = total_elements / self.audio_config.intermediate_size;
if total_elements % self.audio_config.intermediate_size != 0 {
return Err(candle::Error::DimOutOfRange {
shape: candle::Shape::from_dims(&[batch_size, seq_len, hidden_size]),
dim: 0,
op: "reshape",
});
}
let audio_hidden =
audio_outputs.reshape((new_batch_size, self.audio_config.intermediate_size))?;
let projected = self.multi_modal_projector.forward(&audio_hidden)?;
Ok(projected)
}
pub fn get_audio_embeds_chunked(
&self,
input_features: &Tensor,
chunk_size: usize,
overlap: usize,
) -> Result<Tensor> {
let audio_outputs =
self.audio_tower
.process_long_audio(input_features, chunk_size, overlap)?;
let (batch_size, seq_len, hidden_size) = audio_outputs.dims3()?;
let total_elements = batch_size * seq_len * hidden_size;
let new_batch_size = total_elements / self.audio_config.intermediate_size;
let audio_hidden =
audio_outputs.reshape((new_batch_size, self.audio_config.intermediate_size))?;
let projected = self.multi_modal_projector.forward(&audio_hidden)?;
let text_hidden_size = self.text_config.hidden_size;
let projected = projected.reshape((batch_size, seq_len, text_hidden_size))?;
let pooled = projected.mean(1)?;
Ok(pooled)
}
pub fn forward(
&self,
input_ids: &Tensor,
input_features: Option<&Tensor>,
cache: &mut VoxtralCache,
index_pos: usize,
) -> Result<Tensor> {
let mut inputs_embeds = self.language_model.embed(input_ids)?;
if let Some(features) = input_features {
if !cache.audio_processed {
let audio_embeds = self.get_audio_embeds(features)?;
let audio_positions = find_audio_token_positions(input_ids, self.audio_token_id)?;
cache.cached_audio_embeds = Some(audio_embeds.clone());
cache.cached_audio_positions = Some(audio_positions.clone());
cache.audio_processed = true;
inputs_embeds = replace_audio_tokens(
&inputs_embeds,
&audio_embeds,
&audio_positions,
input_ids.device(),
)?;
}
}
self.language_model
.forward_input_embed(&inputs_embeds, index_pos, &mut cache.cache)
}
pub fn generate(
&self,
input_ids: &Tensor,
input_features: Option<&Tensor>,
config: VoxtralGenerationConfig,
) -> Result<Vec<u32>> {
if config.max_new_tokens == 0 {
return input_ids.i(0)?.to_vec1::<u32>(); }
if config.temperature < 0.0 {
candle::bail!(
"Temperature must be non-negative, got {}",
config.temperature
);
}
if let Some(p) = config.top_p {
if !(0.0..=1.0).contains(&p) {
candle::bail!("top_p must be between 0 and 1, got {}", p);
}
}
let mut final_cache = if let Some(cache) = config.cache {
cache
} else {
let dummy_token = Tensor::new(&[1u32], &config.device)?;
let dummy_embed = self.language_model.embed(&dummy_token)?;
let model_dtype = dummy_embed.dtype();
VoxtralCache::new(true, model_dtype, &self.text_config, &config.device)?
};
let mut tokens = input_ids.i(0)?.to_vec1::<u32>()?; let initial_len = tokens.len();
for idx in 0..config.max_new_tokens {
let (input, index_pos) = if idx == 0 {
(input_ids.clone(), 0)
} else {
let last_token = tokens[tokens.len() - 1];
let calculated_pos = initial_len + idx - 1;
(
Tensor::new(&[last_token], &config.device)?.unsqueeze(0)?,
calculated_pos,
)
};
let logits = if idx == 0 {
match self.forward(&input, input_features, &mut final_cache, index_pos) {
Ok(logits) => logits,
Err(e) => {
return Err(candle::Error::Msg(format!(
"Failed to generate tokens: {e}"
)));
}
}
} else {
match self.forward(&input, None, &mut final_cache, index_pos) {
Ok(logits) => logits,
Err(e) => {
return Err(candle::Error::Msg(format!(
"Failed to generate tokens: {e}"
)));
}
}
};
let logits = if logits.dims().len() == 3 {
logits.i((.., logits.dim(1)? - 1, ..))?
} else {
logits
};
let next_token = if config.temperature > 0.0 {
let prs = (logits / config.temperature)?;
let prs = candle_nn::ops::softmax_last_dim(&prs)?;
if let Some(top_p_val) = config.top_p {
sample_top_p(&prs.squeeze(0)?, top_p_val, &config.device)?
} else {
let probs_vec = prs.squeeze(0)?.to_vec1::<f32>()?;
let mut rng = rand::rng();
let mut cumsum = 0.0;
let rand_val: f32 = rng.random();
let mut sampled = 0u32;
for (idx, &prob) in probs_vec.iter().enumerate() {
cumsum += prob;
if cumsum > rand_val {
sampled = idx as u32;
break;
}
}
sampled
}
} else {
let argmax_result = match logits.argmax(D::Minus1) {
Ok(result) => result,
Err(e) => {
return Err(candle::Error::Msg(format!("Argmax failed: {e}")));
}
};
if argmax_result.dims().is_empty() {
match argmax_result.to_scalar::<u32>() {
Ok(token) => token,
Err(e) => {
return Err(candle::Error::Msg(format!("to_scalar failed: {e}")));
}
}
} else if argmax_result.dims() == [1] {
match argmax_result.i(0) {
Ok(scalar_tensor) => match scalar_tensor.to_scalar::<u32>() {
Ok(token) => token,
Err(e) => {
return Err(candle::Error::Msg(format!(
"to_scalar on extracted element failed: {e}"
)));
}
},
Err(e) => {
return Err(candle::Error::Msg(format!(
"indexing argmax result failed: {e}"
)));
}
}
} else {
return Err(candle::Error::Msg(format!(
"Unexpected argmax result shape: {:?}",
argmax_result.shape()
)));
}
};
tokens.push(next_token);
let eos_tokens = [2u32, 128001, 128009, 128256];
if eos_tokens.contains(&next_token) {
break;
}
if next_token == 0 && tokens.len() > 5 {
let last_5_tokens = &tokens[tokens.len() - 5..];
if last_5_tokens.iter().all(|&t| t == 0) {
break;
}
}
}
Ok(tokens)
}
}
/// Nucleus (top-p) sampling over a 1-D probability tensor.
///
/// Sorts descending, keeps the smallest prefix whose cumulative mass stays
/// within `top_p`, renormalizes and samples from it.
///
/// Fix: when even the single most probable token exceeds `top_p`, the mask
/// is all-false; the original then divided by a zero sum, producing NaNs and
/// garbage samples. We now detect the empty nucleus and fall back to the top
/// token (greedy), which is the standard behavior.
fn sample_top_p(probs: &Tensor, top_p: f64, _device: &Device) -> Result<u32> {
    let (sorted_probs, sorted_indices) = probs.sort_last_dim(false)?;
    let cumsum = sorted_probs.cumsum(D::Minus1)?;
    let mask = cumsum.le(top_p)?;
    let filtered_probs = sorted_probs.where_cond(&mask, &Tensor::zeros_like(&sorted_probs)?)?;
    // Empty-nucleus guard: avoid dividing by a zero total.
    let total = filtered_probs.sum_all()?.to_scalar::<f32>()?;
    if !(total > 0.0) {
        return sorted_indices.i(0)?.to_scalar::<u32>();
    }
    let filtered_probs = (&filtered_probs / filtered_probs.sum_keepdim(D::Minus1)?)?;
    let probs_vec = filtered_probs.to_vec1::<f32>()?;
    let mut rng = rand::rng();
    let rand_val: f32 = rng.random();
    let mut cumsum = 0.0f32;
    // Inverse-CDF over the (descending) nucleus; index 0 is the most
    // probable token, so it is also a sensible rounding fallback.
    let mut sample_idx = 0;
    for (idx, &prob) in probs_vec.iter().enumerate() {
        cumsum += prob;
        if cumsum > rand_val {
            sample_idx = idx;
            break;
        }
    }
    sorted_indices.i(sample_idx)?.to_scalar::<u32>()
}