// leptos_ws 0.9.7
//
// Leptos WS is a WebSocket library for the Leptos framework that supports
// updates coordinated from the server.
use std::any::Any;
use std::sync::{Arc, RwLock};

use crate::error::Error;
use crate::messages::{ChannelMessage, Messages};
use crate::traits::{ChannelSignalTrait, private};
use crate::ws_signals::WsSignals;
use async_trait::async_trait;
use leptos::prelude::*;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::broadcast::{Sender, channel};

/// Server-side endpoint of a named websocket channel.
///
/// Outgoing messages are pushed to every subscriber of an internal broadcast
/// sender (see `send_message`). Incoming messages are handed to an optional
/// callback registered with `on_server`. All state lives behind `Arc`s, so
/// clones of a signal share the same sender and the same callback slot.
#[derive(Clone)]
pub struct ServerChannelSignal<T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de>> {
    /// Channel name carried in every `Establish` / `Message` / `Delete` message.
    name: String,
    /// Broadcast sender. Each `subscribe` call hands out a fresh receiver.
    /// NOTE(review): the `Option<String>` slot is always `None` in this file;
    /// its meaning is defined elsewhere — confirm before relying on it.
    observers: Arc<Sender<(Option<String>, Messages)>>,
    /// Handler for messages arriving from clients; `None` until `on_server` is called.
    server_callback: Arc<RwLock<Option<Arc<dyn Fn(&T) + Send + Sync + 'static>>>>,
}

#[async_trait]
impl<T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static> ChannelSignalTrait
    for ServerChannelSignal<T>
{
    /// Returns `self` as `&dyn Any` (presumably so callers can downcast back to
    /// the concrete `ServerChannelSignal<T>`).
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns a new receiver for this channel's broadcast of outgoing messages.
    ///
    /// Never returns `Err` in this implementation.
    fn subscribe(
        &self,
    ) -> Result<tokio::sync::broadcast::Receiver<(Option<String>, Messages)>, Error> {
        Ok(self.observers.subscribe())
    }

    /// Deserializes an incoming JSON `message` into `T` and passes it to the
    /// callback registered with `on_server`, if there is one.
    ///
    /// This is best-effort. If no callback is registered, the lock is poisoned,
    /// or `message` does not deserialize into `T`, the message is dropped and
    /// `Ok(())` is still returned.
    fn handle_message(&self, message: Value) -> Result<(), Error> {
        // Clone the `Arc` out so the read guard is released before user code runs.
        // Calling the callback while still holding the guard would let a callback
        // that invokes `on_server` on this same signal request the write lock while
        // its own thread holds a read lock (deadlock or panic per `std::sync::RwLock`).
        let callback = self
            .server_callback
            .read()
            .ok()
            .and_then(|guard| guard.as_ref().map(Arc::clone));

        let Some(callback) = callback else {
            return Ok(());
        };

        if let Ok(message) = serde_json::from_value::<T>(message) {
            callback(&message);
        }

        Ok(())
    }

    /// Returns an `Establish` message for this channel's name. Going by the
    /// method name, it is used to re-announce the channel after a reconnect.
    fn on_reconnect_message(&self) -> Result<Messages, Error> {
        Ok(Messages::Channel(ChannelMessage::Establish(
            self.name.clone(),
        )))
    }
}

impl<T> ServerChannelSignal<T>
where
    T: Clone + Serialize + Send + Sync + for<'de> Deserialize<'de> + 'static,
{
    /// Buffer capacity of the tokio broadcast channel created for each new signal.
    const CHANNEL_CAPACITY: usize = 32;

    /// Creates, or fetches if it already exists, the channel named `name`.
    /// Uses the [`WsSignals`] registry found in the Leptos context.
    ///
    /// # Errors
    /// Returns [`Error::MissingServerSignals`] if no `WsSignals` is provided in
    /// context. Otherwise returns any error from [`Self::new_with_context`].
    pub fn new(name: &str) -> Result<Self, Error> {
        let mut signals = use_context::<WsSignals>().ok_or(Error::MissingServerSignals)?;
        Self::new_with_context(&mut signals, name)
    }

    /// Creates, or fetches if it already exists, the channel named `name` in an
    /// explicitly supplied `signals` registry.
    ///
    /// If a channel is already registered under `name`, it is returned
    /// unchanged. Otherwise a new channel is built and registered, with an
    /// `Establish` message passed to `create_channel`.
    ///
    /// # Errors
    /// Returns [`Error::AddingSignalFailed`] if registration fails and no
    /// channel can be found under `name` afterwards. Any other error from
    /// `create_channel` is returned unchanged.
    pub fn new_with_context(signals: &mut WsSignals, name: &str) -> Result<Self, Error> {
        if let Some(existing) = signals.get_channel(name) {
            return Ok(existing);
        }

        let (sender, _) = channel(Self::CHANNEL_CAPACITY);
        let signal = Self {
            name: name.to_owned(),
            observers: Arc::new(sender),
            server_callback: Arc::new(RwLock::new(None)),
        };
        let establish = Messages::Channel(ChannelMessage::Establish(name.to_owned()));

        // Clones share every `Arc`, so the registered copy and the returned
        // handle refer to the same channel.
        match signals.create_channel(name, signal.clone(), &establish) {
            Ok(()) => Ok(signal),
            // NOTE(review): presumably `name` was registered by someone else
            // between the lookup above and this insert; return whichever channel
            // ended up in the registry.
            Err(Error::AddingSignalFailed) => {
                signals.get_channel(name).ok_or(Error::AddingSignalFailed)
            }
            Err(e) => Err(e),
        }
    }

    /// Registers the callback invoked when a message for this channel arrives
    /// on the server side. Replaces any previously registered callback.
    ///
    /// # Errors
    /// Returns [`Error::AddingChannelHandlerFailed`] if the callback lock is poisoned.
    pub fn on_server<F>(&self, callback: F) -> Result<(), Error>
    where
        F: Fn(&T) + Send + Sync + 'static,
    {
        let Ok(mut server_callback) = self.server_callback.write() else {
            return Err(Error::AddingChannelHandlerFailed);
        };
        *server_callback = Some(Arc::new(callback));
        Ok(())
    }

    /// Client-side callback registration. On this server-side type it is a
    /// no-op: the callback is dropped and `Ok(())` is returned.
    /// NOTE(review): presumably kept for API parity with the client build.
    pub fn on_client<F>(&self, _callback: F) -> Result<(), Error>
    where
        F: Fn(&T) + Send + Sync + 'static,
    {
        Ok(())
    }

    /// Serializes `message` to JSON and broadcasts it to every current
    /// subscriber of this channel.
    ///
    /// # Errors
    /// - A serialization failure is converted into [`Error`] via `?`.
    /// - [`Error::SendMessageFailed`] if the broadcast send fails. Under tokio
    ///   `broadcast` semantics this happens when there are no active receivers.
    pub fn send_message(&self, message: T) -> Result<(), Error> {
        let message = serde_json::to_value(&message)?;
        self.observers
            .send((
                None,
                Messages::Channel(ChannelMessage::Message(self.name.clone(), message)),
            ))
            .map_err(|_| Error::SendMessageFailed)?;

        Ok(())
    }

    /// Removes this channel from the [`WsSignals`] registry found in the Leptos context.
    ///
    /// # Errors
    /// Returns [`Error::MissingServerSignals`] if no `WsSignals` is provided in
    /// context. Otherwise returns any error from `delete_channel`.
    pub fn delete(&self) -> Result<(), Error> {
        let mut signals = use_context::<WsSignals>().ok_or(Error::MissingServerSignals)?;
        self.delete_with_context(&mut signals)
    }

    /// Removes this channel from an explicitly supplied `signals` registry.
    /// This is the counterpart of [`Self::new_with_context`] for callers
    /// without a Leptos context.
    ///
    /// # Errors
    /// Returns any error from `delete_channel`.
    pub fn delete_with_context(&self, signals: &mut WsSignals) -> Result<(), Error> {
        signals.delete_channel(&self.name)
    }
}

impl<T> private::DeleteTrait for ServerChannelSignal<T>
where
    T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
    fn delete(&self) -> Result<(), Error> {
        self.observers
            .send((
                None,
                Messages::Channel(ChannelMessage::Delete(self.name.clone())),
            ))
            .map_err(|_| Error::SendMessageFailed)?;
        Ok(())
    }
}