use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use super::SpaceId;
/// One user message together with the agent's reply to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConversationTurn {
    /// The user's message. `ConversationBuffer::push_user` overwrites it in place
    /// when the user speaks again before this turn has been answered.
    pub user: String,
    /// The agent's reply as stored by `ConversationBuffer::push_agent`
    /// (length-capped there); empty while the turn is unanswered.
    pub agent: String,
    /// Space associated with the agent's reply; `SpaceId::nil()` until a reply
    /// has been recorded.
    pub space_id: SpaceId,
    /// UTC time at which the user message was recorded (refreshed when an
    /// unanswered message is overwritten; not touched by the agent reply).
    pub timestamp: DateTime<Utc>,
}
/// Bounded in-memory history of recent user/agent exchanges.
///
/// Besides the turns themselves it tracks how many turns have been recorded
/// since the last topic check and the space id of the latest agent reply.
#[derive(Debug, Clone)]
pub struct ConversationBuffer {
    /// Retained turns: oldest at the front, newest at the back.
    turns: VecDeque<ConversationTurn>,
    /// Maximum number of retained turns; excess turns are evicted from the front.
    max_turns: usize,
    /// Number of `record_turn` calls since the last `mark_topic_checked` (or `clear`).
    turns_since_topic_check: usize,
    /// Space id from the most recent `push_agent` call that updated a turn.
    last_space_id: Option<SpaceId>,
}
impl Default for ConversationBuffer {
fn default() -> Self {
Self::new(50)
}
}
impl ConversationBuffer {
    /// Creates an empty buffer that retains at most `max_turns` turns.
    ///
    /// When the limit is exceeded the oldest turns are dropped first.
    pub fn new(max_turns: usize) -> Self {
        Self {
            turns: VecDeque::with_capacity(max_turns),
            max_turns,
            turns_since_topic_check: 0,
            last_space_id: None,
        }
    }

    /// Records a new user message.
    ///
    /// If the newest turn is still unanswered (empty agent text and a nil space
    /// id), its user message and timestamp are overwritten instead of appending,
    /// so consecutive unanswered messages collapse into a single turn holding the
    /// latest message. Otherwise a fresh turn is appended and the oldest turns
    /// are evicted until at most `max_turns` remain.
    pub fn push_user(&mut self, message: &str) {
        if let Some(last) = self.turns.back_mut() {
            if last.agent.is_empty() && last.space_id == SpaceId::nil() {
                last.user = message.to_string();
                last.timestamp = Utc::now();
                return;
            }
        }
        // Only build the new turn once we know it will actually be appended.
        self.turns.push_back(ConversationTurn {
            user: message.to_string(),
            agent: String::new(),
            space_id: SpaceId::nil(),
            timestamp: Utc::now(),
        });
        while self.turns.len() > self.max_turns {
            self.turns.pop_front();
        }
    }

    /// Attaches the agent's reply to the newest turn.
    ///
    /// The reply is stored truncated to 200 bytes (with a `"..."` suffix when
    /// cut), any previous reply on that turn is overwritten, and `space_id`
    /// becomes the buffer's `last_space_id`. Does nothing when the buffer is empty.
    pub fn push_agent(&mut self, response: &str, space_id: &SpaceId) {
        if let Some(last) = self.turns.back_mut() {
            last.agent = truncate_response(response, 200);
            last.space_id = *space_id;
            self.last_space_id = Some(*space_id);
        }
    }

    /// Returns up to `n` turns, newest first.
    pub fn recent(&self, n: usize) -> Vec<&ConversationTurn> {
        self.turns.iter().rev().take(n).collect()
    }

    /// Returns an owned copy of all retained turns, oldest first.
    ///
    /// Note: this clones every turn; prefer `recent` when borrowing is enough.
    pub fn turns(&self) -> std::collections::VecDeque<ConversationTurn> {
        self.turns.clone()
    }

    /// Number of retained turns.
    pub fn len(&self) -> usize {
        self.turns.len()
    }

    /// Whether no turns are retained.
    pub fn is_empty(&self) -> bool {
        self.turns.is_empty()
    }

    /// Whether a topic check is due: at least `min_turns` turns were recorded
    /// since the last check, or [`Self::pattern_changed`] reports a shift.
    pub fn should_check_topic(&self, min_turns: usize) -> bool {
        self.turns_since_topic_check >= min_turns || self.pattern_changed()
    }

    /// Resets the "turns since last topic check" counter.
    pub fn mark_topic_checked(&mut self) {
        self.turns_since_topic_check = 0;
    }

    /// Counts one more turn toward the next topic check and reports whether a
    /// check is now due (see [`Self::should_check_topic`]).
    pub fn record_turn(&mut self, min_turns: usize) -> bool {
        self.turns_since_topic_check += 1;
        self.should_check_topic(min_turns)
    }

    /// Heuristically detects that the conversation has changed direction.
    ///
    /// Compares the two newest turns ("recent") with the two turns just before
    /// them ("previous"); always `false` with fewer than four turns. A change is
    /// reported when either:
    /// * the recent turns' average user word count is less than half or more
    ///   than double the previous turns' average (the previous average is
    ///   floored at 1.0 so it never divides by zero), or
    /// * the previous turns mention a coding keyword AND the recent turns
    ///   mention its paired non-coding keyword (case-insensitive substring match).
    pub fn pattern_changed(&self) -> bool {
        // (keyword looked for in previous turns, keyword looked for in recent turns).
        // All keywords must be lowercase: the text is lowercased before matching.
        const DOMAIN_SHIFT_KEYWORDS: [(&str, &str); 8] = [
            ("code", "food"),
            ("rust", "요리"),
            ("bug", "저녁"),
            ("file", "운동"),
            ("import", "영화"),
            ("commit", "음식"),
            ("function", "게임"),
            ("cargo", "장보기"),
        ];

        let len = self.turns.len();
        if len < 4 {
            return false;
        }
        // Only the last four turns matter; avoid collecting the whole history.
        let window: Vec<&ConversationTurn> = self.turns.range(len - 4..).collect();
        let (previous, recent) = window.split_at(2);

        let avg_words = |turns: &[&ConversationTurn]| {
            turns.iter().map(|t| word_count(&t.user)).sum::<usize>() as f64 / turns.len() as f64
        };
        let ratio = avg_words(recent) / avg_words(previous).max(1.0);
        if !(0.5..=2.0).contains(&ratio) {
            return true;
        }

        // Join with a separator so the end of one message and the start of the
        // next cannot fuse into a spurious keyword.
        let lowercase_text = |turns: &[&ConversationTurn]| {
            turns
                .iter()
                .map(|t| t.user.to_lowercase())
                .collect::<Vec<_>>()
                .join("\n")
        };
        let prev_text = lowercase_text(previous);
        let recent_text = lowercase_text(recent);

        // A shift means the earlier turns were on the first topic AND the newer
        // turns are on the second one.
        DOMAIN_SHIFT_KEYWORDS.iter().any(|&(prev_kw, recent_kw)| {
            prev_text.contains(prev_kw) && recent_text.contains(recent_kw)
        })
    }

    /// Whether the two newest turns carry different space ids.
    ///
    /// Returns `false` with fewer than two turns. An unanswered newest turn has
    /// a nil space id, which counts as different from a non-nil one.
    pub fn space_changed(&self) -> bool {
        let mut newest_first = self.turns.iter().rev();
        match (newest_first.next(), newest_first.next()) {
            (Some(last), Some(prev)) => last.space_id != prev.space_id,
            _ => false,
        }
    }

    /// Space id from the most recent `push_agent` call that updated a turn,
    /// or `None` if no reply has been recorded yet.
    pub fn last_space_id(&self) -> Option<SpaceId> {
        self.last_space_id
    }

    /// Removes all turns and resets the topic-check counter.
    ///
    /// `last_space_id` is NOT reset and keeps reporting the last reply's space.
    /// NOTE(review): confirm callers rely on that.
    pub fn clear(&mut self) {
        self.turns.clear();
        self.turns_since_topic_check = 0;
    }
}
/// Returns how many whitespace-delimited words `s` contains.
///
/// Any Unicode whitespace separates words, and a run of whitespace counts as
/// a single gap, so leading/trailing/repeated spaces add nothing.
fn word_count(s: &str) -> usize {
    s.split_whitespace().fold(0, |count, _word| count + 1)
}
/// Truncates `response` to at most `max_len` bytes, appending `"..."` when
/// anything was cut off.
///
/// The cut point is moved back to the nearest UTF-8 character boundary, so
/// multi-byte text (e.g. Korean) never triggers a slicing panic. The kept prefix
/// may therefore be slightly shorter than `max_len` bytes. Responses that
/// already fit are returned unchanged, without a suffix.
fn truncate_response(response: &str, max_len: usize) -> String {
    if response.len() <= max_len {
        return response.to_string();
    }
    // Here `max_len < response.len()`, so every probed index is in range;
    // index 0 is always a char boundary, so the search cannot fail.
    let cut = (0..=max_len)
        .rev()
        .find(|&i| response.is_char_boundary(i))
        .unwrap_or(0);
    format!("{}...", &response[..cut])
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Records one full exchange: a user message followed by the agent's reply.
    fn exchange(buffer: &mut ConversationBuffer, user: &str, agent: &str, space: &SpaceId) {
        buffer.push_user(user);
        buffer.push_agent(agent, space);
    }

    #[test]
    fn test_push_user_and_agent() {
        let mut buffer = ConversationBuffer::new(10);
        assert!(buffer.is_empty());

        buffer.push_user("Hello, how are you?");
        assert_eq!(buffer.len(), 1);
        let first = &buffer.turns[0];
        assert_eq!(first.user, "Hello, how are you?");
        assert!(first.agent.is_empty());

        buffer.push_agent("I'm doing well!", &SpaceId::nil());
        assert_eq!(buffer.turns[0].agent, "I'm doing well!");
    }

    #[test]
    fn test_max_capacity() {
        let mut buffer = ConversationBuffer::new(3);
        let nil_space = SpaceId::nil();
        for round in 1..=5 {
            let question = format!("msg{}", round);
            let answer = format!("r{}", round);
            exchange(&mut buffer, &question, &answer, &nil_space);
        }
        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.recent(1)[0].user, "msg5");
    }

    #[test]
    fn test_should_check_topic() {
        let mut buffer = ConversationBuffer::new(10);
        assert!(!buffer.should_check_topic(3));
        (0..3).for_each(|_| {
            buffer.push_user("test");
            buffer.mark_topic_checked();
        });
        assert!(!buffer.should_check_topic(3));
    }

    #[test]
    fn test_pattern_changed_word_count() {
        let mut buffer = ConversationBuffer::new(10);
        let nil_space = SpaceId::nil();
        for _ in 0..3 {
            exchange(&mut buffer, "hi", "hi", &nil_space);
        }
        assert!(!buffer.pattern_changed());

        let long_message = "This is a very long message that contains many many many many many words to trigger the pattern detection";
        exchange(&mut buffer, long_message, "ok", &nil_space);
        assert!(buffer.pattern_changed());
    }

    #[test]
    fn test_truncate_response() {
        assert_eq!(truncate_response("Hello", 10), "Hello");

        let cut = truncate_response("This is a very long response", 10);
        assert_eq!(cut.len(), 13);
        assert!(cut.ends_with("..."));
    }

    #[test]
    fn test_recent_turns() {
        let mut buffer = ConversationBuffer::new(10);
        let nil_space = SpaceId::nil();
        for idx in 0..5 {
            let question = format!("msg{}", idx);
            let answer = format!("resp{}", idx);
            exchange(&mut buffer, &question, &answer, &nil_space);
        }

        let newest = buffer.recent(3);
        assert_eq!(newest.len(), 3);
        assert_eq!(newest[0].user, "msg4");
        assert_eq!(newest[2].user, "msg2");
    }
}