use anyhow::anyhow;
use futures_util::StreamExt;
use reqwest_eventsource::{EventSource, retry::Never};
use crate::{
Result,
protocol::Event,
retry::{RetryOptions, RetryState},
url::build_sse_request,
};
use super::{MilkyTransport, MilkyTransportEvent, is_closed, sleep_or_closed};
/// Spawns a background task that streams Milky events from an SSE endpoint.
///
/// The task opens an [`EventSource`] for the request built from `client` and `url`.
/// It sends the parsed result of every SSE message to the returned [`MilkyTransport`].
/// On recoverable failures (see `is_recoverable_sse_error`) it reconnects according
/// to `reconnect`. When reconnecting, the id of the last message that carried a
/// non-empty id is handed to `build_sse_request`.
///
/// Events emitted over the connection lifecycle:
/// - [`MilkyTransportEvent::Open`] on the first successful connection,
/// - [`MilkyTransportEvent::Reconnecting`] before each retry delay,
/// - [`MilkyTransportEvent::Reconnected`] on every later successful connection,
/// - `Err(_)` for messages that fail to parse (streaming continues), and for
///   fatal errors or exhausted retries (the task then stops).
///
/// A server-sent `retry:` value overrides the next reconnect delay, and a
/// successful open resets the retry state.
///
/// The task exits when the transport is closed, when sending to the transport
/// fails, when the event stream yields `None`, or after reporting a fatal error.
pub(crate) fn spawn_sse_transport(
    client: reqwest::Client,
    url: reqwest::Url,
    reconnect: RetryOptions,
) -> MilkyTransport {
    let (transport, sender, mut close_receiver) = MilkyTransport::channel();
    tokio::spawn(async move {
        let mut retry_state = RetryState::new(reconnect);
        let mut last_event_id: Option<String> = None;
        let mut connected_once = false;
        // Cleared once `changed()` reports that the close signal's sender is gone.
        // After that, `changed()` resolves with `Err` immediately on every call.
        // Keeping the branch enabled would make the receive loop below spin without
        // ever yielding. The flag lives outside the reconnect loop so it survives
        // reconnections.
        let mut close_signal_alive = true;
        'reconnect: loop {
            if is_closed(&close_receiver) {
                return;
            }
            let mut source = match EventSource::new(build_sse_request(
                &client,
                url.clone(),
                last_event_id.as_deref(),
            )) {
                Ok(mut source) => {
                    // Reconnection is driven by `retry_state` in this loop, so the
                    // library's own retry behaviour is disabled.
                    source.set_retry_policy(Box::new(Never));
                    source
                }
                Err(error) => {
                    let _ = sender.send(Err(error.into()));
                    return;
                }
            };
            loop {
                tokio::select! {
                    changed = close_receiver.changed(), if close_signal_alive => {
                        if changed.is_err() {
                            // The close flag can no longer change. Keep streaming, as
                            // before, but stop polling this branch to avoid a busy loop.
                            close_signal_alive = false;
                        } else if is_closed(&close_receiver) {
                            return;
                        }
                    }
                    event = source.next() => match event {
                        Some(Ok(reqwest_eventsource::Event::Open)) => {
                            retry_state.reset();
                            let transport_event = if connected_once {
                                MilkyTransportEvent::Reconnected
                            } else {
                                MilkyTransportEvent::Open
                            };
                            connected_once = true;
                            if sender.send(Ok(transport_event)).is_err() {
                                return;
                            }
                        }
                        Some(Ok(reqwest_eventsource::Event::Message(content))) => {
                            // Remember the latest non-empty id for the next reconnect request.
                            if !content.id.is_empty() {
                                last_event_id = Some(content.id.clone());
                            }
                            if let Some(delay) = content.retry {
                                retry_state.override_delay(delay);
                            }
                            if sender
                                .send(parse_message_event(&content.event, &content.data))
                                .is_err()
                            {
                                return;
                            }
                        }
                        Some(Err(error)) if is_recoverable_sse_error(&error) => {
                            let Some(decision) = retry_state.next() else {
                                // Retry budget exhausted: surface the last error and stop.
                                let _ = sender.send(Err(error.into()));
                                return;
                            };
                            if sender
                                .send(Ok(MilkyTransportEvent::Reconnecting {
                                    attempt: decision.attempt,
                                    next_delay: decision.delay,
                                }))
                                .is_err()
                            {
                                return;
                            }
                            if sleep_or_closed(decision.delay, &mut close_receiver).await {
                                return;
                            }
                            continue 'reconnect;
                        }
                        Some(Err(error)) => {
                            let _ = sender.send(Err(error.into()));
                            return;
                        }
                        None => return,
                    }
                }
            }
        }
    });
    transport
}
/// Decodes a single SSE message into a [`MilkyTransportEvent`].
///
/// Only messages tagged `milky_event` are accepted. Their `data` payload is
/// deserialized as a protocol [`Event`] and wrapped in
/// [`MilkyTransportEvent::Push`].
///
/// # Errors
///
/// Returns an error if `event_name` is anything other than `milky_event`, or if
/// `data` is not valid JSON for [`Event`].
fn parse_message_event(event_name: &str, data: &str) -> Result<MilkyTransportEvent> {
    match event_name {
        "milky_event" => {
            let pushed = serde_json::from_str::<Event>(data)?;
            Ok(MilkyTransportEvent::Push(pushed))
        }
        _ => Err(anyhow!("Unexpected SSE event type: {event_name}")),
    }
}
fn is_recoverable_sse_error(error: &reqwest_eventsource::Error) -> bool {
matches!(
error,
reqwest_eventsource::Error::Transport(_) | reqwest_eventsource::Error::StreamEnded
)
}
#[cfg(test)]
mod tests {
    // Unit tests for message decoding and retry classification.
    use super::*;
    use crate::protocol::EventBotOfflineData;

    const BOT_OFFLINE_PAYLOAD: &str = r#"{"event_type":"bot_offline","time":1712740200,"self_id":10001,"data":{"reason":"offline"}}"#;

    #[test]
    fn parses_milky_push_events() {
        let event = parse_message_event("milky_event", BOT_OFFLINE_PAYLOAD).unwrap();
        assert_eq!(
            event,
            MilkyTransportEvent::Push(Event::BotOffline {
                time: 1712740200,
                self_id: 10001,
                data: EventBotOfflineData {
                    reason: "offline".to_string(),
                },
            })
        );
    }

    #[test]
    fn rejects_non_milky_message_types() {
        let error = parse_message_event("heartbeat", "{}").unwrap_err();
        assert!(error.to_string().contains("Unexpected SSE event type"));
    }

    #[test]
    fn rejects_other_event_names_even_with_valid_payload() {
        // The event-name check must win over payload parsing.
        let error = parse_message_event("message", BOT_OFFLINE_PAYLOAD).unwrap_err();
        assert!(error.to_string().contains("Unexpected SSE event type"));
    }

    #[test]
    fn rejects_invalid_milky_payloads() {
        assert!(parse_message_event("milky_event", "{").is_err());
    }

    #[test]
    fn treats_stream_end_as_recoverable() {
        assert!(is_recoverable_sse_error(
            &reqwest_eventsource::Error::StreamEnded
        ));
    }
}