use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use futures_util::SinkExt;
use tokio::sync::{Mutex, mpsc};
use tokio_tungstenite::tungstenite::Message;
use super::event::{GatewayPayload, Opcode};
// Write half of a split client WebSocket stream running over a TCP connection
// that may or may not be wrapped in TLS; it accepts tungstenite `Message`s.
// Aliased here because the full type is unwieldy to spell at every use site.
type WsSink = futures_util::stream::SplitSink<
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
Message,
>;
/// Shared, lock-free state consulted by the heartbeat task.
///
/// Both fields are atomics so the value can be shared behind an `Arc` and
/// updated without a mutex.
#[derive(Debug)]
pub struct HeartbeatState {
    /// Sequence number that the heartbeat task sends as the payload's `d` field.
    pub sequence: AtomicU64,
    /// `true` while no heartbeat is awaiting acknowledgement. The heartbeat task
    /// clears it before each send and treats a still-cleared flag on the next
    /// cycle as a timeout.
    pub acknowledged: AtomicBool,
}

impl HeartbeatState {
    /// Creates the initial state: sequence `0` and the acknowledged flag set,
    /// so the very first heartbeat cycle is not reported as a timeout.
    pub fn new() -> Self {
        let sequence = AtomicU64::new(0);
        let acknowledged = AtomicBool::new(true);
        Self {
            sequence,
            acknowledged,
        }
    }
}

impl Default for HeartbeatState {
    /// Equivalent to [`HeartbeatState::new`].
    fn default() -> Self {
        HeartbeatState::new()
    }
}
/// Reason the heartbeat task stopped, as reported on its failure channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HeartbeatFailure {
/// The previous heartbeat was still unacknowledged when the next one was due.
Timeout,
/// The heartbeat payload could not be serialized or could not be written to
/// the WebSocket sink.
SendFailed,
}
pub fn spawn_heartbeat(
interval_ms: u64,
sink: Arc<Mutex<WsSink>>,
state: Arc<HeartbeatState>,
failure_tx: mpsc::UnboundedSender<HeartbeatFailure>,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let interval = Duration::from_millis(interval_ms);
loop {
tokio::time::sleep(interval).await;
if !state.acknowledged.load(Ordering::SeqCst) {
tracing::warn!("heartbeat not acknowledged, connection may be dead");
let _ = failure_tx.send(HeartbeatFailure::Timeout);
break;
}
state.acknowledged.store(false, Ordering::SeqCst);
let seq = state.sequence.load(Ordering::SeqCst);
let payload = GatewayPayload {
op: Opcode::Heartbeat,
d: Some(serde_json::Value::Number(seq.into())),
s: None,
t: None,
};
let msg = match serde_json::to_string(&payload) {
Ok(msg) => msg,
Err(error) => {
tracing::error!(?error, "failed to serialize heartbeat");
let _ = failure_tx.send(HeartbeatFailure::SendFailed);
break;
}
};
let mut sink = sink.lock().await;
if sink.send(Message::Text(msg.into())).await.is_err() {
tracing::error!("failed to send heartbeat");
let _ = failure_tx.send(HeartbeatFailure::SendFailed);
break;
}
tracing::trace!(seq, "heartbeat sent");
}
})
}