use std::ops::{Deref, DerefMut};
use num_traits::{Float, FromPrimitive, PrimInt, ToPrimitive};
use thiserror::Error;
/// Error type returned by [`Sampler`] operations.
///
/// `Display` output for each variant comes from its `#[error(...)]` format
/// string.
#[derive(Debug, Error)]
pub enum SamplerError {
/// A free-form internal failure described by the contained message.
#[error("internal error: {0}")]
InternalError(String),
/// A failure that originated in logits handling; wraps a [`LogitsError`].
#[error("logits error: {0}")]
LogitsError(LogitsError),
/// Wraps a `rand::Error`. Only compiled when the `rand` feature is enabled.
#[cfg(feature = "rand")]
#[error("rand error: {0}")]
RandError(rand::Error),
/// Wraps a `rand::distributions::WeightedError`. Only compiled when the
/// `rand` feature is enabled.
#[cfg(feature = "rand")]
#[error("rand weights error: {0}")]
RandWeightedError(rand::distributions::WeightedError),
}
/// Error type returned by [`Logits`] construction and manipulation.
#[derive(Debug, Clone, Error)]
pub enum LogitsError {
/// The logit at the contained index was rejected. `Logits::try_from_iter`
/// reports this for NaN values, passing the item's position in the input.
#[error("Invalid logit for token id {0}")]
InvalidLogit(usize),
/// A free-form internal failure described by the contained message.
#[error("internal logits error: {0}")]
InternalError(String),
}
impl From<LogitsError> for SamplerError {
fn from(value: LogitsError) -> Self {
SamplerError::LogitsError(value)
}
}
/// Marker trait for types that can serve as a token id.
///
/// It has no methods of its own; it only bundles the bounds
/// `PrimInt + FromPrimitive + ToPrimitive + Send + Sync`.
pub trait CanTokenId: PrimInt + FromPrimitive + ToPrimitive + Send + Sync {}
// Blanket impl: every type that satisfies the bounds is automatically a `CanTokenId`.
impl<T: PrimInt + FromPrimitive + ToPrimitive + Send + Sync> CanTokenId for T {}
/// Marker trait for types that can serve as a logit or probability value.
///
/// It has no methods of its own; it only bundles the bounds
/// `Float + Send + Sync`.
pub trait CanLogit: Float + Send + Sync {}
// Blanket impl: every type that satisfies the bounds is automatically a `CanLogit`.
impl<T: Float + Send + Sync> CanLogit for T {}
/// A single entry in a [`Logits`] collection: a token id paired with its logit
/// and probability.
#[derive(Debug, Clone, PartialEq)]
pub struct Logit<TID, L> {
/// Identifier of the token this entry belongs to.
pub token_id: TID,
/// Raw logit value for the token.
pub logit: L,
/// Probability for the token. `Logits::try_from_iter` initializes it to zero
/// and `Logits::softmax` overwrites it.
pub prob: L,
}
/// A collection of [`Logit`] entries together with a flag recording whether
/// the entries are currently sorted.
#[derive(Debug, Clone)]
pub struct Logits<TID, L> {
// Set by `ensure_sorted` after it sorts by descending logit; callers can
// override it through `set_sorted`.
sorted: bool,
// The entries themselves; exposed to callers through `Deref`/`DerefMut`.
logits: Vec<Logit<TID, L>>,
}
/// Gives read-only access to the underlying `Vec<Logit<TID, L>>`, so slice and
/// `Vec` methods can be called directly on a `Logits` value.
impl<TID, L> Deref for Logits<TID, L> {
type Target = Vec<Logit<TID, L>>;
fn deref(&self) -> &Self::Target {
&self.logits
}
}
/// Gives mutable access to the underlying `Vec<Logit<TID, L>>`.
///
/// Mutating through this impl does not touch the `sorted` flag; a caller that
/// reorders or edits entries is responsible for updating it via `set_sorted`.
impl<TID, L> DerefMut for Logits<TID, L> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.logits
}
}
impl<L: CanLogit> Logits<u32, L> {
    /// Builds an unsorted `Logits` from an iterator of raw logit values.
    ///
    /// Each item's position in the iterator becomes its `token_id`, and every
    /// `prob` starts at zero.
    ///
    /// # Errors
    ///
    /// * [`LogitsError::InvalidLogit`] with the offending position if a value
    ///   is NaN.
    /// * [`LogitsError::InternalError`] if a position does not fit in `u32`.
    pub fn try_from_iter<I: IntoIterator<Item = L>>(it: I) -> Result<Self, LogitsError> {
        let logits = it
            .into_iter()
            .enumerate()
            .map(|(tid, logit)| {
                if logit.is_nan() {
                    return Err(LogitsError::InvalidLogit(tid));
                }
                // A checked conversion: an `as` cast would silently wrap
                // positions above `u32::MAX` into duplicate token ids.
                let token_id = u32::try_from(tid).map_err(|_| {
                    LogitsError::InternalError(format!(
                        "token id {} does not fit in u32",
                        tid
                    ))
                })?;
                Ok(Logit {
                    token_id,
                    logit,
                    prob: L::zero(),
                })
            })
            .collect::<Result<Vec<_>, LogitsError>>()?;
        Ok(Self {
            sorted: false,
            logits,
        })
    }
}
impl<L: CanLogit> TryFrom<Vec<L>> for Logits<u32, L> {
type Error = LogitsError;
fn try_from(value: Vec<L>) -> Result<Self, Self::Error> {
Self::try_from_iter(value)
}
}
impl<TID: CanTokenId, L: CanLogit> Logits<TID, L> {
    /// Returns the current value of the `sorted` flag.
    pub fn get_sorted(&self) -> bool {
        self.sorted
    }

    /// Overwrites the `sorted` flag without inspecting the entries.
    ///
    /// Returns `self` so calls can be chained.
    pub fn set_sorted(&mut self, is_sorted: bool) -> &mut Self {
        self.sorted = is_sorted;
        self
    }

    /// Sorts the entries by descending logit unless the `sorted` flag is
    /// already set, then sets the flag.
    ///
    /// The sort is stable, so entries with equal logits keep their relative
    /// order.
    ///
    /// # Errors
    ///
    /// Returns [`LogitsError::InternalError`] if any two logits cannot be
    /// compared (for example a NaN written through `DerefMut`). In the NaN
    /// case the entries are left untouched; the flag is never set on error.
    pub fn ensure_sorted(&mut self) -> Result<&mut Self, LogitsError> {
        if self.get_sorted() {
            return Ok(self);
        }
        let comparison_failed = || {
            LogitsError::InternalError(String::from("Impossible: logit comparison failed?"))
        };
        // Reject NaN before sorting: a comparator that cannot order NaN is not
        // a total order, and `sort_by` may panic or scramble the entries when
        // handed one.
        if self.logits.iter().any(|l| l.logit.is_nan()) {
            return Err(comparison_failed());
        }
        let mut incomparable = false;
        self.logits.sort_by(|a, b| {
            b.logit.partial_cmp(&a.logit).unwrap_or_else(|| {
                // Should be unreachable after the NaN check above; still
                // recorded so an exotic `Float` type surfaces as an error.
                incomparable = true;
                std::cmp::Ordering::Equal
            })
        });
        if incomparable {
            return Err(comparison_failed());
        }
        self.set_sorted(true);
        Ok(self)
    }

    /// Fills every entry's `prob` with the softmax of the logits.
    ///
    /// An empty collection is returned unchanged. Otherwise the entries are
    /// sorted via [`Self::ensure_sorted`], the largest logit is subtracted
    /// from each logit before exponentiating, and the results are divided by
    /// their sum.
    ///
    /// Note: if the largest logit is itself infinite, the subtraction for that
    /// entry is `inf - inf`, which under IEEE-754 arithmetic yields NaN
    /// probabilities.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::ensure_sorted`].
    pub fn softmax(&mut self) -> Result<&mut Self, LogitsError> {
        if self.is_empty() {
            return Ok(self);
        }
        self.ensure_sorted()?;
        // Sorted descending, so the first entry holds the maximum logit.
        let max_l = self.logits[0].logit;
        // First pass: store the unnormalized weights and accumulate their sum.
        let cum_sum = self.logits.iter_mut().fold(L::zero(), |cs, l| {
            let p = (l.logit - max_l).exp();
            l.prob = p;
            cs + p
        });
        // Second pass: normalize so the probabilities sum to one.
        self.logits
            .iter_mut()
            .for_each(|l| l.prob = l.prob / cum_sum);
        Ok(self)
    }

    /// Runs `sampler` over these logits by calling its `sample` method and
    /// returns whatever it returns.
    pub fn sample<S: Sampler<TID, L>>(
        &mut self,
        sampler: &mut S,
    ) -> Result<&mut Self, SamplerError> {
        sampler.sample(self)
    }

    /// Runs `sampler` over these logits by calling its `sample_token` method
    /// and returns the token id it reports, if any.
    pub fn sample_token<S: Sampler<TID, L>>(
        &mut self,
        sampler: &mut S,
    ) -> Result<Option<TID>, SamplerError> {
        sampler.sample_token(self)
    }
}
/// A processing step that operates on a [`Logits`] collection and may pick a
/// token.
///
/// Implementors must be `Send + Sync`.
pub trait Sampler<TID, L>: Send + Sync {
    /// Processes `logits` and, on success, hands back a mutable reference
    /// with the same lifetime as the one passed in.
    fn sample<'a>(
        &mut self,
        logits: &'a mut Logits<TID, L>,
    ) -> Result<&'a mut Logits<TID, L>, SamplerError>;

    /// Returns the token id this sampler selected, if any.
    ///
    /// The default implementation always returns `None`.
    fn sampled_token_id(&self) -> Option<TID> {
        None
    }

    /// Calls [`Sampler::sample`], discards the returned reference, and then
    /// returns [`Sampler::sampled_token_id`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Sampler::sample`].
    fn sample_token(&mut self, logits: &mut Logits<TID, L>) -> Result<Option<TID>, SamplerError> {
        self.sample(logits)?;
        Ok(self.sampled_token_id())
    }
}
/// An ordered list of boxed [`Sampler`]s that are applied one after another,
/// plus the token id reported by the most recently run sampler.
pub struct SamplerChain<TID, L> {
    // Samplers in the order they were pushed.
    samplers: Vec<Box<dyn Sampler<TID, L>>>,
    // Token id recorded by the last run of the chain; `None` until one is
    // reported.
    token: Option<TID>,
}

/// An empty chain with no recorded token.
///
/// Implemented by hand because `#[derive(Default)]` would add `TID: Default`
/// and `L: Default` bounds, which neither field actually needs.
impl<TID, L> Default for SamplerChain<TID, L> {
    fn default() -> Self {
        Self {
            samplers: Vec::new(),
            token: None,
        }
    }
}
impl<TID: CanTokenId, L: CanLogit> SamplerChain<TID, L> {
    /// Creates a chain with no samplers and no recorded token.
    pub fn new() -> Self {
        let samplers = Vec::new();
        Self {
            samplers,
            token: None,
        }
    }

    /// Boxes `sampler` and appends it to the end of the chain, clearing any
    /// previously recorded token.
    ///
    /// Returns `self` so calls can be chained.
    pub fn push_sampler(
        &mut self,
        sampler: impl Sampler<TID, L> + Send + Sync + 'static,
    ) -> &mut Self {
        let boxed: Box<dyn Sampler<TID, L>> = Box::new(sampler);
        self.samplers.push(boxed);
        self.token = None;
        self
    }
}
impl<TID: CanTokenId, L: CanLogit> Sampler<TID, L> for SamplerChain<TID, L> {
    /// Runs every sampler in push order, feeding each one the logits returned
    /// by the previous one.
    ///
    /// The recorded token is cleared first and then overwritten after each
    /// successful step with that sampler's `sampled_token_id()`, so on success
    /// it reflects the last sampler in the chain.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a sampler. Later samplers are not
    /// run, and the recorded token keeps whatever the preceding successful
    /// step reported.
    fn sample<'a>(
        &mut self,
        logits: &'a mut Logits<TID, L>,
    ) -> Result<&'a mut Logits<TID, L>, SamplerError> {
        self.token = None;
        let mut current = logits;
        for step in self.samplers.iter_mut() {
            current = step.sample(current)?;
            self.token = step.sampled_token_id();
        }
        Ok(current)
    }

    /// Returns the token id recorded by the most recent call to `sample`.
    fn sampled_token_id(&self) -> Option<TID> {
        self.token
    }
}