use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use uuid::Uuid;
/// One exchange in a conversation: a user message and the agent's reply to it.
///
/// Turns are created by `ConversationBuffer::push_user` and completed by
/// `ConversationBuffer::push_agent`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationTurn {
/// Text of the user's message.
pub user: String,
/// Agent reply as stored by `ConversationBuffer::push_agent` (long replies are
/// shortened and end with `...`). Empty until a reply is recorded.
pub agent: String,
/// Project associated with the reply, if any; set together with `agent`.
pub project_id: Option<Uuid>,
/// When the user message was recorded (UTC). Refreshed if a pending turn's
/// message is replaced; not changed when the reply is attached.
pub timestamp: DateTime<Utc>,
}
/// Bounded history of recent conversation turns.
///
/// Also counts how many turns have passed since the last topic check.
#[derive(Debug, Clone)]
pub struct ConversationBuffer {
/// Retained turns: oldest at the front, newest at the back.
turns: VecDeque<ConversationTurn>,
/// Maximum number of turns kept; the oldest are evicted beyond this.
max_turns: usize,
/// Turns recorded via `record_turn` since the last `mark_topic_checked` / `clear`.
turns_since_topic_check: usize,
}
impl Default for ConversationBuffer {
fn default() -> Self {
Self::new(50)
}
}
impl ConversationBuffer {
    /// Byte budget handed to `truncate_response` when storing an agent reply.
    const MAX_AGENT_RESPONSE_LEN: usize = 200;

    /// Upper bound on the capacity reserved up front by [`ConversationBuffer::new`].
    ///
    /// The deque still grows on demand up to `max_turns`; this only stops a very
    /// large limit (e.g. `usize::MAX` used as "effectively unbounded") from
    /// triggering a huge or overflowing allocation at construction time.
    const MAX_PREALLOCATED_TURNS: usize = 1024;

    /// Creates an empty buffer that retains at most `max_turns` turns.
    ///
    /// Once the limit is exceeded the oldest turns are dropped. A limit of `0`
    /// yields a buffer that never retains anything.
    pub fn new(max_turns: usize) -> Self {
        Self {
            turns: VecDeque::with_capacity(max_turns.min(Self::MAX_PREALLOCATED_TURNS)),
            max_turns,
            turns_since_topic_check: 0,
        }
    }

    /// Records a user message as the start of a turn.
    ///
    /// If the newest turn is still pending (empty agent reply and no project
    /// id), its message and timestamp are overwritten in place instead of
    /// appending a second unanswered turn. Otherwise a new turn is appended and
    /// the oldest turns are evicted to stay within `max_turns`.
    pub fn push_user(&mut self, message: &str) {
        let now = Utc::now();

        if let Some(last) = self.turns.back_mut() {
            if last.agent.is_empty() && last.project_id.is_none() {
                // Reuse the existing String allocation for the replacement text.
                last.user.clear();
                last.user.push_str(message);
                last.timestamp = now;
                return;
            }
        }

        self.turns.push_back(ConversationTurn {
            user: message.to_string(),
            agent: String::new(),
            project_id: None,
            timestamp: now,
        });
        while self.turns.len() > self.max_turns {
            self.turns.pop_front();
        }
    }

    /// Attaches an agent reply and optional project id to the newest turn.
    ///
    /// The reply is shortened with `truncate_response`, and any reply already
    /// stored on that turn is overwritten. Does nothing when the buffer is empty.
    pub fn push_agent(&mut self, response: &str, project_id: Option<Uuid>) {
        if let Some(last) = self.turns.back_mut() {
            last.agent = truncate_response(response, Self::MAX_AGENT_RESPONSE_LEN);
            last.project_id = project_id;
        }
    }

    /// Returns up to `n` turns, newest first.
    pub fn recent(&self, n: usize) -> Vec<&ConversationTurn> {
        self.turns.iter().rev().take(n).collect()
    }

    /// Returns an owned copy of all retained turns, oldest first.
    ///
    /// This clones every turn; prefer [`ConversationBuffer::recent`] when
    /// borrowed access to the newest turns is enough.
    pub fn turns(&self) -> VecDeque<ConversationTurn> {
        self.turns.clone()
    }

    /// Number of retained turns.
    pub fn len(&self) -> usize {
        self.turns.len()
    }

    /// `true` when no turns are retained.
    pub fn is_empty(&self) -> bool {
        self.turns.is_empty()
    }

    /// Whether a topic check is due.
    ///
    /// A check is due when at least `min_turns` turns have been recorded since
    /// the last check, or when [`ConversationBuffer::pattern_changed`] reports a
    /// sharp shift in user message length.
    pub fn should_check_topic(&self, min_turns: usize) -> bool {
        self.turns_since_topic_check >= min_turns || self.pattern_changed()
    }

    /// Resets the counter consulted by [`ConversationBuffer::should_check_topic`].
    pub fn mark_topic_checked(&mut self) {
        self.turns_since_topic_check = 0;
    }

    /// Counts one more turn toward the next topic check and reports whether a
    /// check is now due (see [`ConversationBuffer::should_check_topic`]).
    pub fn record_turn(&mut self, min_turns: usize) -> bool {
        self.turns_since_topic_check = self.turns_since_topic_check.saturating_add(1);
        self.should_check_topic(min_turns)
    }

    /// Heuristic check for whether the length of the user's messages changed sharply.
    ///
    /// Computes the average word count of the two newest user messages and of
    /// the two before them. Returns `true` when their ratio falls outside
    /// `0.5..=2.0`. Both averages are floored at one word. This keeps the
    /// comparison symmetric and stops near-empty messages (e.g. `""` followed
    /// by `""`) from counting as a change.
    ///
    /// Needs at least four retained turns and returns `false` otherwise.
    pub fn pattern_changed(&self) -> bool {
        if self.turns.len() < 4 {
            return false;
        }

        // Walk newest -> oldest; only the four newest turns are inspected.
        let mut lengths = self.turns.iter().rev().map(|t| word_count(&t.user));
        let recent_words: usize = lengths.by_ref().take(2).sum();
        let previous_words: usize = lengths.take(2).sum();

        let avg_recent = (recent_words as f64 / 2.0).max(1.0);
        let avg_prev = (previous_words as f64 / 2.0).max(1.0);
        !(0.5..=2.0).contains(&(avg_recent / avg_prev))
    }

    /// Removes all turns and resets the topic-check counter.
    pub fn clear(&mut self) {
        self.turns.clear();
        self.turns_since_topic_check = 0;
    }
}
/// Counts the words in `s`.
///
/// Words are separated by runs of Unicode whitespace. Leading and trailing
/// whitespace produce no empty words.
fn word_count(s: &str) -> usize {
    s.split(char::is_whitespace)
        .filter(|word| !word.is_empty())
        .count()
}
/// Shortens `response` to roughly `max_len` bytes and appends `...`.
///
/// Strings whose byte length is at most `max_len` are returned unchanged.
/// Otherwise the text is cut at the first char boundary at or after byte
/// `max_len`. A multi-byte character straddling the limit is therefore kept
/// whole, and the kept prefix may exceed `max_len` by up to three bytes.
fn truncate_response(response: &str, max_len: usize) -> String {
    if max_len >= response.len() {
        return response.to_owned();
    }

    // `response.len()` is always a boundary, so this loop terminates in range.
    let mut cut = max_len;
    while !response.is_char_boundary(cut) {
        cut += 1;
    }

    let mut shortened = String::with_capacity(cut + 3);
    shortened.push_str(&response[..cut]);
    shortened.push_str("...");
    shortened
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_push_user_and_agent() {
        let mut buf = ConversationBuffer::new(10);
        assert!(buf.is_empty());
        buf.push_user("Hello, how are you?");
        assert_eq!(buf.len(), 1);
        buf.push_agent("I'm doing well!", None);
        assert_eq!(buf.turns[0].agent, "I'm doing well!");
    }

    #[test]
    fn test_max_capacity() {
        let mut buf = ConversationBuffer::new(3);
        for i in 1..=5 {
            buf.push_user(&format!("msg{}", i));
            buf.push_agent("r", None);
        }
        assert_eq!(buf.len(), 3);
        assert_eq!(buf.recent(1)[0].user, "msg5");
    }

    #[test]
    fn test_pattern_changed() {
        let mut buf = ConversationBuffer::new(10);
        for _ in 0..3 {
            buf.push_user("hi");
            buf.push_agent("hi", None);
        }
        assert!(!buf.pattern_changed());
        buf.push_user("This is a very long message with many many many words to trigger detection");
        buf.push_agent("ok", None);
        assert!(buf.pattern_changed());
    }

    /// A second user message before any reply replaces the pending turn.
    #[test]
    fn test_pending_user_turn_is_replaced() {
        let mut buf = ConversationBuffer::new(10);
        buf.push_user("first");
        buf.push_user("second");
        assert_eq!(buf.len(), 1);
        assert_eq!(buf.recent(1)[0].user, "second");

        buf.push_agent("reply", None);
        buf.push_user("third");
        assert_eq!(buf.len(), 2);
        assert_eq!(buf.recent(1)[0].user, "third");
    }

    #[test]
    fn test_push_agent_on_empty_buffer_is_noop() {
        let mut buf = ConversationBuffer::new(10);
        buf.push_agent("orphan", None);
        assert!(buf.is_empty());
    }

    #[test]
    fn test_recent_is_newest_first() {
        let mut buf = ConversationBuffer::new(10);
        for user in ["u1", "u2", "u3"].iter() {
            buf.push_user(user);
            buf.push_agent("r", None);
        }
        let users: Vec<&str> = buf.recent(2).iter().map(|t| t.user.as_str()).collect();
        assert_eq!(users, ["u3", "u2"]);
        assert_eq!(buf.recent(10).len(), 3);
    }

    #[test]
    fn test_topic_check_counter() {
        let mut buf = ConversationBuffer::new(10);
        assert!(!buf.record_turn(2));
        assert!(buf.record_turn(2));
        buf.mark_topic_checked();
        assert!(!buf.should_check_topic(2));
    }

    #[test]
    fn test_push_agent_truncates_long_reply() {
        let mut buf = ConversationBuffer::new(10);
        buf.push_user("q");
        buf.push_agent(&"x".repeat(250), None);
        let agent = buf.recent(1)[0].agent.clone();
        assert_eq!(agent.len(), 203);
        assert!(agent.ends_with("..."));
    }

    #[test]
    fn test_truncate_response_respects_char_boundaries() {
        assert_eq!(truncate_response("short", 200), "short");
        assert_eq!(truncate_response("abc", 0), "...");
        assert_eq!(
            truncate_response(&"a".repeat(300), 200),
            format!("{}...", "a".repeat(200))
        );
        // 'é' is two bytes; the one straddling byte 200 is kept whole.
        let text = format!("a{}", "é".repeat(150));
        assert_eq!(
            truncate_response(&text, 200),
            format!("a{}...", "é".repeat(100))
        );
    }
}