Skip to main content

adk_managed/
replay.rs

//! Event replay for SSE `Last-Event-ID` reconnection.
//!
//! Provides [`create_event_stream`], which builds a single stream that combines
//! historical events (from the checkpoint log) with live events (from a broadcast channel).
//!
//! # Usage
//!
//! - `from_seq = None` → live tail only (subscribe to broadcast)
//! - `from_seq = Some(k)` → replay all events with `seq > k`, then live tail
//!
//! This enables SSE reconnection: the client sends the last `seq` it received
//! via `Last-Event-ID`, and the stream resumes with the events after it. To avoid gaps,
//! subscribe the broadcast receiver before creating the stream, so that events published
//! while the checkpoint log is being read are still delivered live.
13
14use futures::stream::{self, BoxStream, StreamExt};
15use tokio::sync::broadcast;
16use tokio_stream::wrappers::BroadcastStream;
17
18use crate::checkpoint::CheckpointManager;
19use crate::types::SessionEvent;
20
21/// Extract the `seq` field from any [`SessionEvent`] variant.
22///
23/// Every `SessionEvent` variant carries a monotonic `seq` field.
24/// This helper provides uniform access regardless of variant.
25///
26/// # Example
27///
28/// ```rust
29/// use adk_managed::replay::get_seq;
30/// use adk_managed::types::SessionEvent;
31///
32/// let event = SessionEvent::StatusRunning { seq: 42 };
33/// assert_eq!(get_seq(&event), 42);
34/// ```
35pub fn get_seq(event: &SessionEvent) -> u64 {
36    match event {
37        SessionEvent::Message { seq, .. }
38        | SessionEvent::ToolUse { seq, .. }
39        | SessionEvent::CustomToolUse { seq, .. }
40        | SessionEvent::McpToolUse { seq, .. }
41        | SessionEvent::StatusRunning { seq, .. }
42        | SessionEvent::StatusIdle { seq, .. }
43        | SessionEvent::Error { seq, .. } => *seq,
44    }
45}
46
47/// Create an event stream that replays historical events then attaches to live broadcast.
48///
49/// - `from_seq = None` → live tail only (subscribe to broadcast)
50/// - `from_seq = Some(k)` → replay all events with `seq > k`, then live tail
51///
52/// The returned stream yields events in order: first any historical events from the
53/// checkpoint log that have `seq > k`, then live events from the broadcast channel.
54///
55/// # Arguments
56///
57/// * `checkpoint` - The checkpoint manager holding the event log
58/// * `broadcast_rx` - A broadcast receiver for live events
59/// * `from_seq` - Optional sequence number; if provided, replay events with seq > this value
60///
61/// # Example
62///
63/// ```rust,ignore
64/// use adk_managed::replay::create_event_stream;
65/// use adk_managed::checkpoint::CheckpointManager;
66/// use tokio::sync::broadcast;
67///
68/// let checkpoint = CheckpointManager::new("session_1".to_string());
69/// let (tx, rx) = broadcast::channel(128);
70///
71/// // Live tail only
72/// let stream = create_event_stream(&checkpoint, rx, None);
73///
74/// // Replay from seq 5 onward
75/// let (_, rx2) = (tx.clone(), tx.subscribe());
76/// let stream = create_event_stream(&checkpoint, rx2, Some(5));
77/// ```
78pub fn create_event_stream(
79    checkpoint: &CheckpointManager,
80    broadcast_rx: broadcast::Receiver<SessionEvent>,
81    from_seq: Option<u64>,
82) -> BoxStream<'static, SessionEvent> {
83    // Convert broadcast receiver to a stream, filtering out lagged errors
84    let live_stream = BroadcastStream::new(broadcast_rx).filter_map(|result| async move {
85        result.ok() // Skip lagged messages
86    });
87
88    match from_seq {
89        None => {
90            // Live tail only
91            Box::pin(live_stream)
92        }
93        Some(k) => {
94            // Replay historical events with seq > k, then chain with live
95            let historical: Vec<SessionEvent> =
96                checkpoint.events().iter().filter(|event| get_seq(event) > k).cloned().collect();
97
98            let replay_stream = stream::iter(historical);
99            Box::pin(replay_stream.chain(live_stream))
100        }
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint::RunState;
    use crate::types::{ContentBlock, SessionStatus};
    use futures::StreamExt;
    use serde_json::json;

    /// Build a text `Message` event labelled `msg_<seq>` with the given `seq`.
    fn message_event(seq: u64) -> SessionEvent {
        let text = format!("msg_{seq}");
        SessionEvent::Message { content: vec![ContentBlock::Text { text }], seq }
    }

    /// Create a checkpoint manager for `session_id`. For each value in `seqs`, record one
    /// message event paired with a `Running` run state at the same `seq`.
    fn checkpoint_with_messages(
        session_id: &str,
        seqs: impl IntoIterator<Item = u64>,
    ) -> CheckpointManager {
        let mut checkpoint = CheckpointManager::new(session_id.to_string());
        for seq in seqs {
            let state = RunState { seq, pending_tool_ids: vec![], status: SessionStatus::Running };
            checkpoint.checkpoint(message_event(seq), state);
        }
        checkpoint
    }

    /// Run `stream` to completion and return the `seq` of each yielded event, in order.
    async fn drain_seqs(stream: BoxStream<'static, SessionEvent>) -> Vec<u64> {
        stream.map(|event| get_seq(&event)).collect().await
    }

    #[test]
    fn test_get_seq_message() {
        let message = SessionEvent::Message { content: vec![], seq: 10 };
        assert_eq!(get_seq(&message), 10);
    }

    #[test]
    fn test_get_seq_tool_use() {
        let tool_use = SessionEvent::ToolUse {
            tool_use_id: "tu_1".to_string(),
            name: "search".to_string(),
            input: json!({}),
            seq: 5,
        };
        assert_eq!(get_seq(&tool_use), 5);
    }

    #[test]
    fn test_get_seq_custom_tool_use() {
        let custom = SessionEvent::CustomToolUse {
            custom_tool_use_id: "ctu_1".to_string(),
            name: "deploy".to_string(),
            input: json!({}),
            seq: 7,
        };
        assert_eq!(get_seq(&custom), 7);
    }

    #[test]
    fn test_get_seq_mcp_tool_use() {
        let mcp = SessionEvent::McpToolUse {
            tool_use_id: "mcp_1".to_string(),
            name: "read".to_string(),
            input: json!({}),
            seq: 3,
        };
        assert_eq!(get_seq(&mcp), 3);
    }

    #[test]
    fn test_get_seq_status_running() {
        let running = SessionEvent::StatusRunning { seq: 0 };
        assert_eq!(get_seq(&running), 0);
    }

    #[test]
    fn test_get_seq_status_idle() {
        let idle = SessionEvent::StatusIdle { seq: 99, stop_reason: None, usage: None };
        assert_eq!(get_seq(&idle), 99);
    }

    #[test]
    fn test_get_seq_error() {
        let error =
            SessionEvent::Error { code: "err".to_string(), message: "oops".to_string(), seq: 42 };
        assert_eq!(get_seq(&error), 42);
    }

    #[tokio::test]
    async fn test_replay_with_from_seq_filters_correctly() {
        let checkpoint = checkpoint_with_messages("sess_1", 1..=5);

        // Dropping the only sender closes the channel, so the live tail ends immediately.
        let (tx, rx) = broadcast::channel::<SessionEvent>(16);
        drop(tx);

        // Everything after seq 3 → seq 4 and 5.
        let seqs = drain_seqs(create_event_stream(&checkpoint, rx, Some(3))).await;
        assert_eq!(seqs, vec![4, 5]);
    }

    #[tokio::test]
    async fn test_replay_with_from_seq_zero_returns_all() {
        let checkpoint = checkpoint_with_messages("sess_2", 1..=3);

        let (tx, rx) = broadcast::channel::<SessionEvent>(16);
        drop(tx);

        // `seq > 0` matches every stored event.
        let seqs = drain_seqs(create_event_stream(&checkpoint, rx, Some(0))).await;
        assert_eq!(seqs, vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_live_only_mode() {
        let checkpoint = CheckpointManager::new("sess_3".to_string());

        // `rx` is subscribed as soon as `broadcast::channel` returns, so everything sent
        // from here on is buffered for it, whether or not the stream exists yet.
        let (tx, rx) = broadcast::channel::<SessionEvent>(16);
        let stream = create_event_stream(&checkpoint, rx, None);

        tx.send(message_event(10)).unwrap();
        tx.send(message_event(11)).unwrap();
        drop(tx); // closing the channel ends the live tail

        assert_eq!(drain_seqs(stream).await, vec![10, 11]);
    }

    #[tokio::test]
    async fn test_combined_replay_plus_live() {
        let checkpoint = checkpoint_with_messages("sess_4", 1..=3);

        let (tx, rx) = broadcast::channel::<SessionEvent>(16);

        // Replay after seq 2 yields the historical seq 3 first.
        let stream = create_event_stream(&checkpoint, rx, Some(2));

        tx.send(message_event(4)).unwrap();
        tx.send(message_event(5)).unwrap();
        drop(tx);

        // Historical seq 3, then live seq 4 and 5.
        assert_eq!(drain_seqs(stream).await, vec![3, 4, 5]);
    }

    #[tokio::test]
    async fn test_replay_with_from_seq_beyond_all_events() {
        let checkpoint = checkpoint_with_messages("sess_5", 1..=3);

        let (tx, rx) = broadcast::channel::<SessionEvent>(16);
        drop(tx);

        // Every stored seq is <= 100, so nothing is replayed and the live tail is closed.
        let seqs = drain_seqs(create_event_stream(&checkpoint, rx, Some(100))).await;
        assert!(seqs.is_empty());
    }

    #[tokio::test]
    async fn test_replay_empty_checkpoint_with_live() {
        let checkpoint = CheckpointManager::new("sess_6".to_string());

        let (tx, rx) = broadcast::channel::<SessionEvent>(16);

        // Nothing to replay from an empty checkpoint; only the live event comes through.
        let stream = create_event_stream(&checkpoint, rx, Some(0));

        tx.send(message_event(1)).unwrap();
        drop(tx);

        assert_eq!(drain_seqs(stream).await, vec![1]);
    }
}