use crate::prompt::{Prompt, PromptError};
use serde::{Deserialize, Serialize};
/// Speaker attached to a chat message.
///
/// Each variant serializes (and deserializes) as its lowercase name, the same
/// string returned by [`Role::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Role {
    /// Serialized as `"system"`.
    #[serde(rename = "system")]
    System,
    /// Serialized as `"user"`.
    #[serde(rename = "user")]
    User,
    /// Serialized as `"assistant"`.
    #[serde(rename = "assistant")]
    Assistant,
}
impl Role {
pub fn as_str(&self) -> &'static str {
match self {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
}
}
}
/// A single chat entry: a [`Role`] together with its text.
///
/// [`Messages::render`] produces these from stored templates. Fields are
/// serialized in declaration order (`role`, then `content`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Message {
    /// Speaker of this entry.
    pub role: Role,
    /// Text of this entry.
    pub content: String,
}
#[derive(Default, Clone)]
pub struct Messages {
entries: Vec<(Role, String)>,
}
impl Messages {
    /// Creates an empty list with no templates.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a template for [`Role::System`].
    #[must_use = "builder methods consume `self`; use the returned `Messages`"]
    pub fn system(self, tmpl: impl Into<String>) -> Self {
        self.push(Role::System, tmpl)
    }

    /// Appends a template for [`Role::User`].
    #[must_use = "builder methods consume `self`; use the returned `Messages`"]
    pub fn user(self, tmpl: impl Into<String>) -> Self {
        self.push(Role::User, tmpl)
    }

    /// Appends a template for [`Role::Assistant`].
    #[must_use = "builder methods consume `self`; use the returned `Messages`"]
    pub fn assistant(self, tmpl: impl Into<String>) -> Self {
        self.push(Role::Assistant, tmpl)
    }

    /// Appends a template for an arbitrary `role`.
    ///
    /// The text is stored as-is; it is not handed to [`Prompt::new`] until
    /// [`Messages::render`] runs, so any error from parsing surfaces there.
    #[must_use = "builder methods consume `self`; use the returned `Messages`"]
    pub fn push(mut self, role: Role, tmpl: impl Into<String>) -> Self {
        self.entries.push((role, tmpl.into()));
        self
    }

    /// Renders every stored template with `vars`, preserving insertion order.
    ///
    /// Templates are re-parsed on each call, so the same `Messages` can be
    /// rendered repeatedly with different variables.
    ///
    /// # Errors
    ///
    /// Returns a [`PromptError`] if [`Prompt::new`] or [`Prompt::render`] fails
    /// for any template; templates after the failing one are not processed.
    pub fn render<T: serde::Serialize>(
        &self,
        vars: &T,
    ) -> Result<Vec<Message>, PromptError> {
        let mut out = Vec::with_capacity(self.entries.len());
        for (role, tmpl) in &self.entries {
            // `render` only borrows `self`, so the stored template is cloned for `Prompt::new`.
            let prompt = Prompt::new(tmpl.clone())?;
            out.push(Message {
                role: *role,
                content: prompt.render(vars)?,
            });
        }
        Ok(out)
    }

    /// Number of templates added so far.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if no templates have been added.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}