use std::{collections::HashSet, sync::Arc};
use anyhow::{Error, Result};
use futures::stream::{self, StreamExt};
use tracing as log;
use crate::model_card::model::{ModelDeploymentCard, TokenizerKind};
use dynamo_runtime::{
pipeline::{
async_trait, AsyncEngineContextProvider, ManyOut, Operator, ResponseStream,
ServerStreamingEngine, SingleIn,
},
protocols::annotated::Annotated,
};
use crate::protocols::{
common::{
llm_backend::{BackendInput, BackendOutput, FinishReason, LLMEngineOutput},
StopConditions,
},
TokenIdType,
};
use crate::tokenizers::{DecodeStream, HuggingFaceTokenizer, Tokenizer};
use tokenizers::Tokenizer as HfTokenizer;
/// One item yielded by the downstream execution engine: an [`LLMEngineOutput`]
/// wrapped in an [`Annotated`] envelope.
pub type ExecutionOutputStream = Annotated<LLMEngineOutput>;

/// Handle to the next engine in the pipeline: it accepts a [`BackendInput`] and its
/// responses are `Annotated<LLMEngineOutput>` items (i.e. [`ExecutionOutputStream`]).
pub type ExecutionContext = ServerStreamingEngine<BackendInput, Annotated<LLMEngineOutput>>;
/// Pipeline operator placed in front of the execution engine. It decodes the engine's
/// token ids into text and applies hidden stop tokens / stop sequences.
#[allow(dead_code)]
pub struct Backend {
    /// Tokenizer used to build a per-request decoder. It is `None` when the backend
    /// was built from a model card without a tokenizer.
    pub tokenizer: Option<Tokenizer>,

    /// When `true`, outputs that already carry engine-decoded text are decoded again
    /// and any mismatch is logged.
    validate_engine_decode: bool,
}
/// Per-request state threaded through the `stream::unfold` in the backend's
/// `generate`: the upstream stream plus the decoder that post-processes it.
#[allow(dead_code)]
struct DecoderUnfoldState {
    /// Upstream engine output stream being consumed.
    stream: ManyOut<ExecutionOutputStream>,

    /// Copied from the backend. When set, outputs that already carry text are still
    /// decoded so the two results can be compared.
    validate_engine_decode: bool,

    /// Incremental detokenizer and stop-condition tracker for this request.
    decoder: Decoder,
}
impl Backend {
pub async fn from_tokenizer(tokenizer: HfTokenizer) -> Result<Arc<Self>> {
let tokenizer = HuggingFaceTokenizer::from_tokenizer(tokenizer);
let tokenizer = Tokenizer::from(Arc::new(tokenizer));
Ok(Arc::new(Self {
tokenizer: Some(tokenizer),
validate_engine_decode: false,
}))
}
pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
let tokenizer = match &mdc.tokenizer {
Some(TokenizerKind::HfTokenizerJson(file)) => {
HfTokenizer::from_file(file).map_err(Error::msg)?
}
Some(TokenizerKind::GGUF(t)) => *t.clone(),
None => {
return Ok(Arc::new(Self {
tokenizer: None,
validate_engine_decode: false,
}));
}
};
Self::from_tokenizer(tokenizer).await
}
fn decoder(
&self,
stream: ManyOut<ExecutionOutputStream>,
stop_conditions: StopConditions,
) -> anyhow::Result<DecoderUnfoldState> {
let Some(tokenizer) = self.tokenizer.as_ref() else {
anyhow::bail!("Backend built from blank ModelDeploymentCard, no tokenizer");
};
let decoder = Decoder::new(tokenizer.decode_stream(false), stop_conditions);
Ok(DecoderUnfoldState {
stream,
decoder,
validate_engine_decode: self.validate_engine_decode,
})
}
}
#[async_trait]
impl
Operator<
SingleIn<BackendInput>,
ManyOut<Annotated<BackendOutput>>,
SingleIn<BackendInput>,
ManyOut<Annotated<LLMEngineOutput>>,
> for Backend
{
async fn generate(
&self,
request: SingleIn<BackendInput>,
next: ServerStreamingEngine<BackendInput, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<BackendOutput>>> {
let stop_conditions = request.stop_conditions.clone();
let next_stream = next.generate(request).await?;
let context = next_stream.context();
let state = self.decoder(next_stream, stop_conditions)?;
let processed_stream = stream::unfold(state, |mut state| async move {
match state.stream.next().await {
Some(output) => {
if output.is_event() || output.data.is_none() {
return Some((output, state));
}
if let Some(data) = &output.data {
if data.text.is_some() && !state.validate_engine_decode {
return Some((output, state));
}
}
let data = output.data.as_ref().unwrap();
let result = state.decoder.process_token_ids(&data.token_ids).unwrap();
let finish_reason = match &result.stop_trigger {
Some(StopTrigger::MaxTokensLimit) => Some(FinishReason::Length),
Some(StopTrigger::HiddenStopTokenDetected(_)) => Some(FinishReason::Stop),
Some(StopTrigger::HiddenStopSequenceDetected(_)) => {
Some(FinishReason::Stop)
}
None => None,
};
if data.finish_reason.is_none() && finish_reason.is_some() {
tracing::debug!(
?result.stop_trigger,
"upstream did not provide a finish reason; issuing a stop_generation request to free resources",
);
state.stream.context().stop_generating();
}
let text = result.text;
let tokens = result.tokens;
if state.validate_engine_decode {
if data.finish_reason != finish_reason {
log::warn!(
"finish reason mismatch: expected {:?}, got {:?}",
data.finish_reason,
finish_reason
);
}
if data.text.is_some() && data.text != text {
log::warn!("text mismatch: expected {:?}, got {:?}", data.text, text);
}
}
let mut output = output;
let mut data = output.data.take().unwrap();
data.finish_reason = finish_reason;
data.text = text;
data.tokens = Some(tokens);
output.data = Some(data);
Some((output, state))
}
None => None,
}
});
let stream = processed_stream.map(move |output| {
output.map_data(|data| {
Ok(BackendOutput {
token_ids: data.token_ids,
tokens: data.tokens.unwrap_or_default(),
text: data.text,
cum_log_probs: data.cum_log_probs,
log_probs: data.log_probs,
finish_reason: data.finish_reason,
})
})
});
Ok(ResponseStream::new(Box::pin(stream), context))
}
}
/// Incremental detokenizer for one request. It also enforces hidden stop token ids
/// and hidden stop sequences.
#[allow(dead_code)]
pub struct Decoder {
    /// Streaming detokenizer that turns one token id at a time into text.
    decode_stream: DecodeStream,

    // Stop conditions.
    /// Token ids that trigger a stop when produced. Their text is hidden.
    hidden_stop_ids: HashSet<TokenIdType>,
    /// Strings that trigger a stop when they appear in the decoded text.
    hidden_stop_sequences: Vec<String>,
    /// Stop conditions are not evaluated until `generated_tokens` reaches this value.
    min_tokens: u32,

    // Progress.
    /// Number of token ids stepped so far.
    generated_tokens: u32,

    // Stop-sequence search window.
    /// Tail of recently decoded text that is searched for stop sequences.
    jail: String,
    /// Upper bound on the bytes kept in `jail`: the length of the longest stop sequence.
    jail_max_bytes: usize,
    /// Byte count used by `jailed_string`.
    /// NOTE(review): initialised to 0 and not otherwise written in the visible code.
    jailed_bytes: usize,
}
/// Reason a [`Decoder`] decided that generation should stop.
#[allow(dead_code)]
#[derive(Debug)]
pub enum StopTrigger {
    /// Mapped to `FinishReason::Length` by the backend. Presumably a token budget
    /// was exhausted (not produced by `Decoder` itself in this file).
    MaxTokensLimit,
    /// A hidden stop token id was produced.
    HiddenStopTokenDetected(TokenIdType),
    /// A hidden stop sequence appeared in the decoded text.
    HiddenStopSequenceDetected(String),
}

impl StopTrigger {
    /// Whether the text of the step that fired this trigger must be withheld from the
    /// aggregated output. Returns `true` for hidden stop tokens and sequences and
    /// `false` for [`StopTrigger::MaxTokensLimit`].
    pub fn should_hide_text(&self) -> bool {
        match self {
            Self::HiddenStopTokenDetected(_) | Self::HiddenStopSequenceDetected(_) => true,
            Self::MaxTokensLimit => false,
        }
    }
}
/// Outcome of feeding a single token id to [`Decoder::step`].
pub struct StepResult {
    /// Text produced for this step, if the detokenizer emitted any.
    pub token: Option<String>,
    /// Set when this step fired a stop condition.
    pub stop_trigger: Option<StopTrigger>,
}

impl StepResult {
    /// A step that did not fire any stop condition.
    fn ok(token: Option<String>) -> Self {
        Self {
            stop_trigger: None,
            token,
        }
    }

    /// A step that fired `stop_trigger`.
    fn with_stop_trigger(token: Option<String>, stop_trigger: StopTrigger) -> Self {
        Self {
            stop_trigger: Some(stop_trigger),
            ..Self::ok(token)
        }
    }
}
/// Outcome of [`Decoder::process_token_ids`] over a batch of token ids.
pub struct SeqResult {
    /// Text of every step taken, in order, including a step that fired a stop
    /// trigger. An entry is `None` where the detokenizer produced no text.
    pub tokens: Vec<Option<String>>,
    /// Concatenation of the kept step texts, or `None` if no step contributed text.
    pub text: Option<String>,
    /// The stop trigger that ended processing early, if any.
    pub stop_trigger: Option<StopTrigger>,
}
#[allow(dead_code)]
impl Decoder {
    /// Builds a decoder for one request.
    ///
    /// - `decode_stream`: the incremental detokenizer, fed one token id per
    ///   [`Decoder::step`].
    /// - `stop_condition`: supplies the hidden stop token ids
    ///   (`stop_token_ids_hidden`), the hidden stop sequences (`stop`) and
    ///   `min_tokens`. Missing values default to empty / 0.
    ///
    /// Empty stop sequences are discarded. An empty needle would match at offset 0 of
    /// any text and stop generation on the first eligible token.
    pub fn new(decode_stream: DecodeStream, stop_condition: StopConditions) -> Self {
        let hidden_stop_ids: HashSet<TokenIdType> = stop_condition
            .stop_token_ids_hidden
            .unwrap_or_default()
            .iter()
            .copied()
            .collect();

        let hidden_stop_sequences: Vec<String> = stop_condition
            .stop
            .unwrap_or_default()
            .iter()
            .filter(|seq| !seq.is_empty())
            .map(|seq| seq.to_string())
            .collect();

        // The search window only needs to be as long as the longest stop sequence.
        let jail_max_bytes = hidden_stop_sequences
            .iter()
            .map(|seq| seq.len())
            .max()
            .unwrap_or(0);

        Self {
            decode_stream,
            hidden_stop_ids,
            hidden_stop_sequences,
            min_tokens: stop_condition.min_tokens.unwrap_or(0),
            generated_tokens: 0,
            jail: String::new(),
            jail_max_bytes,
            jailed_bytes: 0,
        }
    }

    /// Feeds one token id through the detokenizer and evaluates stop conditions.
    ///
    /// Stop conditions are skipped while fewer than `min_tokens` tokens (counting this
    /// one) have been stepped. After that, checks run in this order:
    /// 1. A hidden stop token id fires first.
    /// 2. If any stop sequences are configured, this step's text is appended to the
    ///    `jail` search window, which is scanned for each sequence in list order.
    ///
    /// On a sequence match, the returned token holds only the part of this step's text
    /// that precedes the match. It is empty if the match began in earlier text.
    ///
    /// NOTE(review): text is returned as soon as it is decoded. The jail only detects
    /// sequences, so the leading part of a stop sequence that spans several tokens has
    /// already been emitted by earlier steps.
    ///
    /// # Errors
    /// Propagates errors from the detokenizer's `step`.
    pub fn step(&mut self, token_id: TokenIdType) -> Result<StepResult> {
        self.generated_tokens += 1;
        let token = self.decode_stream.step(token_id)?;

        if self.generated_tokens < self.min_tokens {
            return Ok(StepResult::ok(token));
        }

        if self.hidden_stop_ids.contains(&token_id) {
            return Ok(StepResult::with_stop_trigger(
                token,
                StopTrigger::HiddenStopTokenDetected(token_id),
            ));
        }

        if self.jail_max_bytes > 0 {
            if let Some(token) = &token {
                let pre_append = self.jail.len();
                log::debug!("pre_append: {}", pre_append);
                log::debug!("jail: {}", self.jail);
                self.jail.push_str(token);
                log::debug!("post_append: {}", self.jail.len());
                log::debug!("jail: {}", self.jail);

                for seq in &self.hidden_stop_sequences {
                    log::debug!("stop seq: {}", seq);
                    if let Some(offset) =
                        galil_seiferas::gs_find(self.jail.as_bytes(), seq.as_bytes())
                    {
                        log::debug!("offset: {}", offset);
                        // A match starting before `pre_append` began in text returned
                        // by an earlier step, so nothing from this step precedes it.
                        // `offset` is a char boundary because the needle is valid UTF-8.
                        let partial_token = if offset >= pre_append {
                            self.jail[pre_append..offset].to_string()
                        } else {
                            String::new()
                        };
                        return Ok(StepResult::with_stop_trigger(
                            Some(partial_token),
                            StopTrigger::HiddenStopSequenceDetected(seq.to_string()),
                        ));
                    }
                }

                self.trim_jail();
            }
        }

        Ok(StepResult::ok(token))
    }

    /// Drops the oldest bytes of `jail` so that it holds at most `jail_max_bytes` bytes.
    ///
    /// The cut point is moved forward to the next `char` boundary, because
    /// `String::drain` panics when its range splits a multi-byte UTF-8 character.
    /// This does not lose matches:
    /// - Any stop-sequence occurrence starts on a char boundary.
    /// - An occurrence that can still complete with the next token starts after the
    ///   unadjusted cut point.
    fn trim_jail(&mut self) {
        if self.jail.len() <= self.jail_max_bytes {
            return;
        }
        let mut cut = self.jail.len() - self.jail_max_bytes;
        // Terminates: `jail.len()` is always a char boundary.
        while !self.jail.is_char_boundary(cut) {
            cut += 1;
        }
        self.jail.drain(..cut);
    }

    /// Steps through `token_ids` in order and stops at the first stop trigger.
    ///
    /// Returns the per-step texts (including the triggering step) and their
    /// concatenation. The triggering step's text is left out of `text` when
    /// [`StopTrigger::should_hide_text`] is true. `text` is `None` when no kept step
    /// produced text.
    ///
    /// # Errors
    /// Propagates the first error from [`Decoder::step`]. Steps already taken are not
    /// rolled back.
    pub fn process_token_ids(&mut self, token_ids: &[TokenIdType]) -> Result<SeqResult> {
        let mut text: Option<String> = None;
        let mut tokens = Vec::with_capacity(token_ids.len());

        for &token_id in token_ids {
            let StepResult {
                token,
                stop_trigger,
            } = self.step(token_id)?;

            let hide_text = stop_trigger
                .as_ref()
                .map(|trigger| trigger.should_hide_text())
                .unwrap_or(false);

            if !hide_text {
                if let Some(token) = &token {
                    text.get_or_insert_with(String::new).push_str(token);
                }
            }

            tokens.push(token);

            if let Some(stop_trigger) = stop_trigger {
                return Ok(SeqResult {
                    tokens,
                    text,
                    stop_trigger: Some(stop_trigger),
                });
            }
        }

        Ok(SeqResult {
            tokens,
            text,
            stop_trigger: None,
        })
    }

    /// Wraps `token` in a [`StepResult`] without a stop trigger.
    fn return_token(&self, token: Option<String>) -> StepResult {
        StepResult {
            token,
            stop_trigger: None,
        }
    }

    /// Wraps `token` in a [`StepResult`] carrying `stop_trigger`.
    fn return_with_stop_trigger(
        &self,
        token: Option<String>,
        stop_trigger: StopTrigger,
    ) -> StepResult {
        StepResult {
            token,
            stop_trigger: Some(stop_trigger),
        }
    }

    /// Returns the last `jailed_bytes` bytes of the jail, or `None` when `jailed_bytes`
    /// is 0.
    ///
    /// NOTE(review): this would panic if `jailed_bytes` exceeded the jail length or
    /// split a UTF-8 character. It is currently unused and `jailed_bytes` stays 0.
    fn jailed_string(&self) -> Option<String> {
        if self.jailed_bytes > 0 {
            Some(self.jail[self.jail.len() - self.jailed_bytes..].to_string())
        } else {
            None
        }
    }
}