agcodex_core/conversation/
undo_redo.rs

//! Undo/redo system for conversation management.
//!
//! This module tracks snapshots of conversation state and supports undo/redo
//! operations, named branches and checkpoints, and a memory-bounded history.

6use std::collections::VecDeque;
7use std::sync::Arc;
8use std::time::SystemTime;
9
10use serde::Deserialize;
11use serde::Serialize;
12use uuid::Uuid;
13
14use crate::error::CodexErr;
15use crate::error::Result as CodexResult;
16use crate::models::ContentItem;
17use crate::models::ResponseItem;
18
/// Upper bound on how many snapshots the undo stack retains; once reached,
/// the oldest entries are evicted first.
const MAX_UNDO_STATES: usize = 50;

/// Estimated snapshot size, in bytes (10 MiB), above which a snapshot is
/// flagged as `compressed`.
const MAX_SNAPSHOT_SIZE: usize = 10 << 20;

25/// Represents a complete snapshot of conversation state at a point in time
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct ConversationSnapshot {
28    /// Unique identifier for this snapshot
29    pub id: Uuid,
30    /// Timestamp when this snapshot was created
31    pub timestamp: SystemTime,
32    /// The conversation items at this point
33    pub items: Vec<ResponseItem>,
34    /// Metadata about the conversation state
35    pub metadata: SnapshotMetadata,
36    /// Branch information if this is a branch point
37    pub branch_info: Option<BranchInfo>,
38    /// Size in bytes (estimated)
39    pub size_bytes: usize,
40    /// Whether this snapshot is compressed
41    pub compressed: bool,
42}
43
44/// Metadata associated with a conversation snapshot
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct SnapshotMetadata {
47    /// Turn number in the conversation
48    pub turn_number: usize,
49    /// Total token count up to this point
50    pub _total_tokens: usize,
51    /// Active model at this point
52    pub _model: String,
53    /// Active mode (Plan/Build/Review)
54    pub mode: String,
55    /// User who created this turn
56    pub user: Option<String>,
57    /// Custom tags for this snapshot
58    pub tags: Vec<String>,
59}
60
61/// Information about a branch in the conversation
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct BranchInfo {
64    /// Name of the branch
65    pub name: String,
66    /// Parent snapshot ID this branches from
67    pub parent_id: Uuid,
68    /// Description of why this branch was created
69    pub description: Option<String>,
70    /// Whether this is the active branch
71    pub is_active: bool,
72}
73
74/// Manages undo/redo operations for a conversation
75pub struct UndoRedoManager {
76    /// Stack of undo states (newest last)
77    undo_stack: VecDeque<Arc<ConversationSnapshot>>,
78    /// Stack of redo states (newest last)
79    redo_stack: VecDeque<Arc<ConversationSnapshot>>,
80    /// Current conversation state
81    current_state: Option<Arc<ConversationSnapshot>>,
82    /// All branches indexed by their parent snapshot ID
83    branches: std::collections::HashMap<Uuid, Vec<Arc<ConversationSnapshot>>>,
84    /// Memory usage tracking
85    total_memory_usage: usize,
86    /// Maximum memory usage allowed (in bytes)
87    max_memory_usage: usize,
88}
89
90impl UndoRedoManager {
91    /// Create a new undo/redo manager
92    pub fn new() -> Self {
93        Self {
94            undo_stack: VecDeque::with_capacity(MAX_UNDO_STATES),
95            redo_stack: VecDeque::new(),
96            current_state: None,
97            branches: std::collections::HashMap::new(),
98            total_memory_usage: 0,
99            max_memory_usage: 100 * 1024 * 1024, // 100MB default
100        }
101    }
102
103    /// Create a new undo/redo manager with custom memory limit
104    pub fn with_memory_limit(max_memory_mb: usize) -> Self {
105        Self {
106            undo_stack: VecDeque::with_capacity(MAX_UNDO_STATES),
107            redo_stack: VecDeque::new(),
108            current_state: None,
109            branches: std::collections::HashMap::new(),
110            total_memory_usage: 0,
111            max_memory_usage: max_memory_mb * 1024 * 1024,
112        }
113    }
114
115    /// Save the current conversation state
116    pub fn save_state(
117        &mut self,
118        items: Vec<ResponseItem>,
119        metadata: SnapshotMetadata,
120    ) -> CodexResult<Uuid> {
121        // Clear redo stack when new state is saved
122        self.redo_stack.clear();
123
124        // Move current state to undo stack if it exists
125        if let Some(current) = self.current_state.take() {
126            self.push_to_undo_stack(current);
127        }
128
129        // Create new snapshot
130        let snapshot = self.create_snapshot(items, metadata)?;
131        let snapshot_id = snapshot.id;
132        let snapshot_arc = Arc::new(snapshot);
133
134        // Update memory tracking
135        self.total_memory_usage += snapshot_arc.size_bytes;
136        self.enforce_memory_limit();
137
138        // Set as current state
139        self.current_state = Some(snapshot_arc);
140
141        Ok(snapshot_id)
142    }
143
144    /// Undo the last operation
145    pub fn undo(&mut self) -> CodexResult<Option<ConversationSnapshot>> {
146        if self.undo_stack.is_empty() {
147            return Ok(None);
148        }
149
150        // Move current state to redo stack
151        if let Some(current) = self.current_state.take() {
152            self.push_to_redo_stack(current);
153        }
154
155        // Pop from undo stack and set as current
156        if let Some(previous) = self.undo_stack.pop_back() {
157            let snapshot = (*previous).clone();
158            self.current_state = Some(previous);
159            Ok(Some(snapshot))
160        } else {
161            Ok(None)
162        }
163    }
164
165    /// Redo the last undone operation
166    pub fn redo(&mut self) -> CodexResult<Option<ConversationSnapshot>> {
167        if self.redo_stack.is_empty() {
168            return Ok(None);
169        }
170
171        // Move current state to undo stack
172        if let Some(current) = self.current_state.take() {
173            self.push_to_undo_stack(current);
174        }
175
176        // Pop from redo stack and set as current
177        if let Some(next) = self.redo_stack.pop_back() {
178            let snapshot = (*next).clone();
179            self.current_state = Some(next);
180            Ok(Some(snapshot))
181        } else {
182            Ok(None)
183        }
184    }
185
186    /// Create a branch from the current state
187    pub fn create_branch(
188        &mut self,
189        branch_name: String,
190        description: Option<String>,
191        items: Vec<ResponseItem>,
192        metadata: SnapshotMetadata,
193    ) -> CodexResult<Uuid> {
194        let parent_id = self
195            .current_state
196            .as_ref()
197            .map(|s| s.id)
198            .ok_or(CodexErr::NoBranchPointAvailable)?;
199
200        let branch_info = BranchInfo {
201            name: branch_name,
202            parent_id,
203            description,
204            is_active: false,
205        };
206
207        let mut snapshot = self.create_snapshot(items, metadata)?;
208        snapshot.branch_info = Some(branch_info);
209
210        let snapshot_id = snapshot.id;
211        let snapshot_arc = Arc::new(snapshot);
212
213        // Add to branches map
214        self.branches
215            .entry(parent_id)
216            .or_default()
217            .push(snapshot_arc.clone());
218
219        // Update memory tracking
220        self.total_memory_usage += snapshot_arc.size_bytes;
221        self.enforce_memory_limit();
222
223        Ok(snapshot_id)
224    }
225
226    /// Switch to a specific branch
227    pub fn switch_to_branch(
228        &mut self,
229        branch_id: Uuid,
230    ) -> CodexResult<Option<ConversationSnapshot>> {
231        // Find the branch snapshot
232        let branch_snapshot = self
233            .branches
234            .values()
235            .flatten()
236            .find(|s| s.id == branch_id)
237            .cloned();
238
239        if let Some(snapshot) = branch_snapshot {
240            // Save current state to undo stack
241            if let Some(current) = self.current_state.take() {
242                self.push_to_undo_stack(current);
243            }
244
245            // Clear redo stack when switching branches
246            self.redo_stack.clear();
247
248            // Set branch as current
249            let result = (*snapshot).clone();
250            self.current_state = Some(snapshot);
251
252            Ok(Some(result))
253        } else {
254            Ok(None)
255        }
256    }
257
258    /// Get all available branches
259    pub fn get_branches(&self) -> Vec<(Uuid, BranchInfo)> {
260        self.branches
261            .values()
262            .flatten()
263            .filter_map(|s| s.branch_info.as_ref().map(|b| (s.id, b.clone())))
264            .collect()
265    }
266
267    /// Get the current state
268    pub fn current_state(&self) -> Option<&ConversationSnapshot> {
269        self.current_state.as_deref()
270    }
271
272    /// Get undo history (for visualization)
273    pub fn undo_history(&self) -> Vec<&ConversationSnapshot> {
274        self.undo_stack.iter().map(|s| s.as_ref()).collect()
275    }
276
277    /// Get redo history (for visualization)
278    pub fn redo_history(&self) -> Vec<&ConversationSnapshot> {
279        self.redo_stack.iter().map(|s| s.as_ref()).collect()
280    }
281
282    /// Clear all history
283    pub fn clear(&mut self) {
284        self.undo_stack.clear();
285        self.redo_stack.clear();
286        self.current_state = None;
287        self.branches.clear();
288        self.total_memory_usage = 0;
289    }
290
291    /// Create a checkpoint (special snapshot with tag)
292    pub fn create_checkpoint(&mut self, name: String) -> CodexResult<Uuid> {
293        if let Some(current) = &self.current_state {
294            let mut checkpoint = (**current).clone();
295            checkpoint.id = Uuid::new_v4();
296            checkpoint.timestamp = SystemTime::now();
297            checkpoint
298                .metadata
299                .tags
300                .push(format!("checkpoint:{}", name));
301
302            let checkpoint_id = checkpoint.id;
303            let checkpoint_arc = Arc::new(checkpoint);
304
305            // Store in branches as a special branch
306            self.branches
307                .entry(current.id)
308                .or_default()
309                .push(checkpoint_arc);
310
311            Ok(checkpoint_id)
312        } else {
313            Err(CodexErr::NoCurrentStateForCheckpoint)
314        }
315    }
316
317    /// Restore from a checkpoint
318    pub fn restore_checkpoint(
319        &mut self,
320        checkpoint_id: Uuid,
321    ) -> CodexResult<Option<ConversationSnapshot>> {
322        self.switch_to_branch(checkpoint_id)
323    }
324
325    /// Get memory usage information
326    pub fn memory_info(&self) -> MemoryInfo {
327        MemoryInfo {
328            total_usage_bytes: self.total_memory_usage,
329            max_usage_bytes: self.max_memory_usage,
330            undo_stack_size: self.undo_stack.len(),
331            redo_stack_size: self.redo_stack.len(),
332            branch_count: self.branches.values().map(|v| v.len()).sum(),
333            usage_percentage: (self.total_memory_usage as f64 / self.max_memory_usage as f64)
334                * 100.0,
335        }
336    }
337
338    // Private helper methods
339
340    fn create_snapshot(
341        &self,
342        items: Vec<ResponseItem>,
343        metadata: SnapshotMetadata,
344    ) -> CodexResult<ConversationSnapshot> {
345        let size_bytes = Self::estimate_size(&items);
346        let compressed = size_bytes > MAX_SNAPSHOT_SIZE;
347
348        // In a real implementation, we would compress large snapshots here
349        let snapshot = ConversationSnapshot {
350            id: Uuid::new_v4(),
351            timestamp: SystemTime::now(),
352            items,
353            metadata,
354            branch_info: None,
355            size_bytes,
356            compressed,
357        };
358
359        Ok(snapshot)
360    }
361
362    fn push_to_undo_stack(&mut self, snapshot: Arc<ConversationSnapshot>) {
363        // Enforce maximum undo states
364        while self.undo_stack.len() >= MAX_UNDO_STATES {
365            if let Some(removed) = self.undo_stack.pop_front() {
366                self.total_memory_usage =
367                    self.total_memory_usage.saturating_sub(removed.size_bytes);
368            }
369        }
370        self.undo_stack.push_back(snapshot);
371    }
372
373    fn push_to_redo_stack(&mut self, snapshot: Arc<ConversationSnapshot>) {
374        self.redo_stack.push_back(snapshot);
375    }
376
377    fn enforce_memory_limit(&mut self) {
378        // Remove oldest snapshots if memory limit is exceeded
379        while self.total_memory_usage > self.max_memory_usage && !self.undo_stack.is_empty() {
380            if let Some(removed) = self.undo_stack.pop_front() {
381                self.total_memory_usage =
382                    self.total_memory_usage.saturating_sub(removed.size_bytes);
383            }
384        }
385    }
386
387    fn estimate_size(items: &[ResponseItem]) -> usize {
388        // Simple size estimation based on content
389        items
390            .iter()
391            .map(|item| match item {
392                ResponseItem::Message { content, .. } => {
393                    content
394                        .iter()
395                        .map(|c| match c {
396                            ContentItem::InputText { text } | ContentItem::OutputText { text } => {
397                                text.len()
398                            }
399                            ContentItem::InputImage { .. } => 1024, // Estimate for image metadata
400                        })
401                        .sum::<usize>()
402                }
403                ResponseItem::Reasoning {
404                    summary, content, ..
405                } => {
406                    // Estimate size based on summary and optional content
407                    let summary_size: usize = summary.iter().map(|_| 100).sum(); // Estimate 100 bytes per summary item
408                    let content_size = content.as_ref().map(|c| c.len() * 50).unwrap_or(0); // Estimate 50 bytes per content item
409                    summary_size + content_size
410                }
411                ResponseItem::FunctionCall { arguments, .. } => arguments.len(),
412                ResponseItem::FunctionCallOutput { output, .. } => {
413                    // Estimate based on the payload content
414                    output.content.len() + 100 // Content plus overhead
415                }
416                _ => 256, // Default estimate for other types
417            })
418            .sum()
419    }
420}
421
422impl Default for UndoRedoManager {
423    fn default() -> Self {
424        Self::new()
425    }
426}
427
/// Snapshot of the manager's memory accounting, as returned by
/// `UndoRedoManager::memory_info`.
#[derive(Clone, Debug)]
pub struct MemoryInfo {
    /// Estimated bytes currently accounted to stored snapshots.
    pub total_usage_bytes: usize,
    /// Configured memory budget in bytes.
    pub max_usage_bytes: usize,
    /// Number of entries on the undo stack.
    pub undo_stack_size: usize,
    /// Number of entries on the redo stack.
    pub redo_stack_size: usize,
    /// Total snapshots stored as branches (checkpoints included).
    pub branch_count: usize,
    /// `total_usage_bytes` expressed as a percentage of `max_usage_bytes`.
    pub usage_percentage: f64,
}

439/// Diff between two conversation states (for efficient storage)
440#[derive(Debug, Clone, Serialize, Deserialize)]
441pub struct ConversationDiff {
442    /// Items added in this diff
443    pub added: Vec<ResponseItem>,
444    /// Indices of items removed
445    pub removed: Vec<usize>,
446    /// Items that were modified (index, new_item)
447    pub modified: Vec<(usize, ResponseItem)>,
448}
449
450impl ConversationDiff {
451    /// Create a diff between two conversation states
452    pub fn create(old: &[ResponseItem], new: &[ResponseItem]) -> Self {
453        let mut added = Vec::new();
454        let mut removed = Vec::new();
455        let mut modified = Vec::new();
456
457        // Simple diff algorithm - can be optimized with proper diffing
458        let min_len = old.len().min(new.len());
459
460        // Check for modifications in common range
461        for i in 0..min_len {
462            if !Self::items_equal(&old[i], &new[i]) {
463                modified.push((i, new[i].clone()));
464            }
465        }
466
467        // Check for additions
468        if new.len() > old.len() {
469            added.extend(new[old.len()..].iter().cloned());
470        }
471
472        // Check for removals
473        if old.len() > new.len() {
474            for i in new.len()..old.len() {
475                removed.push(i);
476            }
477        }
478
479        Self {
480            added,
481            removed,
482            modified,
483        }
484    }
485
486    /// Apply this diff to a conversation state
487    pub fn apply(&self, items: &mut Vec<ResponseItem>) {
488        // Apply modifications
489        for (index, new_item) in &self.modified {
490            if *index < items.len() {
491                items[*index] = new_item.clone();
492            }
493        }
494
495        // Remove items (in reverse order to maintain indices)
496        for &index in self.removed.iter().rev() {
497            if index < items.len() {
498                items.remove(index);
499            }
500        }
501
502        // Add new items
503        items.extend(self.added.iter().cloned());
504    }
505
506    fn items_equal(a: &ResponseItem, b: &ResponseItem) -> bool {
507        // Simple equality check - could be optimized
508        match (a, b) {
509            (
510                ResponseItem::Message {
511                    role: r1,
512                    content: c1,
513                    ..
514                },
515                ResponseItem::Message {
516                    role: r2,
517                    content: c2,
518                    ..
519                },
520            ) => r1 == r2 && Self::content_equal(c1, c2),
521            _ => false, // Different types or other items
522        }
523    }
524
525    fn content_equal(a: &[ContentItem], b: &[ContentItem]) -> bool {
526        if a.len() != b.len() {
527            return false;
528        }
529        a.iter()
530            .zip(b.iter())
531            .all(|(a_item, b_item)| match (a_item, b_item) {
532                (
533                    ContentItem::InputText { text: t1 } | ContentItem::OutputText { text: t1 },
534                    ContentItem::InputText { text: t2 } | ContentItem::OutputText { text: t2 },
535                ) => t1 == t2,
536                _ => false,
537            })
538    }
539}
540
#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_message(role: &str, content: &str) -> ResponseItem {
        ResponseItem::Message {
            id: None,
            role: role.to_string(),
            content: vec![ContentItem::OutputText {
                text: content.to_string(),
            }],
        }
    }

    fn create_test_metadata(turn: usize) -> SnapshotMetadata {
        SnapshotMetadata {
            turn_number: turn,
            _total_tokens: turn * 100,
            _model: "test-model".to_string(),
            mode: "Build".to_string(),
            user: None,
            tags: Vec::new(),
        }
    }

    /// Assert two item lists match element by element, using the diff
    /// module's own notion of item equality.
    fn assert_items_match(actual: &[ResponseItem], expected: &[ResponseItem]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!(ConversationDiff::items_equal(a, e));
        }
    }

    #[test]
    fn test_save_and_undo() {
        let mut manager = UndoRedoManager::new();

        // Save first state
        let items1 = vec![create_test_message("user", "Hello")];
        let id1 = manager
            .save_state(items1.clone(), create_test_metadata(1))
            .unwrap();
        assert!(manager.current_state().is_some());

        // Save second state
        let items2 = vec![
            create_test_message("user", "Hello"),
            create_test_message("assistant", "Hi there"),
        ];
        manager
            .save_state(items2, create_test_metadata(2))
            .unwrap();

        // Undo should return exactly the first state
        let undone = manager.undo().unwrap().expect("one state to undo");
        assert_eq!(undone.id, id1);
        assert_items_match(&undone.items, &items1);
        assert_eq!(manager.redo_history().len(), 1);

        // Only one earlier state existed
        assert!(manager.undo().unwrap().is_none());
    }

    #[test]
    fn test_undo_redo() {
        let mut manager = UndoRedoManager::new();

        // Create three states
        for turn in 1..=3 {
            let items = vec![create_test_message("user", &turn.to_string())];
            manager.save_state(items, create_test_metadata(turn)).unwrap();
        }

        // Undo twice
        manager.undo().unwrap();
        let state = manager.undo().unwrap().unwrap();
        assert_eq!(state.metadata.turn_number, 1);

        // Redo once
        let state = manager.redo().unwrap().unwrap();
        assert_eq!(state.metadata.turn_number, 2);

        // Saving a new state discards the remaining redo history
        manager
            .save_state(vec![create_test_message("user", "4")], create_test_metadata(4))
            .unwrap();
        assert!(manager.redo_history().is_empty());
        assert!(manager.redo().unwrap().is_none());
    }

    #[test]
    fn test_undo_redo_on_empty_manager() {
        let mut manager = UndoRedoManager::new();
        assert!(manager.undo().unwrap().is_none());
        assert!(manager.redo().unwrap().is_none());
        assert!(manager.current_state().is_none());
    }

    #[test]
    fn test_branching() {
        let mut manager = UndoRedoManager::new();

        // Create initial state
        let items1 = vec![create_test_message("user", "main")];
        manager.save_state(items1, create_test_metadata(1)).unwrap();

        // Create a branch
        let branch_items = vec![create_test_message("user", "branch")];
        let branch_id = manager
            .create_branch(
                "Alternative".to_string(),
                Some("Testing branch".to_string()),
                branch_items.clone(),
                create_test_metadata(2),
            )
            .unwrap();

        // Verify branch exists
        let branches = manager.get_branches();
        assert_eq!(branches.len(), 1);
        assert_eq!(branches[0].1.name, "Alternative");

        // Switch to branch; the state we left becomes undoable
        let switched = manager
            .switch_to_branch(branch_id)
            .unwrap()
            .expect("branch exists");
        assert_eq!(switched.id, branch_id);
        assert_items_match(&switched.items, &branch_items);
        assert_eq!(manager.current_state().map(|s| s.id), Some(branch_id));
        assert_eq!(manager.undo_history().len(), 1);

        // Unknown ids leave the current state untouched
        assert!(manager.switch_to_branch(Uuid::new_v4()).unwrap().is_none());
        assert_eq!(manager.current_state().map(|s| s.id), Some(branch_id));
    }

    #[test]
    fn test_branch_without_current_state_fails() {
        let mut manager = UndoRedoManager::new();
        let result = manager.create_branch(
            "orphan".to_string(),
            None,
            vec![create_test_message("user", "x")],
            create_test_metadata(1),
        );
        assert!(result.is_err());
        assert!(manager.get_branches().is_empty());
    }

    #[test]
    fn test_memory_limit() {
        let mut manager = UndoRedoManager::with_memory_limit(1); // 1MB limit

        // Add many states to exceed memory limit
        for i in 0..100 {
            let items = vec![create_test_message("user", &"x".repeat(20000))]; // ~20KB each
            manager.save_state(items, create_test_metadata(i)).unwrap();
        }

        // Check that old states were removed to stay under limit
        let info = manager.memory_info();
        assert!(info.total_usage_bytes <= info.max_usage_bytes);
        assert!(manager.undo_stack.len() < 100);
    }

    #[test]
    fn test_checkpoint() {
        let mut manager = UndoRedoManager::new();

        // Create initial state
        let items = vec![create_test_message("user", "checkpoint test")];
        manager
            .save_state(items.clone(), create_test_metadata(1))
            .unwrap();

        // Create checkpoint
        let checkpoint_id = manager
            .create_checkpoint("test_checkpoint".to_string())
            .unwrap();

        // Make more changes
        let items2 = vec![create_test_message("user", "after checkpoint")];
        manager.save_state(items2, create_test_metadata(2)).unwrap();

        // Restore checkpoint: same content, tagged, with the checkpoint's id
        let restored = manager
            .restore_checkpoint(checkpoint_id)
            .unwrap()
            .expect("checkpoint exists");
        assert_eq!(restored.id, checkpoint_id);
        assert!(
            restored
                .metadata
                .tags
                .contains(&"checkpoint:test_checkpoint".to_string())
        );
        assert_items_match(&restored.items, &items);
    }

    #[test]
    fn test_checkpoint_without_state_fails() {
        let mut manager = UndoRedoManager::new();
        assert!(manager.create_checkpoint("nope".to_string()).is_err());
    }

    #[test]
    fn test_clear() {
        let mut manager = UndoRedoManager::new();
        for turn in 0..3 {
            let items = vec![create_test_message("user", &turn.to_string())];
            manager.save_state(items, create_test_metadata(turn)).unwrap();
        }
        manager.undo().unwrap();

        manager.clear();

        let info = manager.memory_info();
        assert_eq!(info.total_usage_bytes, 0);
        assert_eq!(info.undo_stack_size, 0);
        assert_eq!(info.redo_stack_size, 0);
        assert_eq!(info.branch_count, 0);
        assert!(manager.current_state().is_none());
    }

    #[test]
    fn test_conversation_diff() {
        let old = vec![
            create_test_message("user", "Hello"),
            create_test_message("assistant", "Hi"),
        ];

        let new = vec![
            create_test_message("user", "Hello"),
            create_test_message("assistant", "Hi there!"),
            create_test_message("user", "How are you?"),
        ];

        let diff = ConversationDiff::create(&old, &new);
        assert_eq!(diff.modified.len(), 1);
        assert_eq!(diff.added.len(), 1);
        assert_eq!(diff.removed.len(), 0);

        // Apply diff: the result must match `new` item for item
        let mut result = old.clone();
        diff.apply(&mut result);
        assert_items_match(&result, &new);

        // Shrinking the conversation records trailing removals
        let shrink = ConversationDiff::create(&new, &old);
        assert_eq!(shrink.removed, vec![2]);
        assert!(shrink.added.is_empty());
        let mut shrunk = new.clone();
        shrink.apply(&mut shrunk);
        assert_items_match(&shrunk, &old);
    }
}