use futures_util::{SinkExt, StreamExt};
use serde::Serialize;
use tokio::sync::mpsc;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use tungstenite::protocol::frame::Utf8Bytes;
use url::Url;
use crate::types::RawEvent;
use crate::{Error, Event, Result};
/// Client WebSocket stream, over either plain TCP or TLS.
type WsStream = tokio_tungstenite::WebSocketStream<
    tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>;
/// Read half of [`WsStream`], consumed by the reader task.
type WsSource = futures_util::stream::SplitStream<WsStream>;
/// Write half of [`WsStream`], used by the writer task to send ack messages.
type WsSink = futures_util::stream::SplitSink<WsStream, Message>;
/// Receiving end of the `/channel` WebSocket event stream.
///
/// Created by `EventReceiver::connect`; call [`EventReceiver::recv`] to obtain events.
pub struct EventReceiver {
    // Text frames forwarded by the background reader task.
    event_rx: mpsc::Receiver<EventWithAck>,
    // Never used directly: holding a sender keeps the ack channel open, so the writer
    // task draining it keeps running for as long as this receiver exists, even after
    // the reader task (which holds its own clone) has finished.
    _ack_tx: mpsc::Sender<u64>,
}
/// A text frame read from the socket, paired with the sender used to acknowledge it.
struct EventWithAck {
    // Unparsed frame payload; `EventReceiver::recv` parses it as JSON.
    event: Utf8Bytes,
    // Clone of the ack channel sender; moved into the `AckGuard` of the resulting event.
    ack_tx: mpsc::Sender<u64>,
}
/// Sends `id` to the ack channel when dropped, acknowledging the event it belongs to.
struct AckGuard {
    // Id of the event to acknowledge (taken from the parsed event's `id()`).
    id: u64,
    // `Some` until `Drop` runs; taken there so the ack is dispatched at most once.
    ack_tx: Option<mpsc::Sender<u64>>,
}
impl Drop for AckGuard {
fn drop(&mut self) {
if let Some(tx) = self.ack_tx.take() {
let id = self.id;
tokio::spawn(async move {
let _ = tx.send(id).await;
});
}
}
}
/// An event returned by [`EventReceiver::recv`].
///
/// The event is acknowledged when this value is dropped (its guard sends the event id to
/// the ack writer), so keep it alive until the event has been handled. Dereferences to
/// [`Event`].
pub struct ReceivedEvent {
    /// The parsed event.
    pub event: Event,
    // Fires the acknowledgement on drop.
    _ack_guard: AckGuard,
}
impl std::ops::Deref for ReceivedEvent {
type Target = Event;
fn deref(&self) -> &Self::Target {
&self.event
}
}
impl EventReceiver {
pub(crate) async fn connect(base_url: &Url, admin_password: Option<&str>) -> Result<Self> {
let mut ws_url = base_url.clone();
match ws_url.scheme() {
"http" => ws_url.set_scheme("ws").unwrap(),
"https" => ws_url.set_scheme("wss").unwrap(),
_ => {}
}
ws_url.set_path("/channel");
if let Some(password) = admin_password {
ws_url
.set_username("admin")
.map_err(|_| Error::InvalidUrl("cannot set username".into()))?;
ws_url
.set_password(Some(password))
.map_err(|_| Error::InvalidUrl("cannot set password".into()))?;
}
let (ws_stream, response) = connect_async(ws_url.as_str())
.await
.map_err(|e| Error::WebSocket(Box::new(e)))?;
if response.status().as_u16() == 400 {
return Err(Error::WebhookModeActive);
}
let (write, read) = ws_stream.split();
let (event_tx, event_rx) = mpsc::channel(100);
let (ack_tx, ack_rx) = mpsc::channel(1000);
let ack_tx_clone = ack_tx.clone();
tokio::spawn(async move {
Self::writer_task(write, ack_rx).await;
});
tokio::spawn(async move {
Self::reader_task(read, event_tx, ack_tx_clone).await;
});
Ok(Self {
event_rx,
_ack_tx: ack_tx,
})
}
pub async fn recv(&mut self) -> Result<ReceivedEvent> {
loop {
match self.event_rx.recv().await {
Some(event_with_ack) => {
let json = event_with_ack.event;
let raw = match serde_json::from_str::<RawEvent>(json.as_str()) {
Ok(raw) => raw,
Err(e) => {
tracing::warn!("Failed to parse event: {}", e);
continue;
}
};
if let Some(event) = raw.into_event(json.clone()) {
let id = event.id();
break Ok(ReceivedEvent {
event,
_ack_guard: AckGuard {
id,
ack_tx: Some(event_with_ack.ack_tx),
},
});
}
}
None => break Err(Error::ChannelClosed),
}
}
}
async fn writer_task(mut write: WsSink, mut ack_rx: mpsc::Receiver<u64>) {
#[derive(Serialize)]
struct AckMessage {
#[serde(rename = "type")]
type_: &'static str,
id: u64,
}
while let Some(id) = ack_rx.recv().await {
let msg = AckMessage { type_: "ack", id };
let json = match serde_json::to_string(&msg) {
Ok(j) => j,
Err(e) => {
tracing::warn!("Failed to serialize ack: {}", e);
continue;
}
};
if let Err(e) = write.send(Message::Text(json.into())).await {
tracing::warn!("Failed to send ack: {}", e);
break;
}
}
}
async fn reader_task(
mut read: WsSource,
event_tx: mpsc::Sender<EventWithAck>,
ack_tx: mpsc::Sender<u64>,
) {
while let Some(msg_result) = read.next().await {
match msg_result {
Ok(Message::Text(event)) => {
let event_with_ack = EventWithAck {
event,
ack_tx: ack_tx.clone(),
};
if event_tx.send(event_with_ack).await.is_err() {
break;
}
}
Ok(Message::Close(_)) => {
break;
}
Ok(_) => {
}
Err(_) => {
break;
}
}
}
}
}