use std::cell::RefCell;
use smol_str::format_smolstr;
use crate::{
array::Array,
error::{
EmptyInputPayload, Error, LengthMismatchPayload, NonFiniteScalarPayload, OutOfRangePayload,
RankMismatchPayload, Result, try_extend_from_slice, try_with_capacity,
},
lm::{cache::KvCache, model::Model, sample},
ops,
};
#[cfg(feature = "tokenizer-stream")]
use crate::tokenizer::StreamingDetokenizer as _;
/// Boxed callback stored in [`LogitsProcessor::Custom`]. It is invoked with the
/// token-history slice and the current logits, and returns the logits to use
/// in their place.
pub type LogitsProcessorFn = Box<dyn Fn(&[u32], &Array) -> Result<Array>>;
/// Boxed callback stored in [`Sampler::Custom`]. It is invoked with the array
/// handed to [`Sampler::sample`] and returns the sampler's output array.
pub type SamplerFn = Box<dyn FnMut(&Array) -> Result<Array>>;
/// Token indices and the matching bias values carried by
/// [`LogitsProcessor::LogitBias`].
///
/// Both parts are handed unchanged to `sample::apply_logit_bias` when the
/// processor is applied.
#[derive(Debug)]
pub struct LogitBiasPayload {
    indices: Vec<i32>,
    values: Array,
}

impl LogitBiasPayload {
    /// Bundles the token `indices` with their bias `values`.
    pub fn new(indices: Vec<i32>, values: Array) -> Self {
        LogitBiasPayload { indices, values }
    }

    /// Borrows the stored token indices.
    #[inline(always)]
    pub fn indices_slice(&self) -> &[i32] {
        self.indices.as_slice()
    }

    /// Borrows the stored bias values.
    #[inline(always)]
    pub fn values_ref(&self) -> &Array {
        let Self { values, .. } = self;
        values
    }
}
/// Settings carried by `LogitsProcessor::RepetitionPenalty`: the penalty
/// strength and the size of the trailing token window it is applied over.
#[derive(Debug, Clone, Copy)]
pub struct RepetitionPenaltyPayload {
    penalty: f32,
    context_size: usize,
}

impl RepetitionPenaltyPayload {
    /// Creates a payload from a `penalty` and a `context_size`.
    pub const fn new(penalty: f32, context_size: usize) -> Self {
        RepetitionPenaltyPayload {
            penalty,
            context_size,
        }
    }

    /// Returns the stored penalty.
    #[inline(always)]
    pub const fn penalty(&self) -> f32 {
        let RepetitionPenaltyPayload { penalty, .. } = *self;
        penalty
    }

    /// Returns the stored context window size.
    #[inline(always)]
    pub const fn context_size(&self) -> usize {
        let RepetitionPenaltyPayload { context_size, .. } = *self;
        context_size
    }
}
/// Settings carried by `LogitsProcessor::PresencePenalty`: the penalty
/// strength and the size of the trailing token window it is applied over.
#[derive(Debug, Clone, Copy)]
pub struct PresencePenaltyPayload {
    penalty: f32,
    context_size: usize,
}

impl PresencePenaltyPayload {
    /// Creates a payload from a `penalty` and a `context_size`.
    pub const fn new(penalty: f32, context_size: usize) -> Self {
        PresencePenaltyPayload {
            penalty,
            context_size,
        }
    }

    /// Returns the stored penalty.
    #[inline(always)]
    pub const fn penalty(&self) -> f32 {
        let PresencePenaltyPayload { penalty, .. } = *self;
        penalty
    }

    /// Returns the stored context window size.
    #[inline(always)]
    pub const fn context_size(&self) -> usize {
        let PresencePenaltyPayload { context_size, .. } = *self;
        context_size
    }
}
/// Settings carried by `LogitsProcessor::FrequencyPenalty`: the penalty
/// strength and the size of the trailing token window it is applied over.
#[derive(Debug, Clone, Copy)]
pub struct FrequencyPenaltyPayload {
    penalty: f32,
    context_size: usize,
}

impl FrequencyPenaltyPayload {
    /// Creates a payload from a `penalty` and a `context_size`.
    pub const fn new(penalty: f32, context_size: usize) -> Self {
        FrequencyPenaltyPayload {
            penalty,
            context_size,
        }
    }

    /// Returns the stored penalty.
    #[inline(always)]
    pub const fn penalty(&self) -> f32 {
        let FrequencyPenaltyPayload { penalty, .. } = *self;
        penalty
    }

    /// Returns the stored context window size.
    #[inline(always)]
    pub const fn context_size(&self) -> usize {
        let FrequencyPenaltyPayload { context_size, .. } = *self;
        context_size
    }
}
/// A transformation applied to the logits before sampling; see
/// [`LogitsProcessor::apply`] for how each variant is executed.
#[non_exhaustive]
#[derive(derive_more::IsVariant)]
pub enum LogitsProcessor {
/// Applies the payload's indices and values through `sample::apply_logit_bias`.
LogitBias(LogitBiasPayload),
/// Applies `sample::apply_repetition_penalty` over the recent token window.
RepetitionPenalty(RepetitionPenaltyPayload),
/// Applies `sample::apply_presence_penalty` over the recent token window.
PresencePenalty(PresencePenaltyPayload),
/// Applies `sample::apply_frequency_penalty` over the recent token window.
FrequencyPenalty(FrequencyPenaltyPayload),
/// Calls a user-supplied callback with the token history and the logits.
Custom(LogitsProcessorFn),
}
impl LogitsProcessor {
    /// Runs this processor over `logits` and returns the adjusted logits.
    ///
    /// `tokens` is the token history. The three penalty variants only see the
    /// trailing window selected by `recent_ids` for their `context_size`;
    /// `LogitBias` does not look at the history; `Custom` receives both
    /// arguments untouched.
    ///
    /// # Errors
    /// Propagates any error from `recent_ids`, from the `sample::apply_*`
    /// helper that is called, or from the custom callback.
    pub fn apply(&self, tokens: &[u32], logits: &Array) -> Result<Array> {
        match self {
            Self::Custom(callback) => callback(tokens, logits),
            Self::LogitBias(bias) => {
                let (indices, values) = (bias.indices_slice(), bias.values_ref());
                sample::apply_logit_bias(logits, indices, values)
            }
            Self::RepetitionPenalty(cfg) => {
                let window = recent_ids(tokens, cfg.context_size())?;
                sample::apply_repetition_penalty(logits, &window, cfg.penalty())
            }
            Self::PresencePenalty(cfg) => {
                let window = recent_ids(tokens, cfg.context_size())?;
                sample::apply_presence_penalty(logits, &window, cfg.penalty())
            }
            Self::FrequencyPenalty(cfg) => {
                let window = recent_ids(tokens, cfg.context_size())?;
                sample::apply_frequency_penalty(logits, &window, cfg.penalty())
            }
        }
    }
}
impl std::fmt::Debug for LogitsProcessor {
    /// Formats the processor without dumping array contents: `LogitBias`
    /// prints only the number of biased indices, the penalty variants print
    /// `penalty` and `context_size`, and `Custom` prints just its name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The three penalty variants share one layout, so resolve them to a
        // (name, penalty, context_size) triple and format once below.
        let (name, penalty, context_size) = match self {
            Self::LogitBias(bias) => {
                let n = bias.indices_slice().len();
                return f.debug_struct("LogitBias").field("n", &n).finish();
            }
            Self::Custom(_) => return f.debug_tuple("Custom").finish(),
            Self::RepetitionPenalty(p) => ("RepetitionPenalty", p.penalty(), p.context_size()),
            Self::PresencePenalty(p) => ("PresencePenalty", p.penalty(), p.context_size()),
            Self::FrequencyPenalty(p) => ("FrequencyPenalty", p.penalty(), p.context_size()),
        };
        f.debug_struct(name)
            .field("penalty", &penalty)
            .field("context_size", &context_size)
            .finish()
    }
}
/// Strategy used by [`Sampler::sample`] to turn an array of logits /
/// log-probabilities into sampled token ids.
pub enum Sampler {
/// Delegates to `sample::argmax_sample`.
Argmax,
/// Runs the configured filter chain followed by categorical sampling.
Chain(SamplerChain),
/// Calls a user-supplied closure.
Custom(SamplerFn),
}
impl Sampler {
pub fn custom<F>(f: F) -> Self
where
F: FnMut(&Array) -> Result<Array> + 'static,
{
Self::Custom(Box::new(f))
}
pub fn sample(&mut self, logits: &Array) -> Result<Array> {
match self {
Self::Argmax => sample::argmax_sample(logits),
Self::Chain(c) => c.sample(logits),
Self::Custom(f) => f(logits),
}
}
}
impl std::fmt::Debug for Sampler {
    /// Prints the variant name; `Chain` additionally prints its settings and
    /// `Custom` hides the (non-`Debug`) closure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Chain(chain) => f.debug_tuple("Chain").field(chain).finish(),
            // A field-less `debug_tuple(..).finish()` emits only the name, so
            // writing the name directly produces the same output.
            Self::Argmax => f.write_str("Argmax"),
            Self::Custom(_) => f.write_str("Custom"),
        }
    }
}
/// Stochastic sampler state built by `make_sampler`: the filter parameters,
/// precomputed on/off flags for each filter, and the random key.
pub struct SamplerChain {
/// Temperature passed to `sample::categorical_sampling`.
temp: f32,
/// Threshold passed to `sample::apply_top_p` (used only when `do_top_p`).
top_p: f32,
/// Threshold passed to `sample::apply_min_p` (used only when `do_min_p`).
min_p: f32,
/// Passed to `sample::apply_min_p` together with `min_p`.
min_tokens_to_keep: i32,
/// Cutoff passed to `sample::apply_top_k` (used only when `do_top_k`).
top_k: i32,
/// Probability passed to `sample::apply_xtc` (used only when `do_xtc`).
xtc_probability: f32,
/// Threshold passed to `sample::apply_xtc`.
xtc_threshold: f32,
/// Token ids passed to `sample::apply_xtc`.
xtc_special: Vec<i32>,
/// Whether the top-p filter runs.
do_top_p: bool,
/// Whether the min-p filter runs.
do_min_p: bool,
/// Whether the XTC filter runs.
do_xtc: bool,
/// Whether the top-k filter runs.
do_top_k: bool,
/// Random key; split and replaced on every `sample` call, hence the `RefCell`.
key: RefCell<Array>,
}
impl SamplerChain {
    /// Draws a sample from `logprobs` after running the enabled filters.
    ///
    /// The stored key is split twice per call: one sub-key feeds the XTC
    /// filter, the other the final categorical draw, and the remainder
    /// replaces the stored key. Filters run in the fixed order
    /// top-p, min-p, XTC, top-k; each one consumes the previous filter's
    /// output (or `logprobs` when nothing ran before it).
    ///
    /// # Errors
    /// Propagates errors from key splitting, any filter, or the final draw.
    /// If a split fails the stored key is left unchanged.
    fn sample(&self, logprobs: &Array) -> Result<Array> {
        let (xtc_key, draw_key) = {
            let mut state = self.key.borrow_mut();
            let (after_first, xtc_key) = ops::random::split(&state)?;
            let (after_second, draw_key) = ops::random::split(&after_first)?;
            *state = after_second;
            (xtc_key, draw_key)
        };

        // `filtered` stays `None` until the first enabled filter produces output.
        let mut filtered: Option<Array> = None;

        if self.do_top_p {
            filtered = Some(sample::apply_top_p(logprobs, self.top_p)?);
        }
        if self.do_min_p {
            let source = filtered.as_ref().unwrap_or(logprobs);
            let next = sample::apply_min_p(source, self.min_p, self.min_tokens_to_keep)?;
            filtered = Some(next);
        }
        if self.do_xtc {
            let source = filtered.as_ref().unwrap_or(logprobs);
            let next = sample::apply_xtc(
                source,
                self.xtc_probability,
                self.xtc_threshold,
                &self.xtc_special,
                &xtc_key,
            )?;
            filtered = Some(next);
        }
        if self.do_top_k {
            let source = filtered.as_ref().unwrap_or(logprobs);
            let next = sample::apply_top_k(source, self.top_k)?;
            filtered = Some(next);
        }

        let final_input = match &filtered {
            Some(array) => array,
            None => logprobs,
        };
        sample::categorical_sampling(final_input, self.temp, &draw_key)
    }
}
impl std::fmt::Debug for SamplerChain {
    /// Prints the main tuning knobs only; the random key, the XTC threshold
    /// and token list, `min_tokens_to_keep` and the `do_*` flags are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            temp,
            top_p,
            min_p,
            top_k,
            xtc_probability,
            ..
        } = self;
        let mut out = f.debug_struct("SamplerChain");
        out.field("temp", temp);
        out.field("top_p", top_p);
        out.field("min_p", min_p);
        out.field("top_k", top_k);
        out.field("xtc_probability", xtc_probability);
        out.finish()
    }
}
/// Default used by `GenConfig::default` for the repetition, presence and
/// frequency `*_context_size` fields.
pub const DEFAULT_REPETITION_CONTEXT_SIZE: usize = 20;
/// Produces a seed for samplers that were not given an explicit one.
///
/// The value mixes the current wall-clock time (nanoseconds since the Unix
/// epoch, truncated to 64 bits, or 0 if the clock reads before the epoch)
/// with a process-wide call counter, so that calls landing on the same
/// clock reading still receive different seeds.
fn next_sampler_seed() -> u64 {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Odd multiplier: distinct counter values stay distinct after the
    // wrapping multiplication.
    const MIX: u64 = 0x9E37_79B9_7F4A_7C15;
    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let ticket = COUNTER.fetch_add(1, Ordering::Relaxed);
    let clock = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    };
    clock ^ ticket.wrapping_mul(MIX)
}
/// Test hook (hidden from rustdoc) that returns a fresh value from
/// `next_sampler_seed`, the same source `make_sampler` falls back to when no
/// seed is supplied. Every call advances the shared seed counter.
#[doc(hidden)]
pub fn __resolved_unseeded_seed_for_test() -> u64 {
next_sampler_seed()
}
/// Generation settings consumed by `generate_step`, `stream_generate`,
/// `generate` and `batch_generate_step`. Ranges quoted below are the ones
/// enforced by [`GenConfig::validate`].
#[derive(Debug, Clone)]
pub struct GenConfig {
/// Maximum number of tokens to produce.
pub max_tokens: usize,
/// Maximum number of prompt tokens per prefill forward pass
/// (`build_generator` raises values below 1 to 1).
pub prefill_step_size: usize,
/// Sampling temperature; finite and `>= 0`. `0.0` selects argmax,
/// `> 0.0` selects stochastic sampling.
pub temp: f32,
/// Top-p threshold in `[0, 1]`; the filter runs only for values strictly
/// between 0 and 1.
pub top_p: f32,
/// Min-p threshold in `[0, 1]`; the filter runs only when non-zero.
pub min_p: f32,
/// Passed to the min-p filter; must be `>= 1`.
pub min_tokens_to_keep: i32,
/// Top-k cutoff; must be `>= 0`, `0` disables the filter.
pub top_k: i32,
/// XTC probability in `[0, 1]`; the filter runs only when `> 0`.
pub xtc_probability: f32,
/// XTC threshold in `[0, 0.5]`.
pub xtc_threshold: f32,
/// Token ids forwarded to the XTC filter.
pub(crate) xtc_special_tokens: Vec<i32>,
/// `(token index, bias)` pairs; every bias must be finite.
pub(crate) logit_bias: Vec<(i32, f32)>,
/// Repetition penalty; finite and `>= 0` when set. `None` or `0.0` adds no processor.
pub repetition_penalty: Option<f32>,
/// Number of trailing history tokens the repetition penalty sees (`0` = whole history).
pub repetition_context_size: usize,
/// Presence penalty; finite when set. `None` or `0.0` adds no processor.
pub presence_penalty: Option<f32>,
/// Number of trailing history tokens the presence penalty sees (`0` = whole history).
pub presence_context_size: usize,
/// Frequency penalty; finite when set. `None` or `0.0` adds no processor.
pub frequency_penalty: Option<f32>,
/// Number of trailing history tokens the frequency penalty sees (`0` = whole history).
pub frequency_context_size: usize,
/// Token ids that end generation. Note: `stream_generate` replaces this
/// list with the tokenizer's EOS ids.
pub(crate) eos: Vec<u32>,
/// Stop strings handed to the stop matcher in `stream_generate`.
pub(crate) stop_strings: Vec<String>,
/// Optional RNG seed; `None` makes `make_sampler` derive one from the clock and a counter.
pub seed: Option<u64>,
/// When `true`, each single-stream `GenStep` carries the normalized log-probabilities.
pub collect_logprobs: bool,
}
impl Default for GenConfig {
fn default() -> Self {
Self {
max_tokens: 256,
prefill_step_size: 2048,
temp: 0.0,
top_p: 0.0,
min_p: 0.0,
min_tokens_to_keep: 1,
top_k: 0,
xtc_probability: 0.0,
xtc_threshold: 0.0,
xtc_special_tokens: Vec::new(),
logit_bias: Vec::new(),
repetition_penalty: None,
repetition_context_size: DEFAULT_REPETITION_CONTEXT_SIZE,
presence_penalty: None,
presence_context_size: DEFAULT_REPETITION_CONTEXT_SIZE,
frequency_penalty: None,
frequency_context_size: DEFAULT_REPETITION_CONTEXT_SIZE,
eos: Vec::new(),
stop_strings: Vec::new(),
seed: None,
collect_logprobs: false,
}
}
}
impl GenConfig {
pub fn new() -> Self {
Self::default()
}
#[inline(always)]
pub fn xtc_special_tokens_slice(&self) -> &[i32] {
&self.xtc_special_tokens
}
#[inline(always)]
pub fn logit_bias_slice(&self) -> &[(i32, f32)] {
&self.logit_bias
}
#[inline(always)]
pub fn eos_slice(&self) -> &[u32] {
&self.eos
}
#[inline(always)]
pub fn stop_strings_slice(&self) -> &[String] {
&self.stop_strings
}
#[must_use]
pub fn with_max_tokens(mut self, n: usize) -> Self {
self.max_tokens = n;
self
}
#[must_use]
pub fn with_prefill_step_size(mut self, n: usize) -> Self {
self.prefill_step_size = n;
self
}
#[must_use]
pub fn with_xtc_special_tokens(mut self, tokens: impl Into<Vec<i32>>) -> Self {
self.xtc_special_tokens = tokens.into();
self
}
#[must_use]
pub fn with_logit_bias(mut self, bias: impl Into<Vec<(i32, f32)>>) -> Self {
self.logit_bias = bias.into();
self
}
#[must_use]
pub fn with_eos(mut self, eos: impl Into<Vec<u32>>) -> Self {
self.eos = eos.into();
self
}
#[must_use]
pub fn with_stop_strings(mut self, stops: impl Into<Vec<String>>) -> Self {
self.stop_strings = stops.into();
self
}
pub fn set_xtc_special_tokens(&mut self, tokens: impl Into<Vec<i32>>) -> &mut Self {
self.xtc_special_tokens = tokens.into();
self
}
pub fn set_logit_bias(&mut self, bias: impl Into<Vec<(i32, f32)>>) -> &mut Self {
self.logit_bias = bias.into();
self
}
pub fn set_eos(&mut self, eos: impl Into<Vec<u32>>) -> &mut Self {
self.eos = eos.into();
self
}
pub fn set_stop_strings(&mut self, stops: impl Into<Vec<String>>) -> &mut Self {
self.stop_strings = stops.into();
self
}
#[must_use]
pub fn with_temp(mut self, temp: f32) -> Self {
self.temp = temp;
self
}
pub fn validate(&self) -> Result<()> {
if !self.temp.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"GenConfig::validate: temp",
self.temp as f64,
)));
}
if self.temp < 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"GenConfig::validate: temp",
"must be a finite non-negative float (0.0 = argmax, > 0.0 = stochastic)",
format_smolstr!("{}", self.temp),
)));
}
if !self.top_p.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"GenConfig::validate: top_p",
self.top_p as f64,
)));
}
if !(0.0..=1.0).contains(&self.top_p) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"GenConfig::validate: top_p",
"must be in [0, 1] (0 = off, (0, 1) = nucleus cutoff, 1 = include everything)",
format_smolstr!("{}", self.top_p),
)));
}
if !self.min_p.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"GenConfig::validate: min_p",
self.min_p as f64,
)));
}
if !(0.0..=1.0).contains(&self.min_p) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"GenConfig::validate: min_p",
"must be in [0, 1]",
format_smolstr!("{}", self.min_p),
)));
}
if self.min_tokens_to_keep < 1 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"GenConfig::validate: min_tokens_to_keep",
"must be a positive integer (>= 1)",
format_smolstr!("{}", self.min_tokens_to_keep),
)));
}
if self.top_k < 0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"GenConfig::validate: top_k",
"must be non-negative (0 = off, > 0 = top-k cutoff)",
format_smolstr!("{}", self.top_k),
)));
}
if !self.xtc_probability.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"GenConfig::validate: xtc_probability",
self.xtc_probability as f64,
)));
}
if !(0.0..=1.0).contains(&self.xtc_probability) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"GenConfig::validate: xtc_probability",
"must be in [0, 1]",
format_smolstr!("{}", self.xtc_probability),
)));
}
if !self.xtc_threshold.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"GenConfig::validate: xtc_threshold",
self.xtc_threshold as f64,
)));
}
if !(0.0..=0.5).contains(&self.xtc_threshold) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"GenConfig::validate: xtc_threshold",
"must be in [0, 0.5]",
format_smolstr!("{}", self.xtc_threshold),
)));
}
if let Some(p) = self.repetition_penalty {
if !p.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"GenConfig::validate: repetition_penalty",
p as f64,
)));
}
if p < 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"GenConfig::validate: repetition_penalty",
"must be a finite non-negative float when Some(_)",
format_smolstr!("{p}"),
)));
}
}
if let Some(p) = self.presence_penalty
&& !p.is_finite()
{
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"GenConfig::validate: presence_penalty",
p as f64,
)));
}
if let Some(p) = self.frequency_penalty
&& !p.is_finite()
{
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"GenConfig::validate: frequency_penalty",
p as f64,
)));
}
for &(_id, v) in &self.logit_bias {
if !v.is_finite() {
return Err(Error::NonFiniteScalar(NonFiniteScalarPayload::new(
"GenConfig::validate: logit_bias value",
v as f64,
)));
}
}
Ok(())
}
}
#[allow(clippy::too_many_arguments)]
pub fn make_sampler(
temp: f32,
top_p: f32,
min_p: f32,
min_tokens_to_keep: i32,
top_k: i32,
xtc_probability: f32,
xtc_threshold: f32,
xtc_special_tokens: &[i32],
seed: Option<u64>,
) -> Result<Sampler> {
if temp == 0.0 {
return Ok(Sampler::Argmax);
}
let do_top_p = top_p > 0.0 && top_p < 1.0;
let do_min_p = min_p != 0.0;
let do_xtc = xtc_probability > 0.0;
let do_top_k = top_k > 0;
let xtc_special: Vec<i32> = xtc_special_tokens.to_vec();
let resolved_seed = seed.unwrap_or_else(next_sampler_seed);
let key = RefCell::new(ops::random::key(resolved_seed)?);
Ok(Sampler::Chain(SamplerChain {
temp,
top_p,
min_p,
min_tokens_to_keep,
top_k,
xtc_probability,
xtc_threshold,
xtc_special,
do_top_p,
do_min_p,
do_xtc,
do_top_k,
key,
}))
}
/// Builds the list of logits processors implied by the given settings.
///
/// Processors are pushed in a fixed order: logit bias (only when
/// `logit_bias` is non-empty), then repetition, presence and frequency
/// penalty. A penalty that is `None` or exactly `0.0` adds no processor.
///
/// # Errors
/// Propagates allocation failures from `try_with_capacity` and errors from
/// `Array::from_slice`.
#[allow(clippy::too_many_arguments)]
pub fn make_logits_processors(
    logit_bias: &[(i32, f32)],
    repetition_penalty: Option<f32>,
    repetition_context_size: usize,
    presence_penalty: Option<f32>,
    presence_context_size: usize,
    frequency_penalty: Option<f32>,
    frequency_context_size: usize,
) -> Result<Vec<LogitsProcessor>> {
    let mut out: Vec<LogitsProcessor> = Vec::new();

    if !logit_bias.is_empty() {
        let count = logit_bias.len();
        let mut indices: Vec<i32> = try_with_capacity(count)?;
        for &(index, _) in logit_bias {
            indices.push(index);
        }
        let mut biases: Vec<f32> = try_with_capacity(count)?;
        for &(_, bias) in logit_bias {
            biases.push(bias);
        }
        let values = Array::from_slice::<f32>(&biases, &(biases.len(),))?;
        let payload = LogitBiasPayload::new(indices, values);
        out.push(LogitsProcessor::LogitBias(payload));
    }

    match repetition_penalty {
        Some(p) if p != 0.0 => {
            let payload = RepetitionPenaltyPayload::new(p, repetition_context_size);
            out.push(LogitsProcessor::RepetitionPenalty(payload));
        }
        _ => {}
    }
    match presence_penalty {
        Some(p) if p != 0.0 => {
            let payload = PresencePenaltyPayload::new(p, presence_context_size);
            out.push(LogitsProcessor::PresencePenalty(payload));
        }
        _ => {}
    }
    match frequency_penalty {
        Some(p) if p != 0.0 => {
            let payload = FrequencyPenaltyPayload::new(p, frequency_context_size);
            out.push(LogitsProcessor::FrequencyPenalty(payload));
        }
        _ => {}
    }

    Ok(out)
}
/// Returns the trailing `context_size` tokens of `tokens`, cast to `i32`.
///
/// A `context_size` of `0` selects the whole history; a `context_size`
/// larger than the history also returns everything.
///
/// # Errors
/// Propagates an allocation failure from `try_with_capacity`.
fn recent_ids(tokens: &[u32], context_size: usize) -> Result<Vec<i32>> {
    let skip = match context_size {
        0 => 0,
        window => tokens.len().saturating_sub(window),
    };
    let recent = &tokens[skip..];
    let mut ids: Vec<i32> = try_with_capacity(recent.len())?;
    for &token in recent {
        ids.push(token as i32);
    }
    Ok(ids)
}
/// Why generation ended. `Display` prints [`FinishReason::as_str`].
#[derive(
Debug,
Clone,
PartialEq,
Eq,
derive_more::Display,
derive_more::IsVariant,
derive_more::Unwrap,
derive_more::TryUnwrap,
)]
#[display("{}", self.as_str())]
#[unwrap(ref, ref_mut)]
#[try_unwrap(ref, ref_mut)]
pub enum FinishReason {
/// An end-of-sequence token was sampled.
Eos,
/// The `max_tokens` budget was reached.
Length,
/// A stop string matched; the payload is the stop string reported by the matcher.
Stop(String),
}
impl FinishReason {
    /// Short label for the reason: `"length"` for [`FinishReason::Length`],
    /// `"stop"` for both [`FinishReason::Eos`] and [`FinishReason::Stop`].
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Length => "length",
            Self::Stop(_) | Self::Eos => "stop",
        }
    }

    /// The matched stop string, if generation ended on one.
    pub fn stop_sequence(&self) -> Option<&str> {
        if let Self::Stop(matched) = self {
            Some(matched.as_str())
        } else {
            None
        }
    }
}
/// One item produced by the single-sequence generator.
#[derive(Debug)]
pub struct GenStep {
/// The sampled token id.
pub token: u32,
/// Normalized log-probabilities for this step; `Some` only when
/// `collect_logprobs` is enabled.
pub logprobs: Option<Array>,
/// Number of tokens produced before this one (0 for the first step).
pub step_index: usize,
/// Set to `Some(FinishReason::Eos)` by the generator when `token` is an
/// EOS id; otherwise `None`.
pub finish_reason: Option<FinishReason>,
}
impl From<GenStep> for (u32, Option<Array>) {
fn from(s: GenStep) -> Self {
(s.token, s.logprobs)
}
}
/// Single-sequence generation state; it is driven through its `Iterator`
/// implementation, which yields one `GenStep` per sampled token.
pub(crate) struct Generator<'a, M: Model + ?Sized> {
/// Model whose `forward` is called for prefill and for every decode step.
model: &'a M,
/// KV cache passed mutably to every `forward` call; handed back by `into_cache`.
cache: Vec<Box<dyn KvCache>>,
/// Sampler applied to the (optionally normalized) last-position logits.
sampler: Sampler,
/// Logits processors applied, in order, before sampling.
processors: Vec<LogitsProcessor>,
/// Owned copy of the prompt tokens.
prompt: Vec<u32>,
/// Number of prompt tokens already consumed by `prefill`.
prefill_offset: usize,
/// Tokens shown to the logits processors; only extended when processors exist.
history: Vec<u32>,
/// Most recently sampled token; it is the input of the next decode step.
last: Option<u32>,
/// Number of tokens yielded so far.
produced: usize,
/// Upper bound on `produced`.
max_tokens: usize,
/// Maximum number of prompt tokens per prefill forward pass.
prefill_step_size: usize,
/// Token ids that end generation.
eos: Vec<u32>,
/// Whether each step returns normalized log-probabilities.
collect_logprobs: bool,
/// Forces log-softmax normalization before sampling even when log-probs
/// are not collected.
needs_logprobs: bool,
/// Whether stochastic sampling is in use (temperature above zero).
temp_stochastic: bool,
/// Set once `prefill` has been attempted.
prefilled: bool,
/// `true` until the first decode step has taken the remaining prompt tail as input.
first_step: bool,
/// Construction error to be yielded as the first item, if any.
pending_err: Option<Error>,
/// Once set, the iterator only returns `None`.
done: bool,
}
impl<M: Model + ?Sized> Generator<'_, M> {
    /// Consumes the generator and hands back its KV cache.
    pub fn into_cache(self) -> Vec<Box<dyn KvCache>> {
        let Self { cache, .. } = self;
        cache
    }

    /// Feeds the prompt to the model in chunks of at most
    /// `prefill_step_size` tokens, leaving exactly one prompt token
    /// unprocessed for the first decode step. The logits of these forward
    /// passes are discarded.
    fn prefill(&mut self) -> Result<()> {
        loop {
            let pending = self.prompt.len() - self.prefill_offset;
            if pending <= 1 {
                return Ok(());
            }
            // Never consume the final prompt token here.
            let take = self.prefill_step_size.min(pending - 1);
            let end = self.prefill_offset + take;
            let chunk = token_window(&self.prompt[self.prefill_offset..end])?;
            let _ = self.model.forward(&chunk, &mut self.cache)?;
            self.prefill_offset = end;
        }
    }

    /// Runs one forward pass over `input` and samples the next token.
    ///
    /// Steps: forward pass, keep the last position's logits, run the logits
    /// processors (which also appends `input` to the processor history),
    /// normalize as required, sample, and optionally return the normalized
    /// log-probabilities. `finish_reason` is always `None` here; the
    /// iterator fills it in.
    fn step(&mut self, input: &[u32]) -> Result<GenStep> {
        let window = token_window(input)?;
        let full_logits = self.model.forward(&window, &mut self.cache)?;
        let mut logits = last_position(&full_logits)?;

        let run_processors = !self.processors.is_empty() && !input.is_empty();
        if run_processors {
            try_extend_from_slice(&mut self.history, input)?;
            for processor in self.processors.iter() {
                let adjusted = processor.apply(&self.history, &logits)?;
                logits = adjusted;
            }
        }

        // Full log-softmax when log-probs are needed; a cheaper max-shift
        // when only stochastic sampling is requested; raw logits otherwise.
        let sampler_input: Option<Array> = if self.collect_logprobs || self.needs_logprobs {
            let lse = ops::reduction::logsumexp(&logits, true)?;
            Some(ops::arithmetic::subtract(&logits, &lse)?)
        } else if self.temp_stochastic {
            let peak = ops::reduction::max(&logits, true)?;
            Some(ops::arithmetic::subtract(&logits, &peak)?)
        } else {
            None
        };

        let to_sample = match &sampler_input {
            Some(prepared) => prepared,
            None => &logits,
        };
        let mut sampled = self.sampler.sample(to_sample)?;
        let token = sampled.item::<u32>()?;

        let logprobs = if !self.collect_logprobs {
            None
        } else {
            let normalized = sampler_input
                .as_ref()
                .expect("sampler_input is Some (full normalization) when collect_logprobs == true");
            Some(ops::shape::squeeze_axes(normalized, &[0])?)
        };

        Ok(GenStep {
            token,
            logprobs,
            step_index: self.produced,
            finish_reason: None,
        })
    }
}
impl<M: Model + ?Sized> Iterator for Generator<'_, M> {
    type Item = Result<GenStep>;

    /// Produces the next generation step.
    ///
    /// Order of checks: already done, pending construction error (yielded
    /// once, then the iterator ends), token budget exhausted, lazy prefill on
    /// first use. The first decode step takes the unprocessed prompt tail as
    /// input; later steps take the previously sampled token. Any error ends
    /// the iteration after being yielded, and sampling an EOS id marks the
    /// step with `FinishReason::Eos` and ends the iteration.
    fn next(&mut self) -> Option<Self::Item> {
        if self.done {
            return None;
        }
        if let Some(err) = self.pending_err.take() {
            self.done = true;
            return Some(Err(err));
        }
        if self.produced >= self.max_tokens {
            self.done = true;
            return None;
        }

        if !self.prefilled {
            self.prefilled = true;
            if let Err(err) = self.prefill() {
                self.done = true;
                return Some(Err(err));
            }
        }

        let is_first = std::mem::replace(&mut self.first_step, false);
        let input: Vec<u32> = if is_first {
            self.prompt[self.prefill_offset..].to_vec()
        } else if let Some(previous) = self.last {
            vec![previous]
        } else {
            self.done = true;
            return None;
        };

        let mut step = match self.step(&input) {
            Ok(step) => step,
            Err(err) => {
                self.done = true;
                return Some(Err(err));
            }
        };

        self.produced += 1;
        self.last = Some(step.token);
        if self.eos.contains(&step.token) {
            self.done = true;
            step.finish_reason = Some(FinishReason::Eos);
        }
        Some(Ok(step))
    }
}
/// Packs `ids` into an `i32` array of shape `(1, ids.len())`.
///
/// # Errors
/// Propagates allocation failures and `Array::from_slice` errors.
fn token_window(ids: &[u32]) -> Result<Array> {
    let mut row: Vec<i32> = try_with_capacity(ids.len())?;
    for &id in ids {
        row.push(id as i32);
    }
    let shape = (1usize, row.len());
    Array::from_slice::<i32>(&row, &shape)
}
/// Extracts the logits of the last sequence position: slices
/// `logits[:, S-1, :]` out of a rank-3 array and squeezes axis 1.
///
/// # Errors
/// `Error::RankMismatch` when `logits` is not rank 3, `Error::OutOfRange`
/// when the second or third axis is empty; slice/squeeze errors are
/// propagated.
fn last_position(logits: &Array) -> Result<Array> {
    let dims = logits.shape();
    let rank = dims.len();
    if rank != 3 {
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            "generate::last_position: expected [B, S, V] logits from `forward`",
            rank as u32,
            dims.to_vec(),
        )));
    }

    let (batch, seq, vocab) = (dims[0], dims[1], dims[2]);
    if seq == 0 || vocab == 0 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "generate::last_position: forward logits axes (S and V)",
            "must be >= 1 (logits[:, -1, :] requires S >= 1 and V >= 1)",
            format_smolstr!("S={}, V={}", seq, vocab),
        )));
    }

    let (b, s, v) = (batch as i32, seq as i32, vocab as i32);
    let start = [0, s - 1, 0];
    let stop = [b, s, v];
    let strides = [1, 1, 1];
    let last = ops::indexing::slice(logits, &start, &stop, &strides)?;
    ops::shape::squeeze_axes(&last, &[1])
}
/// Creates a lazy, token-by-token generator over `model` for `prompt`.
///
/// Thin public wrapper around `build_generator` that hides the concrete
/// generator type behind `impl Iterator`. Setup failures (empty prompt,
/// invalid `cfg`, sampler/processor construction errors) are not returned
/// here; they are yielded as the iterator's first item.
pub fn generate_step<'a, M: Model + ?Sized>(
model: &'a M,
prompt: &[u32],
cache: Vec<Box<dyn KvCache>>,
cfg: GenConfig,
) -> impl Iterator<Item = Result<GenStep>> + 'a {
build_generator(model, prompt, cache, cfg)
}
pub(crate) fn build_generator<'a, M: Model + ?Sized>(
model: &'a M,
prompt: &[u32],
cache: Vec<Box<dyn KvCache>>,
cfg: GenConfig,
) -> Generator<'a, M> {
let built = (|| -> Result<(Sampler, Vec<LogitsProcessor>)> {
if prompt.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"generate: prompt",
)));
}
cfg.validate()?;
let sampler = make_sampler(
cfg.temp,
cfg.top_p,
cfg.min_p,
cfg.min_tokens_to_keep,
cfg.top_k,
cfg.xtc_probability,
cfg.xtc_threshold,
&cfg.xtc_special_tokens,
cfg.seed,
)?;
let processors = make_logits_processors(
&cfg.logit_bias,
cfg.repetition_penalty,
cfg.repetition_context_size,
cfg.presence_penalty,
cfg.presence_context_size,
cfg.frequency_penalty,
cfg.frequency_context_size,
)?;
Ok((sampler, processors))
})();
let collect_logprobs = cfg.collect_logprobs;
let needs_logprobs = cfg.temp != 0.0 && cfg.top_p > 0.0 && cfg.top_p < 1.0;
let temp_stochastic = cfg.temp > 0.0;
match built {
Ok((sampler, processors)) => Generator {
model,
cache,
sampler,
processors,
prompt: prompt.to_vec(),
prefill_offset: 0,
history: Vec::new(),
last: None,
produced: 0,
max_tokens: cfg.max_tokens,
prefill_step_size: cfg.prefill_step_size.max(1),
eos: cfg.eos,
collect_logprobs,
needs_logprobs,
temp_stochastic,
prefilled: false,
first_step: true,
pending_err: None,
done: false,
},
Err(e) => Generator {
model,
cache,
sampler: Sampler::Argmax,
processors: Vec::new(),
prompt: Vec::new(),
prefill_offset: 0,
history: Vec::new(),
last: None,
produced: 0,
max_tokens: cfg.max_tokens,
prefill_step_size: 1,
eos: Vec::new(),
collect_logprobs,
needs_logprobs,
temp_stochastic,
prefilled: true,
first_step: false,
pending_err: Some(e),
done: false,
},
}
}
/// One item yielded by `stream_generate`: the newly emitted text plus
/// running statistics.
#[derive(Debug)]
pub struct GenerationResponse {
/// Text emitted for this step (may be empty while text is held back by the stop matcher).
pub text: String,
/// The token sampled at this step.
pub token: u32,
/// Log-probabilities from the underlying generation step, if collected.
pub logprobs: Option<Array>,
/// Length of the prompt in tokens.
pub prompt_tokens: usize,
/// Prompt tokens per second, measured up to the first generated token.
pub prompt_tps: f64,
/// Number of tokens generated so far, including this one.
pub generation_tokens: usize,
/// Generated tokens per second since the first generated token.
pub generation_tps: f64,
/// Result of `crate::memory::peak_memory()`, or `None` when that call fails.
pub peak_memory_bytes: Option<u64>,
/// `Some` on the final response of a stream, `None` otherwise.
pub finish_reason: Option<FinishReason>,
}
/// Summary statistics returned by `generate`, copied from the last streamed
/// response (or zeroed when nothing was generated).
#[derive(Debug, Clone, Copy)]
pub struct GenerationStats {
/// Length of the prompt in tokens.
pub prompt_tokens: usize,
/// Number of tokens generated.
pub generation_tokens: usize,
/// Prompt tokens per second.
pub prompt_tps: f64,
/// Generated tokens per second.
pub generation_tps: f64,
/// Peak memory as reported by `crate::memory::peak_memory()`, if available.
pub peak_memory_bytes: Option<u64>,
}
/// Streams generated text for `prompt`, one `GenerationResponse` per sampled
/// token.
///
/// Wraps `generate_step`, feeds every token through the tokenizer's streaming
/// detokenizer and reports throughput and peak memory alongside the new text.
///
/// Behaviour visible in this function:
/// - `cfg.eos` is overwritten with the tokenizer's EOS ids before generation.
/// - When `cfg.stop_strings` makes the stop matcher active, text is only
///   emitted up to the length the matcher reports as safe; the held-back tail
///   is flushed on EOS or when `max_tokens` is reached.
/// - The final response carries `finish_reason` (`Eos`, `Length` or `Stop`).
/// - An error from the underlying generator is yielded once and ends the stream.
/// - The prompt timer starts when this function is called, not when the
///   iterator is first polled.
pub fn stream_generate<'a, M: Model + ?Sized>(
model: &'a M,
tokenizer: &'a crate::tokenizer::Tokenizer,
prompt: &[u32],
cache: Vec<Box<dyn KvCache>>,
cfg: GenConfig,
) -> impl Iterator<Item = Result<GenerationResponse>> + 'a {
use std::time::Instant;
let prompt_tokens = prompt.len();
let mut cfg = cfg;
// The tokenizer's EOS ids replace whatever the caller configured.
cfg.eos = tokenizer.eos_token_ids_iter().collect();
let max_tokens = cfg.max_tokens;
let eos: Vec<u32> = cfg.eos.clone();
let matcher = crate::lm::stop::StopMatcher::new(cfg.stop_strings.clone());
// Byte length of the detokenized text already handed to the caller
// (only maintained on the stop-matcher path).
let mut emitted_len: usize = 0;
let mut steps = generate_step(model, prompt, cache, cfg);
let mut detok = tokenizer.detokenizer();
// Number of non-EOS tokens passed to the detokenizer so far.
let mut n: usize = 0;
let mut finished = false;
let mut tic = Instant::now();
let mut prompt_tps = 0.0_f64;
std::iter::from_fn(move || {
if finished {
return None;
}
let GenStep {
token, logprobs, ..
} = match steps.next()? {
Ok(step) => step,
Err(e) => {
finished = true;
return Some(Err(e));
}
};
// First token: everything up to now counts as prompt processing time;
// restart the clock for generation throughput.
if n == 0 {
let prompt_time = tic.elapsed().as_secs_f64();
prompt_tps = if prompt_time > 0.0 {
prompt_tokens as f64 / prompt_time
} else {
0.0
};
tic = Instant::now();
}
let gen_tps = |gen_count: usize| -> f64 {
let dt = tic.elapsed().as_secs_f64();
if dt > 0.0 { gen_count as f64 / dt } else { 0.0 }
};
let peak = crate::memory::peak_memory().ok();
// EOS: the token itself is not added to the detokenizer, but it is
// counted in `generation_tokens` (hence `n + 1`).
if eos.contains(&token) {
finished = true;
detok.finalize();
let (text, reason) = if matcher.is_active() {
finalize_active_tail(&detok, &matcher, &mut emitted_len, FinishReason::Eos)
} else {
(detok.last_segment(), FinishReason::Eos)
};
return Some(Ok(GenerationResponse {
text,
token,
logprobs,
prompt_tokens,
prompt_tps,
generation_tokens: n + 1,
generation_tps: gen_tps(n + 1),
peak_memory_bytes: peak,
finish_reason: Some(reason),
}));
}
detok.add_token(token);
n += 1;
// Stop-string path: emission is driven by the matcher's decision on the
// full detokenized text.
if matcher.is_active() {
let full = detok.text();
match matcher.step(&full) {
crate::lm::stop::StopDecision::Stop(p) => {
finished = true;
// Emit whatever lies between what was already emitted and the
// matcher's trimmed length.
let end = p.trimmed_len().max(emitted_len).min(full.len());
let text = full[emitted_len..end].to_string();
let stop = p.stop().to_owned();
emitted_len = end;
drop(full);
return Some(Ok(GenerationResponse {
text,
token,
logprobs,
prompt_tokens,
prompt_tps,
generation_tokens: n,
generation_tps: gen_tps(n),
peak_memory_bytes: peak,
finish_reason: Some(FinishReason::Stop(stop)),
}));
}
crate::lm::stop::StopDecision::Continue(p) => {
if n >= max_tokens {
finished = true;
// `full` is released before the detokenizer is finalized.
drop(full);
detok.finalize();
let (text, reason) =
finalize_active_tail(&detok, &matcher, &mut emitted_len, FinishReason::Length);
return Some(Ok(GenerationResponse {
text,
token,
logprobs,
prompt_tokens,
prompt_tps,
generation_tokens: n,
generation_tps: gen_tps(n),
peak_memory_bytes: peak,
finish_reason: Some(reason),
}));
}
// Emit only up to the length the matcher reports as safe.
let safe_len = p.safe_len();
let end = safe_len.max(emitted_len).min(full.len());
let text = full[emitted_len..end].to_string();
emitted_len = end;
drop(full);
return Some(Ok(GenerationResponse {
text,
token,
logprobs,
prompt_tokens,
prompt_tps,
generation_tokens: n,
generation_tps: gen_tps(n),
peak_memory_bytes: peak,
finish_reason: None,
}));
}
}
}
// No stop strings: token budget reached, flush and finish with `Length`.
if n >= max_tokens {
finished = true;
detok.finalize();
let text = detok.last_segment();
return Some(Ok(GenerationResponse {
text,
token,
logprobs,
prompt_tokens,
prompt_tps,
generation_tokens: n,
generation_tps: gen_tps(n),
peak_memory_bytes: peak,
finish_reason: Some(FinishReason::Length),
}));
}
// Ordinary step: emit the detokenizer's latest segment.
let text = detok.last_segment();
Some(Ok(GenerationResponse {
text,
token,
logprobs,
prompt_tokens,
prompt_tps,
generation_tokens: n,
generation_tps: gen_tps(n),
peak_memory_bytes: peak,
finish_reason: None,
}))
})
}
/// Computes the final text chunk of a stream whose stop matcher is active.
///
/// The matcher is run one last time over the detokenizer's full text:
/// - on a stop match, the text between `*emitted_len` and the matcher's
///   trimmed length is returned together with `FinishReason::Stop`;
/// - otherwise everything after `*emitted_len` is flushed and
///   `default_reason` is returned.
///
/// `*emitted_len` is updated to the end of the returned text.
///
/// The start offset is clamped to the text length for both outcomes, so an
/// `*emitted_len` that exceeds the current text yields an empty chunk
/// instead of a slice-range panic. Previously only the non-stop branch
/// clamped it.
fn finalize_active_tail(
    detok: &dyn crate::tokenizer::StreamingDetokenizer,
    matcher: &crate::lm::stop::StopMatcher,
    emitted_len: &mut usize,
    default_reason: FinishReason,
) -> (String, FinishReason) {
    let full = detok.text();
    // Clamp once so neither branch can slice with start > end.
    let start = (*emitted_len).min(full.len());
    match matcher.step(&full) {
        crate::lm::stop::StopDecision::Stop(p) => {
            let end = p.trimmed_len().max(start).min(full.len());
            let text = full[start..end].to_string();
            *emitted_len = end;
            (text, FinishReason::Stop(p.stop().to_owned()))
        }
        crate::lm::stop::StopDecision::Continue(_) => {
            let text = full[start..].to_string();
            *emitted_len = full.len();
            (text, default_reason)
        }
    }
}
pub fn generate<M: Model + ?Sized>(
model: &M,
tokenizer: &crate::tokenizer::Tokenizer,
prompt: &[u32],
cache: Vec<Box<dyn KvCache>>,
cfg: GenConfig,
) -> Result<(String, GenerationStats)> {
let prompt_tokens = prompt.len();
let mut text = String::new();
let mut final_response: Option<GenerationResponse> = None;
for response in stream_generate(model, tokenizer, prompt, cache, cfg) {
let response = response?;
text.push_str(&response.text);
final_response = Some(response);
}
let stats = match final_response {
Some(r) => GenerationStats {
prompt_tokens: r.prompt_tokens,
generation_tokens: r.generation_tokens,
prompt_tps: r.prompt_tps,
generation_tps: r.generation_tps,
peak_memory_bytes: r.peak_memory_bytes,
},
None => GenerationStats {
prompt_tokens,
generation_tokens: 0,
prompt_tps: 0.0,
generation_tps: 0.0,
peak_memory_bytes: crate::memory::peak_memory().ok(),
},
};
Ok((text, stats))
}
/// One item produced by the batched generator: the token sampled for a single row.
#[derive(Debug)]
pub struct BatchGenStep {
/// Index of the batch row this step belongs to.
pub row: usize,
/// The token sampled for that row.
pub token: u32,
/// That row's normalized log-probabilities (batch axis squeezed away).
pub logprobs: Array,
/// `Some` when this step ends the row (`Eos` or `Length`), `None` otherwise.
pub finish_reason: Option<FinishReason>,
}
/// Batched generation state; driven through its `Iterator` implementation,
/// which yields one `BatchGenStep` per still-active row per decode step.
pub struct BatchGenerator<'a, M: Model + ?Sized> {
/// Model whose `forward` is called for prefill and for every decode step.
model: &'a M,
/// KV cache passed mutably to every `forward` call.
cache: Vec<Box<dyn KvCache>>,
/// Sampler applied to the normalized log-probabilities of all rows at once.
sampler: Sampler,
/// Logits processors applied to each row separately.
processors: Vec<LogitsProcessor>,
/// One token row per prompt; every row is sliced with the same column
/// range, so rows are expected to share the length `max_len`
/// (NOTE(review): presumably produced by `left_pad_rows` — the constructor is outside this view).
padded_rows: Vec<Vec<u32>>,
/// Common row length used by prefill and the first decode step.
max_len: usize,
/// Number of columns already consumed by `prefill`.
prefill_offset: usize,
/// Per-row token history shown to the logits processors.
history: Vec<Vec<u32>>,
/// Per-row input for the next decode step (pad token for finished rows).
last: Vec<u32>,
/// Per-row count of tokens produced.
produced: Vec<usize>,
/// Per-row finish reason; `None` while the row is still generating.
finished: Vec<Option<FinishReason>>,
/// Steps computed by the latest decode pass that have not been yielded yet.
pending_emit: std::collections::VecDeque<BatchGenStep>,
/// Token fed to the model for rows that already finished.
pad_token_id: u32,
/// Per-row token budget.
max_tokens: usize,
/// Maximum number of columns per prefill forward pass.
prefill_step_size: usize,
/// Token ids that end a row.
eos: Vec<u32>,
/// Set once `prefill` has been attempted.
prefilled: bool,
/// `true` until the first decode step has consumed the remaining prompt columns.
first_step: bool,
/// Construction error to be yielded before anything else, if any.
pending_err: Option<Error>,
/// Once set (and `pending_emit` is drained), the iterator returns `None`.
done: bool,
}
impl<M: Model + ?Sized> BatchGenerator<'_, M> {
    /// Number of rows in the batch.
    fn batch_size(&self) -> usize {
        self.padded_rows.len()
    }

    /// Feeds the prompt columns to the model in chunks of at most
    /// `prefill_step_size`, leaving exactly one column unprocessed for the
    /// first decode step. The logits of these forward passes are discarded.
    fn prefill(&mut self) -> Result<()> {
        loop {
            let pending = self.max_len - self.prefill_offset;
            if pending <= 1 {
                return Ok(());
            }
            let start = self.prefill_offset;
            // Never consume the final column here.
            let end = start + self.prefill_step_size.min(pending - 1);
            let chunk = batch_token_window(&self.padded_rows, start, end)?;
            let _ = self.model.forward(&chunk, &mut self.cache)?;
            self.prefill_offset = end;
        }
    }

    /// Runs one batched forward pass over `input` (row-major, one equally
    /// sized segment per row) and samples one token per row.
    ///
    /// Logits processors, when present, run on each row separately against
    /// that row's history (which is extended with the row's input first).
    /// The returned steps carry each row's normalized log-probabilities and
    /// the finish reason: an already recorded reason is repeated, otherwise
    /// `Eos` for an EOS id, otherwise `Length` when this token exhausts the
    /// row's budget. This method does not update `produced` / `finished`.
    ///
    /// # Errors
    /// `Error::LengthMismatch` when the sampler does not return one token
    /// per row; other errors are propagated from the model and array ops.
    fn step(&mut self, input: &[u32]) -> Result<Vec<BatchGenStep>> {
        let b = self.batch_size();
        let seq = input.len() / b;
        let window = batch_full_window(input, b, seq)?;
        let full_logits = self.model.forward(&window, &mut self.cache)?;
        let mut logits = last_position(&full_logits)?;

        if !self.processors.is_empty() && !input.is_empty() {
            let mut per_row: Vec<Array> = try_with_capacity(b)?;
            for (row, hist) in self.history.iter_mut().enumerate().take(b) {
                let row_input = &input[row * seq..(row + 1) * seq];
                try_extend_from_slice(hist, row_input)?;
                let history_view: &[u32] = hist.as_slice();

                let vocab = logits.shape()[1] as i32;
                let (lo, hi) = (row as i32, (row + 1) as i32);
                let mut row_logits = ops::indexing::slice(&logits, &[lo, 0], &[hi, vocab], &[1, 1])?;
                for processor in self.processors.iter() {
                    row_logits = processor.apply(history_view, &row_logits)?;
                }
                per_row.push(row_logits);
            }
            let row_refs: Vec<&Array> = per_row.iter().collect();
            logits = ops::shape::concatenate(&row_refs, 0)?;
        }

        let lse = ops::reduction::logsumexp_axes(&logits, &[-1], true)?;
        let logprobs = ops::arithmetic::subtract(&logits, &lse)?;
        let mut sampled = self.sampler.sample(&logprobs)?;
        let tokens: Vec<u32> = sampled.to_vec::<u32>()?;
        if tokens.len() != b {
            return Err(Error::LengthMismatch(LengthMismatchPayload::new(
                "batch_generate: sampler returned tokens (must be one per row)",
                b,
                tokens.len(),
            )));
        }

        let mut steps: Vec<BatchGenStep> = try_with_capacity(b)?;
        let vocab = logprobs.shape()[1] as i32;
        for (row, &token) in tokens.iter().enumerate() {
            let (lo, hi) = (row as i32, (row + 1) as i32);
            let row_lp = ops::indexing::slice(&logprobs, &[lo, 0], &[hi, vocab], &[1, 1])?;
            let row_lp = ops::shape::squeeze_axes(&row_lp, &[0])?;

            let finish_reason = match self.finished[row].clone() {
                Some(prior) => Some(prior),
                None if self.eos.contains(&token) => Some(FinishReason::Eos),
                None if self.produced[row] + 1 >= self.max_tokens => Some(FinishReason::Length),
                None => None,
            };

            steps.push(BatchGenStep {
                row,
                token,
                logprobs: row_lp,
                finish_reason,
            });
        }
        Ok(steps)
    }
}
/// Yields one `BatchGenStep` per row that was still active when a decode
/// pass started. Each decode pass computes steps for every row; they are
/// queued in `pending_emit` and handed out one per `next` call.
impl<M: Model + ?Sized> Iterator for BatchGenerator<'_, M> {
type Item = Result<BatchGenStep>;
fn next(&mut self) -> Option<Self::Item> {
// Drain steps left over from the previous decode pass first, even when
// the generator is already marked done.
if let Some(step) = self.pending_emit.pop_front() {
return Some(Ok(step));
}
if self.done {
return None;
}
// A construction error is yielded once, then the iterator ends.
if let Some(e) = self.pending_err.take() {
self.done = true;
return Some(Err(e));
}
// Stop when every row has either finished or used up its token budget.
if self
.produced
.iter()
.zip(self.finished.iter())
.all(|(&p, f)| f.is_some() || p >= self.max_tokens)
{
self.done = true;
return None;
}
// Prefill runs lazily, on the first call that gets this far.
if !self.prefilled {
self.prefilled = true;
if let Err(e) = self.prefill() {
self.done = true;
return Some(Err(e));
}
}
let b = self.batch_size();
// First decode step: feed the columns prefill left unprocessed, row by
// row. Later steps: feed the per-row `last` tokens.
let input: Vec<u32> = if self.first_step {
self.first_step = false;
let tail_len = self.max_len - self.prefill_offset;
let mut buf = match try_with_capacity::<u32>(b * tail_len) {
Ok(b) => b,
Err(e) => {
self.done = true;
return Some(Err(e));
}
};
for row in &self.padded_rows {
buf.extend_from_slice(&row[self.prefill_offset..self.prefill_offset + tail_len]);
}
buf
} else {
self.last.clone()
};
let b = self.batch_size();
// Remember which rows were active before this pass; only those rows'
// steps are emitted.
let mut was_unfinished: Vec<bool> = match try_with_capacity(b) {
Ok(v) => v,
Err(e) => {
self.done = true;
return Some(Err(e));
}
};
for f in &self.finished {
was_unfinished.push(f.is_none());
}
let steps = match self.step(&input) {
Ok(s) => s,
Err(e) => {
self.done = true;
return Some(Err(e));
}
};
// Update per-row state from the freshly computed steps.
for step in &steps {
let row = step.row;
// Rows that finished earlier keep receiving the pad token as input.
if self.finished[row].is_some() {
self.last[row] = self.pad_token_id;
continue;
}
self.last[row] = step.token;
self.produced[row] += 1;
if let Some(ref reason) = step.finish_reason {
self.finished[row] = Some(reason.clone());
// After EOS the row's next input is the pad token rather than the EOS id.
if reason.is_eos() {
self.last[row] = self.pad_token_id;
}
}
}
for step in steps {
if was_unfinished[step.row] {
self.pending_emit.push_back(step);
}
}
if self.finished.iter().all(|r| r.is_some()) {
self.done = true;
}
self.pending_emit.pop_front().map(Ok)
}
}
fn left_pad_rows(prompts: &[&[u32]], pad_token_id: u32) -> Result<(Vec<Vec<u32>>, usize)> {
if prompts.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"batch_generate: prompts",
)));
}
let max_len = prompts.iter().map(|p| p.len()).max().unwrap_or(0);
if max_len == 0 {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"batch_generate: every prompt",
)));
}
let mut padded: Vec<Vec<u32>> = try_with_capacity(prompts.len())?;
for p in prompts {
if p.is_empty() {
return Err(Error::EmptyInput(EmptyInputPayload::new(
"batch_generate: every prompt",
)));
}
let mut row: Vec<u32> = try_with_capacity(max_len)?;
for _ in 0..(max_len - p.len()) {
row.push(pad_token_id);
}
try_extend_from_slice(&mut row, p)?;
padded.push(row);
}
Ok((padded, max_len))
}
/// Builds a `(rows.len(), end - start)` array holding columns `start..end`
/// of every row, flattened row after row.
///
/// Tokens are converted with `as i32`, so values above `i32::MAX` wrap.
///
/// # Panics
///
/// Panics if `start..end` is not a valid range for some row. `end < start`
/// overflows the width subtraction (a panic when overflow checks are on).
fn batch_token_window(rows: &[Vec<u32>], start: usize, end: usize) -> Result<Array> {
    let batch = rows.len();
    let width = end - start;
    let mut flat: Vec<i32> = try_with_capacity(batch * width)?;
    for row in rows {
        for &tok in &row[start..end] {
            flat.push(tok as i32);
        }
    }
    Array::from_slice::<i32>(&flat, &(batch, width))
}
/// Converts an already-flattened token buffer to `i32` and hands it to
/// `Array::from_slice` with shape `(b, s)`.
///
/// Tokens are converted with `as i32`, so values above `i32::MAX` wrap.
/// This function does not itself check that `flat.len() == b * s`.
fn batch_full_window(flat: &[u32], b: usize, s: usize) -> Result<Array> {
    let mut ids: Vec<i32> = try_with_capacity(flat.len())?;
    for &tok in flat {
        ids.push(tok as i32);
    }
    Array::from_slice::<i32>(&ids, &(b, s))
}
/// Returns, for each prompt, how many pad tokens must be prepended so that it
/// reaches the length of the longest prompt in the batch.
///
/// The result has one entry per prompt, in input order. An empty `prompts`
/// slice yields an empty vector. Counts are converted with `as i32`.
pub fn batch_left_padding(prompts: &[&[u32]]) -> Vec<i32> {
    let mut longest = 0usize;
    for prompt in prompts {
        longest = longest.max(prompt.len());
    }

    let mut pads = Vec::with_capacity(prompts.len());
    for prompt in prompts {
        // `longest` is at least `prompt.len()`, so this never underflows.
        pads.push((longest - prompt.len()) as i32);
    }
    pads
}
/// Builds a [`BatchGenerator`] over `prompts`.
///
/// Setup validates `cfg`, left-pads the prompts with `pad_token_id`, and
/// constructs the sampler and logits processors from `cfg`.
///
/// This function never fails directly. If any setup stage returns an error,
/// the returned generator carries it in `pending_err` and has no rows, with
/// `prefilled` set and `first_step` cleared; the iterator then yields that
/// error as its first item.
pub fn batch_generate_step<'a, M: Model + ?Sized>(
    model: &'a M,
    prompts: &[&[u32]],
    pad_token_id: u32,
    cache: Vec<Box<dyn KvCache>>,
    cfg: GenConfig,
) -> BatchGenerator<'a, M> {
    type Built = (Vec<Vec<u32>>, usize, Sampler, Vec<LogitsProcessor>);

    // Run every fallible setup stage in one place so a single `Result`
    // decides which flavour of generator is returned.
    let built = (|| -> Result<Built> {
        cfg.validate()?;
        let (rows, longest) = left_pad_rows(prompts, pad_token_id)?;
        let sampler = make_sampler(
            cfg.temp,
            cfg.top_p,
            cfg.min_p,
            cfg.min_tokens_to_keep,
            cfg.top_k,
            cfg.xtc_probability,
            cfg.xtc_threshold,
            &cfg.xtc_special_tokens,
            cfg.seed,
        )?;
        let processors = make_logits_processors(
            &cfg.logit_bias,
            cfg.repetition_penalty,
            cfg.repetition_context_size,
            cfg.presence_penalty,
            cfg.presence_context_size,
            cfg.frequency_penalty,
            cfg.frequency_context_size,
        )?;
        Ok((rows, longest, sampler, processors))
    })();

    // Normalise both outcomes into one set of parts. A failure yields empty
    // collections plus the stored error.
    let (padded_rows, max_len, sampler, processors, pending_err) = match built {
        Ok((rows, longest, sampler, processors)) => (rows, longest, sampler, processors, None),
        Err(e) => (Vec::new(), 0, Sampler::Argmax, Vec::new(), Some(e)),
    };
    let ready = pending_err.is_none();
    // Zero on the failure path, so every per-row vector below is empty.
    let b = padded_rows.len();

    BatchGenerator {
        model,
        cache,
        sampler,
        processors,
        padded_rows,
        max_len,
        prefill_offset: 0,
        history: vec![Vec::new(); b],
        last: vec![pad_token_id; b],
        produced: vec![0; b],
        finished: vec![None; b],
        pending_emit: std::collections::VecDeque::new(),
        pad_token_id,
        max_tokens: cfg.max_tokens,
        // Clamp to at least 1 on the normal path.
        prefill_step_size: if ready {
            cfg.prefill_step_size.max(1)
        } else {
            1
        },
        eos: if ready { cfg.eos } else { Vec::new() },
        // On failure, prefill is marked as already done and the first-step
        // flag is cleared.
        prefilled: !ready,
        first_step: ready,
        pending_err,
        done: false,
    }
}
/// Same as [`batch_generate_step`], except that `cfg.eos` is replaced with
/// the ids yielded by `tokenizer.eos_token_ids_iter()` before the generator
/// is built.
///
/// Any `eos` value supplied by the caller in `cfg` is overwritten.
pub fn batch_stream_generate<'a, M: Model + ?Sized>(
    model: &'a M,
    tokenizer: &'a crate::tokenizer::Tokenizer,
    prompts: &[&[u32]],
    pad_token_id: u32,
    cache: Vec<Box<dyn KvCache>>,
    cfg: GenConfig,
) -> BatchGenerator<'a, M> {
    let mut effective_cfg = cfg;
    effective_cfg.eos = tokenizer.eos_token_ids_iter().collect();
    batch_generate_step(model, prompts, pad_token_id, cache, effective_cfg)
}
/// Runs batched generation to completion and returns the generated tokens of
/// each prompt, indexed like `prompts`.
///
/// A step whose finish reason reports EOS is not appended to its row; every
/// other emitted step contributes its token.
///
/// # Errors
///
/// * Any error yielded by the underlying [`batch_stream_generate`] iterator.
/// * `Error::OutOfRange` if a step reports a row index that is not smaller
///   than the number of prompts.
/// * The `try_with_capacity` error if the result vector cannot be reserved.
pub fn batch_generate<M: Model + ?Sized>(
    model: &M,
    tokenizer: &crate::tokenizer::Tokenizer,
    prompts: &[&[u32]],
    pad_token_id: u32,
    cache: Vec<Box<dyn KvCache>>,
    cfg: GenConfig,
) -> Result<Vec<Vec<u32>>> {
    let b = prompts.len();
    let mut results: Vec<Vec<u32>> = try_with_capacity(b)?;
    // One empty output row per prompt; capacity is already reserved.
    results.resize_with(b, Vec::new);

    let stream = batch_stream_generate(model, tokenizer, prompts, pad_token_id, cache, cfg);
    for item in stream {
        let step = item?;
        let row = step.row;
        match results.get_mut(row) {
            Some(tokens) => {
                let is_eos_step = matches!(&step.finish_reason, Some(reason) if reason.is_eos());
                if !is_eos_step {
                    tokens.push(step.token);
                }
            }
            None => {
                return Err(Error::OutOfRange(OutOfRangePayload::new(
                    "batch_generate: step row",
                    "must be < prompts count",
                    format_smolstr!("{row} (prompts={b})"),
                )));
            }
        }
    }
    Ok(results)
}
// Test submodules, each declared in its own file and compiled only when
// `cfg(test)` is active.
#[cfg(test)]
mod batch_tests;
#[cfg(test)]
mod stop_sequence_tests;
#[cfg(test)]
mod tests;