use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use crate::error::{CognisError, Result};
/// Returns the current wall-clock time as whole seconds since the Unix epoch,
/// rendered as a decimal string.
///
/// If the system clock reports a time before the epoch, `duration_since`
/// fails and the duration falls back to zero, so `"0"` is returned instead of
/// panicking.
fn now_timestamp() -> String {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs()
        .to_string()
}
/// Cheaply approximates how many tokens `text` occupies: one token per four
/// bytes of UTF-8, rounded up. The empty string costs zero tokens.
///
/// This is a heuristic over the byte length, not a real tokenizer.
fn estimate_tokens(text: &str) -> usize {
    // `0usize.div_ceil(4)` is already 0, so the empty string needs no branch.
    let byte_len = text.len();
    byte_len.div_ceil(4)
}
/// One message recorded in a conversation history.
///
/// `metadata` is skipped when serializing an empty map and defaults to an
/// empty map when the field is missing on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct HistoryMessage {
    /// Speaker role; the predicates below recognise `"human"`, `"ai"` and `"system"`.
    pub role: String,
    /// The message text.
    pub content: String,
    /// Free-form key/value annotations attached to the message.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, Value>,
    /// Timestamp string; the histories in this module fill it from `now_timestamp()`.
    pub timestamp: String,
    /// Sequence position assigned by whichever history created the message.
    pub index: usize,
}

impl HistoryMessage {
    /// Converts the message to a JSON value via its `Serialize` impl,
    /// producing `Value::Null` if serialization fails.
    pub fn to_json(&self) -> Value {
        match serde_json::to_value(self) {
            Ok(json) => json,
            Err(_) => Value::Null,
        }
    }

    /// `true` when the role is exactly `"human"`.
    pub fn is_human(&self) -> bool {
        self.has_role("human")
    }

    /// `true` when the role is exactly `"ai"`.
    pub fn is_ai(&self) -> bool {
        self.has_role("ai")
    }

    /// `true` when the role is exactly `"system"`.
    pub fn is_system(&self) -> bool {
        self.has_role("system")
    }

    /// Number of whitespace-separated words in the content (0 for empty text).
    pub fn word_count(&self) -> usize {
        let words = self.content.split_whitespace();
        words.count()
    }

    /// Exact, case-sensitive role comparison shared by the predicates above.
    fn has_role(&self, expected: &str) -> bool {
        self.role.as_str() == expected
    }
}
#[derive(Debug, Clone, Default)]
pub struct MessageHistory {
messages: Vec<HistoryMessage>,
}
impl MessageHistory {
pub fn new() -> Self {
Self {
messages: Vec::new(),
}
}
pub fn add(&mut self, role: &str, content: &str) {
self.add_with_metadata(role, content, HashMap::new());
}
pub fn add_with_metadata(
&mut self,
role: &str,
content: &str,
metadata: HashMap<String, Value>,
) {
let index = self.messages.len();
self.messages.push(HistoryMessage {
role: role.to_string(),
content: content.to_string(),
metadata,
timestamp: now_timestamp(),
index,
});
}
pub fn messages(&self) -> &[HistoryMessage] {
&self.messages
}
pub fn last(&self) -> Option<&HistoryMessage> {
self.messages.last()
}
pub fn by_role(&self, role: &str) -> Vec<&HistoryMessage> {
self.messages.iter().filter(|m| m.role == role).collect()
}
pub fn len(&self) -> usize {
self.messages.len()
}
pub fn is_empty(&self) -> bool {
self.messages.is_empty()
}
pub fn clear(&mut self) {
self.messages.clear();
}
pub fn to_json(&self) -> Value {
let arr: Vec<Value> = self.messages.iter().map(|m| m.to_json()).collect();
Value::Array(arr)
}
}
/// A message history that keeps only the `max_messages` most recent messages.
///
/// Older messages are dropped from the front as new ones arrive. The number
/// of messages ever added (since the last `clear`) is tracked separately.
#[derive(Debug, Clone)]
pub struct SlidingWindowHistory {
    history: MessageHistory,
    max_messages: usize,
    total_added: usize,
}

impl SlidingWindowHistory {
    /// Creates an empty window that retains at most `max_messages` messages.
    pub fn new(max_messages: usize) -> Self {
        let history = MessageHistory::new();
        Self {
            history,
            max_messages,
            total_added: 0,
        }
    }

    /// Appends a message without metadata, evicting the oldest if needed.
    pub fn add(&mut self, role: &str, content: &str) {
        self.add_with_metadata(role, content, HashMap::new());
    }

    /// Appends a message with `metadata`, then trims the front of the history
    /// so that at most `max_messages` remain.
    pub fn add_with_metadata(
        &mut self,
        role: &str,
        content: &str,
        metadata: HashMap<String, Value>,
    ) {
        self.total_added += 1;
        self.history.add_with_metadata(role, content, metadata);
        // Drop however many of the oldest messages exceed the window in one go.
        let overflow = self.history.len().saturating_sub(self.max_messages);
        self.history.messages.drain(..overflow);
    }

    /// The configured maximum number of retained messages.
    pub fn window_size(&self) -> usize {
        self.max_messages
    }

    /// Count of messages added since creation or the last `clear`, including evicted ones.
    pub fn total_added(&self) -> usize {
        self.total_added
    }

    /// The retained messages, oldest first.
    pub fn messages(&self) -> &[HistoryMessage] {
        self.history.messages()
    }

    /// The newest retained message, if any.
    pub fn last(&self) -> Option<&HistoryMessage> {
        self.history.last()
    }

    /// Number of retained messages.
    pub fn len(&self) -> usize {
        self.history.len()
    }

    /// `true` when nothing is retained.
    pub fn is_empty(&self) -> bool {
        self.history.is_empty()
    }

    /// Empties the window and resets the `total_added` counter.
    pub fn clear(&mut self) {
        self.total_added = 0;
        self.history.clear();
    }

    /// Serializes the retained messages as a JSON array.
    pub fn to_json(&self) -> Value {
        self.history.to_json()
    }
}
#[derive(Debug, Clone)]
pub struct TokenBudgetHistory {
history: MessageHistory,
max_tokens: usize,
current_tokens: usize,
}
impl TokenBudgetHistory {
pub fn new(max_tokens: usize) -> Self {
Self {
history: MessageHistory::new(),
max_tokens,
current_tokens: 0,
}
}
pub fn add(&mut self, role: &str, content: &str) {
let new_tokens = estimate_tokens(content);
self.history.add(role, content);
self.current_tokens += new_tokens;
while self.current_tokens > self.max_tokens {
let evict_idx = self
.history
.messages
.iter()
.position(|m| m.role != "system");
match evict_idx {
Some(idx) => {
let removed = self.history.messages.remove(idx);
self.current_tokens = self
.current_tokens
.saturating_sub(estimate_tokens(&removed.content));
}
None => break, }
}
}
pub fn current_tokens(&self) -> usize {
self.current_tokens
}
pub fn budget(&self) -> usize {
self.max_tokens
}
pub fn utilization(&self) -> f64 {
if self.max_tokens == 0 {
return 0.0;
}
self.current_tokens as f64 / self.max_tokens as f64
}
pub fn messages(&self) -> &[HistoryMessage] {
self.history.messages()
}
pub fn last(&self) -> Option<&HistoryMessage> {
self.history.last()
}
pub fn len(&self) -> usize {
self.history.len()
}
pub fn is_empty(&self) -> bool {
self.history.is_empty()
}
pub fn clear(&mut self) {
self.history.clear();
self.current_tokens = 0;
}
pub fn to_json(&self) -> Value {
self.history.to_json()
}
}
/// One exchange in a conversation: a human message and the AI reply to it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ConversationTurn {
    human: HistoryMessage,
    ai: HistoryMessage,
}

impl ConversationTurn {
    /// Pairs a human message with its AI response.
    pub fn new(human: HistoryMessage, ai: HistoryMessage) -> Self {
        ConversationTurn { human, ai }
    }

    /// The human side of the turn.
    pub fn human(&self) -> &HistoryMessage {
        &self.human
    }

    /// The AI side of the turn.
    pub fn ai(&self) -> &HistoryMessage {
        &self.ai
    }

    /// Sum of the estimated token counts of both messages.
    pub fn total_tokens(&self) -> usize {
        [&self.human, &self.ai]
            .iter()
            .map(|msg| estimate_tokens(&msg.content))
            .sum()
    }

    /// JSON object with `"human"` and `"ai"` keys holding each message's JSON form.
    pub fn to_json(&self) -> Value {
        let human_json = self.human.to_json();
        let ai_json = self.ai.to_json();
        serde_json::json!({
            "human": human_json,
            "ai": ai_json,
        })
    }
}
#[derive(Debug, Clone, Default)]
pub struct TurnHistory {
turns: Vec<ConversationTurn>,
next_index: usize,
}
impl TurnHistory {
pub fn new() -> Self {
Self {
turns: Vec::new(),
next_index: 0,
}
}
pub fn add_turn(&mut self, human_content: &str, ai_content: &str) {
let ts = now_timestamp();
let human = HistoryMessage {
role: "human".to_string(),
content: human_content.to_string(),
metadata: HashMap::new(),
timestamp: ts.clone(),
index: self.next_index,
};
self.next_index += 1;
let ai = HistoryMessage {
role: "ai".to_string(),
content: ai_content.to_string(),
metadata: HashMap::new(),
timestamp: ts,
index: self.next_index,
};
self.next_index += 1;
self.turns.push(ConversationTurn::new(human, ai));
}
pub fn turns(&self) -> &[ConversationTurn] {
&self.turns
}
pub fn last_turn(&self) -> Option<&ConversationTurn> {
self.turns.last()
}
pub fn len(&self) -> usize {
self.turns.len()
}
pub fn is_empty(&self) -> bool {
self.turns.is_empty()
}
pub fn to_json(&self) -> Value {
let arr: Vec<Value> = self.turns.iter().map(|t| t.to_json()).collect();
Value::Array(arr)
}
}
/// Conversions between [`MessageHistory`] and JSON representations.
pub struct HistorySerializer;

impl HistorySerializer {
    /// Serializes `history` as a JSON array (same output as `MessageHistory::to_json`).
    pub fn to_json(history: &MessageHistory) -> Value {
        history.to_json()
    }

    /// Rebuilds a [`MessageHistory`] from a JSON array of message objects, such
    /// as the output of [`HistorySerializer::to_json`].
    ///
    /// Each element must have string `"role"` and `"content"` fields. A string
    /// `"timestamp"` is preserved; when it is absent, the current time is used.
    /// A `"metadata"` value that is missing or not a string-keyed object is
    /// treated as empty (best-effort). Indices are reassigned sequentially from 0.
    ///
    /// # Errors
    ///
    /// Returns `CognisError::Other` if `json` is not an array, or if an element
    /// lacks a string `"role"` or `"content"`.
    pub fn from_json(json: &Value) -> Result<MessageHistory> {
        let arr = json.as_array().ok_or_else(|| {
            CognisError::Other("Expected a JSON array for message history".to_string())
        })?;
        let mut history = MessageHistory::new();
        for (i, item) in arr.iter().enumerate() {
            let role = item.get("role").and_then(|v| v.as_str()).ok_or_else(|| {
                CognisError::Other(format!("Missing 'role' in message at index {}", i))
            })?;
            let content = item
                .get("content")
                .and_then(|v| v.as_str())
                .ok_or_else(|| {
                    CognisError::Other(format!("Missing 'content' in message at index {}", i))
                })?;
            let metadata: HashMap<String, Value> = item
                .get("metadata")
                .and_then(|v| serde_json::from_value(v.clone()).ok())
                .unwrap_or_default();
            history.add_with_metadata(role, content, metadata);
            // Keep the serialized timestamp so a to_json/from_json round trip
            // does not silently replace it with the time of deserialization.
            if let Some(ts) = item.get("timestamp").and_then(|v| v.as_str()) {
                if let Some(restored) = history.messages.last_mut() {
                    restored.timestamp = ts.to_string();
                }
            }
        }
        Ok(history)
    }

    /// Projects each message to a minimal `{"role", "content"}` object, the
    /// shape commonly sent to chat-completion APIs.
    pub fn to_chat_format(history: &MessageHistory) -> Vec<Value> {
        history
            .messages()
            .iter()
            .map(|m| {
                serde_json::json!({
                    "role": m.role,
                    "content": m.content,
                })
            })
            .collect()
    }
}
// Unit tests for the history containers, the turn model, the JSON serializer
// and the `estimate_tokens` heuristic defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// --- HistoryMessage: role predicates, word counting, JSON conversion ---
#[test]
fn test_history_message_is_human() {
let msg = HistoryMessage {
role: "human".into(),
content: "hello".into(),
metadata: HashMap::new(),
timestamp: "0".into(),
index: 0,
};
assert!(msg.is_human());
assert!(!msg.is_ai());
assert!(!msg.is_system());
}
#[test]
fn test_history_message_is_ai() {
let msg = HistoryMessage {
role: "ai".into(),
content: "response".into(),
metadata: HashMap::new(),
timestamp: "0".into(),
index: 0,
};
assert!(msg.is_ai());
assert!(!msg.is_human());
}
#[test]
fn test_history_message_is_system() {
let msg = HistoryMessage {
role: "system".into(),
content: "you are helpful".into(),
metadata: HashMap::new(),
timestamp: "0".into(),
index: 0,
};
assert!(msg.is_system());
}
#[test]
fn test_history_message_word_count() {
let msg = HistoryMessage {
role: "human".into(),
content: "how many words are here".into(),
metadata: HashMap::new(),
timestamp: "0".into(),
index: 0,
};
assert_eq!(msg.word_count(), 5);
}
#[test]
fn test_history_message_word_count_empty() {
let msg = HistoryMessage {
role: "human".into(),
content: "".into(),
metadata: HashMap::new(),
timestamp: "0".into(),
index: 0,
};
assert_eq!(msg.word_count(), 0);
}
#[test]
fn test_history_message_to_json() {
let msg = HistoryMessage {
role: "human".into(),
content: "test".into(),
metadata: HashMap::new(),
timestamp: "123".into(),
index: 0,
};
let json = msg.to_json();
assert_eq!(json["role"], "human");
assert_eq!(json["content"], "test");
assert_eq!(json["timestamp"], "123");
assert_eq!(json["index"], 0);
}
#[test]
fn test_history_message_with_metadata() {
let mut meta = HashMap::new();
meta.insert("source".to_string(), Value::String("test".to_string()));
let msg = HistoryMessage {
role: "ai".into(),
content: "reply".into(),
metadata: meta,
timestamp: "0".into(),
index: 1,
};
let json = msg.to_json();
assert_eq!(json["metadata"]["source"], "test");
}
// --- MessageHistory: insertion order, lookup, clearing, JSON, indices ---
#[test]
fn test_message_history_new_is_empty() {
let h = MessageHistory::new();
assert!(h.is_empty());
assert_eq!(h.len(), 0);
}
#[test]
fn test_message_history_add_and_len() {
let mut h = MessageHistory::new();
h.add("human", "hello");
h.add("ai", "hi there");
assert_eq!(h.len(), 2);
assert!(!h.is_empty());
}
#[test]
fn test_message_history_messages_order() {
let mut h = MessageHistory::new();
h.add("human", "first");
h.add("ai", "second");
h.add("human", "third");
let msgs = h.messages();
assert_eq!(msgs[0].content, "first");
assert_eq!(msgs[1].content, "second");
assert_eq!(msgs[2].content, "third");
}
#[test]
fn test_message_history_last() {
let mut h = MessageHistory::new();
assert!(h.last().is_none());
h.add("human", "only");
assert_eq!(h.last().unwrap().content, "only");
}
#[test]
fn test_message_history_by_role() {
let mut h = MessageHistory::new();
h.add("human", "q1");
h.add("ai", "a1");
h.add("human", "q2");
h.add("system", "sys");
let humans = h.by_role("human");
assert_eq!(humans.len(), 2);
assert_eq!(humans[0].content, "q1");
assert_eq!(humans[1].content, "q2");
}
#[test]
fn test_message_history_clear() {
let mut h = MessageHistory::new();
h.add("human", "a");
h.add("ai", "b");
h.clear();
assert!(h.is_empty());
assert_eq!(h.len(), 0);
}
#[test]
fn test_message_history_add_with_metadata() {
let mut h = MessageHistory::new();
let mut meta = HashMap::new();
meta.insert("key".to_string(), Value::String("val".to_string()));
h.add_with_metadata("human", "content", meta);
assert_eq!(h.messages()[0].metadata["key"], "val");
}
#[test]
fn test_message_history_to_json() {
let mut h = MessageHistory::new();
h.add("human", "hello");
h.add("ai", "world");
let json = h.to_json();
let arr = json.as_array().unwrap();
assert_eq!(arr.len(), 2);
assert_eq!(arr[0]["role"], "human");
assert_eq!(arr[1]["role"], "ai");
}
#[test]
fn test_message_history_index_tracking() {
let mut h = MessageHistory::new();
h.add("human", "a");
h.add("ai", "b");
h.add("human", "c");
assert_eq!(h.messages()[0].index, 0);
assert_eq!(h.messages()[1].index, 1);
assert_eq!(h.messages()[2].index, 2);
}
// --- SlidingWindowHistory: front eviction and the total_added counter ---
#[test]
fn test_sliding_window_basic() {
let mut sw = SlidingWindowHistory::new(3);
sw.add("human", "a");
sw.add("ai", "b");
sw.add("human", "c");
assert_eq!(sw.len(), 3);
assert_eq!(sw.total_added(), 3);
}
#[test]
fn test_sliding_window_eviction() {
let mut sw = SlidingWindowHistory::new(2);
sw.add("human", "first");
sw.add("ai", "second");
sw.add("human", "third");
assert_eq!(sw.len(), 2);
assert_eq!(sw.messages()[0].content, "second");
assert_eq!(sw.messages()[1].content, "third");
assert_eq!(sw.total_added(), 3);
}
#[test]
fn test_sliding_window_eviction_multiple() {
let mut sw = SlidingWindowHistory::new(2);
for i in 0..10 {
sw.add("human", &format!("msg{}", i));
}
assert_eq!(sw.len(), 2);
assert_eq!(sw.total_added(), 10);
assert_eq!(sw.messages()[0].content, "msg8");
assert_eq!(sw.messages()[1].content, "msg9");
}
#[test]
fn test_sliding_window_size() {
let sw = SlidingWindowHistory::new(5);
assert_eq!(sw.window_size(), 5);
}
#[test]
fn test_sliding_window_clear() {
let mut sw = SlidingWindowHistory::new(3);
sw.add("human", "a");
sw.add("ai", "b");
sw.clear();
assert!(sw.is_empty());
assert_eq!(sw.total_added(), 0);
}
#[test]
fn test_sliding_window_last() {
let mut sw = SlidingWindowHistory::new(2);
assert!(sw.last().is_none());
sw.add("human", "hello");
assert_eq!(sw.last().unwrap().content, "hello");
}
#[test]
fn test_sliding_window_to_json() {
let mut sw = SlidingWindowHistory::new(2);
sw.add("human", "a");
sw.add("ai", "b");
let json = sw.to_json();
assert_eq!(json.as_array().unwrap().len(), 2);
}
// --- TokenBudgetHistory: budget enforcement and system-message retention ---
#[test]
fn test_token_budget_basic() {
let mut tb = TokenBudgetHistory::new(100);
tb.add("human", "hello");
assert!(tb.current_tokens() > 0);
assert_eq!(tb.budget(), 100);
}
#[test]
fn test_token_budget_eviction() {
let mut tb = TokenBudgetHistory::new(10);
tb.add("human", "short");
tb.add("ai", "also short text here");
// Estimated totals are 2 + 5 + 8 = 15 > 10, so older messages must be evicted.
tb.add("human", "another message with more text");
assert!(tb.current_tokens() <= tb.budget());
}
#[test]
fn test_token_budget_preserves_system() {
let mut tb = TokenBudgetHistory::new(10);
tb.add("system", "you are helpful");
// Each of the next two messages alone exceeds the 10-token budget, so they
// are evicted; the system message must survive regardless.
tb.add(
"human",
"a long human message that exceeds the budget easily",
);
tb.add("ai", "another long ai message that also exceeds the budget");
let has_system = tb.messages().iter().any(|m| m.role == "system");
assert!(has_system);
}
#[test]
fn test_token_budget_utilization() {
let mut tb = TokenBudgetHistory::new(100);
assert_eq!(tb.utilization(), 0.0);
tb.add("human", "hello world");
assert!(tb.utilization() > 0.0);
assert!(tb.utilization() <= 1.0);
}
#[test]
fn test_token_budget_utilization_zero_budget() {
let tb = TokenBudgetHistory::new(0);
assert_eq!(tb.utilization(), 0.0);
}
#[test]
fn test_token_budget_clear() {
let mut tb = TokenBudgetHistory::new(100);
tb.add("human", "hello");
tb.clear();
assert!(tb.is_empty());
assert_eq!(tb.current_tokens(), 0);
}
#[test]
fn test_token_budget_len_and_empty() {
let mut tb = TokenBudgetHistory::new(1000);
assert!(tb.is_empty());
tb.add("human", "a");
assert_eq!(tb.len(), 1);
assert!(!tb.is_empty());
}
// --- ConversationTurn: accessors, token totals, JSON shape ---
#[test]
fn test_conversation_turn_accessors() {
let human = HistoryMessage {
role: "human".into(),
content: "question".into(),
metadata: HashMap::new(),
timestamp: "0".into(),
index: 0,
};
let ai = HistoryMessage {
role: "ai".into(),
content: "answer".into(),
metadata: HashMap::new(),
timestamp: "0".into(),
index: 1,
};
let turn = ConversationTurn::new(human, ai);
assert_eq!(turn.human().content, "question");
assert_eq!(turn.ai().content, "answer");
}
#[test]
fn test_conversation_turn_total_tokens() {
let human = HistoryMessage {
role: "human".into(),
content: "hello world".into(),
metadata: HashMap::new(),
timestamp: "0".into(),
index: 0,
};
let ai = HistoryMessage {
role: "ai".into(),
content: "hi".into(),
metadata: HashMap::new(),
timestamp: "0".into(),
index: 1,
};
let turn = ConversationTurn::new(human, ai);
assert!(turn.total_tokens() > 0);
}
#[test]
fn test_conversation_turn_to_json() {
let human = HistoryMessage {
role: "human".into(),
content: "q".into(),
metadata: HashMap::new(),
timestamp: "0".into(),
index: 0,
};
let ai = HistoryMessage {
role: "ai".into(),
content: "a".into(),
metadata: HashMap::new(),
timestamp: "0".into(),
index: 1,
};
let turn = ConversationTurn::new(human, ai);
let json = turn.to_json();
assert!(json.get("human").is_some());
assert!(json.get("ai").is_some());
assert_eq!(json["human"]["content"], "q");
assert_eq!(json["ai"]["content"], "a");
}
// --- TurnHistory: turn storage and cross-turn message indexing ---
#[test]
fn test_turn_history_new_is_empty() {
let th = TurnHistory::new();
assert!(th.is_empty());
assert_eq!(th.len(), 0);
}
#[test]
fn test_turn_history_add_turn() {
let mut th = TurnHistory::new();
th.add_turn("hello", "hi");
assert_eq!(th.len(), 1);
assert!(!th.is_empty());
}
#[test]
fn test_turn_history_turns_content() {
let mut th = TurnHistory::new();
th.add_turn("q1", "a1");
th.add_turn("q2", "a2");
let turns = th.turns();
assert_eq!(turns[0].human().content, "q1");
assert_eq!(turns[0].ai().content, "a1");
assert_eq!(turns[1].human().content, "q2");
assert_eq!(turns[1].ai().content, "a2");
}
#[test]
fn test_turn_history_last_turn() {
let mut th = TurnHistory::new();
assert!(th.last_turn().is_none());
th.add_turn("q1", "a1");
th.add_turn("q2", "a2");
assert_eq!(th.last_turn().unwrap().human().content, "q2");
}
#[test]
fn test_turn_history_to_json() {
let mut th = TurnHistory::new();
th.add_turn("hello", "hi");
let json = th.to_json();
let arr = json.as_array().unwrap();
assert_eq!(arr.len(), 1);
assert!(arr[0].get("human").is_some());
assert!(arr[0].get("ai").is_some());
}
#[test]
fn test_turn_history_index_tracking() {
let mut th = TurnHistory::new();
th.add_turn("q1", "a1");
th.add_turn("q2", "a2");
// Each turn consumes two consecutive indices, human first.
assert_eq!(th.turns()[0].human().index, 0);
assert_eq!(th.turns()[0].ai().index, 1);
assert_eq!(th.turns()[1].human().index, 2);
assert_eq!(th.turns()[1].ai().index, 3);
}
// --- HistorySerializer: JSON round trips, error cases, chat format ---
#[test]
fn test_serializer_to_json() {
let mut h = MessageHistory::new();
h.add("human", "hello");
let json = HistorySerializer::to_json(&h);
assert!(json.is_array());
assert_eq!(json.as_array().unwrap().len(), 1);
}
#[test]
fn test_serializer_roundtrip() {
let mut h = MessageHistory::new();
h.add("human", "question");
h.add("ai", "answer");
h.add("system", "context");
let json = HistorySerializer::to_json(&h);
let restored = HistorySerializer::from_json(&json).unwrap();
assert_eq!(restored.len(), h.len());
for (orig, rest) in h.messages().iter().zip(restored.messages().iter()) {
assert_eq!(orig.role, rest.role);
assert_eq!(orig.content, rest.content);
}
}
#[test]
fn test_serializer_from_json_error_not_array() {
let json = serde_json::json!({"not": "array"});
let result = HistorySerializer::from_json(&json);
assert!(result.is_err());
}
#[test]
fn test_serializer_from_json_error_missing_role() {
let json = serde_json::json!([{"content": "hello"}]);
let result = HistorySerializer::from_json(&json);
assert!(result.is_err());
}
#[test]
fn test_serializer_from_json_error_missing_content() {
let json = serde_json::json!([{"role": "human"}]);
let result = HistorySerializer::from_json(&json);
assert!(result.is_err());
}
#[test]
fn test_serializer_to_chat_format() {
let mut h = MessageHistory::new();
h.add("system", "You are helpful.");
h.add("human", "What is 2+2?");
h.add("ai", "4");
let chat = HistorySerializer::to_chat_format(&h);
assert_eq!(chat.len(), 3);
assert_eq!(chat[0]["role"], "system");
assert_eq!(chat[0]["content"], "You are helpful.");
assert_eq!(chat[1]["role"], "human");
assert_eq!(chat[1]["content"], "What is 2+2?");
assert_eq!(chat[2]["role"], "ai");
assert_eq!(chat[2]["content"], "4");
}
#[test]
fn test_serializer_to_chat_format_empty() {
let h = MessageHistory::new();
let chat = HistorySerializer::to_chat_format(&h);
assert!(chat.is_empty());
}
#[test]
fn test_serializer_roundtrip_with_metadata() {
let mut h = MessageHistory::new();
let mut meta = HashMap::new();
meta.insert("source".to_string(), Value::String("test".to_string()));
h.add_with_metadata("human", "hello", meta);
let json = HistorySerializer::to_json(&h);
let restored = HistorySerializer::from_json(&json).unwrap();
assert_eq!(restored.messages()[0].metadata["source"], "test");
}
// --- estimate_tokens: ceil(byte_len / 4), zero for empty input ---
#[test]
fn test_estimate_tokens_empty() {
assert_eq!(estimate_tokens(""), 0);
}
#[test]
fn test_estimate_tokens_nonempty() {
// 5 bytes -> ceil(5 / 4) = 2; 2 bytes -> ceil(2 / 4) = 1.
assert_eq!(estimate_tokens("hello"), 2);
assert_eq!(estimate_tokens("hi"), 1);
}
#[test]
fn test_by_role_empty_result() {
let h = MessageHistory::new();
assert!(h.by_role("human").is_empty());
}
}