use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use serde_json::Value;
use tokio::sync::RwLock;
use crate::callback::{Binding, CallbackRegistry};
use crate::error::RealtimeError;
use crate::types::{
BroadcastConfig, ChannelState, JoinConfig, JoinPayload, PostgresChangesEvent,
PostgresChangesFilter, PresenceConfig, PresenceState, SubscriptionStatus,
};
/// Builder for a realtime channel subscription.
///
/// Accumulates callbacks and join options; [`ChannelBuilder::subscribe`]
/// consumes it and produces a [`RealtimeChannel`].
pub struct ChannelBuilder {
/// Channel name, later exposed via `RealtimeChannel::name`.
pub(crate) name: String,
/// Topic string passed to the client for every outgoing call on this channel.
pub(crate) topic: String,
/// Broadcast options (`ack`, `self_send`) sent in the join payload.
pub(crate) broadcast_config: BroadcastConfig,
/// Presence key sent as `PresenceConfig::key` in the join payload.
pub(crate) presence_key: String,
/// Set by any presence callback registration or by `presence_key`.
/// NOTE(review): not read when `subscribe` builds the join payload — confirm
/// whether the server needs an explicit "presence enabled" flag.
pub(crate) presence_enabled: bool,
/// Postgres-change filters sent in the join payload; each
/// `Binding::PostgresChanges { filter_index, .. }` indexes into this list.
pub(crate) postgres_changes: Vec<PostgresChangesFilter>,
/// Registered callbacks; moved into the channel's `CallbackRegistry` on subscribe.
pub(crate) bindings: Vec<Binding>,
/// Set by `private()`.
/// NOTE(review): not included in the join payload built by `subscribe` — confirm
/// that private channels are actually requested from the server.
pub(crate) is_private: bool,
/// Timeout forwarded to the client's `subscribe_channel` call.
pub(crate) subscribe_timeout: Duration,
/// Optional access token copied into the join payload.
pub(crate) access_token: Option<String>,
/// Client handle used to perform the subscription.
pub(crate) client_sender: crate::client::ClientSender,
}
impl ChannelBuilder {
pub fn on_postgres_changes<F>(
mut self,
event: PostgresChangesEvent,
filter: PostgresChangesFilter,
callback: F,
) -> Self
where
F: Fn(crate::types::PostgresChangePayload) + Send + Sync + 'static,
{
let filter_index = self.postgres_changes.len();
let filter = filter.event(event);
self.postgres_changes.push(filter);
self.bindings.push(Binding::PostgresChanges {
filter_index,
event,
callback: Arc::new(callback),
});
self
}
pub fn on_broadcast<F>(mut self, event: &str, callback: F) -> Self
where
F: Fn(Value) + Send + Sync + 'static,
{
self.bindings.push(Binding::Broadcast {
event: event.to_string(),
callback: Arc::new(callback),
});
self
}
pub fn on_presence_sync<F>(mut self, callback: F) -> Self
where
F: Fn(&PresenceState) + Send + Sync + 'static,
{
self.presence_enabled = true;
self.bindings.push(Binding::PresenceSync(Arc::new(callback)));
self
}
pub fn on_presence_join<F>(mut self, callback: F) -> Self
where
F: Fn(String, Vec<crate::types::PresenceMeta>) + Send + Sync + 'static,
{
self.presence_enabled = true;
self.bindings
.push(Binding::PresenceJoin(Arc::new(callback)));
self
}
pub fn on_presence_leave<F>(mut self, callback: F) -> Self
where
F: Fn(String, Vec<crate::types::PresenceMeta>) + Send + Sync + 'static,
{
self.presence_enabled = true;
self.bindings
.push(Binding::PresenceLeave(Arc::new(callback)));
self
}
pub fn broadcast_ack(mut self, ack: bool) -> Self {
self.broadcast_config.ack = ack;
self
}
pub fn broadcast_self(mut self, self_send: bool) -> Self {
self.broadcast_config.self_send = self_send;
self
}
pub fn presence_key(mut self, key: &str) -> Self {
self.presence_enabled = true;
self.presence_key = key.to_string();
self
}
pub fn private(mut self) -> Self {
self.is_private = true;
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.subscribe_timeout = timeout;
self
}
pub async fn subscribe<F>(
self,
status_callback: F,
) -> Result<RealtimeChannel, RealtimeError>
where
F: Fn(SubscriptionStatus, Option<RealtimeError>) + Send + Sync + 'static,
{
let join_payload = JoinPayload {
config: JoinConfig {
broadcast: self.broadcast_config.clone(),
presence: PresenceConfig {
key: self.presence_key.clone(),
},
postgres_changes: self.postgres_changes.clone(),
},
access_token: self.access_token.clone(),
};
let registry = CallbackRegistry::new();
{
let mut bindings = registry.bindings.write().await;
for binding in self.bindings {
bindings.push(binding);
}
}
{
let mut status_cb = registry.status_callback.write().await;
*status_cb = Some(Arc::new(status_callback));
}
let inner = Arc::new(ChannelInner {
name: self.name.clone(),
topic: self.topic.clone(),
state: RwLock::new(ChannelState::Joining),
join_ref: RwLock::new(None),
join_payload: RwLock::new(join_payload.clone()),
registry,
presence_state: RwLock::new(PresenceState::new()),
pg_change_id_map: RwLock::new(HashMap::new()),
client_sender: self.client_sender.clone(),
});
let channel = RealtimeChannel {
inner: inner.clone(),
};
self.client_sender
.subscribe_channel(channel.clone(), join_payload, self.subscribe_timeout)
.await?;
Ok(channel)
}
}
/// Handle to a channel produced by [`ChannelBuilder::subscribe`].
///
/// Cloning is cheap: every clone shares the same `ChannelInner` through an `Arc`.
#[derive(Clone)]
pub struct RealtimeChannel {
/// Shared channel state.
pub(crate) inner: Arc<ChannelInner>,
}
/// Shared state behind every [`RealtimeChannel`] clone.
pub(crate) struct ChannelInner {
/// Channel name given to the builder.
pub(crate) name: String,
/// Topic passed to the client for every outgoing call.
pub(crate) topic: String,
/// Lifecycle state; initialised to `Joining` by `ChannelBuilder::subscribe`.
pub(crate) state: RwLock<ChannelState>,
/// Join reference required by outgoing sends; starts as `None`.
/// Not set in this file — presumably filled in by the client once the join succeeds (confirm).
pub(crate) join_ref: RwLock<Option<String>>,
/// Copy of the join payload built at subscribe time; `update_access_token` rewrites its token.
pub(crate) join_payload: RwLock<JoinPayload>,
/// Registered callbacks and the subscription status callback.
pub(crate) registry: CallbackRegistry,
/// Presence state returned by `RealtimeChannel::presence_state`; updated outside this file.
pub(crate) presence_state: RwLock<PresenceState>,
/// Starts empty. Presumably maps server-assigned Postgres-change ids to indices
/// into the join payload's `postgres_changes` — TODO confirm.
pub(crate) pg_change_id_map: RwLock<HashMap<u64, usize>>,
/// Client handle used for all outgoing messages.
pub(crate) client_sender: crate::client::ClientSender,
}
impl RealtimeChannel {
    /// Returns the topic this channel is addressed by.
    pub fn topic(&self) -> &str {
        &self.inner.topic
    }

    /// Returns the channel name given to the builder.
    pub fn name(&self) -> &str {
        &self.inner.name
    }

    /// Returns a snapshot of the current channel state.
    pub async fn state(&self) -> ChannelState {
        *self.inner.state.read().await
    }

    /// Sends a broadcast `event` carrying `payload` on this channel.
    ///
    /// # Errors
    ///
    /// - `RealtimeError::InvalidChannelState` if the channel is not `Joined`.
    /// - `RealtimeError::Internal` if no join reference is recorded.
    /// - Any error returned by the client while sending.
    pub async fn send_broadcast(
        &self,
        event: &str,
        payload: Value,
    ) -> Result<(), RealtimeError> {
        let join_ref = self.require_joined_ref().await?;
        self.inner
            .client_sender
            .send_broadcast(&self.inner.topic, event, payload, join_ref.as_str())
            .await
    }

    /// Tracks `payload` as this client's presence on the channel.
    ///
    /// # Errors
    ///
    /// Same conditions as [`RealtimeChannel::send_broadcast`].
    pub async fn track(&self, payload: Value) -> Result<(), RealtimeError> {
        let join_ref = self.require_joined_ref().await?;
        self.inner
            .client_sender
            .send_presence_track(&self.inner.topic, payload, join_ref.as_str())
            .await
    }

    /// Stops tracking this client's presence on the channel.
    ///
    /// # Errors
    ///
    /// Same conditions as [`RealtimeChannel::send_broadcast`].
    pub async fn untrack(&self) -> Result<(), RealtimeError> {
        let join_ref = self.require_joined_ref().await?;
        self.inner
            .client_sender
            .send_presence_untrack(&self.inner.topic, join_ref.as_str())
            .await
    }

    /// Returns a clone of the current presence state.
    pub async fn presence_state(&self) -> PresenceState {
        self.inner.presence_state.read().await.clone()
    }

    /// Leaves the channel.
    ///
    /// Returns `Ok(())` without sending anything if the channel is already
    /// `Closed` or `Leaving`.
    ///
    /// The state moves to `Leaving` *before* the leave message is sent. The
    /// check and the transition happen under a single write lock. This has
    /// two effects:
    /// - Concurrent callers cannot both send a leave.
    /// - A state change made elsewhere while the send is in flight is not
    ///   overwritten afterwards.
    ///
    /// If the send fails, the previous state is restored, provided nothing
    /// else has changed it in the meantime.
    ///
    /// # Errors
    ///
    /// - `RealtimeError::Internal` if no join reference is recorded; the state
    ///   is left untouched in that case.
    /// - Any error returned by the client while sending the leave.
    pub async fn unsubscribe(&self) -> Result<(), RealtimeError> {
        // Copy the join ref out first so the two locks are never held together.
        let join_ref = self.inner.join_ref.read().await.clone();

        let (join_ref, previous) = {
            let mut state = self.inner.state.write().await;
            if matches!(*state, ChannelState::Closed | ChannelState::Leaving) {
                return Ok(());
            }
            let join_ref = join_ref
                .ok_or_else(|| RealtimeError::Internal("No join_ref for leave".to_string()))?;
            (join_ref, std::mem::replace(&mut *state, ChannelState::Leaving))
        };

        if let Err(err) = self
            .inner
            .client_sender
            .send_leave(&self.inner.topic, join_ref.as_str())
            .await
        {
            // Undo our own transition only; keep any newer state set by someone else.
            let mut state = self.inner.state.write().await;
            if *state == ChannelState::Leaving {
                *state = previous;
            }
            return Err(err);
        }
        Ok(())
    }

    /// Replaces the access token for this channel and sends it to the server.
    ///
    /// The token is also written into the stored join payload, presumably so a
    /// later re-join uses it (the reader of `join_payload` is outside this file).
    ///
    /// # Errors
    ///
    /// Same conditions as [`RealtimeChannel::send_broadcast`].
    pub async fn update_access_token(&self, token: &str) -> Result<(), RealtimeError> {
        let join_ref = self.require_joined_ref().await?;
        self.inner.join_payload.write().await.access_token = Some(token.to_string());
        self.inner
            .client_sender
            .send_access_token(&self.inner.topic, token, join_ref.as_str())
            .await
    }

    /// Returns an owned copy of the join reference if the channel is `Joined`.
    ///
    /// The reference is cloned so that no lock guard is held while the caller
    /// awaits a network send.
    async fn require_joined_ref(&self) -> Result<String, RealtimeError> {
        let state = *self.inner.state.read().await;
        if state != ChannelState::Joined {
            return Err(RealtimeError::InvalidChannelState {
                expected: ChannelState::Joined,
                actual: state,
            });
        }
        let join_ref = self.inner.join_ref.read().await.clone();
        join_ref.ok_or_else(|| RealtimeError::Internal("No join_ref".to_string()))
    }
}