use super::ConnectionHandler;
use super::types::*;
use crate::app::config::App;
use crate::channel::ChannelManager;
use crate::channel::{ChannelType, PresenceMemberInfo};
use crate::error::Result;
use crate::presence::PresenceManager;
use crate::protocol::messages::{MessageData, PresenceData, PusherMessage};
use crate::utils::is_cache_channel;
use crate::websocket::SocketId;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
#[derive(Debug)]
pub struct SubscriptionResult {
pub success: bool,
pub auth_error: Option<String>,
pub member: Option<PresenceMember>,
pub channel_connections: Option<usize>,
}
#[derive(Debug, Clone)]
pub struct PresenceMember {
pub user_id: String,
pub user_info: Value,
}
impl ConnectionHandler {
pub async fn execute_subscription(
&self,
socket_id: &SocketId,
app_config: &App,
request: &SubscriptionRequest,
is_authenticated: bool,
) -> Result<SubscriptionResult> {
let temp_message = PusherMessage {
channel: Some(request.channel.clone()),
event: Some("pusher:subscribe".to_string()),
data: Some(MessageData::Json(serde_json::json!({
"channel": request.channel,
"auth": request.auth,
"channel_data": request.channel_data
}))),
name: None,
user_id: None,
};
let subscription_result = ChannelManager::subscribe(
&self.connection_manager,
socket_id.as_ref(),
&temp_message,
&request.channel,
is_authenticated,
&app_config.id,
)
.await?;
if subscription_result.success
&& let Some(ref metrics) = self.metrics
{
let channel_type = ChannelType::from_name(&request.channel);
let channel_type_str = channel_type.as_str();
{
let metrics_locked = metrics.lock().await;
metrics_locked.mark_channel_subscription(&app_config.id, channel_type_str);
}
if subscription_result.channel_connections == Some(1) {
self.increment_active_channel_count(
&app_config.id,
channel_type_str,
metrics.clone(),
)
.await;
}
}
if subscription_result.success
&& ChannelType::from_name(&request.channel).requires_authentication()
&& let Err(e) = self
.clear_user_authentication_timeout(&app_config.id, socket_id)
.await
{
tracing::warn!(
"Failed to clear user auth timeout for socket {}: {}",
socket_id,
e
);
}
Ok(SubscriptionResult {
success: subscription_result.success,
auth_error: subscription_result.auth_error,
member: subscription_result.member.map(|m| PresenceMember {
user_id: m.user_id.to_string(),
user_info: m.user_info,
}),
channel_connections: subscription_result.channel_connections,
})
}
pub async fn handle_post_subscription(
&self,
socket_id: &SocketId,
app_config: &App,
request: &SubscriptionRequest,
subscription_result: &SubscriptionResult,
) -> Result<()> {
self.update_connection_subscription_state(
socket_id,
app_config,
request,
subscription_result,
)
.await?;
let channel_type = ChannelType::from_name(&request.channel);
match channel_type {
ChannelType::Presence => {
self.handle_presence_subscription_success(
socket_id,
app_config,
request,
subscription_result,
)
.await?;
}
_ => {
self.send_subscription_succeeded(socket_id, app_config, &request.channel, None)
.await?;
}
}
if subscription_result.channel_connections == Some(1)
&& let Some(webhook_integration) = self.webhook_integration.clone()
{
let app_config = app_config.clone();
let channel = request.channel.clone();
tokio::spawn(async move {
if let Err(e) = webhook_integration
.send_channel_occupied(&app_config, &channel)
.await
{
tracing::warn!(
"Failed to send channel_occupied webhook for {}: {}",
channel,
e
);
}
});
}
if !request.channel.starts_with("presence-")
&& let Some(webhook_integration) = &self.webhook_integration
{
let current_count = self
.connection_manager
.lock()
.await
.get_channel_socket_count(&app_config.id, &request.channel)
.await;
webhook_integration
.send_subscription_count_changed(app_config, &request.channel, current_count)
.await
.ok();
}
if is_cache_channel(&request.channel) {
self.send_missed_cache_if_exists(&app_config.id, socket_id, &request.channel)
.await?;
}
Ok(())
}
async fn update_connection_subscription_state(
&self,
socket_id: &SocketId,
app_config: &App,
request: &SubscriptionRequest,
subscription_result: &SubscriptionResult,
) -> Result<()> {
let mut connection_manager = self.connection_manager.lock().await;
if let Some(conn_arc) = connection_manager
.get_connection(socket_id, &app_config.id)
.await
{
let mut conn_locked = conn_arc.inner.lock().await;
conn_locked.subscribe_to_channel(request.channel.clone());
if let Some(ref member) = subscription_result.member {
conn_locked.state.user_id = Some(member.user_id.clone());
let presence_info = PresenceMemberInfo {
user_id: member.user_id.clone(),
user_info: Some(member.user_info.clone()),
};
conn_locked.add_presence_info(request.channel.clone(), presence_info);
drop(conn_locked);
drop(connection_manager);
self.connection_manager
.lock()
.await
.add_user(conn_arc.clone())
.await?;
} else {
drop(conn_locked);
drop(connection_manager);
}
}
Ok(())
}
async fn handle_presence_subscription_success(
&self,
socket_id: &SocketId,
app_config: &App,
request: &SubscriptionRequest,
subscription_result: &SubscriptionResult,
) -> Result<()> {
if let Some(ref presence_member) = subscription_result.member {
PresenceManager::handle_member_added(
Arc::clone(&self.connection_manager),
self.webhook_integration.as_ref(),
app_config,
&request.channel,
&presence_member.user_id,
Some(&presence_member.user_info),
Some(socket_id),
)
.await?;
let members_map = self
.connection_manager
.lock()
.await
.get_channel_members(&app_config.id, &request.channel)
.await?;
let presence_data = PresenceData {
ids: members_map.keys().cloned().collect::<Vec<String>>(),
hash: members_map
.iter()
.map(|(k, v)| (k.clone(), v.user_info.clone()))
.collect::<HashMap<String, Option<Value>>>(),
count: members_map.len(),
};
self.send_subscription_succeeded(
socket_id,
app_config,
&request.channel,
Some(presence_data),
)
.await?;
}
Ok(())
}
async fn send_subscription_succeeded(
&self,
socket_id: &SocketId,
app_config: &App,
channel: &str,
data: Option<PresenceData>,
) -> Result<()> {
let response_msg = PusherMessage::subscription_succeeded(channel.to_string(), data);
self.connection_manager
.lock()
.await
.send_message(&app_config.id, socket_id, response_msg)
.await
}
}