use std::sync::Arc;
use anyhow::{Context, Result};
use either::Either;
use indexmap::IndexMap;
use crate::{
request::ReasoningEffort,
vision_models::{preprocessor_config::PreProcessorConfig, processor_config::ProcessorConfig},
MessageContent, Pipeline, Tool,
};
use super::{chat_template::apply_chat_template_to, text_models_inputs_processor, InputsProcessor};
/// Factory trait for constructing a [`Processor`] from its configuration files.
pub trait ProcessorCreator {
    /// Build a new processor from an optional `ProcessorConfig` and a
    /// required `PreProcessorConfig`.
    fn new_processor(
        _: Option<ProcessorConfig>,
        _: PreProcessorConfig,
    ) -> Arc<dyn Processor + Send + Sync>;
}
/// How message contents should be transformed before chat templating.
pub enum MessagesAction {
    /// Pass message contents through unchanged (multi-part rows preserved).
    Keep,
    /// Collapse multi-part `"content"` values down to their first text part.
    FlattenOnlyText,
}
/// Turns chat messages into a rendered prompt and its token ids.
///
/// Implementors customize how messages are templated ([`Self::template_action`]),
/// which special tokens exist ([`Self::get_special_tokens`]), and which
/// [`InputsProcessor`] builds the final model inputs ([`Self::inputs_processor`]).
pub trait Processor {
    /// Render `messages` through the pipeline's chat template and tokenize
    /// the resulting prompt.
    ///
    /// Returns the token ids together with the rendered prompt text.
    ///
    /// # Errors
    /// Fails if the pipeline has no tokenizer or no chat template, or if
    /// template rendering / tokenization fails.
    #[allow(clippy::too_many_arguments)]
    fn process(
        &self,
        pipeline: &dyn Pipeline,
        messages: Vec<IndexMap<String, MessageContent>>,
        add_generation_prompt: bool,
        add_special_tokens: bool,
        enable_thinking: Option<bool>,
        reasoning_effort: Option<ReasoningEffort>,
        tools: Vec<Tool>,
    ) -> Result<(Vec<u32>, String)> {
        let prompt = apply_chat_template(
            pipeline,
            messages,
            add_generation_prompt,
            enable_thinking,
            reasoning_effort,
            self.template_action(),
            tools,
        )?;
        let encoding = pipeline
            .tokenizer()
            .with_context(|| {
                "Default `Processor::process` requires the model to have a tokenizer."
            })?
            // Encode by reference to avoid cloning the prompt string; the
            // owned `prompt` is still returned to the caller below.
            .encode_fast(prompt.as_str(), add_special_tokens)
            .map_err(anyhow::Error::msg)?;
        Ok((encoding.get_ids().to_vec(), prompt))
    }

    /// The input processor used to build model inputs from tokenized prompts.
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor>;
    /// Special tokens recognized by this processor (may be empty).
    fn get_special_tokens(&self) -> &[&'static str];
    /// How message contents should be transformed before templating.
    fn template_action(&self) -> MessagesAction;
}
fn extract_token_string(token: &super::chat_template::BeginEndUnkPadTok) -> String {
match &token.0 {
Either::Left(lit) => lit.clone(),
Either::Right(added) => added.content.clone(),
}
}
/// Collapse a multi-part message content into plain text.
///
/// Plain-text content passes through untouched. For multi-part content, the
/// first row carrying a `"text"` string entry wins; if no row has one, an
/// empty multi-part content is returned.
fn flatten_content(content: MessageContent) -> MessageContent {
    match content {
        Either::Left(_) => content,
        Either::Right(rows) => {
            for row in rows {
                if let Some(text) = row.get("text").and_then(|v| v.as_str()) {
                    return Either::Left(text.to_string());
                }
            }
            Either::Right(Vec::new())
        }
    }
}
/// Render `messages` with the pipeline's chat template into a prompt string.
///
/// Depending on `action`, multi-part message contents are either kept as-is
/// ([`MessagesAction::Keep`]) or flattened to their first `"text"` part
/// ([`MessagesAction::FlattenOnlyText`]).
///
/// # Errors
/// Fails if the pipeline has no chat template, the template itself is unset,
/// or template rendering fails.
pub(crate) fn apply_chat_template(
    pipeline: &dyn Pipeline,
    messages: Vec<IndexMap<String, MessageContent>>,
    add_generation_prompt: bool,
    enable_thinking: Option<bool>,
    reasoning_effort: Option<ReasoningEffort>,
    action: MessagesAction,
    tools: Vec<Tool>,
) -> Result<String> {
    let messages = match action {
        MessagesAction::Keep => messages,
        MessagesAction::FlattenOnlyText => messages
            .into_iter()
            .map(|message| {
                message
                    .into_iter()
                    .map(|(key, value)| {
                        // Only the "content" field may carry multi-part rows.
                        let new_value = if key == "content" {
                            flatten_content(value)
                        } else {
                            value
                        };
                        (key, new_value)
                    })
                    .collect()
            })
            .collect(),
    };
    let chat_template = pipeline
        .get_chat_template()
        .with_context(|| "`apply_chat_template` expects the pipeline to have a chat template.")?;
    // Surface a clear error instead of panicking when the template is unset.
    let template = chat_template
        .chat_template
        .as_ref()
        .with_context(|| "`apply_chat_template` expects the chat template to be set.")?;
    let bos_tok = chat_template.bos_token.as_ref().map(extract_token_string);
    let eos_tok = chat_template.eos_token.as_ref().map(extract_token_string);
    let unk_tok = chat_template.unk_token.as_ref().map(extract_token_string);
    apply_chat_template_to(
        messages,
        add_generation_prompt,
        enable_thinking,
        reasoning_effort,
        template,
        bos_tok,
        eos_tok,
        unk_tok,
        tools,
    )
}
/// A minimal [`Processor`] for text-only models: no special tokens, messages
/// kept unchanged, and plain text input processing.
pub struct BasicProcessor;
impl Processor for BasicProcessor {
    /// Text-only input processing; no vision/audio preprocessing.
    fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
        Arc::new(text_models_inputs_processor::TextInputsProcessor)
    }
    /// No processor-specific special tokens.
    fn get_special_tokens(&self) -> &[&'static str] {
        &[]
    }
    /// Messages are passed to the chat template unchanged.
    fn template_action(&self) -> MessagesAction {
        MessagesAction::Keep
    }
}