use std::any::Any;
use std::sync::{Arc, RwLock};
use crate::error::Error;
use crate::messages::{ChannelMessage, Messages};
use crate::traits::{ChannelSignalTrait, private};
use crate::ws_signals::WsSignals;
use async_trait::async_trait;
use leptos::prelude::*;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::broadcast::{Sender, channel};
#[derive(Clone)]
pub struct ServerChannelSignal<T>
where
T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de>,
{
name: String,
observers: Arc<Sender<(Option<String>, Messages)>>,
server_callback: Arc<RwLock<Option<Arc<dyn Fn(&T) + Send + Sync + 'static>>>>,
}
#[async_trait]
impl<T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static> ChannelSignalTrait
for ServerChannelSignal<T>
{
fn as_any(&self) -> &dyn Any {
self
}
fn subscribe(
&self,
) -> Result<tokio::sync::broadcast::Receiver<(Option<String>, Messages)>, Error> {
Ok(self.observers.subscribe())
}
fn handle_message(&self, message: Value) -> Result<(), Error> {
if let Ok(lock) = self.server_callback.read()
&& let Some(callback) = lock.as_ref()
&& let Ok(message) = serde_json::from_value(message)
{
callback(&message);
}
Ok(())
}
fn on_reconnect_message(&self) -> Result<Messages, Error> {
Ok(Messages::Channel(ChannelMessage::Establish(
self.name.clone(),
)))
}
}
impl<T> ServerChannelSignal<T>
where
T: Clone + Serialize + Send + Sync + for<'de> Deserialize<'de> + 'static,
{
pub fn new(name: &str) -> Result<Self, Error> {
let mut signals = use_context::<WsSignals>().ok_or(Error::MissingServerSignals)?;
Self::new_with_context(&mut signals, name)
}
pub fn new_with_context(signals: &mut WsSignals, name: &str) -> Result<Self, Error> {
if let Some(signal) = signals.get_channel(name) {
return Ok(signal);
}
let (send, _) = channel(32);
let new_signal = Self {
name: name.to_owned(),
observers: Arc::new(send),
server_callback: Arc::new(RwLock::new(None)),
};
let signal = new_signal.clone();
match signals.create_channel(
name,
new_signal,
&Messages::Channel(ChannelMessage::Establish(name.to_owned())),
) {
Ok(()) => Ok(signal),
Err(Error::AddingSignalFailed) => {
signals.get_channel(name).ok_or(Error::AddingSignalFailed)
}
Err(e) => Err(e),
}
}
pub fn on_server<F>(&self, callback: F) -> Result<(), Error>
where
F: Fn(&T) + Send + Sync + 'static,
{
let Ok(mut server_callback) = self.server_callback.write() else {
return Err(Error::AddingChannelHandlerFailed);
};
*server_callback = Some(Arc::new(callback));
Ok(())
}
pub fn on_client<F>(&self, _callback: F) -> Result<(), Error>
where
F: Fn(&T) + Send + Sync + 'static,
{
Ok(())
}
pub fn send_message(&self, message: T) -> Result<(), Error> {
let message = serde_json::to_value(&message)?;
self.observers
.send((
None,
Messages::Channel(ChannelMessage::Message(self.name.clone(), message)),
))
.map_err(|_| Error::SendMessageFailed)?;
Ok(())
}
pub fn delete(&self) -> Result<(), Error> {
let mut signals = use_context::<WsSignals>().ok_or(Error::MissingServerSignals)?;
signals.delete_channel(&self.name)
}
}
impl<T> private::DeleteTrait for ServerChannelSignal<T>
where
T: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
{
fn delete(&self) -> Result<(), Error> {
self.observers
.send((
None,
Messages::Channel(ChannelMessage::Delete(self.name.clone())),
))
.map_err(|_| Error::SendMessageFailed)?;
Ok(())
}
}