use crate::llm::types::{ChatMessage, LLMRequest, MessageRole};
use crate::tokens::budget::TokenBudget;
use crate::tokens::estimate_tokens;
/// An ordered log of the chat messages exchanged in one conversation.
///
/// Messages are appended through [`ConversationHistory::push`], which rejects
/// role sequences it does not accept, and read back through
/// [`ConversationHistory::messages`].
#[derive(Debug, Clone)]
pub struct ConversationHistory {
    /// Stored messages in insertion order (oldest first).
    messages: Vec<ChatMessage>,
}
/// Error describing a role transition rejected by [`ConversationHistory::push`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RoleValidationError {
    /// Role of the message that was rejected.
    pub attempted: MessageRole,
    /// Role of the message it would have followed.
    ///
    /// When the history was empty there is no real predecessor; `push` reports
    /// `MessageRole::System` here as a placeholder in that case.
    pub previous: MessageRole,
}
/// Renders the rejected transition as
/// `invalid role transition: <attempted> after <previous>`, using the `Debug`
/// form of both roles.
impl std::fmt::Display for RoleValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            attempted,
            previous,
        } = self;
        write!(f, "invalid role transition: {attempted:?} after {previous:?}")
    }
}
// Opts into the standard error trait using only its default methods; the
// `Debug` and `Display` implementations it requires are provided on the type.
impl std::error::Error for RoleValidationError {}
impl ConversationHistory {
    /// Creates an empty history.
    pub fn new() -> Self {
        Self {
            messages: Vec::new(),
        }
    }

    /// Returns the number of stored messages.
    pub fn len(&self) -> usize {
        self.messages.len()
    }

    /// Returns `true` when no messages are stored.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Returns the stored messages in insertion order (oldest first).
    pub fn messages(&self) -> &[ChatMessage] {
        &self.messages
    }

    /// Appends `message` if its role may follow the role of the current last
    /// message.
    ///
    /// Accepted transitions:
    ///
    /// | previous role     | roles accepted next      |
    /// |-------------------|--------------------------|
    /// | *(empty history)* | `System`, `User`, `Tool` |
    /// | `System`          | `User`                   |
    /// | `User`            | `Assistant`              |
    /// | `Assistant`       | `User`, `Tool`           |
    /// | `Tool`            | `Assistant`, `Tool`      |
    ///
    /// # Errors
    ///
    /// Returns a [`RoleValidationError`] and leaves the history untouched when
    /// the transition is not in the table above. For an empty history the
    /// error's `previous` field is reported as `MessageRole::System`, because
    /// there is no real predecessor to name.
    pub fn push(&mut self, message: ChatMessage) -> Result<(), RoleValidationError> {
        let previous = self.messages.last().map(|last| last.role);

        // Both matches are exhaustive on purpose: adding a role variant must
        // force a decision here instead of falling through a wildcard.
        let accepted = match (previous, message.role) {
            // First message of the conversation.
            (None, MessageRole::Assistant) => false,
            (None, MessageRole::System | MessageRole::User | MessageRole::Tool) => true,
            // Any later message is checked against its predecessor.
            (Some(_), MessageRole::System) => false,
            (Some(prev), MessageRole::User) => {
                matches!(prev, MessageRole::Assistant | MessageRole::System)
            }
            (Some(prev), MessageRole::Assistant) => {
                matches!(prev, MessageRole::User | MessageRole::Tool)
            }
            (Some(prev), MessageRole::Tool) => {
                matches!(prev, MessageRole::Assistant | MessageRole::Tool)
            }
        };

        if !accepted {
            return Err(RoleValidationError {
                attempted: message.role,
                // Placeholder when the history is empty (see `# Errors`).
                previous: previous.unwrap_or(MessageRole::System),
            });
        }

        self.messages.push(message);
        Ok(())
    }

    /// Returns the estimated token total over all stored messages.
    ///
    /// The total saturates at `u32::MAX` rather than overflowing.
    pub fn token_count(&self) -> u32 {
        self.messages
            .iter()
            .map(Self::message_tokens)
            .fold(0u32, u32::saturating_add)
    }

    /// Tries to reserve `needed` tokens from `budget`, evicting the oldest
    /// messages to make room when the reservation does not fit.
    ///
    /// If the reservation succeeds immediately, nothing is removed. Otherwise
    /// messages are evicted oldest-first, skipping a leading `System` message,
    /// and each evicted message's estimated tokens are released to `budget`.
    /// The reservation is retried after every eviction, including the last
    /// one, and eviction stops as soon as it succeeds or nothing evictable
    /// remains.
    ///
    /// Returns the number of messages removed. The return value does not
    /// indicate whether the reservation ultimately succeeded.
    ///
    /// Eviction does not look at roles, so the remaining history may begin
    /// with a role that [`ConversationHistory::push`] would not accept in that
    /// position.
    ///
    /// NOTE(review): releasing tokens presumes every stored message's
    /// estimated tokens were previously reserved from this same `budget` —
    /// confirm against callers.
    pub fn truncate_to_budget(&mut self, budget: &TokenBudget, needed: u32) -> usize {
        if budget.try_reserve(needed) {
            return 0;
        }

        // A leading system message is never evicted.
        let start = usize::from(
            self.messages
                .first()
                .is_some_and(|m| m.role == MessageRole::System),
        );

        // Release first, then retry: the initial check above already failed,
        // and retrying after each release guarantees the room freed by the
        // final eviction is actually used.
        let mut evicted = 0;
        for message in &self.messages[start..] {
            budget.release(Self::message_tokens(message));
            evicted += 1;
            if budget.try_reserve(needed) {
                break;
            }
        }

        // Drop the evicted prefix in one pass instead of shifting the vector
        // once per removed message.
        self.messages.drain(start..start + evicted);
        evicted
    }

    /// Consumes the history and builds an [`LLMRequest`] from it.
    ///
    /// `system_prompt` becomes the request's `system` field, and any stored
    /// `System`-role messages are dropped from the message list. `temperature`
    /// is set to `0.7`; every other optional field is `None`.
    pub fn into_request(self, system_prompt: impl Into<String>) -> LLMRequest {
        let messages = self
            .messages
            .into_iter()
            .filter(|m| m.role != MessageRole::System)
            .collect();

        LLMRequest {
            system: Some(system_prompt.into()),
            messages,
            temperature: 0.7,
            max_tokens: None,
            model: None,
            response_format: None,
            tools: None,
        }
    }

    /// Estimated token count of one message's text content.
    fn message_tokens(message: &ChatMessage) -> u32 {
        // The `as` cast is kept from the original code; it truncates if the
        // estimator ever returns a value that does not fit in `u32`.
        estimate_tokens(&message.text_content()) as u32
    }
}
impl Default for ConversationHistory {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- push: accepted role sequences ------------------------------------

    /// A user turn followed by an assistant reply is accepted.
    #[test]
    fn push_user_then_assistant() {
        let mut history = ConversationHistory::new();
        assert!(history.push(ChatMessage::user("hello")).is_ok());
        assert!(history.push(ChatMessage::assistant("hi")).is_ok());
        assert_eq!(history.len(), 2);
    }

    /// A system message may open the conversation and be followed by a user turn.
    #[test]
    fn push_system_then_user() {
        let mut history = ConversationHistory::new();
        assert!(history.push(ChatMessage::system("you are helpful")).is_ok());
        assert!(history.push(ChatMessage::user("hello")).is_ok());
        assert_eq!(history.len(), 2);
    }

    // --- push: rejected role sequences ------------------------------------

    /// A system message is rejected once the conversation has started.
    #[test]
    fn push_rejects_system_after_user() {
        let mut history = ConversationHistory::new();
        history.push(ChatMessage::user("hello")).unwrap();

        let rejection = history.push(ChatMessage::system("nope")).unwrap_err();
        assert_eq!(rejection.attempted, MessageRole::System);
        assert_eq!(rejection.previous, MessageRole::User);
    }

    /// Two user messages in a row are rejected.
    #[test]
    fn push_rejects_double_user() {
        let mut history = ConversationHistory::new();
        history.push(ChatMessage::user("first")).unwrap();

        let rejection = history.push(ChatMessage::user("second")).unwrap_err();
        assert_eq!(rejection.attempted, MessageRole::User);
        assert_eq!(rejection.previous, MessageRole::User);
    }

    /// An assistant message cannot be the first entry.
    #[test]
    fn push_rejects_assistant_first() {
        let mut history = ConversationHistory::new();

        let rejection = history.push(ChatMessage::assistant("hi")).unwrap_err();
        assert_eq!(rejection.attempted, MessageRole::Assistant);
    }

    // --- push: tool messages ----------------------------------------------

    /// A tool result may follow an assistant message.
    #[test]
    fn push_tool_after_assistant() {
        let mut history = ConversationHistory::new();
        history.push(ChatMessage::user("run tool")).unwrap();
        history.push(ChatMessage::assistant("calling tool")).unwrap();

        assert!(history.push(ChatMessage::tool("result")).is_ok());
        assert_eq!(history.len(), 3);
    }

    /// Several tool results may follow one another.
    #[test]
    fn push_consecutive_tools_allowed() {
        let mut history = ConversationHistory::new();
        history.push(ChatMessage::user("run tools")).unwrap();
        history.push(ChatMessage::assistant("calling")).unwrap();

        assert!(history.push(ChatMessage::tool("result1")).is_ok());
        assert!(history.push(ChatMessage::tool("result2")).is_ok());
        assert_eq!(history.len(), 4);
    }

    // --- into_request -----------------------------------------------------

    /// The supplied prompt replaces any stored system message in the request.
    #[test]
    fn into_request_sets_system_prompt() {
        let mut history = ConversationHistory::new();
        history.push(ChatMessage::system("original")).unwrap();
        history.push(ChatMessage::user("hello")).unwrap();

        let request = history.into_request("new system prompt");
        assert_eq!(request.system.as_deref(), Some("new system prompt"));
        assert_eq!(request.messages.len(), 1);
        assert_eq!(request.messages[0].role, MessageRole::User);
    }

    // --- truncate_to_budget -----------------------------------------------

    /// With the budget nearly exhausted, the oldest non-system messages are
    /// evicted and the system message stays in front.
    #[test]
    fn truncate_to_budget_removes_oldest() {
        let budget = TokenBudget::new(100);
        assert!(budget.try_reserve(90));

        let long_user_text = "x".repeat(200);
        let long_assistant_text = "y".repeat(200);
        let mut history = ConversationHistory::new();
        history.push(ChatMessage::system("system")).unwrap();
        history.push(ChatMessage::user(&long_user_text)).unwrap();
        history
            .push(ChatMessage::assistant(&long_assistant_text))
            .unwrap();

        let len_before = history.len();
        let removed = history.truncate_to_budget(&budget, 50);
        assert!(removed > 0);
        assert_eq!(history.len(), len_before - removed);
        assert_eq!(history.messages()[0].role, MessageRole::System);
    }

    /// A lone system message is never evicted, even when the request cannot fit.
    #[test]
    fn truncate_preserves_system_message() {
        let budget = TokenBudget::new(5);
        let mut history = ConversationHistory::new();
        history
            .push(ChatMessage::system("system instruction"))
            .unwrap();

        let removed = history.truncate_to_budget(&budget, 100);
        assert_eq!(removed, 0);
        assert_eq!(history.messages()[0].role, MessageRole::System);
    }

    // --- misc -------------------------------------------------------------

    /// An empty history counts zero tokens; adding text makes the count positive.
    #[test]
    fn token_count_estimates() {
        let mut history = ConversationHistory::new();
        assert_eq!(history.token_count(), 0);

        history.push(ChatMessage::user("hello world")).unwrap();
        assert!(history.token_count() > 0);
    }

    /// `Default` yields an empty history.
    #[test]
    fn default_is_empty() {
        let history = ConversationHistory::default();
        assert!(history.is_empty());
    }
}