pondsocket-axum 0.1.1

Axum WebSocket adapter for PondSocket.
Documentation
//! Axum adapter for PondSocket.

use async_trait::async_trait;
use axum::extract::ws::{Message, WebSocket};
use futures_util::{SinkExt, StreamExt};
use pondsocket::contexts::IncomingConnection;
use pondsocket::errors::{Result, internal};
use pondsocket::transport::Transport;
use pondsocket::types::{Event, TransportType};
use pondsocket::wire::{event_to_json, parse_inbound_text};
use pondsocket::{Endpoint, PondSocket};
use pondsocket_common::PondAssigns;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::{RwLock, mpsc};

pub struct AxumWebSocketTransport {
    id: String,
    assigns: RwLock<PondAssigns>,
    active: RwLock<bool>,
    tx: mpsc::Sender<Message>,
}

impl AxumWebSocketTransport {
    pub fn new(id: impl Into<String>, assigns: PondAssigns, tx: mpsc::Sender<Message>) -> Self {
        Self {
            id: id.into(),
            assigns: RwLock::new(assigns),
            active: RwLock::new(true),
            tx,
        }
    }
}

#[async_trait]
impl Transport for AxumWebSocketTransport {
    fn id(&self) -> &str {
        &self.id
    }

    async fn send_event(&self, event: Event) -> Result<()> {
        let text = event_to_json(&event).map_err(|e| internal("", e.to_string()))?;
        self.tx
            .send(Message::Text(text.into()))
            .await
            .map_err(|_| internal("", "websocket writer closed"))
    }

    async fn close(&self) -> Result<()> {
        *self.active.write().await = false;
        let _ = self.tx.send(Message::Close(None)).await;
        Ok(())
    }

    fn transport_type(&self) -> TransportType {
        TransportType::WebSocket
    }

    async fn is_active(&self) -> bool {
        *self.active.read().await
    }

    async fn get_assign(&self, key: &str) -> Option<Value> {
        self.assigns.read().await.get(key).cloned()
    }

    async fn set_assign(&self, key: &str, value: Value) {
        self.assigns.write().await.insert(key.to_owned(), value);
    }

    async fn clone_assigns(&self) -> PondAssigns {
        self.assigns.read().await.clone()
    }
}

/// Request data the caller extracts (presumably from the HTTP upgrade request) and
/// passes to `handle_socket`.
#[derive(Debug, Clone, Default)]
pub struct RequestParts {
    /// Request path; `handle_socket` uses it to look up the matching endpoint.
    pub path: String,
    /// Request headers, forwarded unchanged into the endpoint's `IncomingConnection`.
    pub headers: HashMap<String, String>,
    /// Request cookies, forwarded unchanged into the endpoint's `IncomingConnection`.
    pub cookies: HashMap<String, String>,
    /// Query parameters, forwarded unchanged into the endpoint's `IncomingConnection`.
    pub query: HashMap<String, String>,
    /// Client address, forwarded unchanged into the endpoint's `IncomingConnection`.
    /// NOTE(review): format (IP vs. IP:port) is decided by the caller — confirm.
    pub address: String,
}

/// Runs one upgraded Axum WebSocket connection as a PondSocket session.
///
/// The socket gets a bare Close frame and the function returns early in any of
/// these cases:
/// - no endpoint matches `request.path`;
/// - the endpoint declines the connection;
/// - the transport cannot be registered.
///
/// Otherwise it:
/// - queues any pending reply from the connection context;
/// - spawns a writer task that forwards queued outbound frames to the socket;
/// - processes inbound frames until the read loop ends;
/// - unregisters and closes the transport.
///
/// Returns only after the writer task has finished.
pub async fn handle_socket(pond: Arc<PondSocket>, socket: WebSocket, request: RequestParts) {
    let Some(matched) = pond.match_endpoint(&request.path).await else {
        reject_socket(socket).await;
        return;
    };
    let endpoint = matched.endpoint;
    let incoming = IncomingConnection {
        id: String::new(),
        headers: request.headers,
        cookies: request.cookies,
        query: request.query,
        params: matched.route.params.clone(),
        address: request.address,
    };
    let ctx = endpoint
        .request_connection(incoming, matched.route, None)
        .await;
    if ctx.is_declined() {
        reject_socket(socket).await;
        return;
    }

    let (mut sender, mut receiver) = socket.split();
    let (tx, mut rx) = mpsc::channel::<Message>(1024);
    let transport = Arc::new(AxumWebSocketTransport::new(
        ctx.user_id.clone(),
        ctx.assigns(),
        tx,
    ));
    if endpoint
        .register_transport(transport.clone())
        .await
        .is_err()
    {
        let _ = sender.send(Message::Close(None)).await;
        return;
    }
    // The writer is not running yet; the channel buffers this reply until it starts.
    send_pending_reply(&ctx, transport.clone()).await;

    let writer = tokio::spawn(async move {
        while let Some(message) = rx.recv().await {
            // Nothing may follow a Close frame, so stop right after forwarding one.
            let closing = matches!(message, Message::Close(_));
            if sender.send(message).await.is_err() || closing {
                break;
            }
        }
    });

    read_loop(endpoint.clone(), transport.clone(), &mut receiver).await;
    let _ = endpoint.unregister_transport(transport.id()).await;
    let _ = transport.close().await;
    // Wait for the writer instead of aborting it. `close()` may have just queued a
    // Close frame, and aborting would drop that frame (and anything queued before
    // it) unsent. The writer exits after forwarding a Close frame, as soon as a send
    // fails, or immediately if it has already stopped.
    let _ = writer.await;
}

/// Sends a bare Close frame on a socket that never became a PondSocket session.
async fn reject_socket(mut socket: WebSocket) {
    let _ = socket.send(Message::Close(None)).await;
}

/// Reads inbound frames from `receiver` and dispatches them to `endpoint`.
///
/// Text frames, and binary frames whose bytes are valid UTF-8, are parsed with
/// `parse_inbound_text` and passed to `Endpoint::handle_message`. Frames that fail
/// either step are skipped.
///
/// The loop stops when:
/// - the stream ends or yields an error;
/// - a Close frame arrives;
/// - a text or binary frame is longer than `endpoint.max_message_size()`.
///
/// All remaining frame kinds are ignored.
async fn read_loop(
    endpoint: Arc<Endpoint>,
    transport: Arc<AxumWebSocketTransport>,
    receiver: &mut futures_util::stream::SplitStream<WebSocket>,
) {
    while let Some(Ok(message)) = receiver.next().await {
        match message {
            Message::Text(text) => {
                if text.len() > endpoint.max_message_size() {
                    break;
                }
                dispatch_inbound(&endpoint, &transport, &text).await;
            }
            Message::Binary(bytes) => {
                if bytes.len() > endpoint.max_message_size() {
                    break;
                }
                // Validate UTF-8 in place instead of copying the payload into a String.
                if let Ok(text) = std::str::from_utf8(&bytes) {
                    dispatch_inbound(&endpoint, &transport, text).await;
                }
            }
            Message::Close(_) => break,
            _ => {}
        }
    }
}

/// Parses one inbound text payload and forwards the resulting event to `endpoint`.
///
/// Payloads rejected by `parse_inbound_text` are dropped. The result of
/// `handle_message` is discarded, so one bad message never ends the read loop.
async fn dispatch_inbound(
    endpoint: &Arc<Endpoint>,
    transport: &Arc<AxumWebSocketTransport>,
    text: &str,
) {
    let Ok(event) = parse_inbound_text(text) else {
        return;
    };
    let _ = endpoint.handle_message(event, transport.clone()).await;
}

/// Sends the reply left on `ctx` by the connection handler, if there is one.
///
/// The reply is wrapped in an `Event` built from the fixed strings "SYSTEM" and
/// "GATEWAY", a fresh `pondsocket_common::uuid()`, and the pending `event` and
/// `payload`. Errors from `send_event` are deliberately ignored.
async fn send_pending_reply(
    ctx: &pondsocket::ConnectionContext,
    transport: Arc<AxumWebSocketTransport>,
) {
    let Some((event, payload)) = ctx.pending_reply() else {
        return;
    };
    let id = pondsocket_common::uuid();
    let reply = Event::new("SYSTEM", "GATEWAY", id, event, payload);
    let _ = transport.send_event(reply).await;
}