atomcode_core/conversation/
turn.rs1use super::message::{Message, Role};
9
10#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
12pub enum TurnStatus {
13 Active,
15 Completed,
17 Summarized,
19}
20
21#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
23pub struct Turn {
24 pub start_idx: usize,
26 pub msg_count: usize,
28 pub status: TurnStatus,
30 pub summary: Option<String>,
34}
35
36impl Turn {
37 pub fn end_idx(&self) -> usize {
39 self.start_idx + self.msg_count
40 }
41}
42
43#[derive(Debug, Clone, Default)]
48pub struct TurnTracker {
49 pub turns: Vec<Turn>,
50}
51
52impl TurnTracker {
53 pub fn new() -> Self {
54 Self { turns: Vec::new() }
55 }
56
57 pub fn rebuild(messages: &[Message]) -> Self {
60 let mut tracker = Self::new();
61 for (i, msg) in messages.iter().enumerate() {
62 if matches!(msg.role, Role::User) {
63 if let Some(prev) = tracker.turns.last_mut() {
65 if prev.status == TurnStatus::Active {
66 prev.msg_count = i - prev.start_idx;
67 prev.status = TurnStatus::Completed;
68 }
69 }
70 tracker.turns.push(Turn {
72 start_idx: i,
73 msg_count: 1,
74 status: TurnStatus::Active,
75 summary: None,
76 });
77 } else if let Some(current) = tracker.turns.last_mut() {
78 current.msg_count = i - current.start_idx + 1;
79 }
80 }
81 let len = tracker.turns.len();
85 if len > 1 {
86 for turn in &mut tracker.turns[..len - 1] {
87 turn.status = TurnStatus::Completed;
88 }
89 }
90 tracker
91 }
92
93 pub fn on_user_message(&mut self, msg_idx: usize) {
101 if let Some(prev) = self.turns.last_mut() {
103 if prev.status == TurnStatus::Active {
104 prev.msg_count = msg_idx.saturating_sub(prev.start_idx);
108 prev.status = TurnStatus::Completed;
109 }
110 }
111 self.turns.push(Turn {
112 start_idx: msg_idx,
113 msg_count: 1,
114 status: TurnStatus::Active,
115 summary: None,
116 });
117 }
118
119 pub fn on_message_added(&mut self, msg_idx: usize) {
122 if let Some(current) = self.turns.last_mut() {
123 if current.status == TurnStatus::Active {
124 current.msg_count = msg_idx - current.start_idx + 1;
125 }
126 }
127 }
128
129 pub fn complete_current(&mut self) {
131 if let Some(current) = self.turns.last_mut() {
132 if current.status == TurnStatus::Active {
133 current.status = TurnStatus::Completed;
134 }
135 }
136 }
137
138 pub fn active_turn(&self) -> Option<&Turn> {
140 self.turns.last().filter(|t| t.status == TurnStatus::Active)
141 }
142
143 pub fn completed_count(&self) -> usize {
145 self.turns
146 .iter()
147 .filter(|t| t.status == TurnStatus::Completed)
148 .count()
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conversation::message::{Message, Role};

    /// Rebuilding from an empty history produces no turns.
    #[test]
    fn test_rebuild_empty() {
        let rebuilt = TurnTracker::rebuild(&[]);
        assert!(rebuilt.turns.is_empty());
    }

    /// One user message plus its reply form a single turn that stays open.
    #[test]
    fn test_rebuild_single_turn() {
        let history = vec![
            Message::new(Role::User, "hello"),
            Message::new(Role::Assistant, "hi there"),
        ];
        let rebuilt = TurnTracker::rebuild(&history);
        assert_eq!(rebuilt.turns.len(), 1);

        let only = &rebuilt.turns[0];
        assert_eq!(only.start_idx, 0);
        assert_eq!(only.msg_count, 2);
        assert_eq!(only.status, TurnStatus::Active);
    }

    /// Each user message starts a turn; all but the last are closed.
    #[test]
    fn test_rebuild_multi_turn() {
        let history = vec![
            Message::new(Role::User, "task 1"),
            Message::new(Role::Assistant, "done 1"),
            Message::new(Role::User, "task 2"),
            Message::new(Role::Assistant, "done 2"),
            Message::new(Role::User, "task 3"),
        ];
        let rebuilt = TurnTracker::rebuild(&history);

        // (start_idx, msg_count, status) expected for each turn, in order.
        let expected = [
            (0usize, 2usize, TurnStatus::Completed),
            (2, 2, TurnStatus::Completed),
            (4, 1, TurnStatus::Active),
        ];
        assert_eq!(rebuilt.turns.len(), 3);
        for (turn, (start_idx, msg_count, status)) in rebuilt.turns.iter().zip(expected.iter()) {
            assert_eq!(turn.start_idx, *start_idx);
            assert_eq!(turn.msg_count, *msg_count);
            assert_eq!(&turn.status, status);
        }
    }

    /// A new user message closes the open turn and opens the next one.
    #[test]
    fn test_on_user_message_closes_previous() {
        let mut live = TurnTracker::new();
        live.on_user_message(0);
        assert_eq!(live.turns.len(), 1);
        assert_eq!(live.turns[0].status, TurnStatus::Active);

        live.on_message_added(1);
        live.on_message_added(2);
        live.on_user_message(3);
        assert_eq!(live.turns.len(), 2);

        let (first, second) = (&live.turns[0], &live.turns[1]);
        assert_eq!(first.status, TurnStatus::Completed);
        assert_eq!(first.msg_count, 3);
        assert_eq!(second.status, TurnStatus::Active);
        assert_eq!(second.start_idx, 3);
    }

    /// `complete_current` closes the open turn.
    #[test]
    fn test_complete_current() {
        let mut live = TurnTracker::new();
        live.on_user_message(0);
        live.on_message_added(1);
        live.complete_current();
        assert_eq!(live.turns[0].status, TurnStatus::Completed);
    }

    /// Only closed turns are counted; a freshly opened turn is not.
    #[test]
    fn test_completed_count() {
        let mut live = TurnTracker::new();
        live.on_user_message(0);
        live.on_message_added(1);
        assert_eq!(live.completed_count(), 0);

        live.complete_current();
        assert_eq!(live.completed_count(), 1);

        live.on_user_message(2);
        assert_eq!(live.completed_count(), 1);
    }

    /// After the message list is truncated, a rebuilt tracker must not
    /// describe any turn that reaches past the end of the list.
    #[test]
    fn test_rebuild_matches_truncated_messages_length() {
        use super::super::message::MessageContent;
        use crate::tool::{ToolCall, ToolResult};

        // Three rounds of: user request, assistant tool call, tool result,
        // assistant wrap-up — four messages per round.
        let mut msgs: Vec<Message> = Vec::new();
        for t in 0..3 {
            let request = Message::new(Role::User, &format!("task {}", t));
            let tool_call = Message {
                role: Role::Assistant,
                content: MessageContent::AssistantWithToolCalls {
                    text: Some("working".into()),
                    tool_calls: vec![ToolCall {
                        id: format!("c{}", t),
                        name: "bash".into(),
                        arguments: "{}".into(),
                    }],
                    reasoning_content: None,
                    thinking_blocks: Vec::new(),
                },
            };
            let tool_result = Message {
                role: Role::Tool,
                content: MessageContent::ToolResult(ToolResult {
                    call_id: format!("c{}", t),
                    output: "ok".into(),
                    success: true,
                }),
            };
            let wrap_up = Message::new(Role::Assistant, &format!("done {}", t));

            msgs.push(request);
            msgs.push(tool_call);
            msgs.push(tool_result);
            msgs.push(wrap_up);
        }
        assert_eq!(msgs.len(), 12);

        // Drop the final round, then rebuild from what is left.
        msgs.truncate(msgs.len() - 4);
        let rebuilt = TurnTracker::rebuild(&msgs);

        for (i, t) in rebuilt.turns.iter().enumerate() {
            assert!(
                t.end_idx() <= msgs.len(),
                "turn {} end_idx {} exceeds messages.len() {}",
                i,
                t.end_idx(),
                msgs.len(),
            );
        }
    }
}