mod types;
pub use types::*;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use serde_json::{Map, Value};
use crate::error::{Result, TinyAgentsError};
use crate::harness::message::Message;
use crate::harness::model::{ModelRequest, PromptSegment, ResponseFormat, SegmentRole};
use crate::harness::tool::ToolSchema;
impl PromptTemplate {
    /// Wraps a raw template string.
    ///
    /// Nothing is parsed or validated here; placeholder errors surface only when the
    /// template is rendered.
    pub fn new(template: impl Into<String>) -> Self {
        let template = template.into();
        Self { template }
    }

    /// Substitutes `{name}` placeholders using `vars` and returns the rendered text.
    ///
    /// # Errors
    /// Propagates the error from `render_template` (unclosed or unknown placeholder).
    pub fn render(&self, vars: &Map<String, Value>) -> Result<String> {
        render_template(self.template.as_str(), vars)
    }

    /// Renders the template and wraps the text in a [`Message`] of the given role.
    ///
    /// # Errors
    /// Fails exactly when [`PromptTemplate::render`] fails.
    pub fn render_message(
        &self,
        role: TemplateRole,
        vars: &Map<String, Value>,
    ) -> Result<Message> {
        let text = self.render(vars)?;
        let message = match role {
            TemplateRole::System => Message::system(text),
            TemplateRole::User => Message::user(text),
            TemplateRole::Assistant => Message::assistant(text),
        };
        Ok(message)
    }

    /// Shorthand for `render_message(TemplateRole::System, vars)`.
    pub fn render_system(&self, vars: &Map<String, Value>) -> Result<Message> {
        let role = TemplateRole::System;
        self.render_message(role, vars)
    }

    /// Shorthand for `render_message(TemplateRole::User, vars)`.
    pub fn render_user(&self, vars: &Map<String, Value>) -> Result<Message> {
        let role = TemplateRole::User;
        self.render_message(role, vars)
    }

    /// Shorthand for `render_message(TemplateRole::Assistant, vars)`.
    pub fn render_assistant(&self, vars: &Map<String, Value>) -> Result<Message> {
        let role = TemplateRole::Assistant;
        self.render_message(role, vars)
    }
}
impl MessagesTemplate {
    /// Creates an empty messages template; identical to `Default::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends a `(role, template)` entry and returns `self` for chaining.
    ///
    /// Entries are rendered in the order they were pushed.
    pub fn push(&mut self, role: TemplateRole, template: PromptTemplate) -> &mut Self {
        let entry = (role, template);
        self.entries.push(entry);
        self
    }

    /// Renders every entry against the same `vars`, preserving push order.
    ///
    /// # Errors
    /// Returns the first rendering error encountered; later entries are not rendered.
    pub fn render(&self, vars: &Map<String, Value>) -> Result<Vec<Message>> {
        let mut rendered = Vec::with_capacity(self.entries.len());
        for (role, tpl) in &self.entries {
            let message = tpl.render_message(*role, vars)?;
            rendered.push(message);
        }
        Ok(rendered)
    }
}
impl PromptBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn push_system(&mut self, id: impl Into<String>, messages: Vec<Message>) -> &mut Self {
self.segments.push(BuiltSegment {
messages,
meta: PromptSegment {
id: id.into(),
role: SegmentRole::System,
cacheable: true,
},
});
self
}
pub fn push_tools_segment(
&mut self,
id: impl Into<String>,
tools: Vec<ToolSchema>,
) -> &mut Self {
self.tools.extend(tools);
self.segments.push(BuiltSegment {
messages: vec![],
meta: PromptSegment {
id: id.into(),
role: SegmentRole::Tools,
cacheable: true,
},
});
self
}
pub fn push_instructions(
&mut self,
id: impl Into<String>,
messages: Vec<Message>,
) -> &mut Self {
self.segments.push(BuiltSegment {
messages,
meta: PromptSegment {
id: id.into(),
role: SegmentRole::Instructions,
cacheable: true,
},
});
self
}
pub fn push_history(&mut self, id: impl Into<String>, messages: Vec<Message>) -> &mut Self {
self.segments.push(BuiltSegment {
messages,
meta: PromptSegment {
id: id.into(),
role: SegmentRole::History,
cacheable: false,
},
});
self
}
pub fn push_volatile(&mut self, id: impl Into<String>, messages: Vec<Message>) -> &mut Self {
self.segments.push(BuiltSegment {
messages,
meta: PromptSegment {
id: id.into(),
role: SegmentRole::Volatile,
cacheable: false,
},
});
self
}
pub fn with_response_format(&mut self, format: ResponseFormat) -> &mut Self {
self.response_format = Some(format);
self
}
pub fn build(&self, tail: Vec<Message>) -> ModelRequest {
let mut messages: Vec<Message> = self
.segments
.iter()
.flat_map(|s| s.messages.iter().cloned())
.collect();
messages.extend(tail);
let cache_segments: Vec<PromptSegment> =
self.segments.iter().map(|s| s.meta.clone()).collect();
let fp = self.fingerprint();
let mut req = ModelRequest::new(messages)
.with_tools(self.tools.clone())
.with_cache_segments(cache_segments);
req.prompt_fingerprint = Some(fp);
if let Some(fmt) = &self.response_format {
req = req.with_response_format(fmt.clone());
}
req
}
pub fn fingerprint(&self) -> String {
let mut hasher = DefaultHasher::new();
for seg in self.segments.iter().filter(|s| s.meta.cacheable) {
seg.meta.id.hash(&mut hasher);
for msg in &seg.messages {
msg.text().hash(&mut hasher);
}
}
for tool in &self.tools {
tool.name.hash(&mut hasher);
}
format!("{:016x}", hasher.finish())
}
}
/// Substitutes `{name}` placeholders in `template` with values from `vars`.
///
/// Escapes and substitution rules:
/// - `{{` renders as a literal `{` and `}}` renders as a literal `}`.
/// - A lone `}` outside a placeholder is passed through unchanged.
/// - The placeholder name is everything between `{` and the next `}`, taken verbatim
///   (it is not trimmed).
/// - String values are inserted as-is; every other JSON value is inserted in its
///   `to_string` form (for example `42` or `null`).
///
/// # Errors
/// Returns `TinyAgentsError::Validation` if a `{` is never closed by a `}`, or if a
/// placeholder name is not present in `vars`.
fn render_template(template: &str, vars: &Map<String, Value>) -> Result<String> {
    let mut out = String::with_capacity(template.len());
    let mut iter = template.chars().peekable();

    while let Some(ch) = iter.next() {
        if ch == '}' {
            // `}}` collapses to one `}`; a single `}` is emitted unchanged.
            iter.next_if_eq(&'}');
            out.push('}');
            continue;
        }
        if ch != '{' {
            out.push(ch);
            continue;
        }
        if iter.next_if_eq(&'{').is_some() {
            // Escaped `{{` -> literal `{`.
            out.push('{');
            continue;
        }

        // Gather the placeholder name, consuming the closing `}` if there is one.
        let mut name = String::new();
        let closed = loop {
            match iter.next() {
                Some('}') => break true,
                Some(nc) => name.push(nc),
                None => break false,
            }
        };
        if !closed {
            return Err(TinyAgentsError::Validation(format!(
                "unclosed placeholder '{{{name}'"
            )));
        }

        let value = vars.get(&name).ok_or_else(|| {
            TinyAgentsError::Validation(format!("unknown placeholder '{{{name}}}'"))
        })?;
        match value {
            Value::String(s) => out.push_str(s),
            other => out.push_str(&other.to_string()),
        }
    }

    Ok(out)
}
#[cfg(test)]
mod test;