use futures::stream::{self, BoxStream, StreamExt};
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;
use crate::checkpoint::CheckpointManager;
use crate::types::SessionEvent;
pub fn get_seq(event: &SessionEvent) -> u64 {
match event {
SessionEvent::Message { seq, .. }
| SessionEvent::ToolUse { seq, .. }
| SessionEvent::CustomToolUse { seq, .. }
| SessionEvent::McpToolUse { seq, .. }
| SessionEvent::StatusRunning { seq, .. }
| SessionEvent::StatusIdle { seq, .. }
| SessionEvent::Error { seq, .. } => *seq,
}
}
pub fn create_event_stream(
checkpoint: &CheckpointManager,
broadcast_rx: broadcast::Receiver<SessionEvent>,
from_seq: Option<u64>,
) -> BoxStream<'static, SessionEvent> {
let live_stream = BroadcastStream::new(broadcast_rx).filter_map(|result| async move {
result.ok() });
match from_seq {
None => {
Box::pin(live_stream)
}
Some(k) => {
let historical: Vec<SessionEvent> =
checkpoint.events().iter().filter(|event| get_seq(event) > k).cloned().collect();
let replay_stream = stream::iter(historical);
Box::pin(replay_stream.chain(live_stream))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint::RunState;
    use crate::types::{ContentBlock, SessionStatus};
    use futures::StreamExt;
    use serde_json::json;

    /// Builds a single-text-block `Message` event tagged with `seq`.
    fn message_event(seq: u64) -> SessionEvent {
        let text = format!("msg_{seq}");
        SessionEvent::Message { content: vec![ContentBlock::Text { text }], seq }
    }

    /// Creates a checkpoint manager for `session_id` holding one message per entry of `seqs`,
    /// each checkpointed with a `Running` state at the same seq.
    fn checkpoint_with_messages(
        session_id: &str,
        seqs: impl IntoIterator<Item = u64>,
    ) -> CheckpointManager {
        let mut manager = CheckpointManager::new(session_id.to_string());
        for seq in seqs {
            manager.checkpoint(
                message_event(seq),
                RunState { seq, pending_tool_ids: vec![], status: SessionStatus::Running },
            );
        }
        manager
    }

    /// Extracts the sequence numbers of `events`, preserving order.
    fn seqs_of(events: &[SessionEvent]) -> Vec<u64> {
        events.iter().map(get_seq).collect()
    }

    #[test]
    fn test_get_seq_message() {
        assert_eq!(get_seq(&SessionEvent::Message { content: vec![], seq: 10 }), 10);
    }

    #[test]
    fn test_get_seq_tool_use() {
        let tool_use = SessionEvent::ToolUse {
            tool_use_id: "tu_1".into(),
            name: "search".into(),
            input: json!({}),
            seq: 5,
        };
        assert_eq!(get_seq(&tool_use), 5);
    }

    #[test]
    fn test_get_seq_custom_tool_use() {
        let custom = SessionEvent::CustomToolUse {
            custom_tool_use_id: "ctu_1".into(),
            name: "deploy".into(),
            input: json!({}),
            seq: 7,
        };
        assert_eq!(get_seq(&custom), 7);
    }

    #[test]
    fn test_get_seq_mcp_tool_use() {
        let mcp = SessionEvent::McpToolUse {
            tool_use_id: "mcp_1".into(),
            name: "read".into(),
            input: json!({}),
            seq: 3,
        };
        assert_eq!(get_seq(&mcp), 3);
    }

    #[test]
    fn test_get_seq_status_running() {
        assert_eq!(get_seq(&SessionEvent::StatusRunning { seq: 0 }), 0);
    }

    #[test]
    fn test_get_seq_status_idle() {
        let idle = SessionEvent::StatusIdle { seq: 99, stop_reason: None, usage: None };
        assert_eq!(get_seq(&idle), 99);
    }

    #[test]
    fn test_get_seq_error() {
        let error = SessionEvent::Error { code: "err".into(), message: "oops".into(), seq: 42 };
        assert_eq!(get_seq(&error), 42);
    }

    #[tokio::test]
    async fn test_replay_with_from_seq_filters_correctly() {
        let checkpoint = checkpoint_with_messages("sess_1", 1..=5);
        let (tx, rx) = broadcast::channel::<SessionEvent>(16);
        drop(tx);
        let events: Vec<SessionEvent> =
            create_event_stream(&checkpoint, rx, Some(3)).collect().await;
        assert_eq!(seqs_of(&events), vec![4, 5]);
    }

    #[tokio::test]
    async fn test_replay_with_from_seq_zero_returns_all() {
        let checkpoint = checkpoint_with_messages("sess_2", 1..=3);
        let (tx, rx) = broadcast::channel::<SessionEvent>(16);
        drop(tx);
        let events: Vec<SessionEvent> =
            create_event_stream(&checkpoint, rx, Some(0)).collect().await;
        assert_eq!(seqs_of(&events), vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_live_only_mode() {
        let checkpoint = CheckpointManager::new("sess_3".to_string());
        let (tx, rx) = broadcast::channel::<SessionEvent>(16);
        let stream = create_event_stream(&checkpoint, rx, None);
        for seq in [10, 11] {
            tx.send(message_event(seq)).unwrap();
        }
        drop(tx);
        let events: Vec<SessionEvent> = stream.collect().await;
        assert_eq!(seqs_of(&events), vec![10, 11]);
    }

    #[tokio::test]
    async fn test_combined_replay_plus_live() {
        let checkpoint = checkpoint_with_messages("sess_4", 1..=3);
        let (tx, rx) = broadcast::channel::<SessionEvent>(16);
        let stream = create_event_stream(&checkpoint, rx, Some(2));
        for seq in [4, 5] {
            tx.send(message_event(seq)).unwrap();
        }
        drop(tx);
        let events: Vec<SessionEvent> = stream.collect().await;
        assert_eq!(seqs_of(&events), vec![3, 4, 5]);
    }

    #[tokio::test]
    async fn test_replay_with_from_seq_beyond_all_events() {
        let checkpoint = checkpoint_with_messages("sess_5", 1..=3);
        let (tx, rx) = broadcast::channel::<SessionEvent>(16);
        drop(tx);
        let events: Vec<SessionEvent> =
            create_event_stream(&checkpoint, rx, Some(100)).collect().await;
        assert!(events.is_empty());
    }

    #[tokio::test]
    async fn test_replay_empty_checkpoint_with_live() {
        let checkpoint = CheckpointManager::new("sess_6".to_string());
        let (tx, rx) = broadcast::channel::<SessionEvent>(16);
        let stream = create_event_stream(&checkpoint, rx, Some(0));
        tx.send(message_event(1)).unwrap();
        drop(tx);
        let events: Vec<SessionEvent> = stream.collect().await;
        assert_eq!(seqs_of(&events), vec![1]);
    }
}