use crate::adapter::adapter::Adapter;
use crate::app::auth::AuthValidator;
use crate::app::config::App;
use crate::app::manager::AppManager;
use crate::cache::manager::CacheManager;
use crate::channel::{ChannelManager, ChannelType, PresenceMemberInfo};
use crate::metrics::MetricsInterface;
use crate::options::ServerOptions;
use crate::protocol::constants::{
CHANNEL_NAME_MAX_LENGTH as DEFAULT_CHANNEL_NAME_MAX_LENGTH, CLIENT_EVENT_PREFIX,
EVENT_NAME_MAX_LENGTH as DEFAULT_EVENT_NAME_MAX_LENGTH,
};
use crate::protocol::messages::{ErrorData, MessageData, PusherApiMessage, PusherMessage};
use crate::rate_limiter::{RateLimiter, memory_limiter::MemoryRateLimiter};
use crate::utils::{is_cache_channel, validate_channel_name};
use crate::watchlist::WatchlistManager;
use crate::webhook::integration::WebhookIntegration;
use crate::websocket::{SocketId, UserInfo, WebSocketRef};
use crate::{
error::{Error, Result}, utils,
};
use dashmap::DashMap;
use fastwebsockets::{
FragmentCollectorRead, Frame, OpCode, Payload, WebSocketError, WebSocketWrite, upgrade,
};
use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::WriteHalf; use tokio::sync::Mutex;
use tokio::sync::RwLock;
use tokio::time::sleep;
use tracing::{error, info, warn};
/// Per-server handler for Pusher-protocol WebSocket connections.
///
/// Owns the shared collaborators used while serving a socket: app lookup,
/// channel bookkeeping, the connection adapter, the cache, optional metrics
/// and webhooks, per-socket client-event rate limiters, the watchlist manager
/// and the server options.
pub struct ConnectionHandler {
/// Looks up app configuration by key or by id.
pub(crate) app_manager: Arc<dyn AppManager + Send + Sync>,
/// Channel subscribe/unsubscribe bookkeeping and signature validation.
pub(crate) channel_manager: Arc<RwLock<ChannelManager>>,
/// Adapter that stores sockets and delivers messages to them.
pub(crate) connection_manager: Arc<Mutex<Box<dyn Adapter + Send + Sync>>>,
/// Cache consulted for the last message of cache channels.
pub(crate) cache_manager: Arc<Mutex<dyn CacheManager + Send + Sync>>,
/// Optional metrics sink; `None` disables metric reporting.
pub(crate) metrics: Option<Arc<Mutex<dyn MetricsInterface + Send + Sync>>>,
/// Optional webhook sender; `None` disables webhooks.
pub(crate) webhook_integration: Option<Arc<WebhookIntegration>>,
/// Client-event rate limiters keyed by socket id; entries are inserted and
/// removed by `handle_socket`.
pub(crate) client_event_limiters: Arc<DashMap<SocketId, Arc<dyn RateLimiter + Send + Sync>>>,
/// Tracks user watchlists for sign-in online notifications.
pub(crate) watchlist_manager: Arc<WatchlistManager>,
/// Server-wide options (e.g. the user authentication timeout).
pub(crate) server_options: Arc<ServerOptions>,
}
impl ConnectionHandler {
/// Assembles a handler from its collaborators.
///
/// The per-socket rate-limiter map starts empty, a new `WatchlistManager`
/// is created, and `server_options` is moved behind an `Arc`.
pub fn new(
    app_manager: Arc<dyn AppManager + Send + Sync>,
    channel_manager: Arc<RwLock<ChannelManager>>,
    connection_manager: Arc<Mutex<Box<dyn Adapter + Send + Sync>>>,
    cache_manager: Arc<Mutex<dyn CacheManager + Send + Sync>>,
    metrics: Option<Arc<Mutex<dyn MetricsInterface + Send + Sync>>>,
    webhook_integration: Option<Arc<WebhookIntegration>>,
    server_options: ServerOptions,
) -> Self {
    // State owned by the handler itself rather than injected by the caller.
    let client_event_limiters = Arc::new(DashMap::new());
    let watchlist_manager = Arc::new(WatchlistManager::new());
    let server_options = Arc::new(server_options);

    Self {
        app_manager,
        channel_manager,
        connection_manager,
        cache_manager,
        metrics,
        webhook_integration,
        client_event_limiters,
        watchlist_manager,
        server_options,
    }
}
/// Reports whether any socket registered for `user_id` in app `app_id` is
/// currently subscribed to `channel_name`.
///
/// The connection-manager lock is held for the whole scan; each socket is
/// locked only long enough to read its subscription state.
///
/// # Errors
/// Propagates the error returned by the user-socket lookup.
async fn user_has_other_connections_in_presence_channel(
    &self,
    app_id: &str,
    channel_name: &str,
    user_id: &str,
) -> Result<bool> {
    let mut manager = self.connection_manager.lock().await;
    let sockets = manager.get_user_sockets(user_id, app_id).await?;

    let mut still_present = false;
    for socket in sockets.iter() {
        if socket.0.lock().await.state.is_subscribed(channel_name) {
            still_present = true;
            break;
        }
    }
    Ok(still_present)
}
/// Runs `webhook_fn` against the configured webhook integration, if there is
/// one and it reports itself enabled.
///
/// Delivery is best-effort: a failing webhook is logged as a warning and the
/// call still returns `Ok(())`.
#[allow(dead_code)]
async fn send_webhook_event<F, Fut>(&self, app: &App, webhook_fn: F) -> Result<()>
where
    F: FnOnce(&WebhookIntegration, &App) -> Fut,
    Fut: Future<Output = Result<()>>,
{
    // No integration configured: nothing to do.
    let Some(integration) = &self.webhook_integration else {
        return Ok(());
    };
    if !integration.is_enabled() {
        return Ok(());
    }

    if let Err(e) = webhook_fn(integration, app).await {
        warn!("Webhook event failed to send: {}", e);
    }
    Ok(())
}
pub async fn send_missed_cache_if_exists(
&self,
app_id: &str,
socket_id: &SocketId,
channel: &str,
) -> Result<()> {
let mut cache_manager = self.cache_manager.lock().await;
let key = format!("app:{}:channel:{}:cache_miss", app_id, channel);
let cache_result = cache_manager.get(key.as_str()).await;
match cache_result {
Ok(Some(cache_content)) => {
let cache_message: PusherMessage = serde_json::from_str(&cache_content)?;
self.connection_manager
.lock()
.await
.send_message(app_id, socket_id, cache_message)
.await?;
}
Ok(None) => {
let message = PusherMessage {
channel: Some(channel.to_string()),
name: None,
event: Some("pusher:cache_miss".to_string()),
data: None,
};
self.connection_manager
.lock()
.await
.send_message(app_id, socket_id, message)
.await?;
if let Some(app_config) = self.app_manager.find_by_id(app_id).await? {
if let Some(webhook_integration_instance) = &self.webhook_integration {
webhook_integration_instance
.send_cache_missed(&app_config, channel)
.await?;
}
}
info!("No missed cache for channel: {}", channel);
}
Err(e) => {
error!("Failed to get cache for channel {}: {}", channel, e);
return Err(e);
}
}
Ok(())
}
async fn send_error_and_close_ws(
ws_tx: &mut WebSocketWrite<WriteHalf<TokioIo<Upgraded>>>,
error: &Error,
) {
let error_data = ErrorData {
message: error.to_string(),
code: Some(error.close_code()),
};
let error_json_message = PusherMessage::error(
error_data.code.unwrap_or(4000),
error_data.message.clone(),
None,
);
if let Ok(payload_str) = serde_json::to_string(&error_json_message) {
if let Err(e) = ws_tx
.write_frame(Frame::text(Payload::from(payload_str.into_bytes())))
.await
{
warn!("Failed to send pusher:error message before close: {}", e);
}
} else {
warn!("Failed to serialize pusher:error message.");
}
if let Err(e) = ws_tx
.write_frame(Frame::close(
error.close_code(),
error.to_string().as_bytes(),
))
.await
{
warn!("Failed to send WebSocket close frame: {}", e);
}
}
pub async fn handle_socket(&self, fut: upgrade::UpgradeFut, app_key: String) -> Result<()> {
let app_config_option = self.app_manager.find_by_key(&app_key).await;
let (socket_rx_frag, mut socket_tx_direct) = match fut.await {
Ok(ws) => ws.split(tokio::io::split),
Err(e) => {
error!("WebSocket upgrade failed for app_key {}: {}", app_key, e);
return Err(Error::WebSocketError(e)); }
};
let app_config = match app_config_option {
Ok(Some(app)) => app,
Ok(None) => {
ConnectionHandler::send_error_and_close_ws(
&mut socket_tx_direct,
&Error::ApplicationNotFound,
)
.await;
return Ok(());
}
Err(db_err) => {
error!(
"Database error during app lookup for key {}: {}",
app_key, db_err
);
let internal_err = Error::InternalError("App lookup failed".to_string());
ConnectionHandler::send_error_and_close_ws(&mut socket_tx_direct, &internal_err)
.await;
return Ok(());
}
};
if !app_config.enabled {
ConnectionHandler::send_error_and_close_ws(
&mut socket_tx_direct,
&Error::ApplicationDisabled,
)
.await;
return Ok(());
}
let max_connections = app_config.max_connections;
if max_connections > 0 {
let current_connections_result = self
.connection_manager
.lock()
.await
.get_sockets_count(&app_config.id)
.await;
match current_connections_result {
Ok(count) if count >= max_connections as usize => {
ConnectionHandler::send_error_and_close_ws(
&mut socket_tx_direct,
&Error::OverConnectionQuota,
)
.await;
return Ok(());
}
Err(e) => {
error!(
"Error getting sockets count for app {}: {}",
app_config.id, e
);
let internal_err =
Error::InternalError("Failed to check connection quota".to_string());
ConnectionHandler::send_error_and_close_ws(
&mut socket_tx_direct,
&internal_err,
)
.await;
return Ok(());
}
_ => {} }
}
let socket_id = SocketId::new();
info!("New socket: {} for app: {}", socket_id, app_config.id);
{
let mut connection_manager_locked = self.connection_manager.lock().await;
if let Some(conn) = connection_manager_locked
.get_connection(&socket_id, &app_config.id)
.await
{
connection_manager_locked
.cleanup_connection(&app_config.id, WebSocketRef(conn))
.await;
}
if let Err(e) = connection_manager_locked
.add_socket(
socket_id.clone(),
socket_tx_direct, &app_config.id,
&self.app_manager,
)
.await
{
error!(
"Fatal error: Failed to add socket {} to connection manager: {}. Connection cannot proceed.",
socket_id, e
);
return Err(e); }
if let Some(ref metrics) = self.metrics {
let metrics_locked = metrics.lock().await;
metrics_locked.mark_new_connection(&app_config.id, &socket_id)
}
}
if app_config.max_client_events_per_second > 0 {
let limiter = Arc::new(MemoryRateLimiter::new(
app_config.max_client_events_per_second,
1, ));
self.client_event_limiters
.insert(socket_id.clone(), limiter);
info!(
"Initialized client event rate limiter for socket {}: {} events/sec",
socket_id, app_config.max_client_events_per_second
);
}
if let Err(e) = self
.send_connection_established(&app_config.id, &socket_id)
.await
{
self.send_error(&app_config.id, &socket_id, &e, None)
.await
.unwrap_or_else(|err_send| {
error!(
"Failed to send pusher:error after send_connection_established failed: {}",
err_send
);
});
let mut connection_manager_locked = self.connection_manager.lock().await;
if let Some(conn_arc) = connection_manager_locked
.get_connection(&socket_id, &app_config.id)
.await
{
let mut conn_locked = conn_arc.lock().await;
if let Err(close_err) = conn_locked.close(e.close_code(), e.to_string()).await {
warn!(
"Failed to send WebSocket close frame to socket {} after send_connection_established failed: {}",
socket_id, close_err
);
}
}
drop(connection_manager_locked);
if let Err(disconnect_err) = self.handle_disconnect(&app_config.id, &socket_id).await {
error!(
"Error during handle_disconnect after send_connection_established failed for {}: {}",
socket_id, disconnect_err
);
}
self.client_event_limiters.remove(&socket_id); return Ok(()); }
self.set_activity_timeout(&app_config.id, &socket_id)
.await?;
if app_config.enable_user_authentication.unwrap_or(false) {
let auth_timeout = self.server_options.user_authentication_timeout;
self.set_user_authentication_timeout(&app_config.id, &socket_id, auth_timeout)
.await?;
}
let mut fragment_collector = FragmentCollectorRead::new(socket_rx_frag);
while let Ok(frame) = fragment_collector
.read_frame(&mut move |_| async { Ok::<_, WebSocketError>(()) })
.await
{
match frame.opcode {
OpCode::Close => {
info!("Received Close frame from socket {}", socket_id);
if let Some(ref metrics) = self.metrics {
let metrics_locked = metrics.lock().await;
metrics_locked.mark_disconnection(&app_config.id, &socket_id);
}
if let Err(e) = self.handle_disconnect(&app_config.id, &socket_id).await {
error!(
"Error during client-initiated disconnect for socket {}: {}",
socket_id, e
);
}
break; }
OpCode::Text | OpCode::Binary => {
if let Err(e) = self
.handle_message(frame, &socket_id, app_config.clone())
.await
{
error!(
"Message handling for socket {} resulted in error: {}. Connection loop will terminate.",
socket_id, e
);
break; }
}
OpCode::Ping => {
let mut connection_manager_locked = self.connection_manager.lock().await;
if let Some(conn_arc) = connection_manager_locked
.get_connection(&socket_id, &app_config.id)
.await
{
let mut conn_locked = conn_arc.lock().await;
conn_locked.state.update_ping();
}
}
_ => {
warn!(
"Unsupported opcode received from {}: {:?}",
socket_id, frame.opcode
);
}
}
}
self.client_event_limiters.remove(&socket_id);
info!("Message loop terminated for socket {}", socket_id);
Ok(())
}
pub async fn handle_message(
&self,
frame: Frame<'static>,
socket_id: &SocketId,
app_config: App,
) -> Result<()> {
self.update_activity_timeout(&app_config.id, socket_id)
.await?;
let msg_payload = String::from_utf8(frame.payload.to_vec())
.map_err(|e| Error::InvalidMessageFormat(format!("Invalid UTF-8: {}", e)))?;
let message: PusherMessage = serde_json::from_str(&msg_payload)
.map_err(|e| Error::InvalidMessageFormat(format!("Invalid JSON: {}", e)))?;
info!("Received message from {}: {:?}", socket_id, message);
let event_name_str = message
.event
.as_deref()
.ok_or_else(|| Error::InvalidEventName("Event name is required".into()))?;
let channel_name_option = message.channel.clone();
if event_name_str.starts_with(CLIENT_EVENT_PREFIX) {
if let Some(limiter_arc) = self.client_event_limiters.get(socket_id) {
let limiter = limiter_arc.value(); let limit_result = limiter.increment(socket_id.as_ref()).await?;
if !limit_result.allowed {
warn!(
"Client event rate limit exceeded for socket {}: event '{}'",
socket_id, event_name_str
);
self.send_error(
&app_config.id,
socket_id,
&Error::ClientEventRateLimit, channel_name_option.clone(),
)
.await?;
return Err(Error::ClientEventRateLimit); }
} else if app_config.max_client_events_per_second > 0 {
warn!(
"Client event rate limiter not found for socket {} though app config expects one. App: {}, Event: {}",
socket_id, app_config.id, event_name_str
);
let err = Error::InternalError("Rate limiter misconfiguration".to_string());
self.send_error(&app_config.id, socket_id, &err, channel_name_option.clone())
.await?;
return Err(err);
}
}
let processing_result = match event_name_str {
"pusher:ping" => self.handle_ping(&app_config.id, socket_id).await,
"pusher:subscribe" => {
self.handle_subscribe(socket_id, &app_config, &message)
.await
}
"pusher:unsubscribe" => {
self.handle_unsubscribe(socket_id, &message, &app_config)
.await
}
"pusher:signin" => {
self.handle_signin(socket_id, message.clone(), &app_config) .await
}
_ if event_name_str.starts_with(CLIENT_EVENT_PREFIX) => {
self.handle_client_event(
&app_config,
socket_id,
event_name_str,
message.channel.as_deref(),
message
.data
.and_then(|d| serde_json::to_value(d).ok())
.unwrap_or_default(),
)
.await
}
_ => {
warn!(
"Received unknown Pusher event '{}' from socket {}",
event_name_str, socket_id
);
Ok(())
}
};
if let Err(e) = processing_result {
if !matches!(e, Error::ClientEventRateLimit) {
self.send_error(&app_config.id, socket_id, &e, channel_name_option)
.await
.unwrap_or_else(|send_err| {
error!("Failed to send error to socket {}: {}", socket_id, send_err);
});
}
if e.is_fatal() {
info!(
"Fatal error encountered for socket {}: {}. Closing connection.",
socket_id, e
);
let mut connection_manager_locked = self.connection_manager.lock().await;
if let Some(conn_arc) = connection_manager_locked
.get_connection(socket_id, &app_config.id)
.await
{
let mut conn_locked = conn_arc.lock().await;
if let Err(close_err) = conn_locked.close(e.close_code(), e.to_string()).await {
warn!(
"Attempted to send WebSocket close frame to socket {} due to fatal error, but failed: {}. The connection might already be closing or closed.",
socket_id, close_err
);
}
} else {
warn!(
"Fatal error for socket {}: connection not found in manager for explicit close.",
socket_id
);
}
drop(connection_manager_locked);
if let Err(disconnect_err) = self.handle_disconnect(&app_config.id, socket_id).await
{
error!(
"Error during handle_disconnect after fatal error processing message for socket {}: {}",
socket_id, disconnect_err
);
}
}
return Err(e); }
if let Some(ref metrics) = self.metrics {
let metrics_locked = metrics.lock().await;
let message_size = msg_payload.len();
metrics_locked.mark_ws_message_received(&app_config.id, message_size);
}
Ok(())
}
/// Answers a `pusher:ping` by sending a pong message to the same socket.
///
/// # Errors
/// Propagates the connection manager's send error.
pub async fn handle_ping(&self, app_id: &str, socket_id: &SocketId) -> Result<()> {
    let pong = PusherMessage::pong();
    self.connection_manager
        .lock()
        .await
        .send_message(app_id, socket_id, pong)
        .await
}
fn extract_signature(&self, message: &PusherMessage) -> Result<String> {
match &message.data {
Some(MessageData::String(_)) => {
Err(Error::InvalidMessageFormat(
"Subscribe message data should be structured, not a plain string for auth."
.into(),
))
}
Some(MessageData::Json(data_val)) => Ok(data_val
.get("auth")
.and_then(Value::as_str)
.unwrap_or("")
.to_string()),
Some(MessageData::Structured { extra, .. }) => Ok(extra
.get("auth")
.and_then(Value::as_str)
.unwrap_or("")
.to_string()),
None => {
Err(Error::InvalidMessageFormat(
"Missing data field in message requiring authentication.".into(),
))
}
}
}
pub async fn handle_subscribe(
&self,
socket_id: &SocketId,
app_config: &App,
message: &PusherMessage,
) -> Result<()> {
let channel_str = match &message.data {
Some(MessageData::Structured { channel, .. }) => {
channel.as_ref().map(|s| s.as_str()).ok_or_else(|| {
Error::ChannelError("Missing channel field in structured data".into())
})?
}
Some(MessageData::Json(data_val)) => data_val
.get("channel")
.and_then(Value::as_str)
.ok_or_else(|| {
Error::ChannelError("Missing 'channel' field in JSON data".into())
})?,
_ => {
return Err(Error::InvalidMessageFormat(
"Subscribe message data malformed or missing channel".into(),
));
}
};
if !app_config.enabled {
return Err(Error::ApplicationDisabled);
}
validate_channel_name(app_config, channel_str).await?;
let is_authenticated = {
let channel_manager_locked = self.channel_manager.read().await;
let signature = match self.extract_signature(message) {
Ok(s) => s,
Err(_)
if !(channel_str.starts_with("presence-")
|| channel_str.starts_with("private-")) =>
{
String::new()
}
Err(e) => return Err(e), };
if (channel_str.starts_with("presence-") || channel_str.starts_with("private-"))
&& signature.is_empty()
{
return Err(Error::AuthError(
"Authentication signature required for this channel".into(),
));
}
if signature.is_empty()
&& !(channel_str.starts_with("presence-") || channel_str.starts_with("private-"))
{
true } else {
channel_manager_locked.signature_is_valid(
app_config.clone(),
socket_id,
&signature,
message.clone(),
)
}
};
if (channel_str.starts_with("presence-") || channel_str.starts_with("private-"))
&& !is_authenticated
{
return Err(Error::AuthError("Invalid authentication signature".into()));
}
if channel_str.starts_with("presence-") {
let user_info_from_data = match &message.data {
Some(MessageData::Structured { channel_data, .. }) => {
Some(channel_data.as_ref().unwrap().as_str())
}
Some(MessageData::Json(json_data)) => Some(
json_data
.get("channel_data")
.and_then(Value::as_str)
.unwrap(),
),
_ => None,
};
if let Some(cd_str) = user_info_from_data {
let user_info_payload: Value = serde_json::from_str(cd_str).map_err(|_| {
Error::InvalidMessageFormat("Invalid channel_data JSON for presence".into())
})?;
let user_info_for_size_calc = user_info_payload
.get("user_info")
.cloned()
.unwrap_or_default();
let user_info_size_kb =
utils::data_to_bytes_flexible(vec![user_info_for_size_calc]) / 1024;
if let Some(max_size) = app_config.max_presence_member_size_in_kb {
if user_info_size_kb > max_size as usize {
return Err(Error::ChannelError(format!(
"Presence member data size ({}KB) exceeds limit ({}KB)",
user_info_size_kb, max_size
)));
}
}
} else {
return Err(Error::InvalidMessageFormat(
"Missing 'channel_data' for presence channel subscription.".into(),
));
}
if let Some(max_members) = app_config.max_presence_members_per_channel {
let current_members = self
.connection_manager
.lock()
.await
.get_channel_members(&app_config.id, channel_str) .await?
.len();
if current_members >= max_members as usize {
return Err(Error::OverCapacity); }
}
}
let subscription_result = {
let channel_manager_locked = self.channel_manager.write().await;
channel_manager_locked
.subscribe(
socket_id.0.as_str(),
message,
channel_str,
is_authenticated, &app_config.id,
)
.await? };
if !subscription_result.success {
return Err(Error::AuthError(
subscription_result.auth_error.unwrap_or_else(|| {
"Subscription failed due to an authentication issue within channel manager"
.to_string()
}),
));
}
if subscription_result.channel_connections == Some(1) {
if let Some(webhook_integration_instance) = &self.webhook_integration {
webhook_integration_instance
.send_channel_occupied(app_config, channel_str)
.await
.ok();
}
}
if !channel_str.starts_with("presence-") {
if let Some(webhook_integration_instance) = &self.webhook_integration {
let current_count = self
.connection_manager
.lock()
.await
.get_channel_socket_count(&app_config.id, channel_str)
.await;
info!(
"Sending subscription_count webhook for channel {} (count: {}) after subscribe",
channel_str, current_count
);
webhook_integration_instance
.send_subscription_count_changed(app_config, channel_str, current_count)
.await
.ok();
}
}
let channel_type = ChannelType::from_name(channel_str);
let presence_data_tuple = if channel_type == ChannelType::Presence {
subscription_result.member.as_ref().map(|presence_member| {
(
presence_member.user_id.as_str(),
PresenceMemberInfo {
user_id: presence_member.user_id.clone(),
user_info: Some(presence_member.user_info.clone()),
},
)
})
} else {
None
};
{
let mut connection_manager_locked = self.connection_manager.lock().await;
if let Some(conn_arc) = connection_manager_locked
.get_connection(socket_id, &app_config.id)
.await
{
let mut conn_locked = conn_arc.lock().await;
conn_locked
.state
.subscribed_channels
.insert(channel_str.to_string());
if let Some((user_id_str, presence_info_val)) = presence_data_tuple {
conn_locked.state.user_id = Some(user_id_str.to_string());
if let Some(ref mut presence_map_val) = conn_locked.state.presence {
presence_map_val.insert(channel_str.to_string(), presence_info_val);
} else {
let mut new_presence_map = HashMap::new();
new_presence_map.insert(channel_str.to_string(), presence_info_val);
conn_locked.state.presence = Some(new_presence_map);
}
}
}
}
if channel_type == ChannelType::Presence {
if let Some(presence_member) = subscription_result.member {
let user_id_str = &presence_member.user_id;
let presence_info_val = PresenceMemberInfo {
user_id: user_id_str.clone(),
user_info: Some(presence_member.user_info.clone()),
};
if let Some(webhook_integration_instance) = &self.webhook_integration {
webhook_integration_instance
.send_member_added(app_config, channel_str, user_id_str)
.await
.ok();
}
let members_map = {
let mut connection_manager_locked = self.connection_manager.lock().await;
let current_members = connection_manager_locked
.get_channel_members(&app_config.id, channel_str)
.await?;
let member_added_msg = PusherMessage::member_added(
channel_str.to_string(),
user_id_str.clone(),
presence_info_val.user_info.clone(),
);
connection_manager_locked
.send(
channel_str,
member_added_msg,
Some(socket_id),
&app_config.id,
)
.await?;
current_members
};
let presence_message_val = json!({
"presence": {
"ids": members_map.keys().collect::<Vec<&String>>(),
"hash": members_map.iter()
.map(|(k, v)| (k.as_str(), v.user_info.clone()))
.collect::<HashMap<&str, Option<Value>>>(),
"count": members_map.len()
}
});
let subscription_succeeded_msg = PusherMessage::subscription_succeeded(
channel_str.to_string(),
Some(presence_message_val),
);
self.connection_manager
.lock()
.await
.send_message(&app_config.id, socket_id, subscription_succeeded_msg)
.await?;
}
} else {
let response_msg = PusherMessage::subscription_succeeded(channel_str.to_string(), None);
self.connection_manager
.lock()
.await
.send_message(&app_config.id, socket_id, response_msg)
.await?;
}
if is_cache_channel(channel_str) {
self.send_missed_cache_if_exists(&app_config.id, socket_id, channel_str)
.await?;
}
Ok(())
}
pub async fn handle_unsubscribe(
&self,
socket_id: &SocketId,
message: &PusherMessage,
app_config: &App,
) -> Result<()> {
let message_data_ref = message.data.as_ref().ok_or_else(|| {
Error::InvalidMessageFormat("Missing data in unsubscribe message".into())
})?;
let channel_name_str = match message_data_ref {
MessageData::String(channel_str_val) => channel_str_val.as_str(),
MessageData::Json(data_val) => data_val
.get("channel")
.and_then(Value::as_str)
.ok_or_else(|| {
Error::InvalidMessageFormat("Missing channel in unsubscribe message".into())
})?,
MessageData::Structured { channel, .. } => {
channel.as_ref().map(|s| s.as_str()).ok_or_else(|| {
Error::InvalidMessageFormat("Missing channel in unsubscribe message".into())
})?
}
};
let user_id_of_socket: Option<String> = {
let mut conn_manager = self.connection_manager.lock().await;
if let Some(conn) = conn_manager.get_connection(socket_id, &app_config.id).await {
conn.lock().await.state.user_id.clone()
} else {
None
}
};
let _leave_response = {
let channel_manager_locked = self.channel_manager.write().await;
channel_manager_locked
.unsubscribe(
socket_id.0.as_str(),
channel_name_str,
&app_config.id,
user_id_of_socket.as_deref(),
)
.await? };
{
let mut conn_manager = self.connection_manager.lock().await;
if let Some(conn_arc) = conn_manager.get_connection(socket_id, &app_config.id).await {
let mut conn_state_guard = conn_arc.lock().await;
conn_state_guard
.state
.subscribed_channels
.remove(channel_name_str);
if channel_name_str.starts_with("presence-") {
if let Some(presence_map) = conn_state_guard.state.presence.as_mut() {
presence_map.remove(channel_name_str);
}
}
}
}
let current_sub_count = self
.connection_manager
.lock()
.await
.get_channel_socket_count(&app_config.id, channel_name_str)
.await;
if channel_name_str.starts_with("presence-") {
if let Some(user_id_that_left) = user_id_of_socket {
let has_other_connections = self
.user_has_other_connections_in_presence_channel(
&app_config.id,
channel_name_str,
&user_id_that_left,
)
.await?;
if !has_other_connections {
if let Some(webhook_integration_instance) = &self.webhook_integration {
info!(
"Sending member_removed webhook for user {} from channel {}",
user_id_that_left, channel_name_str
);
webhook_integration_instance
.send_member_removed(app_config, channel_name_str, &user_id_that_left)
.await
.ok();
}
let member_removed_msg = PusherMessage::member_removed(
channel_name_str.to_string(),
user_id_that_left.clone(),
);
self.connection_manager
.lock()
.await
.send(
channel_name_str,
member_removed_msg,
Some(socket_id),
&app_config.id,
)
.await?;
}
}
} else if let Some(webhook_integration_instance) = &self.webhook_integration {
info!(
"Sending subscription_count webhook for channel {} (count: {}) after unsubscribe",
channel_name_str, current_sub_count
);
webhook_integration_instance
.send_subscription_count_changed(app_config, channel_name_str, current_sub_count)
.await
.ok();
}
if current_sub_count == 0 {
if let Some(webhook_integration_instance) = &self.webhook_integration {
info!(
"Sending channel_vacated webhook for channel {}",
channel_name_str
);
webhook_integration_instance
.send_channel_vacated(app_config, channel_name_str)
.await
.ok();
}
}
Ok(())
}
/// Handles a `pusher:signin` message: authenticates the user for this socket,
/// stores the user on the connection state, and emits the follow-up events.
///
/// Flow:
/// 1. Reject when user authentication is not enabled for the app.
/// 2. Read the `user_data` and `auth` string fields from the message data.
/// 3. Parse `user_data` as JSON; a string `id` is required, `watchlist` is
///    optional.
/// 4. Validate the auth through `AuthValidator` and clear the pending
///    user-authentication timeout.
/// 5. Store the user on the connection, re-add the socket to the connection
///    manager and register the user there.
/// 6. When watchlist events are enabled and a watchlist was supplied,
///    register it and collect the watchers to notify.
/// 7. Send `pusher:signin_success`, then the watchlist events to this socket
///    and an online event to each watcher (both best-effort).
///
/// # Errors
/// `AuthError` for a disabled feature, missing fields, invalid `user_data`
/// or a rejected signature; `InvalidMessageFormat` for plain-string data;
/// `ConnectionNotFound` / `InternalError` when the connection state is
/// missing or inconsistent; other errors are propagated from the
/// collaborators.
pub async fn handle_signin(
&self,
socket_id: &SocketId,
data: PusherMessage,
app_config: &App,
) -> Result<()> {
if !app_config.enable_user_authentication.unwrap_or(false) {
return Err(Error::AuthError(
"User authentication is disabled for this app".into(),
));
}
let message_data_val = data
.data
.ok_or_else(|| Error::AuthError("Missing data in signin message".into()))?;
let (user_data_str, auth_str) = {
// Reads one string field from JSON data or from the structured
// data's `extra` entries; plain-string data is rejected.
let extract_field = |field: &str| -> Result<&str> {
match &message_data_val {
MessageData::Json(json_val) => {
json_val.get(field).and_then(|v| v.as_str()).ok_or_else(|| {
Error::AuthError(format!(
"Missing '{}' field in signin JSON data",
field
))
})
}
MessageData::Structured { extra, .. } => {
extra.get(field).and_then(|v| v.as_str()).ok_or_else(|| {
Error::AuthError(format!(
"Missing '{}' field in signin structured data",
field
))
})
}
MessageData::String(_) => Err(Error::InvalidMessageFormat(
"Signin data should be structured, not a plain string.".into(),
)),
}
};
(extract_field("user_data")?, extract_field("auth")?)
};
let user_info_val: Value = serde_json::from_str(user_data_str)
.map_err(|e| Error::AuthError(format!("Invalid user_data JSON: {}", e)))?;
// `id` must be a JSON string; other JSON types are treated as missing.
let user_id = user_info_val
.get("id")
.and_then(|id| id.as_str())
.ok_or_else(|| Error::AuthError("Missing 'id' field in user_data".into()))?
.to_string();
// Optional watchlist; non-string entries in the array are skipped.
let watchlist = user_info_val
.get("watchlist")
.and_then(|w| w.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect::<Vec<String>>()
});
let auth_validator = AuthValidator::new(self.app_manager.clone());
let is_valid_auth = auth_validator
.validate_channel_auth(socket_id.clone(), &app_config.key, user_data_str, auth_str)
.await?;
if !is_valid_auth {
return Err(Error::AuthError(
"Connection not authorized for signin.".into(),
));
}
self.clear_user_authentication_timeout(&app_config.id, socket_id)
.await?;
// The connection-manager lock is held from here until the explicit
// `drop` below, covering the state update, the socket re-add and the
// user registration.
let mut connection_manager_locked = self.connection_manager.lock().await;
let connection_arc = connection_manager_locked
.get_connection(socket_id, &app_config.id)
.await
.ok_or_else(|| Error::ConnectionNotFound)?;
{
let mut conn_locked = connection_arc.lock().await;
// Abort any timeout task handle still stored on the connection state.
if let Some(handle) = conn_locked.state.auth_timeout_handle.take() {
handle.abort();
}
conn_locked.state.user_id = Some(user_id.clone());
conn_locked.state.user_info = Some(UserInfo {
id: user_id.clone(),
watchlist: watchlist.clone(),
info: Some(user_info_val.clone()),
});
conn_locked.state.user = Some(user_info_val.clone());
}
// Take the socket write half out of the connection and hand it back to
// the connection manager via `add_socket`.
// NOTE(review): presumably this refreshes the manager's registration now
// that the user is known — confirm against the adapter implementation.
let temp_socket_tx = {
let mut conn_locked = connection_arc.lock().await;
conn_locked.socket.take()
};
if let Some(socket_tx_val) = temp_socket_tx {
connection_manager_locked
.add_socket(
socket_id.clone(),
socket_tx_val,
app_config.id.as_str(),
&self.app_manager,
)
.await?;
} else {
error!(
"Socket write half was None during signin for socket {}",
socket_id
);
return Err(Error::InternalError(
"Socket state inconsistent during signin".into(),
));
}
connection_manager_locked
.add_user(connection_arc.clone())
.await?;
// Release the manager before the sends below, which lock it again.
drop(connection_manager_locked);
let mut watchlist_events_for_user = Vec::new();
let mut watchers_to_notify = Vec::new();
if app_config.enable_watchlist_events.unwrap_or(false) && watchlist.is_some() {
info!(
"Processing watchlist for user {} with {} watched users",
user_id,
watchlist.as_ref().unwrap().len()
);
let events = self
.watchlist_manager
.add_user_with_watchlist(
&app_config.id,
&user_id,
socket_id.clone(),
watchlist.clone(),
)
.await?;
watchlist_events_for_user = events;
watchers_to_notify = self.get_watchers_for_user(&app_config.id, &user_id).await?;
info!(
"User {} signin: sending {} watchlist events to user, notifying {} watchers",
user_id,
watchlist_events_for_user.len(),
watchers_to_notify.len()
);
}
// The success message echoes the received `user_data` and `auth` values.
let success_message_val = PusherMessage {
channel: None,
name: None,
event: Some("pusher:signin_success".into()),
data: Some(MessageData::Json(json!({
"user_data": user_data_str,
"auth": auth_str
}))),
};
self.connection_manager
.lock()
.await
.send_message(&app_config.id, socket_id, success_message_val)
.await?;
// Watchlist events for the signing-in user are best-effort: failures are
// logged and the loop continues.
for event in watchlist_events_for_user {
if let Err(e) = self
.connection_manager
.lock()
.await
.send_message(&app_config.id, socket_id, event)
.await
{
warn!("Failed to send watchlist event to user {}: {}", user_id, e);
}
}
if !watchers_to_notify.is_empty() {
let online_event = PusherMessage::watchlist_online_event(vec![user_id.clone()]);
// The sent-message metric is recorded once here, before the per-watcher
// sends, using the size of the serialized online event.
if let Some(ref metrics) = self.metrics {
let metrics_locked = metrics.lock().await;
let online_event_json = serde_json::to_value(&online_event).map_err(|e| {
Error::InternalError(format!("Failed to serialize event: {}", e))
})?;
let size = utils::data_to_bytes_flexible(vec![online_event_json]);
metrics_locked.mark_ws_message_sent(&app_config.id, size);
}
// Notify each watcher; a failed send is logged and does not stop the loop.
for watcher_socket_id in watchers_to_notify {
if let Err(e) = self
.connection_manager
.lock()
.await
.send_message(&app_config.id, &watcher_socket_id, online_event.clone())
.await
{
warn!(
"Failed to send online notification to watcher {}: {}",
watcher_socket_id, e
);
}
}
}
info!(
"User {} successfully signed in on socket {}",
user_id, socket_id
);
Ok(())
}
/// Validates a client-originated event and relays it to the other subscribers of a channel.
///
/// Checks are performed in this order, each failing with the error noted:
/// 1. a channel name is present (`ClientEventError`);
/// 2. the event name fits the app's (or the default) length limit (`InvalidEventName`);
/// 3. the payload fits `max_event_payload_in_kb`, when the app sets one (`ClientEventError`);
/// 4. the event name starts with `CLIENT_EVENT_PREFIX` (`InvalidEventName`);
/// 5. the channel name fits the app's (or the default) length limit (`InvalidChannelName`);
/// 6. the channel is a private or presence channel (`ClientEventError`);
/// 7. the app has client messages enabled (`ClientEventError`);
/// 8. the sending socket is subscribed to the channel (`ClientEventError`).
///
/// On success the event is broadcast to the channel with the sender excluded, and, when a
/// webhook integration is configured, a client-event webhook is sent. A webhook failure is
/// only logged; it does not fail the call.
///
/// # Errors
/// Any of the validation errors above, or whatever the connection manager returns from the
/// subscription check or the broadcast.
async fn handle_client_event(
    &self,
    app_config: &App,
    socket_id: &SocketId,
    event: &str,
    channel: Option<&str>,
    data: Value,
) -> Result<()> {
    let Some(channel_name) = channel else {
        return Err(Error::ClientEventError(
            "Channel name is required for client event".into(),
        ));
    };

    let event_name_limit = app_config
        .max_event_name_length
        .unwrap_or(DEFAULT_EVENT_NAME_MAX_LENGTH as u32);
    if event.len() > event_name_limit as usize {
        return Err(Error::InvalidEventName(format!(
            "Client event name '{}' exceeds maximum length of {}",
            event, event_name_limit
        )));
    }

    // The payload limit is optional; apps without one accept any size here.
    if let Some(limit_kb) = app_config.max_event_payload_in_kb {
        let limit_bytes = limit_kb as usize * 1024;
        let actual_bytes = utils::data_to_bytes_flexible(vec![data.clone()]);
        if actual_bytes > limit_bytes {
            return Err(Error::ClientEventError(format!(
                "Client event payload size ({} bytes) for event '{}' exceeds limit ({}KB)",
                actual_bytes, event, limit_kb
            )));
        }
    }

    if !event.starts_with(CLIENT_EVENT_PREFIX) {
        return Err(Error::InvalidEventName(
            "Client events must start with 'client-'".into(),
        ));
    }

    let channel_name_limit = app_config
        .max_channel_name_length
        .unwrap_or(DEFAULT_CHANNEL_NAME_MAX_LENGTH as u32);
    if channel_name.len() > channel_name_limit as usize {
        return Err(Error::InvalidChannelName(format!(
            "Channel name '{}' for client event exceeds maximum length of {}",
            channel_name, channel_name_limit
        )));
    }

    match ChannelType::from_name(channel_name) {
        ChannelType::Private | ChannelType::Presence => {}
        _ => {
            return Err(Error::ClientEventError(
                "Client events can only be sent to private or presence channels".into(),
            ));
        }
    }

    if !app_config.enable_client_messages {
        return Err(Error::ClientEventError(
            "Client events are not enabled for this app".into(),
        ));
    }

    // One critical section answers both questions: is the sender subscribed, and which
    // presence user id (if any) does it hold on this channel?
    let (sender_is_subscribed, presence_user_id) = {
        let mut manager = self.connection_manager.lock().await;
        let sender_is_subscribed = manager
            .is_in_channel(&app_config.id, channel_name, socket_id)
            .await?;
        let presence_user_id = match manager.get_connection(socket_id, &app_config.id).await {
            Some(connection) => connection
                .lock()
                .await
                .state
                .presence
                .as_ref()
                .and_then(|members| members.get(channel_name))
                .map(|member| member.user_id.clone()),
            None => None,
        };
        (sender_is_subscribed, presence_user_id)
    };

    if !sender_is_subscribed {
        return Err(Error::ClientEventError(format!(
            "Client {} is not subscribed to channel {} (or subscription check failed)",
            socket_id, channel_name
        )));
    }

    let outbound = PusherMessage {
        channel: Some(channel_name.to_string()),
        name: None,
        event: Some(event.to_string()),
        data: Some(MessageData::Json(data.clone())),
    };

    // Broadcast to the channel, excluding the sender itself.
    self.connection_manager
        .lock()
        .await
        .send(channel_name, outbound, Some(socket_id), &app_config.id)
        .await?;

    if let Some(webhooks) = &self.webhook_integration {
        // A user id is only reported for presence channels.
        let webhook_user_id = presence_user_id
            .as_deref()
            .filter(|_| channel_name.starts_with("presence-"));
        let delivery = webhooks
            .send_client_event(
                app_config,
                channel_name,
                event,
                data,
                Some(socket_id.as_ref()),
                webhook_user_id,
            )
            .await;
        if let Err(e) = delivery {
            warn!(
                "Failed to send client_event webhook for {}: {}",
                channel_name, e
            );
        }
    }

    Ok(())
}
pub async fn send_error(
&self,
app_id: &str,
socket_id: &SocketId,
error: &Error,
channel: Option<String>,
) -> Result<()> {
let error_data = ErrorData {
message: error.to_string(),
code: Some(error.close_code()),
};
let error_message =
PusherMessage::error(error_data.code.unwrap_or(4000), error_data.message, channel);
self.connection_manager
.lock()
.await
.send_message(app_id, socket_id, error_message) .await
}
/// Sends the connection-established message, carrying the socket's id, to that socket.
///
/// # Errors
/// Returns whatever the connection manager's `send_message` returns.
pub async fn send_connection_established(
    &self,
    app_id: &str,
    socket_id: &SocketId,
) -> Result<()> {
    let raw_socket_id = socket_id.0.clone();
    let greeting = PusherMessage::connection_established(raw_socket_id);
    self.connection_manager
        .lock()
        .await
        .send_message(app_id, socket_id, greeting)
        .await
}
pub async fn handle_disconnect(&self, app_id: &str, socket_id: &SocketId) -> Result<()> {
info!("Handling disconnect for socket: {}", socket_id);
self.clear_activity_timeout(app_id, socket_id).await?;
self.clear_user_authentication_timeout(app_id, socket_id)
.await?;
if self.client_event_limiters.remove(socket_id).is_some() {
info!(
"Removed client event rate limiter for socket: {}",
socket_id
);
}
let app_config = match self.app_manager.find_by_id(app_id).await? {
Some(app) => app,
None => {
error!("App not found during disconnect: {}", app_id);
let mut conn_manager = self.connection_manager.lock().await;
if let Some(conn_to_cleanup) = conn_manager.get_connection(socket_id, app_id).await
{
conn_manager
.cleanup_connection(app_id, WebSocketRef(conn_to_cleanup))
.await;
}
conn_manager.remove_connection(socket_id, app_id).await.ok();
return Err(Error::ApplicationNotFound);
}
};
let (subscribed_channels_set, user_id_of_disconnected_socket, user_watchlist) = {
let mut connection_manager_locked = self.connection_manager.lock().await;
match connection_manager_locked
.get_connection(socket_id, app_id)
.await
{
Some(conn_val_arc) => {
let conn_locked = conn_val_arc.lock().await;
if let Some(handle) = &conn_locked.state.activity_timeout_handle {
handle.abort();
}
if let Some(handle) = &conn_locked.state.auth_timeout_handle {
handle.abort();
}
let watchlist = conn_locked
.state
.user_info
.as_ref()
.and_then(|ui| ui.watchlist.clone());
(
conn_locked.state.subscribed_channels.clone(),
conn_locked.state.user_id.clone(),
watchlist,
)
}
None => {
warn!(
"No connection found for socket during disconnect: {}. Already cleaned up?",
socket_id
);
return Ok(());
}
}
};
if !subscribed_channels_set.is_empty() {
info!(
"Processing {} channels for disconnecting socket: {}",
subscribed_channels_set.len(),
socket_id
);
let channel_manager_locked = self.channel_manager.write().await;
for channel_str in &subscribed_channels_set {
info!(
"Processing channel {} for disconnect of socket {}",
channel_str, socket_id
);
match channel_manager_locked
.unsubscribe(
socket_id.0.as_str(),
channel_str,
app_id,
user_id_of_disconnected_socket.as_deref(),
)
.await
{
Ok(_leave_response) => {
let current_sub_count_after_cm_unsubscribe = self
.connection_manager
.lock()
.await
.get_channel_socket_count(app_id, channel_str)
.await;
if channel_str.starts_with("presence-") {
if let Some(ref disconnected_user_id) = user_id_of_disconnected_socket {
let has_other_connections = self
.user_has_other_connections_in_presence_channel(
app_id,
channel_str,
disconnected_user_id,
)
.await?;
if !has_other_connections {
if let Some(webhook_integration_instance) =
&self.webhook_integration
{
info!(
"Sending member_removed webhook for user {} from channel {}",
disconnected_user_id, channel_str
);
webhook_integration_instance
.send_member_removed(
&app_config,
channel_str,
disconnected_user_id,
)
.await
.ok();
}
let member_removed_msg = PusherMessage::member_removed(
channel_str.to_string(),
disconnected_user_id.clone(),
);
self.connection_manager
.lock()
.await
.send(
channel_str,
member_removed_msg,
Some(socket_id),
app_id,
)
.await
.ok();
}
}
} else {
if let Some(webhook_integration_instance) = &self.webhook_integration {
info!(
"Sending subscription_count webhook for channel {} (count: {}) after disconnect processing",
channel_str, current_sub_count_after_cm_unsubscribe
);
webhook_integration_instance
.send_subscription_count_changed(
&app_config,
channel_str,
current_sub_count_after_cm_unsubscribe,
)
.await
.ok();
}
}
if current_sub_count_after_cm_unsubscribe == 0 {
if let Some(webhook_integration_instance) = &self.webhook_integration {
info!(
"Sending channel_vacated webhook for channel {}",
channel_str
);
webhook_integration_instance
.send_channel_vacated(&app_config, channel_str)
.await
.ok();
}
}
}
Err(e) => {
error!(
"Error unsubscribing socket {} from channel {} during disconnect: {}",
socket_id, channel_str, e
);
}
}
}
}
let mut watchers_to_notify = Vec::new();
let mut offline_notification_count = 0;
if let Some(ref user_id_str) = user_id_of_disconnected_socket {
if app_config.enable_watchlist_events.unwrap_or(false) && user_watchlist.is_some() {
info!(
"Processing watchlist disconnect for user {} on socket {}",
user_id_str, socket_id
);
let offline_events = self
.watchlist_manager
.remove_user_connection(app_id, user_id_str, socket_id)
.await?;
if !offline_events.is_empty() {
watchers_to_notify = self.get_watchers_for_user(app_id, user_id_str).await?;
info!(
"User {} went offline: notifying {} watchers",
user_id_str,
watchers_to_notify.len()
);
for event in offline_events {
let event_json = serde_json::to_string(&event).unwrap_or_default();
let message_size = event_json.len();
for watcher_socket_id in &watchers_to_notify {
match self
.connection_manager
.lock()
.await
.send_message(app_id, watcher_socket_id, event.clone())
.await
{
Ok(_) => {
offline_notification_count += 1;
if let Some(ref metrics) = self.metrics {
let metrics_locked = metrics.lock().await;
metrics_locked.mark_ws_message_sent(app_id, message_size);
}
}
Err(e) => {
warn!(
"Failed to send offline notification to watcher {}: {}",
watcher_socket_id, e
);
}
}
}
}
}
}
}
{
let mut connection_manager_locked = self.connection_manager.lock().await;
if let Some(conn_to_cleanup) = connection_manager_locked
.get_connection(socket_id, app_id)
.await
{
connection_manager_locked
.cleanup_connection(app_id, WebSocketRef(conn_to_cleanup))
.await;
}
connection_manager_locked
.remove_connection(socket_id, app_id)
.await
.ok();
}
if let Some(ref metrics) = self.metrics {
let metrics_locked = metrics.lock().await;
metrics_locked.mark_disconnection(app_id, socket_id);
info!(
"Disconnect metrics updated for socket {} (sent {} offline notifications)",
socket_id, offline_notification_count
);
}
info!(
"Successfully processed full disconnect for socket: {} (user: {:?}, notified {} watchers)",
socket_id,
user_id_of_disconnected_socket,
watchers_to_notify.len()
);
Ok(())
}
pub async fn send_message(
&self,
app_id: &str,
socket_id: Option<&SocketId>,
message: PusherApiMessage,
channel: &str,
) {
let pusher_message_val = PusherMessage {
event: message.name,
data: message.data.map(|api_data| match api_data {
crate::protocol::messages::ApiMessageData::String(s) => MessageData::String(s),
crate::protocol::messages::ApiMessageData::Json(j) => MessageData::Json(j),
}),
channel: Some(channel.to_string()),
name: None,
};
if let Some(ref metrics) = self.metrics {
let metrics_locked = metrics.lock().await;
let message_size_val =
serde_json::to_string(&pusher_message_val).map_or(0, |s| s.len());
metrics_locked.mark_ws_message_sent(app_id, message_size_val);
}
if let Err(e) = self
.connection_manager
.lock()
.await
.send(channel, pusher_message_val, socket_id, app_id)
.await
{
error!("Failed to send message to channel {}: {:?}", channel, e);
} else {
info!(
"Message sent to channel {} successfully (via HTTP API path)",
channel
);
}
}
/// Collects the socket ids of every connection belonging to a user who watches `user_id`.
///
/// The watcher user ids come from the watchlist manager; each watcher's sockets come from
/// the connection manager, which stays locked for the whole collection.
///
/// # Errors
/// Propagates failures from the watchlist manager or from `get_user_sockets`.
async fn get_watchers_for_user(&self, app_id: &str, user_id: &str) -> Result<Vec<SocketId>> {
    let watcher_user_ids = self
        .watchlist_manager
        .get_watchers_for_user(app_id, user_id)
        .await?;

    let mut manager = self.connection_manager.lock().await;
    let mut socket_ids = Vec::new();
    for watcher in watcher_user_ids {
        let sockets = manager.get_user_sockets(&watcher, app_id).await?;
        for socket in sockets {
            let id = socket.0.lock().await.state.socket_id.clone();
            socket_ids.push(id);
        }
    }
    Ok(socket_ids)
}
/// (Re)arms the inactivity timer for a socket.
///
/// Any previously armed timer is cancelled first. A background task is then spawned that
/// sleeps for the timeout and, if the connection is still registered, sends a 4201 error
/// frame ("Pong reply not received in time") and closes the socket with code 4201. The
/// task's handle is stored on the connection state so that it can be cancelled later.
///
/// If the connection is no longer registered, the new timer is aborted immediately instead
/// of being left to sleep for nothing.
///
/// # Errors
/// Propagates the result of `clear_activity_timeout`.
async fn set_activity_timeout(&self, app_id: &str, socket_id: &SocketId) -> Result<()> {
    // Seconds without activity after which the socket is closed.
    const ACTIVITY_TIMEOUT_SECS: u64 = 120;

    let socket_id_clone = socket_id.clone();
    let app_id_clone = app_id.to_string();
    let connection_manager = self.connection_manager.clone();
    self.clear_activity_timeout(app_id, socket_id).await?;
    let timeout_handle = tokio::spawn(async move {
        sleep(Duration::from_secs(ACTIVITY_TIMEOUT_SECS)).await;
        let mut conn_manager = connection_manager.lock().await;
        if let Some(conn) = conn_manager
            .get_connection(&socket_id_clone, &app_id_clone)
            .await
        {
            let mut ws = conn.lock().await;
            let error_frame = Frame::text(Payload::from(
                serde_json::to_string(&PusherMessage::error(
                    4201,
                    "Pong reply not received in time".to_string(),
                    None,
                ))
                .unwrap_or_default()
                .into_bytes(),
            ));
            // Best effort: the peer may already be gone, so failures are deliberately
            // ignored (same handling as the authentication timeout).
            let _ = ws.message_sender.send(error_frame);
            let _ = ws.close(4201, "Activity timeout".to_string()).await;
        }
    });
    let mut conn_manager = self.connection_manager.lock().await;
    match conn_manager.get_connection(socket_id, app_id).await {
        Some(conn) => {
            let mut ws = conn.lock().await;
            // The manager lock was released between the clear above and this point, so a
            // handle may have been stored in the meantime. Abort it rather than orphaning
            // a timer that would still close the socket.
            let stale = ws.state.activity_timeout_handle.replace(timeout_handle);
            if let Some(stale_handle) = stale {
                stale_handle.abort();
            }
        }
        // Nobody can cancel the timer if there is nowhere to store its handle.
        None => timeout_handle.abort(),
    }
    Ok(())
}
/// Cancels the socket's pending inactivity timer, if one is armed.
///
/// Does nothing when the connection is not registered. As written, always returns `Ok(())`.
async fn clear_activity_timeout(&self, app_id: &str, socket_id: &SocketId) -> Result<()> {
    let mut manager = self.connection_manager.lock().await;
    let Some(connection) = manager.get_connection(socket_id, app_id).await else {
        return Ok(());
    };
    // `take` leaves `None` behind so the same handle is never aborted twice.
    let pending = connection.lock().await.state.activity_timeout_handle.take();
    if let Some(timer) = pending {
        timer.abort();
    }
    Ok(())
}
/// Restarts the socket's inactivity timer by delegating to `set_activity_timeout`.
///
/// # Errors
/// Propagates the result of `set_activity_timeout`.
async fn update_activity_timeout(&self, app_id: &str, socket_id: &SocketId) -> Result<()> {
    self.set_activity_timeout(app_id, socket_id).await?;
    Ok(())
}
async fn set_user_authentication_timeout(
&self,
app_id: &str,
socket_id: &SocketId,
timeout_seconds: u64,
) -> Result<()> {
let socket_id_clone = socket_id.clone();
let app_id_clone = app_id.to_string();
let connection_manager = self.connection_manager.clone();
self.clear_user_authentication_timeout(app_id, socket_id)
.await?;
let timeout_handle = tokio::spawn(async move {
sleep(Duration::from_secs(timeout_seconds)).await;
let mut conn_manager = connection_manager.lock().await;
if let Some(conn) = conn_manager
.get_connection(&socket_id_clone, &app_id_clone)
.await
{
let mut ws = conn.lock().await;
if ws.state.user.is_none() {
let error_frame = Frame::text(Payload::from(
serde_json::to_string(&PusherMessage::error(
4009,
"Connection not authorized within timeout.".to_string(),
None,
))
.unwrap_or_default()
.into_bytes(),
));
let _ = ws.message_sender.send(error_frame);
let _ = ws.close(4009, "Authentication timeout".to_string()).await;
}
}
});
let mut conn_manager = self.connection_manager.lock().await;
if let Some(conn) = conn_manager.get_connection(socket_id, app_id).await {
let mut ws = conn.lock().await;
ws.state.auth_timeout_handle = Some(timeout_handle);
}
Ok(())
}
/// Cancels the socket's pending authentication timer, if one is armed.
///
/// Does nothing when the connection is not registered. As written, always returns `Ok(())`.
async fn clear_user_authentication_timeout(
    &self,
    app_id: &str,
    socket_id: &SocketId,
) -> Result<()> {
    let mut manager = self.connection_manager.lock().await;
    let Some(connection) = manager.get_connection(socket_id, app_id).await else {
        return Ok(());
    };
    // `take` leaves `None` behind so the same handle is never aborted twice.
    let pending = connection.lock().await.state.auth_timeout_handle.take();
    if let Some(timer) = pending {
        timer.abort();
    }
    Ok(())
}
}