// tapped 0.3.1
//
// Rust wrapper for the tap ATProto utility.
//! WebSocket event channel and receiver.

use futures_util::{SinkExt, StreamExt};
use serde::Serialize;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::{header::AUTHORIZATION, HeaderValue};
use tokio_tungstenite::{connect_async, tungstenite::Message};
use tungstenite::protocol::frame::Utf8Bytes;
use url::Url;

use crate::types::RawEvent;
use crate::{Error, Event, Result};

/// The WebSocket stream produced by `connect_async`; it is split into the
/// two halves below inside `EventReceiver::connect`.
type WsStream =
    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>;
/// Write half of a [`WsStream`]; owned by the ack writer task.
type WsSink = futures_util::stream::SplitSink<WsStream, Message>;
/// Read half of a [`WsStream`]; owned by the event reader task.
type WsSource = futures_util::stream::SplitStream<WsStream>;

/// Receiver for events from a tap WebSocket channel.
///
/// Events are received via the [`recv`](EventReceiver::recv) method.
/// Acknowledgments are sent automatically when the [`ReceivedEvent`]
/// returned by `recv` is dropped.
///
/// This type does not implement auto-reconnection. If the connection
/// closes, `recv()` will return an error and you must create a new
/// `EventReceiver` via [`TapClient::channel()`](crate::TapClient::channel).
pub struct EventReceiver {
    /// Raw text frames forwarded by the reader task, each paired with a
    /// sender for acknowledging it.
    event_rx: mpsc::Receiver<EventWithAck>,
    /// Never sent on directly. Holding it keeps the ack channel open, so the
    /// writer task (which runs until every ack sender is dropped) stays alive
    /// for as long as this receiver exists.
    _ack_tx: mpsc::Sender<u64>,
}

/// A raw text frame handed from the reader task to `EventReceiver::recv`,
/// together with the sender used to acknowledge it.
struct EventWithAck {
    /// Unparsed JSON text of the WebSocket frame.
    event: Utf8Bytes,
    /// Ack channel sender. It is moved into an [`AckGuard`] once the frame
    /// has been turned into an event.
    ack_tx: mpsc::Sender<u64>,
}

/// Sends an acknowledgment for `id` to the writer task when dropped
/// (see the `Drop` impl for this type).
struct AckGuard {
    /// Id of the event to acknowledge.
    id: u64,
    /// Ack channel sender. It is wrapped in `Option` so `drop` can take
    /// ownership of it through `&mut self`.
    ack_tx: Option<mpsc::Sender<u64>>,
}

impl Drop for AckGuard {
    fn drop(&mut self) {
        if let Some(tx) = self.ack_tx.take() {
            // Fire and forget - if the channel is closed, we can't ack anyway
            let id = self.id;
            tokio::spawn(async move {
                let _ = tx.send(id).await;
            });
        }
    }
}

/// An [`Event`] returned by [`EventReceiver::recv`], bundled with an ack guard.
///
/// Dropping this value acknowledges the event to the server, so keep it alive
/// until the event has been fully processed. It dereferences to [`Event`],
/// so event methods can be called on it directly.
pub struct ReceivedEvent {
    /// The parsed event.
    pub event: Event,
    /// Sends the acknowledgment when this `ReceivedEvent` is dropped.
    _ack_guard: AckGuard,
}

impl std::ops::Deref for ReceivedEvent {
    type Target = Event;

    /// Exposes the wrapped [`Event`], so a `ReceivedEvent` can be used
    /// anywhere an `&Event` is expected without touching the ack guard.
    fn deref(&self) -> &Event {
        let Self { event, .. } = self;
        event
    }
}

impl EventReceiver {
    /// Connect to a tap WebSocket channel.
    pub(crate) async fn connect(base_url: &Url, admin_password: Option<&str>) -> Result<Self> {
        let mut ws_url = base_url.clone();
        match ws_url.scheme() {
            "http" => ws_url.set_scheme("ws").unwrap(),
            "https" => ws_url.set_scheme("wss").unwrap(),
            _ => {}
        }
        ws_url.set_path("/channel");

        if let Some(password) = admin_password {
            ws_url
                .set_username("admin")
                .map_err(|_| Error::InvalidUrl("cannot set username".into()))?;
            ws_url
                .set_password(Some(password))
                .map_err(|_| Error::InvalidUrl("cannot set password".into()))?;
        }

        let (ws_stream, response) = connect_async(ws_url.as_str())
            .await
            .map_err(|e| Error::WebSocket(Box::new(e)))?;

        if response.status().as_u16() == 400 {
            return Err(Error::WebhookModeActive);
        }

        let (write, read) = ws_stream.split();

        let (event_tx, event_rx) = mpsc::channel(100);
        let (ack_tx, ack_rx) = mpsc::channel(1000);

        let ack_tx_clone = ack_tx.clone();
        tokio::spawn(async move {
            Self::writer_task(write, ack_rx).await;
        });

        tokio::spawn(async move {
            Self::reader_task(read, event_tx, ack_tx_clone).await;
        });

        Ok(Self {
            event_rx,
            _ack_tx: ack_tx,
        })
    }

    /// Receive the next event.
    ///
    /// Returns the event wrapped in a [`ReceivedEvent`] that automatically
    /// sends an acknowledgment when dropped.
    ///
    /// # Errors
    ///
    /// Returns [`Error::ChannelClosed`] if the WebSocket connection closes.
    pub async fn recv(&mut self) -> Result<ReceivedEvent> {
        loop {
            match self.event_rx.recv().await {
                Some(event_with_ack) => {
                    let json = event_with_ack.event;
                    let raw = match serde_json::from_str::<RawEvent>(json.as_str()) {
                        Ok(raw) => raw,
                        Err(e) => {
                            tracing::warn!("Failed to parse event: {}", e);
                            continue;
                        }
                    };

                    if let Some(event) = raw.into_event(json.clone()) {
                        let id = event.id();
                        break Ok(ReceivedEvent {
                            event,
                            _ack_guard: AckGuard {
                                id,
                                ack_tx: Some(event_with_ack.ack_tx),
                            },
                        });
                    }
                }
                None => break Err(Error::ChannelClosed),
            }
        }
    }

    /// Writer task: sends ack messages to the WebSocket.
    async fn writer_task(mut write: WsSink, mut ack_rx: mpsc::Receiver<u64>) {
        #[derive(Serialize)]
        struct AckMessage {
            #[serde(rename = "type")]
            type_: &'static str,
            id: u64,
        }

        while let Some(id) = ack_rx.recv().await {
            let msg = AckMessage { type_: "ack", id };
            let json = match serde_json::to_string(&msg) {
                Ok(j) => j,
                Err(e) => {
                    tracing::warn!("Failed to serialize ack: {}", e);
                    continue;
                }
            };

            if let Err(e) = write.send(Message::Text(json.into())).await {
                tracing::warn!("Failed to send ack: {}", e);
                break;
            }
        }
    }

    /// Reader task: reads events from WebSocket and sends to channel.
    async fn reader_task(
        mut read: WsSource,
        event_tx: mpsc::Sender<EventWithAck>,
        ack_tx: mpsc::Sender<u64>,
    ) {
        while let Some(msg_result) = read.next().await {
            match msg_result {
                Ok(Message::Text(event)) => {
                    let event_with_ack = EventWithAck {
                        event,
                        ack_tx: ack_tx.clone(),
                    };
                    if event_tx.send(event_with_ack).await.is_err() {
                        break;
                    }
                }
                Ok(Message::Close(_)) => {
                    break;
                }
                Ok(_) => {
                    // Ignore ping/pong/binary
                }
                Err(_) => {
                    break;
                }
            }
        }
    }
}