use std::{collections::HashSet, sync::Arc};
use anyhow::{Context, Ok, Result};
use minijinja::Environment;
use crate::model_card::model::{ModelDeploymentCard, PromptContextMixin, PromptFormatterArtifact};
mod context;
mod formatters;
mod oai;
mod tokcfg;
use super::{OAIChatLikeRequest, OAIPromptFormatter, PromptFormatter};
use tokcfg::ChatTemplate;
impl PromptFormatter {
    /// Builds a [`PromptFormatter`] from the prompt-formatter artifact carried by a
    /// [`ModelDeploymentCard`].
    ///
    /// Two artifact kinds are handled:
    /// - `HfTokenizerConfigJson(file)`: the file is read and deserialized as a
    ///   [`ChatTemplate`]; the card's `prompt_context`, when present, is turned into
    ///   [`ContextMixins`], otherwise the default (empty) mixins are used.
    /// - `GGUF(gguf_path)`: the [`ChatTemplate`] is obtained via `ChatTemplate::from_gguf`
    ///   and combined with default mixins.
    ///
    /// # Errors
    ///
    /// Returns an error when the card has no prompt formatter, when the tokenizer config
    /// file cannot be read or parsed, when the GGUF template cannot be loaded, or when
    /// [`PromptFormatter::from_parts`] fails.
    pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<PromptFormatter> {
        // Build the error lazily so nothing is constructed on the success path.
        let artifact = mdc
            .prompt_formatter
            .ok_or_else(|| anyhow::anyhow!("MDC does not contain a prompt formatter"))?;

        match artifact {
            PromptFormatterArtifact::HfTokenizerConfigJson(file) => {
                // NOTE(review): this is a blocking read inside an `async fn`; presumably
                // acceptable because it runs once at startup on a small file — confirm.
                let content = std::fs::read_to_string(&file)
                    .with_context(|| format!("fs:read_to_string '{file}'"))?;
                // Name the offending file so a malformed config is easy to track down.
                let config = serde_json::from_str::<ChatTemplate>(&content)
                    .with_context(|| format!("failed to parse chat template JSON from '{file}'"))?;
                let mixins = mdc
                    .prompt_context
                    .map_or_else(ContextMixins::default, |x| ContextMixins::new(&x));
                Self::from_parts(config, mixins)
            }
            PromptFormatterArtifact::GGUF(gguf_path) => {
                let config = ChatTemplate::from_gguf(&gguf_path)?;
                Self::from_parts(config, ContextMixins::default())
            }
        }
    }

    /// Assembles a [`PromptFormatter`] from an already-loaded chat template and a set of
    /// context mixins.
    ///
    /// The pair is handed to `HfTokenizerConfigJsonFormatter::new`, and the resulting
    /// formatter is returned as the `OAI` variant behind an [`Arc`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `HfTokenizerConfigJsonFormatter::new`.
    pub fn from_parts(config: ChatTemplate, context: ContextMixins) -> Result<PromptFormatter> {
        let formatter = HfTokenizerConfigJsonFormatter::new(config, context)?;
        Ok(Self::OAI(Arc::new(formatter)))
    }
}
/// Thin wrapper that owns a `minijinja` [`Environment`] with a `'static` lifetime.
struct JinjaEnvironment {
/// The wrapped template environment.
env: Environment<'static>,
}
/// Prompt formatter built from a [`ChatTemplate`] plus a set of [`ContextMixins`].
///
/// Instances are created through `HfTokenizerConfigJsonFormatter::new` and exposed to
/// callers as the `OAI` variant of `PromptFormatter`.
#[derive(Debug)]
struct HfTokenizerConfigJsonFormatter {
/// `minijinja` environment owned by this formatter.
env: Environment<'static>,
/// Chat template configuration this formatter was built from.
config: ChatTemplate,
/// Context mixins, held behind an [`Arc`] so they can be shared cheaply.
mixins: Arc<ContextMixins>,
/// NOTE(review): presumably records whether the chat template makes use of
/// `add_generation_prompt` — confirm against the constructor.
supports_add_generation_prompt: bool,
}
/// Set of [`PromptContextMixin`] values attached to a prompt formatter.
///
/// The `Default` value is an empty set.
#[derive(Debug, Clone, Default)]
pub struct ContextMixins {
/// Distinct mixins; stored in a `HashSet`, so duplicates collapse.
context_mixins: HashSet<PromptContextMixin>,
}