#![allow(dead_code)]
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::adapter::ConnectionManager;
use crate::adapter::local_adapter::LocalAdapter;
use crate::channel::PresenceMemberInfo;
use crate::error::{Error, Result};
use crate::metrics::MetricsInterface;
use crate::websocket::SocketId;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, Notify};
use tokio::time::sleep;
use tracing::{info, warn};
use uuid::Uuid;
/// Kind of query (or command) a node asks the other nodes to run against their local state.
///
/// How each variant is answered is decided by `HorizontalAdapter::process_request`:
/// - `ChannelMembers`: fills `ResponseBody::members` for the requested channel.
/// - `ChannelSockets`: fills `ResponseBody::socket_ids` for the requested channel.
/// - `ChannelSocketsCount`: fills `ResponseBody::sockets_count` for the requested channel.
/// - `SocketExistsInChannel`: sets `ResponseBody::exists` for the requested channel/socket pair.
/// - `TerminateUserConnections`: terminates the user's local connections and sets
///   `ResponseBody::exists` to `true`.
/// - `ChannelsWithSocketsCount`: fills `ResponseBody::channels_with_sockets_count`.
/// - `Sockets`: fills `ResponseBody::socket_ids` and `ResponseBody::sockets_count` for the app.
/// - `Channels`: fills `ResponseBody::channels` for the app.
/// - `SocketsCount`: fills `ResponseBody::sockets_count` for the app.
/// - `ChannelMembersCount`: fills `ResponseBody::members_count` for the requested channel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RequestType {
ChannelMembers, ChannelSockets, ChannelSocketsCount, SocketExistsInChannel, TerminateUserConnections, ChannelsWithSocketsCount,
Sockets, Channels, SocketsCount, ChannelMembersCount, }
/// A request sent from one node to the others, answered by `HorizontalAdapter::process_request`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestBody {
/// Identifier of this request; copied into the matching `ResponseBody::request_id`.
pub request_id: String,
/// Id of the node that issued the request. A node rejects requests carrying its own id.
pub node_id: String,
/// Application the query is scoped to; passed to every local adapter call.
pub app_id: String,
/// Which query to run.
pub request_type: RequestType,
/// Channel name; read by the channel-scoped request types and ignored by the others.
pub channel: Option<String>,
/// Socket id; read only by `RequestType::SocketExistsInChannel`.
pub socket_id: Option<String>,
/// User id; read only by `RequestType::TerminateUserConnections`.
pub user_id: Option<String>,
}
/// One node's answer to a `RequestBody`, also used as the aggregated result of several answers.
///
/// Only the fields relevant to the request type are populated; the rest keep their empty or
/// zero initial value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseBody {
/// Id of the request this response answers.
pub request_id: String,
/// Id of the node that produced the response.
pub node_id: String,
/// Application the response belongs to.
pub app_id: String,
/// Channel members (filled for `RequestType::ChannelMembers`).
pub members: HashMap<String, PresenceMemberInfo>,
/// Socket count per channel name (filled for `RequestType::ChannelsWithSocketsCount`).
pub channels_with_sockets_count: HashMap<String, usize>,
/// Socket ids (filled for `RequestType::ChannelSockets` and `RequestType::Sockets`).
pub socket_ids: Vec<String>,
/// Number of sockets (filled for the socket-count request types and `RequestType::Sockets`).
pub sockets_count: usize,
/// Result flag for `RequestType::SocketExistsInChannel`; also set to `true` after
/// `RequestType::TerminateUserConnections` has been carried out.
pub exists: bool,
/// Channel names (filled for `RequestType::Channels`).
pub channels: HashSet<String>,
/// Number of channel members (filled for `RequestType::ChannelMembersCount`).
pub members_count: usize, }
/// Serializable message envelope for broadcasting to a channel.
///
/// NOTE(review): this type is not used anywhere in the visible part of the file; the field
/// descriptions below are inferred from the field names only — confirm against the callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BroadcastMessage {
/// Presumably the id of the node that originated the broadcast.
pub node_id: String,
/// Presumably the application the message belongs to.
pub app_id: String,
/// Presumably the target channel name.
pub channel: String,
/// Presumably the already-serialized message payload.
pub message: String,
/// Presumably a socket that must not receive the message — TODO confirm.
pub except_socket_id: Option<String>,
}
/// Bookkeeping for a request that has been issued and is still collecting responses.
///
/// Entries live in `HorizontalAdapter::pending_requests`, keyed by request id.
#[derive(Clone)]
pub struct PendingRequest {
/// When the request was created; compared against the timeout to detect expiry.
pub(crate) start_time: Instant,
/// Application the request was issued for.
pub(crate) app_id: String,
/// Responses received so far, appended by `HorizontalAdapter::process_response`.
pub(crate) responses: Vec<ResponseBody>,
/// Signalled with `notify_one` each time a response is appended.
/// NOTE(review): nothing in the visible code waits on this handle; `send_request` polls.
pub(crate) notify: Arc<Notify>,
}
/// Node-local state for answering other nodes' requests and collecting responses to this
/// node's own requests.
pub struct HorizontalAdapter {
/// Id of this node; `new` sets it to a random UUID v4 string.
pub node_id: String,
/// Adapter holding this node's local connection/channel state; all queries are run against it.
pub local_adapter: LocalAdapter,
/// Requests issued by this node that are still awaiting responses, keyed by request id.
pub pending_requests: DashMap<String, PendingRequest>,
/// Maximum time to wait for responses, in milliseconds.
pub requests_timeout: u64,
/// Optional metrics sink; when `None`, no metrics are recorded.
pub metrics: Option<Arc<Mutex<dyn MetricsInterface + Send + Sync>>>,
}
impl HorizontalAdapter {
/// Builds an adapter with a random UUID v4 node id, a fresh local adapter, no pending
/// requests, a 5000 ms request timeout and no metrics sink.
pub fn new() -> Self {
    let node_id = Uuid::new_v4().to_string();
    let local_adapter = LocalAdapter::new();
    let pending_requests = DashMap::new();

    Self {
        node_id,
        local_adapter,
        pending_requests,
        requests_timeout: 5000,
        metrics: None,
    }
}
/// Spawns a detached Tokio task that wakes once per second, logs every pending request older
/// than `requests_timeout` milliseconds and removes it from the map it watches.
///
/// NOTE(review): `DashMap::clone` copies the entries into a new, independent map. The spawned
/// task therefore only ever sees the requests that were pending when this method was called;
/// requests registered later in `self.pending_requests` are never expired by it. Sharing the
/// live map with the task requires the field to be stored behind an `Arc`, which changes the
/// struct's public field type and is therefore not done here.
///
/// # Panics
///
/// `tokio::spawn` panics when called outside of a Tokio runtime.
pub fn start_request_cleanup(&mut self) {
    let timeout_ms = u128::from(self.requests_timeout);
    let pending_requests_clone = self.pending_requests.clone();
    tokio::spawn(async move {
        loop {
            sleep(Duration::from_millis(1000)).await;
            let now = Instant::now();
            // Collect the keys first so that no map reference is alive while removing;
            // `DashMap::remove` may deadlock if a reference into the map is still held.
            let expired_requests: Vec<String> = pending_requests_clone
                .iter()
                .filter(|entry| {
                    now.duration_since(entry.value().start_time).as_millis() > timeout_ms
                })
                .map(|entry| entry.key().clone())
                .collect();
            for request_id in expired_requests {
                warn!("Request {} expired", request_id);
                pending_requests_clone.remove(&request_id);
            }
        }
    });
}
/// Answers a request received from another node by querying this node's local adapter.
///
/// The returned `ResponseBody` echoes the request id and app id, carries this node's id, and
/// has only the fields relevant to `request.request_type` populated. Channel-, socket- or
/// user-scoped request types whose corresponding optional field is `None` yield an
/// otherwise empty response.
///
/// # Errors
///
/// Returns `Error::OwnRequestIgnored` when the request was issued by this node, and
/// propagates any error returned by the local adapter.
pub async fn process_request(&mut self, request: RequestBody) -> Result<ResponseBody> {
    // Format directly through the tracing macro instead of allocating with `format!` first.
    info!(
        "Processing request from node {}: {:?}",
        request.node_id, request.request_type
    );
    if request.node_id == self.node_id {
        return Err(Error::OwnRequestIgnored);
    }
    if let Some(ref metrics) = self.metrics {
        let metrics = metrics.lock().await;
        metrics.mark_horizontal_adapter_request_received(&request.app_id);
    }
    let mut response = ResponseBody {
        request_id: request.request_id.clone(),
        node_id: self.node_id.clone(),
        app_id: request.app_id.clone(),
        members: HashMap::new(),
        socket_ids: Vec::new(),
        sockets_count: 0,
        channels_with_sockets_count: HashMap::new(),
        exists: false,
        channels: HashSet::new(),
        members_count: 0,
    };
    match request.request_type {
        RequestType::ChannelMembers => {
            if let Some(channel) = &request.channel {
                let members = self
                    .local_adapter
                    .get_channel_members(&request.app_id, channel)
                    .await?;
                response.members = members;
            }
        }
        RequestType::ChannelSockets => {
            if let Some(channel) = &request.channel {
                let channel_set = self
                    .local_adapter
                    .get_channel_sockets(&request.app_id, channel)
                    .await?;
                response.socket_ids = channel_set
                    .iter()
                    .map(|socket_id| socket_id.0.clone())
                    .collect();
            }
        }
        RequestType::ChannelSocketsCount => {
            if let Some(channel) = &request.channel {
                response.sockets_count = self
                    .local_adapter
                    .get_channel_socket_count(&request.app_id, channel)
                    .await;
            }
        }
        RequestType::SocketExistsInChannel => {
            if let (Some(channel), Some(socket_id)) = (&request.channel, &request.socket_id) {
                let socket_id = SocketId(socket_id.clone());
                response.exists = self
                    .local_adapter
                    .is_in_channel(&request.app_id, channel, &socket_id)
                    .await?;
            }
        }
        RequestType::TerminateUserConnections => {
            if let Some(user_id) = &request.user_id {
                self.local_adapter
                    .terminate_user_connections(&request.app_id, user_id)
                    .await?;
                // `exists` doubles as the "command was carried out" flag for this type.
                response.exists = true;
            }
        }
        RequestType::ChannelsWithSocketsCount => {
            let channels = self
                .local_adapter
                .get_channels_with_socket_count(&request.app_id)
                .await?;
            response.channels_with_sockets_count = channels
                .iter()
                .map(|entry| (entry.key().clone(), *entry.value()))
                .collect();
        }
        RequestType::Sockets => {
            let connections = self
                .local_adapter
                .get_all_connections(&request.app_id)
                .await;
            response.socket_ids = connections
                .iter()
                .map(|entry| entry.key().0.clone())
                .collect();
            response.sockets_count = connections.len();
        }
        RequestType::Channels => {
            // Only the channel names are reported; the per-channel counts are discarded.
            let channels = self
                .local_adapter
                .get_channels_with_socket_count(&request.app_id)
                .await?;
            response.channels = channels.iter().map(|entry| entry.key().clone()).collect();
        }
        RequestType::SocketsCount => {
            let connections = self
                .local_adapter
                .get_all_connections(&request.app_id)
                .await;
            response.sockets_count = connections.len();
        }
        RequestType::ChannelMembersCount => {
            if let Some(channel) = &request.channel {
                let members = self
                    .local_adapter
                    .get_channel_members(&request.app_id, channel)
                    .await?;
                response.members_count = members.len();
            }
        }
    }
    Ok(response)
}
/// Records a response received from another node.
///
/// Marks the response in the metrics sink (when one is configured), then appends it to the
/// pending request with the same request id and signals that request's `notify` handle.
/// Responses whose request id is not pending are dropped silently. Always returns `Ok(())`.
pub async fn process_response(&self, response: ResponseBody) -> Result<()> {
    if let Some(metrics_ref) = self.metrics.as_ref() {
        let guard = metrics_ref.lock().await;
        guard.mark_horizontal_adapter_response_received(&response.app_id);
    }

    let pending_entry = self.pending_requests.get_mut(&response.request_id);
    if let Some(mut pending) = pending_entry {
        pending.responses.push(response);
        pending.notify.notify_one();
    }

    Ok(())
}
/// Registers a new pending request and waits for the other nodes' responses, returning their
/// aggregate.
///
/// `expected_node_count` is the total number of nodes including this one, so
/// `expected_node_count - 1` responses are awaited. When that is zero, an empty response is
/// returned immediately. Otherwise the pending entry is polled every 50 ms (or sooner when
/// less time remains) until enough responses have arrived or `requests_timeout`
/// milliseconds have elapsed since the request was created; on timeout the responses
/// gathered so far are aggregated and returned.
///
/// This method does not transmit anything itself: the request body is built but not sent
/// from here (the log line states that broadcasting is handled by the adapter).
///
/// # Errors
///
/// Returns `Error::Other` when the pending entry disappears from `pending_requests` while
/// waiting.
pub async fn send_request(
    &mut self,
    app_id: &str,
    request_type: RequestType,
    channel: Option<&str>,
    socket_id: Option<&str>,
    user_id: Option<&str>,
    expected_node_count: usize,
) -> Result<ResponseBody> {
    let request_id = Uuid::new_v4().to_string();
    let start = Instant::now();
    let max_expected_responses = expected_node_count.saturating_sub(1);

    // Built for parity with the wire format but intentionally not sent from this method.
    let _request = RequestBody {
        request_id: request_id.clone(),
        node_id: self.node_id.clone(),
        app_id: app_id.to_string(),
        request_type: request_type.clone(),
        channel: channel.map(String::from),
        socket_id: socket_id.map(String::from),
        user_id: user_id.map(String::from),
    };
    self.pending_requests.insert(
        request_id.clone(),
        PendingRequest {
            start_time: start,
            app_id: app_id.to_string(),
            responses: Vec::with_capacity(max_expected_responses),
            notify: Arc::new(Notify::new()),
        },
    );
    if let Some(metrics_ref) = &self.metrics {
        let metrics = metrics_ref.lock().await;
        metrics.mark_horizontal_adapter_request_sent(app_id);
    }
    info!(
        "Request {} created for type {:?} on app {} - broadcasting handled by adapter",
        request_id, request_type, app_id
    );

    if max_expected_responses == 0 {
        info!(
            "Single node deployment, no responses expected for request {}",
            request_id
        );
        self.pending_requests.remove(&request_id);
        return Ok(ResponseBody {
            request_id: request_id.clone(),
            node_id: self.node_id.clone(),
            app_id: app_id.to_string(),
            members: HashMap::new(),
            socket_ids: Vec::new(),
            sockets_count: 0,
            channels_with_sockets_count: HashMap::new(),
            exists: false,
            channels: HashSet::new(),
            members_count: 0,
        });
    }

    let timeout_duration = Duration::from_millis(self.requests_timeout);
    let check_interval = Duration::from_millis(50);
    let responses = loop {
        // Wall-clock deadline: counting 50 ms polls would make any timeout below 50 ms
        // expire immediately and would truncate timeouts that are not a multiple of 50 ms.
        if start.elapsed() >= timeout_duration {
            let current_responses = self
                .pending_requests
                .get(&request_id)
                .map(|r| r.responses.len())
                .unwrap_or(0);
            warn!(
                "Request {} timed out after {}ms, got {} responses out of {} expected",
                request_id,
                start.elapsed().as_millis(),
                current_responses,
                max_expected_responses
            );
            break self
                .pending_requests
                .remove(&request_id)
                .map(|(_, req)| req.responses)
                .unwrap_or_default();
        }

        // Copy the count out so the map reference is released at the end of this statement:
        // calling `remove` while a reference into the same entry is alive deadlocks.
        let received = self
            .pending_requests
            .get(&request_id)
            .map(|pending| pending.responses.len());
        match received {
            Some(count) if count >= max_expected_responses => {
                info!(
                    "Request {} completed successfully with {}/{} responses in {}ms",
                    request_id,
                    count,
                    max_expected_responses,
                    start.elapsed().as_millis()
                );
                break self
                    .pending_requests
                    .remove(&request_id)
                    .map(|(_, req)| req.responses)
                    .unwrap_or_default();
            }
            Some(_) => {}
            None => {
                return Err(Error::Other(format!(
                    "Request {} was removed unexpectedly (possibly by cleanup task)",
                    request_id
                )));
            }
        }

        // Never sleep past the deadline.
        let remaining = timeout_duration.saturating_sub(start.elapsed());
        tokio::time::sleep(check_interval.min(remaining)).await;
    };

    let combined_response = self.aggregate_responses(
        request_id.clone(),
        self.node_id.clone(),
        app_id.to_string(),
        &request_type,
        responses,
    );
    if let Err(e) = self.validate_aggregated_response(&combined_response, &request_type) {
        warn!(
            "Response validation failed for request {}: {}",
            request_id, e
        );
    }
    if let Some(metrics_ref) = &self.metrics {
        let metrics = metrics_ref.lock().await;
        let duration_ms = start.elapsed().as_millis() as f64;
        metrics.track_horizontal_adapter_resolve_time(app_id, duration_ms);
        // A request counts as resolved when the aggregate carries any non-empty result.
        // (The single-node case returned earlier, so it needs no special handling here.)
        let resolved = combined_response.sockets_count > 0
            || !combined_response.members.is_empty()
            || combined_response.exists
            || !combined_response.channels.is_empty()
            || combined_response.members_count > 0
            || !combined_response.channels_with_sockets_count.is_empty();
        metrics.track_horizontal_adapter_resolved_promises(app_id, resolved);
    }
    Ok(combined_response)
}
/// Merges the per-node `responses` to one request into a single `ResponseBody`.
///
/// The result carries the given `request_id`, `node_id` and `app_id`. Which fields are merged
/// depends on `request_type`:
/// - member maps and channel sets are unioned;
/// - socket ids are de-duplicated (their order in the result is unspecified);
/// - counts are summed, except that for `ChannelSockets` the count is the number of
///   distinct socket ids;
/// - `exists` is the logical OR of all responses;
/// - per-channel socket counts are summed per channel.
///
/// An empty `responses` yields an empty result.
pub fn aggregate_responses(
    &self,
    request_id: String,
    node_id: String,
    app_id: String,
    request_type: &RequestType,
    responses: Vec<ResponseBody>,
) -> ResponseBody {
    let mut merged = ResponseBody {
        request_id,
        node_id,
        app_id,
        members: HashMap::new(),
        channels_with_sockets_count: HashMap::new(),
        socket_ids: Vec::new(),
        sockets_count: 0,
        exists: false,
        channels: HashSet::new(),
        members_count: 0,
    };
    let mut distinct_sockets: HashSet<String> = HashSet::new();

    for node_response in responses {
        match request_type {
            RequestType::ChannelMembers => merged.members.extend(node_response.members),
            RequestType::ChannelSockets => distinct_sockets.extend(node_response.socket_ids),
            RequestType::Sockets => {
                distinct_sockets.extend(node_response.socket_ids);
                merged.sockets_count += node_response.sockets_count;
            }
            RequestType::ChannelSocketsCount | RequestType::SocketsCount => {
                merged.sockets_count += node_response.sockets_count;
            }
            RequestType::SocketExistsInChannel | RequestType::TerminateUserConnections => {
                merged.exists |= node_response.exists;
            }
            RequestType::ChannelsWithSocketsCount => {
                for (channel, socket_count) in node_response.channels_with_sockets_count {
                    *merged
                        .channels_with_sockets_count
                        .entry(channel)
                        .or_default() += socket_count;
                }
            }
            RequestType::Channels => merged.channels.extend(node_response.channels),
            RequestType::ChannelMembersCount => {
                merged.members_count += node_response.members_count;
            }
        }
    }

    // Only the two socket-listing request types publish the de-duplicated id list.
    match request_type {
        RequestType::ChannelSockets => {
            merged.socket_ids = distinct_sockets.into_iter().collect();
            merged.sockets_count = merged.socket_ids.len();
        }
        RequestType::Sockets => {
            merged.socket_ids = distinct_sockets.into_iter().collect();
        }
        _ => {}
    }
    merged
}
/// Sanity-checks an aggregated response and logs a warning for each inconsistency found.
///
/// The checks are advisory only: this function always returns `Ok(())`.
/// - socket-count requests: warns when the count is 0 although socket ids are present;
/// - member-count requests: warns when the count is 0 although members are present;
/// - channels-with-socket-count requests: warns when the per-channel counts sum to 0 although
///   `sockets_count` is positive.
pub fn validate_aggregated_response(
    &self,
    response: &ResponseBody,
    request_type: &RequestType,
) -> Result<()> {
    match request_type {
        RequestType::ChannelSocketsCount | RequestType::SocketsCount
            if response.sockets_count == 0 && !response.socket_ids.is_empty() =>
        {
            warn!("Inconsistent response: sockets_count is 0 but socket_ids is not empty");
        }
        RequestType::ChannelMembersCount
            if response.members_count == 0 && !response.members.is_empty() =>
        {
            warn!("Inconsistent response: members_count is 0 but members map is not empty");
        }
        RequestType::ChannelsWithSocketsCount => {
            let channel_total = response
                .channels_with_sockets_count
                .values()
                .sum::<usize>();
            if response.sockets_count > 0 && channel_total == 0 {
                warn!("Inconsistent response: channels show 0 sockets but sockets_count > 0");
            }
        }
        _ => {}
    }
    Ok(())
}
}