use std::collections::VecDeque;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::llm::ChatMessage;
/// Number of undo checkpoints an `UndoManager` retains by default; override it
/// with `UndoManager::with_max_checkpoints`.
const DEFAULT_MAX_CHECKPOINTS: usize = 20;
/// A saved snapshot of conversation state: the messages as of a given turn,
/// plus an id and a label used to find and display it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Identifier; `Checkpoint::new` fills it with a random `Uuid::new_v4()`.
    pub id: Uuid,
    /// Turn number the snapshot was taken at, as supplied by the caller.
    pub turn_number: usize,
    /// The messages captured in this snapshot (whatever the caller passed in).
    pub messages: Vec<ChatMessage>,
    /// Human-readable label. Snapshots taken by `UndoManager::undo`/`redo` use
    /// the form `"Turn <n>"`.
    pub description: String,
}
impl Checkpoint {
pub fn new(
turn_number: usize,
messages: Vec<ChatMessage>,
description: impl Into<String>,
) -> Self {
Self {
id: Uuid::new_v4(),
turn_number,
messages,
description: description.into(),
}
}
}
/// Bounded undo/redo history of conversation checkpoints.
///
/// Undo entries live in a `VecDeque` so the oldest can be evicted from the
/// front cheaply once `max_checkpoints` is exceeded. Redo entries live in a
/// plain `Vec` used as a stack.
#[derive(Debug, Clone)]
pub struct UndoManager {
    /// Checkpoints that can be undone: oldest at the front, newest at the back.
    undo_stack: VecDeque<Checkpoint>,
    /// States saved by `undo` so they can be redone; the most recent is last.
    redo_stack: Vec<Checkpoint>,
    /// Maximum number of undo entries retained; older entries are evicted from
    /// the front of `undo_stack` once this is exceeded.
    max_checkpoints: usize,
}
impl UndoManager {
    /// Creates an empty manager that keeps at most `DEFAULT_MAX_CHECKPOINTS`
    /// entries on its undo stack.
    pub fn new() -> Self {
        Self {
            undo_stack: VecDeque::new(),
            redo_stack: Vec::new(),
            max_checkpoints: DEFAULT_MAX_CHECKPOINTS,
        }
    }

    /// Builder-style setter for the maximum number of undo entries retained.
    ///
    /// The limit takes effect immediately. If the undo stack already holds more
    /// than `max` entries, the oldest ones are evicted. With `max == 0` no undo
    /// history is kept at all. The redo stack is not affected.
    pub fn with_max_checkpoints(mut self, max: usize) -> Self {
        self.max_checkpoints = max;
        // Trim now so the cap holds even if no further checkpoint is pushed.
        self.enforce_limit();
        self
    }

    /// Evicts the oldest undo entries until the stack fits within `max_checkpoints`.
    fn enforce_limit(&mut self) {
        while self.undo_stack.len() > self.max_checkpoints {
            self.undo_stack.pop_front();
        }
    }

    /// Pushes `checkpoint` as the newest undo entry, evicting the oldest if over capacity.
    fn push_undo(&mut self, checkpoint: Checkpoint) {
        self.undo_stack.push_back(checkpoint);
        self.enforce_limit();
    }

    /// Wraps the caller's live state in a checkpoint labelled `"Turn <turn>"`.
    fn snapshot_current(turn: usize, messages: Vec<ChatMessage>) -> Checkpoint {
        Checkpoint::new(turn, messages, format!("Turn {}", turn))
    }

    /// Records a new checkpoint as the newest undo entry.
    ///
    /// Any pending redo entries are discarded, because recording new state
    /// starts a new branch of history.
    pub fn checkpoint(
        &mut self,
        turn_number: usize,
        messages: Vec<ChatMessage>,
        description: impl Into<String>,
    ) {
        self.redo_stack.clear();
        self.push_undo(Checkpoint::new(turn_number, messages, description));
    }

    /// Steps back to the most recent checkpoint and returns it.
    ///
    /// The caller's current state (`current_turn`, `current_messages`) is saved
    /// on the redo stack so `redo` can return to it. Returns `None`, leaving both
    /// stacks untouched, when there is nothing to undo.
    pub fn undo(
        &mut self,
        current_turn: usize,
        current_messages: Vec<ChatMessage>,
    ) -> Option<Checkpoint> {
        let target = self.undo_stack.pop_back()?;
        self.redo_stack
            .push(Self::snapshot_current(current_turn, current_messages));
        Some(target)
    }

    /// Removes and returns the newest undo entry without saving anything for redo.
    pub fn pop_undo(&mut self) -> Option<Checkpoint> {
        self.undo_stack.pop_back()
    }

    /// Re-applies the most recently undone state and returns it.
    ///
    /// The caller's current state is pushed back onto the undo stack, subject
    /// to the `max_checkpoints` limit. Returns `None`, leaving both stacks
    /// untouched, when there is nothing to redo.
    pub fn redo(
        &mut self,
        current_turn: usize,
        current_messages: Vec<ChatMessage>,
    ) -> Option<Checkpoint> {
        let target = self.redo_stack.pop()?;
        self.push_undo(Self::snapshot_current(current_turn, current_messages));
        Some(target)
    }

    /// Returns `true` if at least one undo entry is available.
    pub fn can_undo(&self) -> bool {
        !self.undo_stack.is_empty()
    }

    /// Returns `true` if at least one redo entry is available.
    pub fn can_redo(&self) -> bool {
        !self.redo_stack.is_empty()
    }

    /// Number of entries currently on the undo stack.
    pub fn undo_count(&self) -> usize {
        self.undo_stack.len()
    }

    /// Number of entries currently on the redo stack.
    pub fn redo_count(&self) -> usize {
        self.redo_stack.len()
    }

    /// Looks up a checkpoint by id, searching the undo stack first, then the redo stack.
    pub fn get_checkpoint(&self, id: Uuid) -> Option<&Checkpoint> {
        self.undo_stack
            .iter()
            .chain(self.redo_stack.iter())
            .find(|c| c.id == id)
    }

    /// Lists the undo-stack checkpoints from oldest to newest (redo entries excluded).
    pub fn list_checkpoints(&self) -> Vec<&Checkpoint> {
        self.undo_stack.iter().collect()
    }

    /// Discards all undo and redo history.
    pub fn clear(&mut self) {
        self.undo_stack.clear();
        self.redo_stack.clear();
    }

    /// Jumps back to the undo-stack checkpoint with `checkpoint_id` and returns it.
    ///
    /// The target and every newer undo entry are removed, and the redo stack is
    /// cleared. If no undo entry has that id, returns `None` and changes nothing.
    pub fn restore(&mut self, checkpoint_id: Uuid) -> Option<Checkpoint> {
        let pos = self
            .undo_stack
            .iter()
            .position(|c| c.id == checkpoint_id)?;
        self.redo_stack.clear();
        // Keep entries up to and including the target, then hand the target back.
        self.undo_stack.truncate(pos + 1);
        self.undo_stack.pop_back()
    }
}
impl Default for UndoManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checkpoint_creation() {
        let mut manager = UndoManager::new();
        manager.checkpoint(0, vec![], "Initial state");
        manager.checkpoint(1, vec![ChatMessage::user("Hello")], "Turn 1");
        assert_eq!(manager.undo_count(), 2);
    }

    #[test]
    fn test_undo_redo() {
        let mut manager = UndoManager::new();
        manager.checkpoint(0, vec![], "Turn 0");
        manager.checkpoint(1, vec![ChatMessage::user("Hello")], "Turn 1");
        assert!(manager.can_undo());
        assert!(!manager.can_redo());
        let current = vec![ChatMessage::user("Hello"), ChatMessage::assistant("Hi")];
        let checkpoint = manager.undo(2, current).expect("undo should succeed");
        assert_eq!(checkpoint.turn_number, 1);
        assert!(manager.can_redo());
        let restored = manager
            .redo(checkpoint.turn_number, checkpoint.messages)
            .expect("redo should succeed");
        // Redo hands back the state that was current when `undo` was called.
        assert_eq!(restored.turn_number, 2);
        assert_eq!(restored.messages.len(), 2);
        assert!(!manager.can_redo());
    }

    #[test]
    fn test_max_checkpoints() {
        let mut manager = UndoManager::new().with_max_checkpoints(3);
        for i in 0..5 {
            manager.checkpoint(i, vec![], format!("Turn {}", i));
        }
        assert_eq!(manager.undo_count(), 3);
    }

    #[test]
    fn test_eviction_drops_oldest_first() {
        let mut manager = UndoManager::new().with_max_checkpoints(2);
        for i in 0..3 {
            manager.checkpoint(i, vec![], format!("Turn {}", i));
        }
        let turns: Vec<usize> = manager
            .list_checkpoints()
            .iter()
            .map(|c| c.turn_number)
            .collect();
        assert_eq!(turns, vec![1, 2]);
    }

    #[test]
    fn test_restore_to_checkpoint() {
        let mut manager = UndoManager::new();
        manager.checkpoint(0, vec![], "Turn 0");
        let checkpoint_id = manager.undo_stack.back().unwrap().id;
        manager.checkpoint(1, vec![], "Turn 1");
        manager.checkpoint(2, vec![], "Turn 2");
        let restored = manager
            .restore(checkpoint_id)
            .expect("checkpoint should exist");
        assert_eq!(restored.id, checkpoint_id);
        assert_eq!(restored.turn_number, 0);
        assert_eq!(manager.undo_count(), 0);
    }

    #[test]
    fn test_restore_discards_newer_checkpoints_and_redo() {
        let mut manager = UndoManager::new();
        manager.checkpoint(0, vec![], "Turn 0");
        manager.checkpoint(1, vec![], "Turn 1");
        let target_id = manager.undo_stack.back().unwrap().id;
        manager.checkpoint(2, vec![], "Turn 2");
        manager.checkpoint(3, vec![], "Turn 3");
        manager.undo(4, vec![]).expect("undo should succeed");
        assert_eq!(manager.redo_count(), 1);
        let restored = manager.restore(target_id).expect("checkpoint should exist");
        assert_eq!(restored.turn_number, 1);
        assert_eq!(manager.undo_count(), 1);
        assert_eq!(manager.list_checkpoints()[0].turn_number, 0);
        assert!(!manager.can_redo());
    }

    #[test]
    fn test_restore_unknown_id_leaves_history_intact() {
        let mut manager = UndoManager::new();
        manager.checkpoint(0, vec![], "Turn 0");
        manager.checkpoint(1, vec![], "Turn 1");
        manager.undo(2, vec![]).expect("undo should succeed");
        assert!(manager.restore(Uuid::new_v4()).is_none());
        assert_eq!(manager.undo_count(), 1);
        assert_eq!(manager.redo_count(), 1);
    }

    #[test]
    fn test_undo_and_redo_on_empty_stacks_are_noops() {
        let mut manager = UndoManager::new();
        assert!(manager.undo(0, vec![ChatMessage::user("x")]).is_none());
        // A failed undo must not leave the current state behind on the redo stack.
        assert_eq!(manager.redo_count(), 0);
        assert!(manager.redo(0, vec![]).is_none());
        assert_eq!(manager.undo_count(), 0);
    }

    #[test]
    fn test_new_checkpoint_clears_redo() {
        let mut manager = UndoManager::new();
        manager.checkpoint(0, vec![], "Turn 0");
        manager.checkpoint(1, vec![], "Turn 1");
        manager.undo(2, vec![]).expect("undo should succeed");
        assert_eq!(manager.redo_count(), 1);
        manager.checkpoint(2, vec![], "Turn 2 (new branch)");
        assert!(!manager.can_redo());
        assert_eq!(manager.undo_count(), 2);
    }

    #[test]
    fn test_get_checkpoint_searches_redo_stack() {
        let mut manager = UndoManager::new();
        manager.checkpoint(0, vec![], "Turn 0");
        let undone_id = manager.undo_stack.back().unwrap().id;
        manager
            .undo(1, vec![ChatMessage::user("a")])
            .expect("undo should succeed");
        let redo_id = manager.redo_stack.last().unwrap().id;
        let found = manager
            .get_checkpoint(redo_id)
            .expect("redo entry should be found");
        assert_eq!(found.turn_number, 1);
        assert_eq!(found.description, "Turn 1");
        // The undone checkpoint was handed back to the caller, so it is no longer tracked.
        assert!(manager.get_checkpoint(undone_id).is_none());
    }

    #[test]
    fn test_repeated_undo_advances_through_stack() {
        let mut manager = UndoManager::new();
        manager.checkpoint(0, vec![], "Turn 0");
        manager.checkpoint(1, vec![ChatMessage::user("msg1")], "Turn 1");
        manager.checkpoint(2, vec![ChatMessage::user("msg2")], "Turn 2");
        assert_eq!(manager.undo_count(), 3);
        let cp1 = manager
            .undo(3, vec![ChatMessage::user("msg3")])
            .expect("first undo should succeed");
        assert_eq!(cp1.turn_number, 2);
        assert_eq!(manager.undo_count(), 2);
        let cp2 = manager
            .undo(cp1.turn_number, cp1.messages)
            .expect("second undo should succeed");
        assert_eq!(cp2.turn_number, 1);
        assert_eq!(manager.undo_count(), 1);
        assert_ne!(cp1.turn_number, cp2.turn_number);
    }

    #[test]
    fn test_undo_redo_cycle_preserves_state() {
        let mut manager = UndoManager::new();
        let msgs_t0: Vec<ChatMessage> = vec![];
        let msgs_t1 = vec![ChatMessage::user("hello")];
        let msgs_t2 = vec![ChatMessage::user("hello"), ChatMessage::assistant("hi")];
        manager.checkpoint(0, msgs_t0, "Turn 0");
        manager.checkpoint(1, msgs_t1, "Turn 1");
        let cp_undo1 = manager
            .undo(2, msgs_t2.clone())
            .expect("undo should succeed");
        assert_eq!(cp_undo1.turn_number, 1);
        let cp_redo = manager
            .redo(cp_undo1.turn_number, cp_undo1.messages)
            .expect("redo should succeed");
        assert_eq!(cp_redo.turn_number, 2);
        assert_eq!(cp_redo.messages.len(), 2);
        let cp_undo2 = manager
            .undo(cp_redo.turn_number, cp_redo.messages)
            .expect("second undo should succeed");
        assert_eq!(cp_undo2.turn_number, 1);
    }

    #[test]
    fn test_undo_redo_stack_sizes_consistent() {
        let mut manager = UndoManager::new();
        manager.checkpoint(0, vec![], "Turn 0");
        manager.checkpoint(1, vec![ChatMessage::user("a")], "Turn 1");
        manager.checkpoint(2, vec![ChatMessage::user("b")], "Turn 2");
        let total = manager.undo_count() + manager.redo_count();
        assert_eq!(total, 3);
        let cp = manager.undo(3, vec![]).unwrap();
        assert_eq!(manager.undo_count() + manager.redo_count(), 3);
        let cp2 = manager.redo(cp.turn_number, cp.messages).unwrap();
        assert_eq!(manager.undo_count() + manager.redo_count(), 3);
        let _cp3 = manager.undo(cp2.turn_number, cp2.messages).unwrap();
        assert_eq!(manager.undo_count() + manager.redo_count(), 3);
    }
}