pub(crate) mod ast;
pub(crate) mod eval;
pub(crate) mod lexer;
pub(crate) mod parser;
use crate::eval::{Evaluator, Value};
use crate::parser::Parser;
use std::collections::HashMap;
/// A single chat turn handed to a template through the `messages` variable.
///
/// During rendering each message is exposed to the template as a map with
/// `role` and `content` keys.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ChatMessage {
    /// Speaker role, exposed to templates under the `role` key.
    pub role: String,
    /// Message text, exposed to templates under the `content` key.
    pub content: String,
}
/// Extra template variables supplied alongside `messages` at render time.
///
/// Entries in `vars` are exposed to templates as strings and entries in
/// `flags` as booleans. If the same name appears in both, the flag wins.
#[derive(Clone, Debug, Default)]
pub struct RenderContext {
    /// String-valued template variables, keyed by variable name.
    pub vars: HashMap<String, String>,
    /// Boolean template variables, keyed by variable name.
    pub flags: HashMap<String, bool>,
}

impl RenderContext {
    /// Creates an empty context with no variables and no flags.
    pub fn new() -> Self {
        Self {
            vars: HashMap::new(),
            flags: HashMap::new(),
        }
    }

    /// Sets (or replaces) the string variable `key`; returns `self` for chaining.
    pub fn set_var(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
        let (name, text) = (key.into(), value.into());
        self.vars.insert(name, text);
        self
    }

    /// Sets (or replaces) the boolean flag `key`; returns `self` for chaining.
    pub fn set_flag(&mut self, key: impl Into<String>, value: bool) -> &mut Self {
        let name: String = key.into();
        self.flags.insert(name, value);
        self
    }
}
/// Error returned by the fallible `try_render_*` functions.
///
/// A human-readable message prefixed with the stage that failed
/// (`"parse error: "` or `"render error: "`).
pub type RenderError = std::string::String;
pub fn render_chat_template(template: &str, messages: &[ChatMessage]) -> String {
let mut ctx = RenderContext::new();
ctx.set_var("eos_token", "</s>");
ctx.set_flag("add_generation_prompt", true);
render_chat_template_with_context(template, messages, &ctx)
}
/// Renders `template` against `messages` with caller-supplied variables in `ctx`.
///
/// This is the panicking counterpart of [`try_render_chat_template_with_context`].
///
/// # Panics
///
/// Panics with `shimmyjinja render error: ...` if the template fails to parse
/// or render.
pub fn render_chat_template_with_context(
    template: &str,
    messages: &[ChatMessage],
    ctx: &RenderContext,
) -> String {
    match try_render_chat_template_with_context(template, messages, ctx) {
        Ok(rendered) => rendered,
        Err(err) => panic!("shimmyjinja render error: {}", err),
    }
}
/// Fallibly renders `template` against `messages` using the default context.
///
/// The default context sets the string variable `eos_token` to `"</s>"` and
/// the flag `add_generation_prompt` to `true`.
///
/// # Errors
///
/// Propagates any parse or render error from
/// [`try_render_chat_template_with_context`].
pub fn try_render_chat_template(
    template: &str,
    messages: &[ChatMessage],
) -> Result<String, RenderError> {
    let mut defaults = RenderContext::new();
    defaults
        .set_var("eos_token", "</s>")
        .set_flag("add_generation_prompt", true);
    try_render_chat_template_with_context(template, messages, &defaults)
}
pub fn try_render_chat_template_with_context(
template: &str,
messages: &[ChatMessage],
ctx: &RenderContext,
) -> Result<String, RenderError> {
let mut parser = Parser::new(template);
let ast = parser.parse().map_err(|e| format!("parse error: {}", e))?;
let mut context = HashMap::new();
let mut msgs_val = Vec::new();
for m in messages {
let mut map = HashMap::new();
map.insert("role".to_string(), Value::String(m.role.clone()));
map.insert("content".to_string(), Value::String(m.content.clone()));
msgs_val.push(Value::Map(map));
}
context.insert("messages".to_string(), Value::Array(msgs_val));
for (k, v) in &ctx.vars {
context.insert(k.clone(), Value::String(v.clone()));
}
for (k, v) in &ctx.flags {
context.insert(k.clone(), Value::Bool(*v));
}
let mut eval = Evaluator::new(context);
eval.render(&ast).map_err(|e| format!("render error: {}", e))
}