use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tiktoken_rs::cl100k_base;
/// Author role attached to every `ChatMLMessage`.
///
/// Because of `rename_all = "lowercase"`, serde reads and writes the variants
/// as `"system"`, `"user"`, `"assistant"` and `"tool"`; the `Display` impl
/// prints the same lowercase names.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
/// Displayed as `system`.
System,
/// Displayed as `user`.
User,
/// Displayed as `assistant`; the role set by the `new_assistant_*` constructors.
Assistant,
/// Displayed as `tool`; the role set by `ChatMLMessage::new_tool`.
Tool,
}
/// Prints the role as its lowercase name: `system`, `user`, `assistant` or `tool`.
impl std::fmt::Display for MessageRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Resolve the label first so the formatter is written to exactly once.
        let label = match self {
            MessageRole::System => "system",
            MessageRole::User => "user",
            MessageRole::Assistant => "assistant",
            MessageRole::Tool => "tool",
        };
        f.write_str(label)
    }
}
/// One chat message: a role, its text content and optional metadata.
///
/// Every `Option` field is left out of the serialized form when it is `None`
/// (`skip_serializing_if = "Option::is_none"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMLMessage {
/// Author role of the message.
pub role: MessageRole,
/// Text body of the message.
pub content: String,
/// Reasoning text kept next to `content`; populated by
/// `new_assistant_with_reasoning`, `None` from the other constructors.
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
/// Optional name; `to_chatml_string` appends it to the header as ` name=<name>`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Tool-call identifier; populated by `new_tool`, `None` from the other constructors.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Tool calls attached to the message; populated by `new_assistant_with_tool_calls`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<crate::ToolCall>>,
}
impl ChatMLMessage {
pub fn new(role: MessageRole, content: String, name: Option<String>) -> Self {
Self {
role,
content,
reasoning_content: None,
name,
tool_call_id: None,
tool_calls: None,
}
}
pub fn new_tool(content: String, tool_call_id: String, name: String) -> Self {
Self {
role: MessageRole::Tool,
content,
reasoning_content: None,
name: Some(name),
tool_call_id: Some(tool_call_id),
tool_calls: None,
}
}
pub fn new_assistant_with_tool_calls(
content: String,
tool_calls: Vec<crate::ToolCall>,
) -> Self {
Self {
role: MessageRole::Assistant,
content,
reasoning_content: None,
name: None,
tool_call_id: None,
tool_calls: Some(tool_calls),
}
}
pub fn new_assistant_with_reasoning(content: String, reasoning_content: String) -> Self {
Self {
role: MessageRole::Assistant,
content,
reasoning_content: Some(reasoning_content),
name: None,
tool_call_id: None,
tool_calls: None,
}
}
pub fn to_dict(&self) -> HashMap<String, serde_json::Value> {
let mut message = HashMap::new();
message.insert(
"role".to_string(),
serde_json::Value::String(self.role.to_string()),
);
message.insert(
"content".to_string(),
serde_json::Value::String(self.content.clone()),
);
if let Some(ref reasoning) = self.reasoning_content {
message.insert(
"reasoning_content".to_string(),
serde_json::Value::String(reasoning.clone()),
);
}
if let Some(name) = &self.name {
message.insert("name".to_string(), serde_json::Value::String(name.clone()));
}
if let Some(tool_call_id) = &self.tool_call_id {
message.insert(
"tool_call_id".to_string(),
serde_json::Value::String(tool_call_id.clone()),
);
}
if let Some(tool_calls) = &self.tool_calls {
let tool_calls_json = serde_json::to_value(tool_calls)
.unwrap_or_else(|_| serde_json::Value::Array(vec![]));
message.insert("tool_calls".to_string(), tool_calls_json);
}
message
}
pub fn to_chatml_string(&self) -> String {
let name_part = if let Some(name) = &self.name {
format!(" name={}", name)
} else {
String::new()
};
format!(
"<|im_start|>{}{}\n{}\n<|im_end|>",
self.role, name_part, self.content
)
}
}
/// Ordered list of `ChatMLMessage`s with helpers to append, trim, render,
/// validate and token-count the conversation.
#[derive(Debug, Clone)]
pub struct ChatMLFormatter {
/// Messages in the order they were added.
messages: Vec<ChatMLMessage>,
}
impl ChatMLFormatter {
pub fn new() -> Self {
Self {
messages: Vec::new(),
}
}
pub fn add_system_message(&mut self, content: String, name: Option<String>) -> &mut Self {
self.messages
.push(ChatMLMessage::new(MessageRole::System, content, name));
self
}
pub fn add_user_message(&mut self, content: String, name: Option<String>) -> &mut Self {
self.messages
.push(ChatMLMessage::new(MessageRole::User, content, name));
self
}
pub fn add_assistant_message(&mut self, content: String, name: Option<String>) -> &mut Self {
self.messages
.push(ChatMLMessage::new(MessageRole::Assistant, content, name));
self
}
pub fn add_assistant_message_with_reasoning(
&mut self,
content: String,
reasoning_content: String,
tool_calls: Option<Vec<crate::ToolCall>>,
) -> &mut Self {
let mut message = ChatMLMessage::new_assistant_with_reasoning(content, reasoning_content);
message.tool_calls = tool_calls;
self.messages.push(message);
self
}
pub fn add_assistant_message_with_tool_calls(
&mut self,
content: String,
tool_calls: Vec<crate::ToolCall>,
) -> &mut Self {
self.messages
.push(ChatMLMessage::new_assistant_with_tool_calls(
content, tool_calls,
));
self
}
pub fn add_tool_message(
&mut self,
content: String,
tool_call_id: String,
name: String,
) -> &mut Self {
self.messages
.push(ChatMLMessage::new_tool(content, tool_call_id, name));
self
}
pub fn add_tool_results_message(&mut self, content: String, name: Option<String>) -> &mut Self {
self.messages.push(ChatMLMessage::new_tool(
content,
"combined_tool_results".to_string(),
name.unwrap_or_else(|| "tool_results".to_string()),
));
self
}
pub fn to_openai_format(&self) -> Vec<HashMap<String, serde_json::Value>> {
self.messages.iter().map(|msg| msg.to_dict()).collect()
}
pub fn to_chatml_string(&self) -> String {
self.messages
.iter()
.map(|msg| msg.to_chatml_string())
.collect::<Vec<_>>()
.join("\n")
}
pub fn clear(&mut self) -> &mut Self {
self.messages.clear();
self
}
pub fn limit_history(&mut self, max_messages: usize) -> &mut Self {
if self.messages.len() > max_messages {
let system_message = self.messages.first().cloned();
let recent_messages = self
.messages
.iter()
.rev()
.take(max_messages - 1)
.rev()
.cloned()
.collect::<Vec<_>>();
self.messages = if let Some(system) = system_message {
std::iter::once(system).chain(recent_messages).collect()
} else {
recent_messages
};
}
self
}
pub fn get_message_count(&self) -> usize {
self.messages.len()
}
pub fn get_last_message(&self) -> Option<&ChatMLMessage> {
self.messages.last()
}
pub fn get_messages(&self) -> &Vec<ChatMLMessage> {
&self.messages
}
pub fn format_thought_command(&self, thought: &str, command: &str) -> String {
format!("THOUGHT: {}\n\n```bash\n{}\n```", thought, command)
}
pub fn replace_template_variables(
&self,
template: &str,
variables: &HashMap<String, String>,
) -> String {
let mut result = template.to_string();
for (key, value) in variables {
let placeholder = format!("{{{}}}", key);
result = result.replace(&placeholder, value);
}
result
}
pub fn process_template(
&self,
template_path: &str,
variables: &HashMap<String, String>,
) -> Result<String, Box<dyn std::error::Error>> {
let template_content = std::fs::read_to_string(template_path)?;
Ok(self.replace_template_variables(&template_content, variables))
}
pub fn validate_messages(&self) -> bool {
for message in &self.messages {
if message.content.is_empty() && message.tool_calls.is_none() {
return false;
}
if message.role == MessageRole::System {
if message.name.is_none() {
return false;
}
}
if message.role == MessageRole::Assistant {
if message.tool_calls.is_none() && message.name.is_none() {
return false;
}
}
if matches!(message.role, MessageRole::Tool) {
if message.tool_call_id.is_none() || message.name.is_none() {
return false;
}
}
}
true
}
pub fn count_tokens(&self) -> usize {
match cl100k_base() {
Ok(bpe) => {
let chatml_string = self.to_chatml_string();
let tokens = bpe.encode_with_special_tokens(&chatml_string);
tokens.len()
}
Err(_) => 0,
}
}
}
/// Counts the tokens `text` encodes to with the encoder returned by
/// `cl100k_base`, using `encode_with_special_tokens`.
///
/// Returns `0` when the encoder cannot be created.
pub fn count_tokens_for_text(text: &str) -> usize {
    cl100k_base()
        .map(|bpe| bpe.encode_with_special_tokens(text).len())
        .unwrap_or(0)
}
impl Default for ChatMLFormatter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests;