use anyhow::anyhow;
use futures_util::StreamExt;
use tokio_tungstenite::tungstenite::{Message, handshake::client::Request};
use crate::{
Result,
protocol::Event,
retry::{RetryOptions, RetryState},
};
use super::{MilkyTransport, MilkyTransportEvent, is_closed, sleep_or_closed};
pub(crate) fn spawn_websocket_transport(
request: Request,
reconnect: RetryOptions,
) -> MilkyTransport {
let (transport, sender, mut close_receiver) = MilkyTransport::channel();
tokio::spawn(async move {
let mut retry_state = RetryState::new(reconnect);
let mut connected_once = false;
'reconnect: loop {
if is_closed(&close_receiver) {
return;
}
let connection = tokio::select! {
changed = close_receiver.changed() => {
if changed.is_ok() && is_closed(&close_receiver) {
return;
}
continue;
}
connection = tokio_tungstenite::connect_async(request.clone()) => connection,
};
let (mut source, _) = match connection {
Ok(connection) => connection,
Err(error) => {
let Some(decision) = retry_state.next() else {
let _ = sender.send(Err(error.into()));
return;
};
if sender
.send(Ok(MilkyTransportEvent::Reconnecting {
attempt: decision.attempt,
next_delay: decision.delay,
}))
.is_err()
{
return;
}
if sleep_or_closed(decision.delay, &mut close_receiver).await {
return;
}
continue;
}
};
retry_state.reset();
let transport_event = if connected_once {
MilkyTransportEvent::Reconnected
} else {
MilkyTransportEvent::Open
};
connected_once = true;
if sender.send(Ok(transport_event)).is_err() {
return;
}
loop {
tokio::select! {
changed = close_receiver.changed() => {
if changed.is_ok() && is_closed(&close_receiver) {
return;
}
}
message = source.next() => match message {
Some(Ok(Message::Close(_))) | None => {
let Some(decision) = retry_state.next() else {
return;
};
if sender
.send(Ok(MilkyTransportEvent::Reconnecting {
attempt: decision.attempt,
next_delay: decision.delay,
}))
.is_err()
{
return;
}
if sleep_or_closed(decision.delay, &mut close_receiver).await {
return;
}
continue 'reconnect;
}
Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) | Some(Ok(Message::Frame(_))) => {}
Some(Ok(message)) => {
if sender.send(parse_websocket_message(message)).is_err() {
return;
}
}
Some(Err(error)) => {
let Some(decision) = retry_state.next() else {
let _ = sender.send(Err(error.into()));
return;
};
if sender
.send(Ok(MilkyTransportEvent::Reconnecting {
attempt: decision.attempt,
next_delay: decision.delay,
}))
.is_err()
{
return;
}
if sleep_or_closed(decision.delay, &mut close_receiver).await {
return;
}
continue 'reconnect;
}
}
}
}
}
});
transport
}
/// Decodes a single WebSocket data frame into a transport event.
///
/// Only text frames are accepted: their payload is deserialized as JSON into an
/// [`Event`] and wrapped in `MilkyTransportEvent::Push`.
///
/// # Errors
///
/// Returns an error if the frame is not a text frame, or if the text payload is
/// not valid JSON for an [`Event`].
fn parse_websocket_message(message: Message) -> Result<MilkyTransportEvent> {
    let Message::Text(payload) = message else {
        return Err(anyhow!("Unexpected WebSocket message type"));
    };
    let pushed = serde_json::from_str::<Event>(&payload)?;
    Ok(MilkyTransportEvent::Push(pushed))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::EventBotOfflineData;

    /// A well-formed `bot_offline` text frame decodes into a pushed `BotOffline` event.
    #[test]
    fn parses_websocket_text_messages() {
        const PAYLOAD: &str = r#"{"event_type":"bot_offline","time":1712740200,"self_id":10001,"data":{"reason":"offline"}}"#;
        let expected = MilkyTransportEvent::Push(Event::BotOffline {
            time: 1712740200,
            self_id: 10001,
            data: EventBotOfflineData {
                reason: "offline".to_string(),
            },
        });

        let parsed = parse_websocket_message(Message::Text(PAYLOAD.into()))
            .expect("a valid bot_offline payload should parse");

        assert_eq!(parsed, expected);
    }

    /// Truncated JSON in a text frame is reported as an error.
    #[test]
    fn rejects_invalid_websocket_json_payloads() {
        let truncated = Message::Text("{".into());
        let outcome = parse_websocket_message(truncated);
        assert!(outcome.is_err());
    }

    /// Binary frames are not part of the protocol and are rejected.
    #[test]
    fn rejects_binary_websocket_messages() {
        let binary = Message::Binary(vec![1, 2, 3].into());
        let rendered = parse_websocket_message(binary)
            .expect_err("binary frames must be rejected")
            .to_string();
        assert!(rendered.contains("Unexpected WebSocket message type"));
    }
}