1use futures::stream::{self, BoxStream, StreamExt};
15use tokio::sync::broadcast;
16use tokio_stream::wrappers::BroadcastStream;
17
18use crate::checkpoint::CheckpointManager;
19use crate::types::SessionEvent;
20
21pub fn get_seq(event: &SessionEvent) -> u64 {
36 match event {
37 SessionEvent::Message { seq, .. }
38 | SessionEvent::ToolUse { seq, .. }
39 | SessionEvent::CustomToolUse { seq, .. }
40 | SessionEvent::McpToolUse { seq, .. }
41 | SessionEvent::StatusRunning { seq, .. }
42 | SessionEvent::StatusIdle { seq, .. }
43 | SessionEvent::Error { seq, .. } => *seq,
44 }
45}
46
47pub fn create_event_stream(
79 checkpoint: &CheckpointManager,
80 broadcast_rx: broadcast::Receiver<SessionEvent>,
81 from_seq: Option<u64>,
82) -> BoxStream<'static, SessionEvent> {
83 let live_stream = BroadcastStream::new(broadcast_rx).filter_map(|result| async move {
85 result.ok() });
87
88 match from_seq {
89 None => {
90 Box::pin(live_stream)
92 }
93 Some(k) => {
94 let historical: Vec<SessionEvent> =
96 checkpoint.events().iter().filter(|event| get_seq(event) > k).cloned().collect();
97
98 let replay_stream = stream::iter(historical);
99 Box::pin(replay_stream.chain(live_stream))
100 }
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint::RunState;
    use crate::types::{ContentBlock, SessionStatus};
    use futures::StreamExt;
    use serde_json::json;

    /// Builds a `Message` event holding a single text block `msg_<seq>`.
    fn message_event(seq: u64) -> SessionEvent {
        let text = format!("msg_{seq}");
        SessionEvent::Message { content: vec![ContentBlock::Text { text }], seq }
    }

    /// Creates a checkpoint manager for `session_id` and records message
    /// events with sequence numbers `1..=last_seq`, each with a `Running`
    /// state and no pending tools.
    fn seeded_checkpoint(session_id: &str, last_seq: u64) -> CheckpointManager {
        let mut manager = CheckpointManager::new(session_id.to_string());
        for seq in 1..=last_seq {
            let state = RunState { seq, pending_tool_ids: vec![], status: SessionStatus::Running };
            manager.checkpoint(message_event(seq), state);
        }
        manager
    }

    /// Returns a receiver whose sender is already dropped, so a stream built
    /// on it terminates once any replayed events have been yielded.
    fn closed_receiver() -> broadcast::Receiver<SessionEvent> {
        let (tx, rx) = broadcast::channel::<SessionEvent>(16);
        drop(tx);
        rx
    }

    /// Drains `stream` to completion and returns the sequence number of every
    /// event in the order it was yielded.
    async fn collect_seqs(stream: BoxStream<'static, SessionEvent>) -> Vec<u64> {
        stream.map(|event| get_seq(&event)).collect::<Vec<u64>>().await
    }

    #[test]
    fn test_get_seq_message() {
        let subject = SessionEvent::Message { content: vec![], seq: 10 };
        assert_eq!(get_seq(&subject), 10);
    }

    #[test]
    fn test_get_seq_tool_use() {
        let subject = SessionEvent::ToolUse {
            tool_use_id: String::from("tu_1"),
            name: String::from("search"),
            input: json!({}),
            seq: 5,
        };
        assert_eq!(get_seq(&subject), 5);
    }

    #[test]
    fn test_get_seq_custom_tool_use() {
        let subject = SessionEvent::CustomToolUse {
            custom_tool_use_id: String::from("ctu_1"),
            name: String::from("deploy"),
            input: json!({}),
            seq: 7,
        };
        assert_eq!(get_seq(&subject), 7);
    }

    #[test]
    fn test_get_seq_mcp_tool_use() {
        let subject = SessionEvent::McpToolUse {
            tool_use_id: String::from("mcp_1"),
            name: String::from("read"),
            input: json!({}),
            seq: 3,
        };
        assert_eq!(get_seq(&subject), 3);
    }

    #[test]
    fn test_get_seq_status_running() {
        let subject = SessionEvent::StatusRunning { seq: 0 };
        assert_eq!(get_seq(&subject), 0);
    }

    #[test]
    fn test_get_seq_status_idle() {
        let subject = SessionEvent::StatusIdle { seq: 99, stop_reason: None, usage: None };
        assert_eq!(get_seq(&subject), 99);
    }

    #[test]
    fn test_get_seq_error() {
        let subject = SessionEvent::Error {
            code: String::from("err"),
            message: String::from("oops"),
            seq: 42,
        };
        assert_eq!(get_seq(&subject), 42);
    }

    #[tokio::test]
    async fn test_replay_with_from_seq_filters_correctly() {
        let history = seeded_checkpoint("sess_1", 5);

        // Sender is already gone, so only the replayed events can appear.
        let replayed = create_event_stream(&history, closed_receiver(), Some(3));

        assert_eq!(collect_seqs(replayed).await, vec![4, 5]);
    }

    #[tokio::test]
    async fn test_replay_with_from_seq_zero_returns_all() {
        let history = seeded_checkpoint("sess_2", 3);

        let replayed = create_event_stream(&history, closed_receiver(), Some(0));

        assert_eq!(collect_seqs(replayed).await, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_live_only_mode() {
        let history = CheckpointManager::new("sess_3".to_string());
        let (tx, rx) = broadcast::channel::<SessionEvent>(16);

        let live = create_event_stream(&history, rx, None);

        // Publish after the stream exists, then close the channel so that
        // collecting the stream can finish.
        tx.send(message_event(10)).unwrap();
        tx.send(message_event(11)).unwrap();
        drop(tx);

        assert_eq!(collect_seqs(live).await, vec![10, 11]);
    }

    #[tokio::test]
    async fn test_combined_replay_plus_live() {
        let history = seeded_checkpoint("sess_4", 3);
        let (tx, rx) = broadcast::channel::<SessionEvent>(16);

        let combined = create_event_stream(&history, rx, Some(2));

        tx.send(message_event(4)).unwrap();
        tx.send(message_event(5)).unwrap();
        drop(tx);

        // One replayed event (seq 3) followed by the two live ones.
        assert_eq!(collect_seqs(combined).await, vec![3, 4, 5]);
    }

    #[tokio::test]
    async fn test_replay_with_from_seq_beyond_all_events() {
        let history = seeded_checkpoint("sess_5", 3);

        let replayed = create_event_stream(&history, closed_receiver(), Some(100));

        assert!(collect_seqs(replayed).await.is_empty());
    }

    #[tokio::test]
    async fn test_replay_empty_checkpoint_with_live() {
        let history = CheckpointManager::new("sess_6".to_string());
        let (tx, rx) = broadcast::channel::<SessionEvent>(16);

        let combined = create_event_stream(&history, rx, Some(0));

        tx.send(message_event(1)).unwrap();
        drop(tx);

        assert_eq!(collect_seqs(combined).await, vec![1]);
    }
}