use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use rocket::async_stream::stream;
use rocket::futures::Stream;
use std::time::Instant;
use crate::{
models::common::{
InferenceModel, MultiModalData,
sample::{get_logit_processor, use_repeat_penalty},
},
params::chat::{ChatCompletionChunkResponse, ChatCompletionParameters, ChatCompletionResponse},
tokenizer::TokenizerModel,
utils::response_utils::{
build_chunk_response_with_reasoning, build_chunk_response_with_usage,
build_completion_chunk_response, build_completion_response_with_time,
},
};
/// Mutable state carried across the decoding steps of a single generation run.
pub struct GenerationContext {
/// Sampler used to pick the next token id from a logits tensor.
pub logit_processor: LogitsProcessor,
/// Repeat-penalty factor forwarded to `use_repeat_penalty` on every step.
pub repeat_penalty: f32,
/// Window size forwarded to `use_repeat_penalty` together with the penalty.
pub repeat_last_n: usize,
/// Offset passed to the model's forward calls; starts at 0 and grows by
/// `seq_len` each time `prepare_for_next_token` is called.
pub seqlen_offset: usize,
/// Length of the input fed on the current step: the initial sequence length
/// at first, then 1 after the first call to `prepare_for_next_token`.
pub seq_len: usize,
/// Upper bound used by the generation loops (the `max_tokens` given to `new`).
pub sample_len: u32,
/// Device on which the single-token input tensors are created.
pub device: Device,
}
impl GenerationContext {
    /// Creates the context for a fresh generation run.
    ///
    /// * `temperature`, `top_p`, `top_k`, `seed` are handed to
    ///   `get_logit_processor` to build the sampler.
    /// * `repeat_penalty` falls back to `1.0` and `repeat_last_n` to `64`
    ///   when not supplied.
    /// * `initial_seq_len` becomes the first `seq_len`; `seqlen_offset`
    ///   starts at 0.
    /// * `max_tokens` is stored as `sample_len`.
    /// * `device` is kept for building per-step input tensors.
    pub fn new(
        temperature: Option<f32>,
        top_p: Option<f32>,
        top_k: Option<usize>,
        repeat_penalty: Option<f32>,
        repeat_last_n: Option<usize>,
        seed: u64,
        initial_seq_len: usize,
        max_tokens: u32,
        device: Device,
    ) -> Self {
        let logit_processor = get_logit_processor(temperature, top_p, top_k, seed);
        let repeat_penalty = repeat_penalty.unwrap_or(1.0);
        let repeat_last_n = repeat_last_n.unwrap_or(64);
        Self {
            logit_processor,
            repeat_penalty,
            repeat_last_n,
            seqlen_offset: 0,
            seq_len: initial_seq_len,
            sample_len: max_tokens,
            device,
        }
    }

    /// Advances the positional bookkeeping and returns the `(1, 1)` input
    /// tensor holding `token` for the next forward step.
    ///
    /// # Errors
    /// Fails when the tensor cannot be created on `self.device`.
    pub fn prepare_for_next_token(&mut self, token: u32) -> Result<Tensor> {
        self.update_status();
        self.create_input_ids(token)
    }

    /// Accounts for the input just consumed: the offset grows by the current
    /// `seq_len`, and every following step feeds exactly one token.
    fn update_status(&mut self) {
        let consumed = self.seq_len;
        self.seq_len = 1;
        self.seqlen_offset += consumed;
    }

    /// Wraps a single token id in a `(1, 1)` tensor on the context's device.
    fn create_input_ids(&self, token: u32) -> Result<Tensor> {
        let single_token = Tensor::from_vec(vec![token], (1, 1), &self.device)?;
        Ok(single_token)
    }
}
/// Samples the next token from `logits` and appends it to `generated`.
///
/// The two leading dimensions of `logits` are squeezed away and the result is
/// converted to `f32`. The repeat penalty configured in `ctx` is then applied
/// against the tokens generated so far, before the context's sampler picks
/// the token.
///
/// # Errors
/// Propagates failures from the tensor operations, the repeat-penalty helper
/// and the sampler.
fn sample_and_push(
    ctx: &mut GenerationContext,
    logits: &Tensor,
    generated: &mut Vec<u32>,
) -> Result<u32> {
    let squeezed = logits.squeeze(0)?.squeeze(0)?;
    let as_f32 = squeezed.to_dtype(DType::F32)?;
    let penalized = use_repeat_penalty(
        ctx.repeat_penalty,
        Some(ctx.repeat_last_n),
        &as_f32,
        generated,
    )?;
    let sampled = ctx.logit_processor.sample(&penalized)?;
    generated.push(sampled);
    Ok(sampled)
}
/// Runs the decode loop for [`generate_generic_text`] and returns the sampled
/// token ids.
///
/// Kept separate so the caller can clear the model cache on both the success
/// and the error path.
fn run_text_generation<M: InferenceModel>(
    model: &mut M,
    input_ids: &Tensor,
    data: MultiModalData,
    ctx: &mut GenerationContext,
) -> Result<Vec<u32>> {
    let eos_ids = model.stop_token_ids();
    let mut generated = Vec::new();

    let logits = model.forward_initial(input_ids, ctx.seqlen_offset, data)?;
    let first_token = sample_and_push(ctx, &logits, &mut generated)?;
    // A stop token can be the very first sample; do not decode past it.
    if eos_ids.contains(&first_token) {
        return Ok(generated);
    }

    let mut step_input = ctx.prepare_for_next_token(first_token)?;
    for _ in 1..ctx.sample_len {
        let logits = model.forward_step(&step_input, ctx.seqlen_offset)?;
        let next_token = sample_and_push(ctx, &logits, &mut generated)?;
        if eos_ids.contains(&next_token) {
            break;
        }
        step_input = ctx.prepare_for_next_token(next_token)?;
    }
    Ok(generated)
}

/// Generates a completion for `input_ids` and returns it as decoded text.
///
/// The prompt is run through `forward_initial`, then tokens are produced one
/// at a time with `forward_step` until a stop token is sampled or
/// `ctx.sample_len` tokens exist. One token is always sampled from the prompt
/// pass, even when `ctx.sample_len` is 0. The stop token, when hit, is part
/// of the ids handed to the tokenizer.
///
/// The model cache is cleared before returning, whether generation succeeded
/// or failed, so a failed request does not leave stale state behind.
///
/// # Errors
/// Propagates failures from the model forward passes, sampling, tensor
/// creation and token decoding.
pub fn generate_generic_text<M: InferenceModel>(
    model: &mut M,
    tokenizer: &TokenizerModel,
    input_ids: Tensor,
    data: MultiModalData,
    ctx: &mut GenerationContext,
) -> Result<String> {
    let outcome = run_text_generation(model, &input_ids, data, ctx);
    // Clear first, propagate second: the cache must not survive an error.
    model.clear_cache();
    let generated = outcome?;
    let text = tokenizer.token_decode(generated)?;
    Ok(text)
}
/// Runs the decode loop for [`generate_generic`].
///
/// Returns the sampled token ids together with the seconds spent on the
/// prompt pass (including sampling the first token) and the seconds spent in
/// the step-by-step loop. Kept separate so the caller can clear the model
/// cache on both the success and the error path.
fn run_timed_generation<M: InferenceModel>(
    model: &mut M,
    input_ids: &Tensor,
    data: MultiModalData,
    ctx: &mut GenerationContext,
) -> Result<(Vec<u32>, f64, f64)> {
    let eos_ids = model.stop_token_ids();
    let mut generated = Vec::new();

    let prefill_start = Instant::now();
    let logits = model.forward_initial(input_ids, ctx.seqlen_offset, data)?;
    let first_token = sample_and_push(ctx, &logits, &mut generated)?;
    let prompt_secs = prefill_start.elapsed().as_secs_f64();

    // A stop token can be the very first sample; do not decode past it.
    if eos_ids.contains(&first_token) {
        return Ok((generated, prompt_secs, 0.0));
    }

    let mut step_input = ctx.prepare_for_next_token(first_token)?;
    let decode_start = Instant::now();
    for _ in 1..ctx.sample_len {
        let logits = model.forward_step(&step_input, ctx.seqlen_offset)?;
        let next_token = sample_and_push(ctx, &logits, &mut generated)?;
        if eos_ids.contains(&next_token) {
            break;
        }
        step_input = ctx.prepare_for_next_token(next_token)?;
    }
    let completion_secs = decode_start.elapsed().as_secs_f64();

    Ok((generated, prompt_secs, completion_secs))
}

/// Generates a completion for `input_ids` and wraps it in a
/// [`ChatCompletionResponse`] carrying token counts and timings.
///
/// `ctx.seq_len` at entry is reported as the prompt token count. The number
/// of sampled ids (including a stop token, when one was hit) is reported as
/// the completion token count. One token is always sampled from the prompt
/// pass, even when `ctx.sample_len` is 0.
///
/// The model cache is cleared before returning, whether generation succeeded
/// or failed, so a failed request does not leave stale state behind.
///
/// # Errors
/// Propagates failures from the model forward passes, sampling, tensor
/// creation and token decoding.
pub fn generate_generic<M: InferenceModel>(
    model: &mut M,
    tokenizer: &TokenizerModel,
    input_ids: Tensor,
    data: MultiModalData,
    ctx: &mut GenerationContext,
    model_name: &str,
) -> Result<ChatCompletionResponse> {
    let prompt_tokens = ctx.seq_len as u32;

    let outcome = run_timed_generation(model, &input_ids, data, ctx);
    // Clear first, propagate second: the cache must not survive an error.
    model.clear_cache();
    let (generated, prompt_secs, completion_secs) = outcome?;

    let num_tokens = generated.len() as u32;
    let text = tokenizer.token_decode(generated)?;
    Ok(build_completion_response_with_time(
        text,
        model_name,
        Some(num_tokens),
        Some(completion_secs),
        Some(prompt_tokens),
        Some(prompt_secs),
    ))
}
/// Streams a completion for `input_ids` as pieces of decoded text.
///
/// Up to `max_tokens` tokens are sampled. The first iteration runs
/// `forward_initial` on the prompt (while `seqlen_offset` is still 0), later
/// iterations run `forward_step` on the previously sampled token.
///
/// Tokens whose decoded text contains the replacement character `�` are held
/// back and decoded together with the following tokens; after more than three
/// held tokens the buffer is dropped. The stream ends after yielding the text
/// of a stop token, or when the token budget is used up, and the model cache
/// is cleared once the loop finishes.
///
/// # Errors
/// Returns an error immediately when the sequence dimension of `input_ids`
/// cannot be read. Failures during generation are propagated with `?` inside
/// the stream body.
pub fn generate_stream_generic_text<M: InferenceModel>(
    model: &mut M,
    tokenizer: &TokenizerModel,
    input_ids: Tensor,
    data: MultiModalData,
    temperature: Option<f32>,
    top_p: Option<f32>,
    top_k: Option<usize>,
    repeat_penalty: Option<f32>,
    repeat_last_n: Option<usize>,
    seed: u64,
    max_tokens: u32,
    device: &Device,
) -> Result<impl Stream<Item = Result<String, anyhow::Error>>> {
    let prompt_len = input_ids.dim(1)?;
    let mut ctx = GenerationContext::new(
        temperature,
        top_p,
        top_k,
        repeat_penalty,
        repeat_last_n,
        seed,
        prompt_len,
        max_tokens,
        device.clone(),
    );
    // Tokens held back because their text is not yet a complete character.
    let mut error_tokens = Vec::new();
    let eos_ids = model.stop_token_ids();

    let text_stream = stream! {
        let mut input_ids = input_ids;
        let mut generated = Vec::new();
        for _ in 0..ctx.sample_len {
            let forward_result = if ctx.seqlen_offset == 0 {
                model.forward_initial(&input_ids, ctx.seqlen_offset, data.clone())
            } else {
                model.forward_step(&input_ids, ctx.seqlen_offset)
            };
            let logits = forward_result?;
            let next_token = sample_and_push(&mut ctx, &logits, &mut generated)?;

            // Decode the held-back tokens (if any) together with the new one.
            let mut decode_ids = error_tokens.clone();
            decode_ids.push(next_token);
            let decoded = tokenizer.token_decode(decode_ids)?;

            if decoded.contains("�") {
                error_tokens.push(next_token);
                if error_tokens.len() > 3 {
                    error_tokens.clear();
                }
            } else {
                error_tokens.clear();
                yield Ok(decoded);
                if eos_ids.contains(&next_token) {
                    break;
                }
            }
            input_ids = ctx.prepare_for_next_token(next_token)?;
        }
        model.clear_cache();
    };
    Ok(text_stream)
}
/// Streams a chat completion for `input_ids` as [`ChatCompletionChunkResponse`]s.
///
/// Up to `max_tokens` tokens are sampled. The first iteration runs
/// `forward_initial` on the prompt (while `seqlen_offset` is still 0) and its
/// duration is booked as prompt time. Later iterations run `forward_step` and
/// are booked as completion time.
///
/// Decoded text is routed as follows:
/// * text containing the replacement character `�` is held back and decoded
///   together with the following tokens (the buffer is dropped after more
///   than three held tokens);
/// * `<think>` / `</think>` toggle reasoning mode and are not emitted;
/// * `<tool_call>` opens a tool call with a fresh UUID, the text up to
///   `</tool_call>` is accumulated and emitted as one chunk on the closing tag;
/// * everything else is emitted as a reasoning or a regular content chunk,
///   depending on the reasoning mode (initially `in_reasoning`).
///
/// The stream always ends with a usage chunk, whether generation stopped on a
/// stop token or because the token budget was used up. A stop token ends
/// generation regardless of how its text was routed, so it is also honored
/// inside an open tool call. The model cache is cleared just before the usage
/// chunk is yielded, so a consumer that stops after the last chunk still gets
/// a clean model.
///
/// NOTE(review): a failure propagated with `?` inside the stream body, or a
/// stream dropped early by its consumer, presumably ends before
/// `model.clear_cache()` is reached — confirm whether callers reset the cache
/// in those cases.
///
/// # Errors
/// Returns an error immediately when the sequence dimension of `input_ids`
/// cannot be read. Failures during generation are propagated with `?` inside
/// the stream body.
pub fn generate_stream_generic<M: InferenceModel>(
    model: &mut M,
    tokenizer: &TokenizerModel,
    input_ids: Tensor,
    data: MultiModalData,
    temperature: Option<f32>,
    top_p: Option<f32>,
    top_k: Option<usize>,
    repeat_penalty: Option<f32>,
    repeat_last_n: Option<usize>,
    seed: u64,
    max_tokens: u32,
    in_reasoning: bool,
    device: &Device,
    model_name: &str,
) -> Result<impl Stream<Item = Result<ChatCompletionChunkResponse, anyhow::Error>>> {
    let mut ctx = GenerationContext::new(
        temperature,
        top_p,
        top_k,
        repeat_penalty,
        repeat_last_n,
        seed,
        input_ids.dim(1)?,
        max_tokens,
        device.clone(),
    );
    let prompt_tokens = ctx.seq_len as u32;
    let mut prompt_secs = 0.0f64;
    let mut completion_tokens = 0u32;
    let mut completion_secs = 0.0f64;
    // Tokens held back because their text is not yet a complete character.
    let mut error_tokens = Vec::new();
    let eos_ids = model.stop_token_ids();
    let stream = stream! {
        let mut input_ids = input_ids;
        let mut tool_call_id = None;
        let mut tool_call_content = String::new();
        let mut in_reasoning = in_reasoning;
        let mut generated = Vec::new();
        for _ in 0..ctx.sample_len {
            let i_start = Instant::now();
            // The offset only moves in `prepare_for_next_token`, so this flag
            // stays valid for the timing attribution below.
            let is_prefill = ctx.seqlen_offset == 0;
            let logits = if is_prefill {
                model.forward_initial(&input_ids, ctx.seqlen_offset, data.clone())
            } else {
                model.forward_step(&input_ids, ctx.seqlen_offset)
            }?;
            let next_token = sample_and_push(&mut ctx, &logits, &mut generated)?;
            completion_tokens += 1;
            let step_secs = i_start.elapsed().as_secs_f64();
            if is_prefill {
                prompt_secs += step_secs;
            } else {
                completion_secs += step_secs;
            }
            // Decided up front so that no routing branch below can skip it.
            let is_eos = eos_ids.contains(&next_token);

            // Decode the held-back tokens (if any) together with the new one.
            let mut decode_ids = error_tokens.clone();
            decode_ids.push(next_token);
            let decoded = tokenizer.token_decode(decode_ids)?;

            if decoded.contains("�") {
                error_tokens.push(next_token);
                if error_tokens.len() > 3 {
                    error_tokens.clear();
                }
            } else {
                error_tokens.clear();
                match decoded.as_str() {
                    "<think>" => {
                        in_reasoning = true;
                    }
                    "</think>" => {
                        in_reasoning = false;
                    }
                    "<tool_call>" => {
                        tool_call_id = Some(uuid::Uuid::new_v4().to_string());
                    }
                    "</tool_call>" => {
                        // Hand over the accumulated call and reset the state.
                        let chunk = build_completion_chunk_response(
                            decoded,
                            model_name,
                            tool_call_id.take(),
                            Some(std::mem::take(&mut tool_call_content)),
                        );
                        yield Ok(chunk);
                    }
                    _ => {
                        if tool_call_id.is_some() {
                            tool_call_content.push_str(&decoded);
                        } else {
                            let chunk = if in_reasoning {
                                build_chunk_response_with_reasoning(decoded, model_name)
                            } else {
                                build_completion_chunk_response(decoded, model_name, None, None)
                            };
                            yield Ok(chunk);
                        }
                    }
                }
            }

            if is_eos {
                break;
            }
            input_ids = ctx.prepare_for_next_token(next_token)?;
        }
        // Clear before the final yield: a consumer may stop polling once it
        // has received the usage chunk.
        model.clear_cache();
        yield Ok(build_chunk_response_with_usage(model_name, completion_tokens.into(), completion_secs.into(), prompt_tokens.into(), prompt_secs.into()));
    };
    Ok(stream)
}
/// Model inputs produced by `GenerationDataProvider::get_data` for one request.
pub struct PrepareData {
/// Initial reasoning-mode flag; `impl_generate_model!` forwards it to the
/// streaming generator.
pub in_reasoning: bool,
/// Input ids tensor; its dimension 1 is used as the initial sequence length.
pub input_ids: Tensor,
/// Extra data passed to the model's `forward_initial` call.
pub multi_model_data: MultiModalData,
}
/// Per-model hooks for deriving sampling parameters and model inputs from a
/// chat request.
///
/// Every method except `get_data` has a default implementation that an
/// implementor may override.
pub trait GenerationDataProvider {
/// Returns the temperature to sample with; by default the requested value
/// is passed through unchanged.
fn get_temperature(&self, req_temp: Option<f32>) -> Option<f32> {
req_temp
}
/// Returns the `top_p` to sample with; by default the requested value is
/// passed through unchanged.
fn get_top_p(&self, req_top_p: Option<f32>) -> Option<f32> {
req_top_p
}
/// Returns the `top_k` to sample with; by default the requested value is
/// passed through unchanged.
fn get_top_k(&self, top_k: Option<usize>) -> Option<usize> {
top_k
}
/// Reports whether `text` ends with an opening `<think>` tag followed by a
/// newline.
fn is_in_reasoning(&self, text: &str) -> bool {
text.ends_with("<think>\n")
}
/// Returns the multi-modal data for the request; by default it is built
/// from an empty vector.
fn get_multi_model_data(&self) -> MultiModalData {
MultiModalData::new(vec![])
}
/// Builds the model inputs for the given chat request.
fn get_data(&self, mes: &ChatCompletionParameters) -> Result<PrepareData>;
}
/// Implements `$crate::models::GenerateModel` for `$struct_name` on top of the
/// generic generation helpers in `$crate::models::common::generate`.
///
/// The expansion reads `self.model`, `self.tokenizer`, `self.device` and
/// `self.model_name`, and calls `get_temperature`, `get_top_p`, `get_top_k`
/// and `get_data` on `self`, so the target type must provide all of them.
#[macro_export]
macro_rules! impl_generate_model {
($struct_name: ty) => {
impl<'a> $crate::models::GenerateModel for $struct_name {
// Non-streaming entry point: runs the whole generation and returns a
// single response.
fn generate(
&mut self,
mes: $crate::params::chat::ChatCompletionParameters,
) -> anyhow::Result<$crate::params::chat::ChatCompletionResponse> {
// Fixed fallback seed when the request does not carry one.
let seed = mes.seed.unwrap_or(299792458) as u64;
// Fallback token budget when the request does not set `max_tokens`.
let sample_len = mes.max_tokens.unwrap_or(1024);
let temperature = self.get_temperature(mes.temperature);
let top_p = self.get_top_p(mes.top_p);
let top_k = self.get_top_k(mes.top_k);
let prepare_data = self.get_data(&mes)?;
let input_ids = prepare_data.input_ids;
let data = prepare_data.multi_model_data;
// Dimension 1 of the input ids is used as the initial sequence length.
let mut ctx = $crate::models::common::generate::GenerationContext::new(
temperature,
top_p,
top_k,
mes.repeat_penalty,
mes.repeat_last_n,
seed,
input_ids.dim(1)?,
sample_len,
self.device.clone(),
);
$crate::models::common::generate::generate_generic(
&mut self.model,
&self.tokenizer,
input_ids,
data,
&mut ctx,
&self.model_name,
)
}
// Streaming entry point: returns a boxed stream of chunk responses that
// borrows from `self`.
fn generate_stream(
&mut self,
mes: $crate::params::chat::ChatCompletionParameters,
) -> anyhow::Result<
Box<
dyn rocket::futures::Stream<
Item = anyhow::Result<
$crate::params::chat::ChatCompletionChunkResponse,
>,
> + Send
+ Unpin
+ '_,
>,
> {
// Same fallbacks as in `generate`: fixed seed and a 1024-token budget.
let seed = mes.seed.unwrap_or(299792458) as u64;
let prepare_data = self.get_data(&mes)?;
let input_ids = prepare_data.input_ids;
let data = prepare_data.multi_model_data;
let in_reasoning = prepare_data.in_reasoning;
let sample_len = mes.max_tokens.unwrap_or(1024);
let temperature = self.get_temperature(mes.temperature);
let top_p = self.get_top_p(mes.top_p);
let top_k = self.get_top_k(mes.top_k);
let stream = $crate::models::common::generate::generate_stream_generic(
&mut self.model,
&self.tokenizer,
input_ids,
data,
temperature,
top_p,
top_k,
mes.repeat_penalty,
mes.repeat_last_n,
seed,
sample_len,
in_reasoning,
&self.device,
&self.model_name,
)?;
// Pin the stream on the heap so the boxed trait object is `Unpin`.
Ok(Box::new(Box::pin(stream)))
}
}
};
}