use crate::error::Result;
use numr::dtype::DType;
use numr::ops::RandomOps;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Applies repeat, frequency, and presence penalties to `logits` in place.
///
/// The logits are copied to host memory, penalized on the CPU for each
/// `(token_id, count)` pair from `token_ids`/`token_counts`, and the result
/// is written back to the tensor's device buffer.
///
/// * `repeat_penalty` — `1.0` disables; positive logits are divided by it,
///   negative logits multiplied by it, pushing seen tokens toward zero
///   probability.
/// * `frequency_penalty` — `0.0` disables; subtracts `penalty * count`.
/// * `presence_penalty` — `0.0` disables; subtracts a flat penalty once.
///
/// Token ids outside the vocabulary (`>= logits.len()`) are skipped.
///
/// # Errors
/// Returns an error if the penalized logits cannot be copied back to the
/// device.
pub fn apply_sampling_penalties_impl<R: Runtime>(
    _client: &R::Client,
    logits: &Tensor<R>,
    token_ids: &Tensor<R>,
    token_counts: &Tensor<R>,
    repeat_penalty: f32,
    frequency_penalty: f32,
    presence_penalty: f32,
) -> Result<()> {
    let mut logits_vec: Vec<f32> = logits.to_vec();
    let ids_vec: Vec<i64> = token_ids.to_vec();
    let counts_vec: Vec<i32> = token_counts.to_vec();
    for (&token_id, &count) in ids_vec.iter().zip(counts_vec.iter()) {
        let i = token_id as usize;
        if i >= logits_vec.len() {
            // Ignore ids that fall outside the vocabulary.
            continue;
        }
        if repeat_penalty != 1.0 {
            // Scale toward zero: shrink positive logits, push negatives lower.
            if logits_vec[i] > 0.0 {
                logits_vec[i] /= repeat_penalty;
            } else {
                logits_vec[i] *= repeat_penalty;
            }
        }
        if frequency_penalty != 0.0 {
            logits_vec[i] -= frequency_penalty * count as f32;
        }
        if presence_penalty != 0.0 {
            logits_vec[i] -= presence_penalty;
        }
    }
    // SAFETY: `logits_vec` is a live, properly aligned Vec<f32>; its backing
    // storage is valid for `len * size_of::<f32>()` bytes and the byte slice
    // does not outlive the Vec (it is consumed by `copy_to_device` below).
    let bytes: &[u8] = unsafe {
        std::slice::from_raw_parts(
            logits_vec.as_ptr() as *const u8,
            logits_vec.len() * std::mem::size_of::<f32>(),
        )
    };
    R::copy_to_device(bytes, logits.ptr(), logits.device()).map_err(|e| {
        crate::error::Error::Numr(numr::error::Error::Internal(format!(
            "Failed to write back penalized logits: {}",
            e
        )))
    })?;
    Ok(())
}
/// Samples a token id from `logits` with temperature, top-k, top-p (nucleus),
/// and min-p filtering, using the runtime client's RNG for the final draw.
///
/// A temperature of exactly `0.0` performs greedy argmax decoding, matching
/// `logits_to_token_impl`; previously this path divided by zero and filled
/// the softmax with inf/NaN values.
///
/// * `top_k` — `0` disables; otherwise keep only the `top_k` most likely.
/// * `top_p` — `>= 1.0` disables; keep the smallest prefix of the sorted
///   distribution whose cumulative probability exceeds `top_p`.
/// * `min_p` — `0.0` disables; drop candidates below `min_p * max_prob`.
///
/// # Errors
/// Propagates RNG failures from the runtime client.
pub fn sample_token_impl<R: Runtime>(
    client: &R::Client,
    logits: &Tensor<R>,
    temperature: f32,
    top_k: usize,
    top_p: f32,
    min_p: f32,
) -> Result<u32>
where
    R::Client: RandomOps<R>,
{
    let mut logits_vec: Vec<f32> = logits.to_vec();
    // Greedy decoding: temperature 0 means "always pick the best token".
    // Scaling by 1/0 would otherwise poison the distribution with NaN.
    if temperature == 0.0 {
        let mut best_idx = 0usize;
        let mut best_val = f32::NEG_INFINITY;
        for (i, &v) in logits_vec.iter().enumerate() {
            if v > best_val {
                best_val = v;
                best_idx = i;
            }
        }
        return Ok(best_idx as u32);
    }
    if temperature != 1.0 {
        let inv_temp = 1.0 / temperature;
        for l in logits_vec.iter_mut() {
            *l *= inv_temp;
        }
    }
    // Numerically stable softmax: subtract the max before exponentiating.
    let max_logit = logits_vec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<f32> = logits_vec.iter().map(|&l| (l - max_logit).exp()).collect();
    let sum: f32 = probs.iter().sum();
    for p in probs.iter_mut() {
        *p /= sum;
    }
    // Sort candidates by descending probability for the prefix-based filters.
    let mut indexed: Vec<(usize, f32)> = probs.iter().enumerate().map(|(i, &p)| (i, p)).collect();
    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    if top_k > 0 && top_k < indexed.len() {
        indexed.truncate(top_k);
    }
    if top_p < 1.0 {
        // Keep the shortest prefix whose cumulative probability exceeds top_p.
        let mut cumsum = 0.0f32;
        let mut cutoff = indexed.len();
        for (i, (_, p)) in indexed.iter().enumerate() {
            cumsum += p;
            if cumsum > top_p {
                cutoff = i + 1;
                break;
            }
        }
        indexed.truncate(cutoff);
    }
    if min_p > 0.0 && !indexed.is_empty() {
        // Drop everything far below the most likely candidate.
        let max_prob = indexed[0].1;
        let threshold = min_p * max_prob;
        indexed.retain(|(_, p)| *p >= threshold);
    }
    let rand_tensor = client
        .rand(&[1], numr::dtype::DType::F32)
        .map_err(crate::error::Error::Numr)?;
    let random_val: f32 = rand_tensor.to_vec::<f32>()[0];
    // Renormalize over the surviving candidates and draw via inverse CDF.
    let total: f32 = indexed.iter().map(|(_, p)| p).sum();
    let mut cumsum = 0.0f32;
    for (i, p) in &indexed {
        cumsum += p / total;
        if cumsum > random_val {
            return Ok(*i as u32);
        }
    }
    // Float rounding can leave cumsum just below 1.0; fall back to the last
    // candidate (or 0 if the candidate list is somehow empty).
    Ok(indexed.last().map(|(i, _)| *i as u32).unwrap_or(0))
}
/// Converts a logits tensor into a single sampled token id tensor.
///
/// Pipeline: cast logits to F32 if needed, slice out the last sequence
/// position, apply repetition/frequency/presence penalties for the first
/// `num_unique` `(token_id, count)` pairs, then either take the argmax
/// (`temperature == 0.0`, greedy decoding) or sample with temperature,
/// top-k, top-p, and min-p filtering. Returns a rank-1 `[1]` tensor holding
/// the chosen token id as `i64`.
///
/// NOTE(review): `offset` addresses the last sequence position of the
/// *first* batch element; this assumes logits are `[1, seq_len, vocab_size]`
/// — confirm with callers before using batch sizes > 1.
///
/// # Errors
/// Returns `InvalidArgument` if the logits rank is < 3 or if `seq_len` /
/// `vocab_size` is zero; propagates cast and RNG failures from the client.
#[allow(clippy::too_many_arguments)]
pub fn logits_to_token_impl<R: Runtime<DType = numr::dtype::DType>>(
    client: &R::Client,
    logits: &Tensor<R>,
    token_ids: &Tensor<R>,
    token_counts: &Tensor<R>,
    num_unique: usize,
    repeat_penalty: f32,
    frequency_penalty: f32,
    presence_penalty: f32,
    temperature: f32,
    top_k: usize,
    top_p: f32,
    min_p: f32,
    seed: Option<u64>,
) -> Result<Tensor<R>>
where
    R::Client: numr::ops::RandomOps<R> + numr::ops::TypeConversionOps<R>,
{
    // Greedy argmax; first maximum wins, empty input yields index 0.
    fn argmax(values: &[f32]) -> usize {
        let mut best_idx = 0usize;
        let mut best_val = f32::NEG_INFINITY;
        for (i, &v) in values.iter().enumerate() {
            if v > best_val {
                best_val = v;
                best_idx = i;
            }
        }
        best_idx
    }

    // In-place CPU penalties, mirroring apply_sampling_penalties_impl:
    // ids outside the vocabulary are skipped; repeat penalty scales logits
    // toward zero, frequency/presence penalties subtract.
    fn apply_penalties(
        logits: &mut [f32],
        ids: &[i64],
        counts: &[i32],
        num_unique: usize,
        repeat_penalty: f32,
        frequency_penalty: f32,
        presence_penalty: f32,
    ) {
        let penalty_count = num_unique.min(ids.len()).min(counts.len());
        for idx in 0..penalty_count {
            let token_id = ids[idx] as usize;
            if token_id >= logits.len() {
                continue;
            }
            let count = counts[idx];
            if repeat_penalty != 1.0 {
                if logits[token_id] > 0.0 {
                    logits[token_id] /= repeat_penalty;
                } else {
                    logits[token_id] *= repeat_penalty;
                }
            }
            if frequency_penalty != 0.0 {
                logits[token_id] -= frequency_penalty * count as f32;
            }
            if presence_penalty != 0.0 {
                logits[token_id] -= presence_penalty;
            }
        }
    }

    // Temperature-scaled softmax, sorted descending, then filtered by
    // top-k, top-p (nucleus), and min-p. Returns (index, prob) candidates.
    fn filter_candidates(
        logits: &mut [f32],
        temperature: f32,
        top_k: usize,
        top_p: f32,
        min_p: f32,
    ) -> Vec<(usize, f32)> {
        let inv_temp = 1.0 / temperature;
        for l in logits.iter_mut() {
            *l *= inv_temp;
        }
        // Numerically stable softmax: subtract the max before exp.
        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut probs: Vec<f32> = logits.iter().map(|&l| (l - max_logit).exp()).collect();
        let sum: f32 = probs.iter().sum();
        for p in probs.iter_mut() {
            *p /= sum;
        }
        let mut indexed: Vec<(usize, f32)> =
            probs.iter().enumerate().map(|(i, &p)| (i, p)).collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        if top_k > 0 && top_k < indexed.len() {
            indexed.truncate(top_k);
        }
        if top_p < 1.0 {
            // Keep the shortest prefix whose cumulative probability > top_p.
            let mut cumsum = 0.0f32;
            let mut cutoff = indexed.len();
            for (i, (_, p)) in indexed.iter().enumerate() {
                cumsum += p;
                if cumsum > top_p {
                    cutoff = i + 1;
                    break;
                }
            }
            indexed.truncate(cutoff);
        }
        if min_p > 0.0 && !indexed.is_empty() {
            // Drop everything far below the most likely candidate.
            let max_prob = indexed[0].1;
            let threshold = min_p * max_prob;
            indexed.retain(|(_, p)| *p >= threshold);
        }
        indexed
    }

    // Sampling math below assumes f32 logits.
    let logits = if logits.dtype() != DType::F32 {
        use numr::ops::TypeConversionOps;
        client
            .cast(logits, DType::F32)
            .map_err(crate::error::Error::Numr)?
    } else {
        logits.clone()
    };
    let shape = logits.shape();
    if shape.len() < 3 {
        return Err(crate::error::Error::InvalidArgument {
            arg: "logits",
            reason: format!("expected rank >= 3, got rank {}", shape.len()),
        });
    }
    let seq_len = shape[1];
    let vocab_size = shape[2];
    if seq_len == 0 || vocab_size == 0 {
        return Err(crate::error::Error::InvalidArgument {
            arg: "logits",
            reason: format!("seq_len and vocab_size must be > 0, got shape {:?}", shape),
        });
    }
    let all_logits: Vec<f32> = logits.to_vec();
    // Only the last position's logits matter for next-token prediction.
    let offset = (seq_len - 1) * vocab_size;
    let mut last_logits: Vec<f32> = all_logits[offset..offset + vocab_size].to_vec();
    if num_unique > 0 {
        let ids_vec: Vec<i64> = token_ids.to_vec();
        let counts_vec: Vec<i32> = token_counts.to_vec();
        apply_penalties(
            &mut last_logits,
            &ids_vec,
            &counts_vec,
            num_unique,
            repeat_penalty,
            frequency_penalty,
            presence_penalty,
        );
    }
    let token_id = if temperature == 0.0 {
        // Greedy decoding; avoids dividing by a zero temperature.
        argmax(&last_logits) as i64
    } else {
        let indexed = filter_candidates(&mut last_logits, temperature, top_k, top_p, min_p);
        // A seed gives reproducible sampling; otherwise use the client RNG.
        let rand_tensor = if let Some(s) = seed {
            client
                .rand_seeded(&[1], numr::dtype::DType::F32, s)
                .map_err(crate::error::Error::Numr)?
        } else {
            client
                .rand(&[1], numr::dtype::DType::F32)
                .map_err(crate::error::Error::Numr)?
        };
        let random_val: f32 = rand_tensor.to_vec::<f32>()[0];
        // Renormalize over survivors and draw via inverse CDF. Default to the
        // last candidate in case float rounding keeps cumsum below random_val.
        let total: f32 = indexed.iter().map(|(_, p)| p).sum();
        let mut cumsum = 0.0f32;
        let mut sampled = indexed.last().map(|(i, _)| *i).unwrap_or(0);
        for (i, p) in &indexed {
            cumsum += p / total;
            if cumsum > random_val {
                sampled = *i;
                break;
            }
        }
        sampled as i64
    };
    Ok(Tensor::from_slice(&[token_id], &[1], logits.device()))
}