// atomcode_core/conversation/turn.rs

1//! Turn tracking for conversation context management.
2//!
3//! A "turn" is one user request and everything that follows until the next
4//! user message (assistant text, tool calls, tool results). Tracking turns
5//! allows the windowing algorithm to operate at semantic boundaries instead
6//! of raw message indices.
7
8use super::message::{Message, Role};
9
10/// Status of a conversation turn.
11#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
12pub enum TurnStatus {
13    /// Currently being processed by the agent loop.
14    Active,
15    /// Turn finished (agent returned final text, no more tool calls).
16    Completed,
17    /// A summary has been generated for this turn (Phase 3).
18    Summarized,
19}
20
21/// A single conversation turn: one user request + all resulting messages.
22#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
23pub struct Turn {
24    /// Index of the user message that started this turn (into Vec<Message>).
25    pub start_idx: usize,
26    /// Number of messages in this turn (including the user message).
27    pub msg_count: usize,
28    /// Turn status.
29    pub status: TurnStatus,
30    /// Semantic summary of this turn (populated by Phase 3 LLM summarization).
31    /// When present, the windowing algorithm can inject this instead of
32    /// individual messages, drastically reducing token usage.
33    pub summary: Option<String>,
34}
35
36impl Turn {
37    /// Exclusive end index: start_idx + msg_count.
38    pub fn end_idx(&self) -> usize {
39        self.start_idx + self.msg_count
40    }
41}
42
43/// Tracks turn boundaries over a `Vec<Message>`.
44///
45/// This is a lightweight index — it doesn't own the messages, just tracks
46/// where each turn starts and how many messages it contains.
47#[derive(Debug, Clone, Default)]
48pub struct TurnTracker {
49    pub turns: Vec<Turn>,
50}
51
52impl TurnTracker {
53    pub fn new() -> Self {
54        Self { turns: Vec::new() }
55    }
56
57    /// Rebuild the turn index from an existing message list.
58    /// Used when loading conversation history from disk.
59    pub fn rebuild(messages: &[Message]) -> Self {
60        let mut tracker = Self::new();
61        for (i, msg) in messages.iter().enumerate() {
62            if matches!(msg.role, Role::User) {
63                // Close the previous turn if any
64                if let Some(prev) = tracker.turns.last_mut() {
65                    if prev.status == TurnStatus::Active {
66                        prev.msg_count = i - prev.start_idx;
67                        prev.status = TurnStatus::Completed;
68                    }
69                }
70                // Start a new turn
71                tracker.turns.push(Turn {
72                    start_idx: i,
73                    msg_count: 1,
74                    status: TurnStatus::Active,
75                    summary: None,
76                });
77            } else if let Some(current) = tracker.turns.last_mut() {
78                current.msg_count = i - current.start_idx + 1;
79            }
80        }
81        // Mark the last turn as completed if it ends with assistant text
82        // (no pending tool calls). For simplicity on rebuild, mark all as Completed
83        // except the very last one which stays Active (might still be in progress).
84        let len = tracker.turns.len();
85        if len > 1 {
86            for turn in &mut tracker.turns[..len - 1] {
87                turn.status = TurnStatus::Completed;
88            }
89        }
90        tracker
91    }
92
93    /// Notify that a new user message was added at `msg_idx`.
94    /// Closes the previous turn and opens a new Active turn.
95    ///
96    /// ── SAFETY INVARIANT ──
97    /// This method assumes msg_idx >= prev.start_idx (always true when messages are
98    /// added sequentially). However, after compression, this invariant could be violated
99    /// if Turn indices are corrupted. We now defend against this by clamping the result.
100    pub fn on_user_message(&mut self, msg_idx: usize) {
101        // Close previous active turn
102        if let Some(prev) = self.turns.last_mut() {
103            if prev.status == TurnStatus::Active {
104                // DEFENSIVE: Guard against underflow from compression bugs.
105                // Use saturating_sub to safely clamp msg_count to 0 if msg_idx < prev.start_idx.
106                // This prevents panic and maintains internal consistency.
107                prev.msg_count = msg_idx.saturating_sub(prev.start_idx);
108                prev.status = TurnStatus::Completed;
109            }
110        }
111        self.turns.push(Turn {
112            start_idx: msg_idx,
113            msg_count: 1,
114            status: TurnStatus::Active,
115            summary: None,
116        });
117    }
118
119    /// Notify that a message was appended (assistant text, tool call, tool result).
120    /// Extends the current active turn's msg_count.
121    pub fn on_message_added(&mut self, msg_idx: usize) {
122        if let Some(current) = self.turns.last_mut() {
123            if current.status == TurnStatus::Active {
124                current.msg_count = msg_idx - current.start_idx + 1;
125            }
126        }
127    }
128
129    /// Mark the current (last) turn as completed.
130    pub fn complete_current(&mut self) {
131        if let Some(current) = self.turns.last_mut() {
132            if current.status == TurnStatus::Active {
133                current.status = TurnStatus::Completed;
134            }
135        }
136    }
137
138    /// Get the current (last) active turn, if any.
139    pub fn active_turn(&self) -> Option<&Turn> {
140        self.turns.last().filter(|t| t.status == TurnStatus::Active)
141    }
142
143    /// Number of completed turns (available for summarization).
144    pub fn completed_count(&self) -> usize {
145        self.turns
146            .iter()
147            .filter(|t| t.status == TurnStatus::Completed)
148            .count()
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conversation::message::{Message, Role};

    #[test]
    fn test_rebuild_empty() {
        let tracker = TurnTracker::rebuild(&[]);
        assert!(tracker.turns.is_empty());
    }

    #[test]
    fn test_rebuild_single_turn() {
        let messages = vec![
            Message::new(Role::User, "hello"),
            Message::new(Role::Assistant, "hi there"),
        ];
        let tracker = TurnTracker::rebuild(&messages);
        assert_eq!(tracker.turns.len(), 1);
        assert_eq!(tracker.turns[0].start_idx, 0);
        assert_eq!(tracker.turns[0].msg_count, 2);
        // Single turn stays Active (might still be in progress)
        assert_eq!(tracker.turns[0].status, TurnStatus::Active);
    }

    #[test]
    fn test_rebuild_multi_turn() {
        let messages = vec![
            Message::new(Role::User, "task 1"),
            Message::new(Role::Assistant, "done 1"),
            Message::new(Role::User, "task 2"),
            Message::new(Role::Assistant, "done 2"),
            Message::new(Role::User, "task 3"),
        ];
        let tracker = TurnTracker::rebuild(&messages);
        assert_eq!(tracker.turns.len(), 3);

        assert_eq!(tracker.turns[0].start_idx, 0);
        assert_eq!(tracker.turns[0].msg_count, 2);
        assert_eq!(tracker.turns[0].status, TurnStatus::Completed);

        assert_eq!(tracker.turns[1].start_idx, 2);
        assert_eq!(tracker.turns[1].msg_count, 2);
        assert_eq!(tracker.turns[1].status, TurnStatus::Completed);

        assert_eq!(tracker.turns[2].start_idx, 4);
        assert_eq!(tracker.turns[2].msg_count, 1);
        assert_eq!(tracker.turns[2].status, TurnStatus::Active);
    }

    /// Messages that precede the first user message belong to no turn.
    #[test]
    fn test_rebuild_ignores_messages_before_first_user() {
        let messages = vec![
            Message::new(Role::Assistant, "preamble"),
            Message::new(Role::User, "task"),
            Message::new(Role::Assistant, "done"),
        ];
        let tracker = TurnTracker::rebuild(&messages);
        assert_eq!(tracker.turns.len(), 1);
        assert_eq!(tracker.turns[0].start_idx, 1);
        assert_eq!(tracker.turns[0].msg_count, 2);
        assert_eq!(tracker.turns[0].end_idx(), messages.len());
    }

    #[test]
    fn test_on_user_message_closes_previous() {
        let mut tracker = TurnTracker::new();
        tracker.on_user_message(0);
        assert_eq!(tracker.turns.len(), 1);
        assert_eq!(tracker.turns[0].status, TurnStatus::Active);

        tracker.on_message_added(1); // assistant response
        tracker.on_message_added(2); // tool result

        tracker.on_user_message(3); // new turn
        assert_eq!(tracker.turns.len(), 2);
        assert_eq!(tracker.turns[0].status, TurnStatus::Completed);
        assert_eq!(tracker.turns[0].msg_count, 3);
        assert_eq!(tracker.turns[1].status, TurnStatus::Active);
        assert_eq!(tracker.turns[1].start_idx, 3);
    }

    /// The documented defensive clamp: an index that goes backwards closes
    /// the previous turn with `msg_count == 0` instead of panicking.
    #[test]
    fn test_on_user_message_clamps_corrupted_index() {
        let mut tracker = TurnTracker::new();
        tracker.on_user_message(5);
        tracker.on_user_message(3); // index went backwards
        assert_eq!(tracker.turns[0].msg_count, 0);
        assert_eq!(tracker.turns[0].status, TurnStatus::Completed);
        assert_eq!(tracker.turns[1].start_idx, 3);
    }

    #[test]
    fn test_complete_current() {
        let mut tracker = TurnTracker::new();
        tracker.on_user_message(0);
        tracker.on_message_added(1);
        tracker.complete_current();
        assert_eq!(tracker.turns[0].status, TurnStatus::Completed);
    }

    /// Once a turn is completed, further non-user messages do not extend it.
    #[test]
    fn test_message_after_completion_is_not_counted() {
        let mut tracker = TurnTracker::new();
        tracker.on_user_message(0);
        tracker.on_message_added(1);
        tracker.complete_current();
        tracker.on_message_added(2);
        assert_eq!(tracker.turns[0].msg_count, 2);
    }

    #[test]
    fn test_active_turn() {
        let mut tracker = TurnTracker::new();
        assert!(tracker.active_turn().is_none());

        tracker.on_user_message(0);
        tracker.on_message_added(1);
        let active = tracker.active_turn().expect("a turn was just opened");
        assert_eq!(active.start_idx, 0);
        assert_eq!(active.msg_count, 2);

        tracker.complete_current();
        assert!(tracker.active_turn().is_none());
    }

    #[test]
    fn test_completed_count() {
        let mut tracker = TurnTracker::new();
        tracker.on_user_message(0);
        tracker.on_message_added(1);
        assert_eq!(tracker.completed_count(), 0);

        tracker.complete_current();
        assert_eq!(tracker.completed_count(), 1);

        tracker.on_user_message(2);
        assert_eq!(tracker.completed_count(), 1);
    }

    /// Invariant: after `rebuild`, no turn's end_idx exceeds the
    /// message vec length. Verifies the agent's retry-path fix: after
    /// `messages.truncate(n)` followed by
    /// `TurnTracker::rebuild(&messages)`, the tracker is internally
    /// consistent with the surviving messages (unlike the previous
    /// behavior where the tracker still pointed past the end).
    #[test]
    fn test_rebuild_matches_truncated_messages_length() {
        use super::super::message::MessageContent;
        use crate::tool::{ToolCall, ToolResult};

        // Build a 12-msg conversation: 3 turns × 4 msgs each
        // (user, atc, tool, assistant).
        let mut msgs: Vec<Message> = Vec::new();
        for t in 0..3 {
            msgs.push(Message::new(Role::User, &format!("task {}", t)));
            msgs.push(Message {
                role: Role::Assistant,
                content: MessageContent::AssistantWithToolCalls {
                    text: Some("working".into()),
                    tool_calls: vec![ToolCall {
                        id: format!("c{}", t),
                        name: "bash".into(),
                        arguments: "{}".into(),
                    }],
                    reasoning_content: None,
                    thinking_blocks: Vec::new(),
                },
            });
            msgs.push(Message {
                role: Role::Tool,
                content: MessageContent::ToolResult(ToolResult {
                    call_id: format!("c{}", t),
                    output: "ok".into(),
                    success: true,
                }),
            });
            msgs.push(Message::new(Role::Assistant, &format!("done {}", t)));
        }
        assert_eq!(msgs.len(), 12);

        // Simulate the agent's overflow retry: truncate 4 msgs.
        msgs.truncate(msgs.len() - 4);
        let tracker = TurnTracker::rebuild(&msgs);

        for (i, t) in tracker.turns.iter().enumerate() {
            assert!(
                t.end_idx() <= msgs.len(),
                "turn {} end_idx {} exceeds messages.len() {}",
                i,
                t.end_idx(),
                msgs.len(),
            );
        }
        // Two whole turns survive, and the last one ends exactly at the
        // end of the truncated message list.
        assert_eq!(tracker.turns.len(), 2);
        assert_eq!(tracker.turns.last().map(Turn::end_idx), Some(msgs.len()));
    }
}