use async_trait::async_trait;
#[cfg(feature = "bebop")]
use bebop::Record;
use dashmap::DashMap;
use std::collections::VecDeque;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio_tungstenite::tungstenite::Message;
#[cfg(feature = "bebop")]
use crate::schema::Data;
use std::sync::Arc;
use crate::{
client_sender::ServerOptions,
helpers::{
common::make_disconnect_message, metrics::Metrics, retry::ExponentialBackoff,
traits::date_time::now,
},
log_debug, log_error,
};
use super::{common::make_expired_output_message, types::RwClientSenders};
/// Registry of connected peers plus the pipeline that carries inbound messages to a handler.
///
/// Outbound messages are routed per peer id through each peer's [`ClientSender`]. Inbound
/// messages are pushed by [`ClientSenders::send_handle_message`] into a bounded channel; while
/// that channel is full they are parked in `spillover` (at most `spillover_buffer_size` entries).
pub struct ClientSenders {
    /// Registered peers, keyed by the peer id passed to `add`.
    clients: DashMap<String, ClientSender>,
    /// Sending half of the inbound `(payload, peer)` channel.
    handle_message_sx: Sender<(Vec<u8>, String)>,
    /// Receiving half of that channel; taken (leaving `None`) by the first call to
    /// `get_handle_message_receiver`.
    handle_message_rx: std::sync::Mutex<Option<Receiver<(Vec<u8>, String)>>>,
    /// Server options; `client_timeout_seconds` is read by `check_client_send_time`.
    options: std::sync::RwLock<ServerOptions>,
    /// Connection and message counters updated by this registry.
    pub metrics: Arc<Metrics>,
    /// Inbound messages that did not fit in the channel, kept in arrival order.
    spillover: std::sync::Mutex<VecDeque<(Vec<u8>, String)>>,
    /// Maximum number of entries kept in `spillover`; further messages are counted as dropped.
    spillover_buffer_size: usize,
}
impl Default for ClientSenders {
fn default() -> Self {
Self::new()
}
}
impl ClientSenders {
pub fn new() -> Self {
Self::new_with_buffer_size(1024, 1024)
}
pub fn new_with_buffer_size(handler_buffer_size: usize, spillover_buffer_size: usize) -> Self {
let (handle_message_sx, handle_message_rx) = mpsc::channel(handler_buffer_size);
Self {
clients: DashMap::new(),
handle_message_sx,
handle_message_rx: std::sync::Mutex::new(Some(handle_message_rx)),
options: std::sync::RwLock::new(ServerOptions::default()),
metrics: Arc::new(Metrics::new()),
spillover: std::sync::Mutex::new(VecDeque::new()),
spillover_buffer_size,
}
}
pub async fn add(&self, peer: &str, sx: Sender<Message>) {
log_debug!(
"Add peer: {:?}, exists: {:?}",
peer,
self.clients.contains_key(peer)
);
if let Some(existing) = self.clients.get(peer) {
let _ = existing.sx.send(make_disconnect_message(peer)).await;
}
self.clients.insert(
peer.to_owned(),
ClientSender {
sx,
send_time: now().timestamp(),
},
);
self.metrics.inc_connections_total();
self.metrics.inc_connections_active();
}
pub fn get_handle_message_receiver(&self) -> Option<Receiver<(Vec<u8>, String)>> {
self.handle_message_rx
.lock()
.unwrap_or_else(|e| e.into_inner())
.take()
}
pub fn set_options(&self, options: ServerOptions) {
*self.options.write().unwrap_or_else(|e| e.into_inner()) = options;
}
pub fn options(&self) -> ServerOptions {
self.options
.read()
.unwrap_or_else(|e| e.into_inner())
.clone()
}
pub fn send_handle_message(&self, data: Vec<u8>, peer: &str) {
self.metrics.inc_messages_received();
let mut spillover = self.spillover.lock().unwrap_or_else(|e| e.into_inner());
while let Some(item) = spillover.front().cloned() {
match self.handle_message_sx.try_send(item) {
Ok(()) => {
spillover.pop_front();
}
Err(_) => break,
}
}
if spillover.is_empty() {
match self.handle_message_sx.try_send((data, peer.to_owned())) {
Ok(()) => (),
Err(tokio::sync::mpsc::error::TrySendError::Full(item)) => {
if spillover.len() < self.spillover_buffer_size {
spillover.push_back(item);
} else {
self.metrics.inc_messages_dropped();
}
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
log_error!("Handle message channel closed");
}
}
} else {
if spillover.len() < self.spillover_buffer_size {
spillover.push_back((data, peer.to_owned()));
} else {
self.metrics.inc_messages_dropped();
}
}
}
pub fn check_client_send_time(&self) {
let now = now().timestamp();
let timeout = self
.options
.read()
.unwrap_or_else(|e| e.into_inner())
.client_timeout_seconds as i64;
self.clients.retain(|_, client| {
let keep = client.send_time + timeout >= now;
if !keep {
self.metrics.dec_connections_active();
}
keep
});
}
pub fn remove(&self, peer: &str) {
if self.clients.remove(peer).is_some() {
self.metrics.dec_connections_active();
}
log_debug!("Remove peer: {:?}", peer);
}
pub fn write_time(&self, peer: &str) {
if let Some(mut client) = self.clients.get_mut(peer) {
client.write_time();
}
}
pub async fn send(&self, peer: &str, message: Message) -> bool {
let sender = {
let Some(client) = self.clients.get(peer) else {
return false;
};
client.sx.clone()
};
let mut backoff = ExponentialBackoff::default();
loop {
match sender.send(message.clone()).await {
Ok(_) => {
self.metrics.inc_messages_sent();
return true;
}
Err(e) => {
log_error!(
"Error sending message (attempt {}): {:?}",
backoff.count() + 1,
e
);
if !backoff.wait().await {
self.metrics.inc_send_errors();
log_error!("Failed to send after {} retries", backoff.count());
return false;
}
}
}
}
}
pub fn is_active(&self, peer: &str) -> bool {
self.clients.contains_key(peer)
}
pub fn len(&self) -> usize {
self.clients.len()
}
pub fn is_empty(&self) -> bool {
self.clients.is_empty()
}
pub fn peers(&self) -> Vec<String> {
self.clients
.iter()
.map(|entry| entry.key().clone())
.collect()
}
pub fn peers_except(&self, valid: &std::collections::HashSet<&String>) -> Vec<String> {
self.clients
.iter()
.filter(|entry| !valid.contains(entry.key()))
.map(|entry| entry.key().clone())
.collect()
}
pub fn peers_in(&self, target: &std::collections::HashSet<&String>) -> Vec<String> {
self.clients
.iter()
.filter(|entry| target.contains(entry.key()))
.map(|entry| entry.key().clone())
.collect()
}
}
/// Async facade over the client registry; implemented in this file for [`RwClientSenders`].
#[async_trait]
pub trait ClientSendersTrait {
    /// Registers `sx` as the outbound channel for `peer`.
    async fn add(&self, peer: &str, sx: Sender<Message>);
    /// Takes the inbound-message receiver (`Some` at most once, in the visible implementation).
    async fn get_handle_message_receiver(&self) -> Option<Receiver<(Vec<u8>, String)>>;
    /// Queues an inbound message from `peer`; with the `bebop` feature the record is serialized
    /// first.
    #[cfg(feature = "bebop")]
    async fn send_handle_message(&self, data: Data<'_>, peer: &str);
    /// Queues a raw inbound payload from `peer`.
    #[cfg(not(feature = "bebop"))]
    async fn send_handle_message(&self, data: Vec<u8>, peer: &str);
    /// Sends `message` to `peer`; returns whether it was delivered.
    async fn send(&self, peer: &str, message: Message) -> bool;
    /// Sends the "expired" message to every registered peer that is NOT in `peer_list`.
    async fn expire_send(&self, peer_list: &[String]);
    /// Returns whether `peer` is currently registered.
    async fn is_active(&self, peer: &str) -> bool;
    /// Sends `message` to every registered peer listed in `peer_list`; unknown ids are ignored.
    async fn send_message_in_list(&self, peer_list: &[String], message: Message);
    /// Sends `message` to every registered peer.
    async fn send_all(&self, message: Message);
    /// Sends `message` to every registered peer listed in `peer_list`; unknown ids are ignored.
    async fn send_all_in_list(&self, peer_list: &[String], message: Message);
}
#[async_trait]
impl ClientSendersTrait for RwClientSenders {
    async fn add(&self, peer: &str, sx: Sender<Message>) {
        (**self).add(peer, sx).await;
    }

    async fn get_handle_message_receiver(&self) -> Option<Receiver<(Vec<u8>, String)>> {
        (**self).get_handle_message_receiver()
    }

    /// Serializes `data` and queues it; a serialization failure is logged and the record dropped.
    #[cfg(feature = "bebop")]
    async fn send_handle_message(&self, data: Data<'_>, peer: &str) {
        let mut buf = Vec::with_capacity(256);
        if let Err(e) = data.serialize(&mut buf) {
            log_error!("Failed to serialize data: {:?}", e);
            return;
        }
        (**self).send_handle_message(buf, peer);
    }

    #[cfg(not(feature = "bebop"))]
    async fn send_handle_message(&self, data: Vec<u8>, peer: &str) {
        (**self).send_handle_message(data, peer);
    }

    /// Sends `message` to `peer`. On success the peer's send time is refreshed. On failure the
    /// connection that was targeted is unregistered.
    async fn send(&self, peer: &str, message: Message) -> bool {
        // Remember which connection we are sending to. A failed send can take a while (the
        // inherent `send` backs off between retries), and the peer may reconnect under the same
        // id meanwhile; only the channel that actually failed must be evicted.
        let Some(target) = self.clients.get(peer).map(|client| client.sx.clone()) else {
            return false;
        };
        let delivered = (**self).send(peer, message).await;
        if delivered {
            (**self).write_time(peer);
        } else if self
            .clients
            .remove_if(peer, |_, client| client.sx.same_channel(&target))
            .is_some()
        {
            self.metrics.dec_connections_active();
        }
        delivered
    }

    async fn expire_send(&self, peer_list: &[String]) {
        use std::collections::HashSet;
        let valid_peers: HashSet<&String> = peer_list.iter().collect();
        let peers_to_expire = (**self).peers_except(&valid_peers);
        // Concurrent like `send_all`, so one dead peer's retries do not delay the others.
        send_to_each(self, &peers_to_expire, make_expired_output_message()).await;
    }

    async fn is_active(&self, peer: &str) -> bool {
        (**self).is_active(peer)
    }

    async fn send_message_in_list(&self, peer_list: &[String], message: Message) {
        use std::collections::HashSet;
        let target_peers: HashSet<&String> = peer_list.iter().collect();
        let peers = (**self).peers_in(&target_peers);
        send_to_each(self, &peers, message).await;
    }

    async fn send_all(&self, message: Message) {
        let all_peers = (**self).peers();
        send_to_each(self, &all_peers, message).await;
    }

    async fn send_all_in_list(&self, peer_list: &[String], message: Message) {
        // Same contract as `send_message_in_list`; delegate instead of duplicating it.
        self.send_message_in_list(peer_list, message).await;
    }
}

/// Sends a clone of `message` to every id in `peers` concurrently through
/// [`ClientSendersTrait::send`] (which refreshes send times and evicts failed connections), and
/// completes once every send has finished.
async fn send_to_each(senders: &RwClientSenders, peers: &[String], message: Message) {
    let sends: Vec<_> = peers
        .iter()
        .map(|peer| senders.send(peer, message.clone()))
        .collect();
    futures_util::future::join_all(sends).await;
}
/// Outbound half of one peer connection, plus the time it was last used.
#[derive(Clone, Debug)]
struct ClientSender {
    /// Channel carrying outgoing websocket messages for this peer.
    sx: Sender<Message>,
    /// `now().timestamp()` at registration or the last successful write. It is compared against
    /// `client_timeout_seconds`, so it is presumably in seconds; confirm against `now`.
    send_time: i64,
}

impl ClientSender {
    /// Stamps this sender with the current time, marking it as recently active.
    pub fn write_time(&mut self) {
        let stamp = now().timestamp();
        self.send_time = stamp;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_tungstenite::tungstenite::Bytes;

    /// How long to wait for a message that should (or should not) have been delivered.
    const RECV_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(100);

    fn create_test_client_senders() -> ClientSenders {
        ClientSenders::new()
    }

    #[test]
    fn test_client_senders_new() {
        let senders = create_test_client_senders();
        assert!(senders.is_empty());
        assert_eq!(senders.len(), 0);
    }

    #[test]
    fn test_client_senders_default() {
        let senders = ClientSenders::default();
        assert!(senders.is_empty());
        assert_eq!(senders.len(), 0);
    }

    #[tokio::test]
    async fn test_client_senders_add() {
        let senders = create_test_client_senders();
        let (tx, _rx) = mpsc::channel(8);
        senders.add("peer1", tx).await;
        assert_eq!(senders.len(), 1);
        assert!(senders.is_active("peer1"));
        assert!(!senders.is_empty());
    }

    #[tokio::test]
    async fn test_client_senders_add_multiple() {
        let senders = create_test_client_senders();
        let (tx1, _rx1) = mpsc::channel(8);
        let (tx2, _rx2) = mpsc::channel(8);
        senders.add("peer1", tx1).await;
        senders.add("peer2", tx2).await;
        assert_eq!(senders.len(), 2);
        assert!(senders.is_active("peer1"));
        assert!(senders.is_active("peer2"));
    }

    #[tokio::test]
    async fn test_client_senders_remove() {
        let senders = create_test_client_senders();
        let (tx, _rx) = mpsc::channel(8);
        senders.add("peer1", tx).await;
        assert_eq!(senders.len(), 1);
        senders.remove("peer1");
        assert_eq!(senders.len(), 0);
        assert!(!senders.is_active("peer1"));
    }

    #[tokio::test]
    async fn test_client_senders_peers() {
        let senders = create_test_client_senders();
        let (tx1, _rx1) = mpsc::channel(8);
        let (tx2, _rx2) = mpsc::channel(8);
        senders.add("peer1", tx1).await;
        senders.add("peer2", tx2).await;
        let peers = senders.peers();
        assert_eq!(peers.len(), 2);
        assert!(peers.contains(&"peer1".to_string()));
        assert!(peers.contains(&"peer2".to_string()));
    }

    #[test]
    fn test_client_senders_is_active_nonexistent() {
        let senders = create_test_client_senders();
        assert!(!senders.is_active("nonexistent"));
    }

    #[tokio::test]
    async fn test_client_senders_send_success() {
        let senders = create_test_client_senders();
        let (tx, mut rx) = mpsc::channel(8);
        senders.add("peer1", tx).await;
        let msg = Message::Binary(Bytes::from_static(b"test"));
        let result = senders.send("peer1", msg).await;
        assert!(result);
        let received = rx.recv().await;
        assert!(received.is_some());
    }

    #[tokio::test]
    async fn test_client_senders_send_nonexistent_peer() {
        let senders = create_test_client_senders();
        let msg = Message::Binary(Bytes::from_static(b"test"));
        let result = senders.send("nonexistent", msg).await;
        assert!(!result);
    }

    #[test]
    fn test_client_senders_get_handle_message_receiver() {
        let senders = create_test_client_senders();
        let rx = senders.get_handle_message_receiver();
        assert!(rx.is_some(), "First call should return Some");
        let rx2 = senders.get_handle_message_receiver();
        assert!(rx2.is_none(), "Second call should return None");
    }

    #[tokio::test]
    async fn test_client_senders_send_handle_message() {
        let senders = create_test_client_senders();
        let mut rx = senders.get_handle_message_receiver().expect("receiver");
        senders.send_handle_message(vec![1, 2, 3], "peer1");
        let received = rx.recv().await;
        assert!(received.is_some());
        let (data, peer) = received.unwrap();
        assert_eq!(data, vec![1, 2, 3]);
        assert_eq!(peer, "peer1");
    }

    #[tokio::test]
    async fn test_send_handle_message_spills_in_order_and_drops_overflow() {
        let senders = ClientSenders::new_with_buffer_size(1, 1);
        let mut rx = senders.get_handle_message_receiver().expect("receiver");
        senders.send_handle_message(vec![1], "a"); // fills the channel
        senders.send_handle_message(vec![2], "b"); // parked in spillover
        senders.send_handle_message(vec![3], "c"); // spillover full -> dropped
        assert_eq!(rx.recv().await, Some((vec![1], "a".to_string())));
        senders.send_handle_message(vec![4], "d"); // flushes "b", parks "d"
        assert_eq!(rx.recv().await, Some((vec![2], "b".to_string())));
        senders.send_handle_message(vec![5], "e"); // flushes "d", parks "e"
        assert_eq!(rx.recv().await, Some((vec![4], "d".to_string())));
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_client_senders_write_time() {
        let senders = create_test_client_senders();
        let (tx, _rx) = mpsc::channel(8);
        senders.add("peer1", tx).await;
        senders.clients.get_mut("peer1").unwrap().send_time = 0;
        senders.write_time("peer1");
        assert!(senders.clients.get("peer1").unwrap().send_time > 0);
    }

    #[tokio::test]
    async fn test_client_senders_replace_existing() {
        let senders = create_test_client_senders();
        let (tx1, mut rx1) = mpsc::channel(8);
        let (tx2, _rx2) = mpsc::channel(8);
        senders.add("peer1", tx1).await;
        senders.add("peer1", tx2).await;
        assert_eq!(senders.len(), 1);
        let msg = rx1.recv().await;
        assert!(msg.is_some());
    }

    #[test]
    fn test_client_sender_write_time() {
        let (tx, _rx) = mpsc::channel(8);
        let mut sender = ClientSender {
            sx: tx,
            send_time: 0,
        };
        assert_eq!(sender.send_time, 0);
        sender.write_time();
        assert!(sender.send_time > 0);
    }

    #[tokio::test]
    async fn test_check_client_send_time_removes_inactive_clients() {
        let senders = create_test_client_senders();
        let (tx, _rx) = mpsc::channel(8);
        senders.add("peer1", tx).await;
        senders.clients.get_mut("peer1").unwrap().send_time = 0;
        assert!(senders.is_active("peer1"));
        senders.check_client_send_time();
        assert!(!senders.is_active("peer1"));
        assert_eq!(senders.len(), 0);
    }

    #[tokio::test]
    async fn test_check_client_send_time_keeps_active_clients() {
        let senders = create_test_client_senders();
        let (tx, _rx) = mpsc::channel(8);
        senders.add("peer1", tx).await;
        senders.write_time("peer1");
        senders.check_client_send_time();
        assert!(senders.is_active("peer1"));
        assert_eq!(senders.len(), 1);
    }

    #[tokio::test]
    async fn test_check_client_send_time_mixed_clients() {
        let senders = create_test_client_senders();
        let (tx1, _rx1) = mpsc::channel(8);
        let (tx2, _rx2) = mpsc::channel(8);
        let (tx3, _rx3) = mpsc::channel(8);
        senders.add("inactive1", tx1).await;
        senders.add("active", tx2).await;
        senders.add("inactive2", tx3).await;
        senders.clients.get_mut("inactive1").unwrap().send_time = 0;
        senders.clients.get_mut("inactive2").unwrap().send_time = 0;
        senders.write_time("active");
        assert_eq!(senders.len(), 3);
        senders.check_client_send_time();
        assert!(!senders.is_active("inactive1"));
        assert!(senders.is_active("active"));
        assert!(!senders.is_active("inactive2"));
        assert_eq!(senders.len(), 1);
    }

    #[test]
    fn test_check_client_send_time_empty_clients() {
        let senders = create_test_client_senders();
        senders.check_client_send_time();
        assert_eq!(senders.len(), 0);
    }

    #[tokio::test]
    async fn test_write_time_updates_timestamp_correctly() {
        let senders = create_test_client_senders();
        let (tx, _rx) = mpsc::channel(8);
        senders.add("peer1", tx).await;
        let initial_time = senders.clients.get("peer1").unwrap().send_time;
        let now_ts = crate::helpers::traits::date_time::now().timestamp();
        assert!((initial_time - now_ts).abs() <= 1);
        senders.write_time("peer1");
        let updated_time = senders.clients.get("peer1").unwrap().send_time;
        assert!(updated_time > 0);
        let now_after = crate::helpers::traits::date_time::now().timestamp();
        assert!((updated_time - now_after).abs() <= 1);
    }

    #[tokio::test]
    async fn test_write_time_nonexistent_peer_no_panic() {
        let senders = create_test_client_senders();
        senders.write_time("nonexistent");
    }

    fn create_rw_client_senders() -> RwClientSenders {
        Arc::new(ClientSenders::new())
    }

    #[tokio::test]
    async fn test_trait_send_all_broadcasts_to_all_clients() {
        let senders = create_rw_client_senders();
        let (tx1, mut rx1) = mpsc::channel(8);
        let (tx2, mut rx2) = mpsc::channel(8);
        let (tx3, mut rx3) = mpsc::channel(8);
        senders.add("peer1", tx1).await;
        senders.add("peer2", tx2).await;
        senders.add("peer3", tx3).await;
        let msg = Message::Binary(Bytes::from_static(b"broadcast"));
        senders.send_all(msg).await;
        assert!(rx1.recv().await.is_some());
        assert!(rx2.recv().await.is_some());
        assert!(rx3.recv().await.is_some());
    }

    #[tokio::test]
    async fn test_trait_send_all_empty_clients() {
        let senders = create_rw_client_senders();
        let msg = Message::Binary(Bytes::from_static(b"broadcast"));
        senders.send_all(msg).await;
    }

    #[tokio::test]
    async fn test_trait_send_all_in_list_filters_correctly() {
        let senders = create_rw_client_senders();
        let (tx1, mut rx1) = mpsc::channel(8);
        let (tx2, mut rx2) = mpsc::channel(8);
        let (tx3, mut rx3) = mpsc::channel(8);
        senders.add("peer1", tx1).await;
        senders.add("peer2", tx2).await;
        senders.add("peer3", tx3).await;
        let target_list = vec!["peer1".to_string(), "peer3".to_string()];
        let msg = Message::Binary(Bytes::from_static(b"filtered"));
        senders.send_all_in_list(&target_list, msg).await;
        let recv1 = tokio::time::timeout(RECV_TIMEOUT, rx1.recv()).await;
        let recv2 = tokio::time::timeout(RECV_TIMEOUT, rx2.recv()).await;
        let recv3 = tokio::time::timeout(RECV_TIMEOUT, rx3.recv()).await;
        assert!(recv1.is_ok() && recv1.unwrap().is_some());
        assert!(recv2.is_err() || recv2.unwrap().is_none());
        assert!(recv3.is_ok() && recv3.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_trait_send_message_in_list_filters_by_existing_peers() {
        let senders = create_rw_client_senders();
        let (tx1, mut rx1) = mpsc::channel(8);
        let (tx2, mut rx2) = mpsc::channel(8);
        senders.add("peer1", tx1).await;
        senders.add("peer2", tx2).await;
        let target_list = vec![
            "peer1".to_string(),
            "peer3".to_string(),
            "peer4".to_string(),
        ];
        let msg = Message::Binary(Bytes::from_static(b"test"));
        senders.send_message_in_list(&target_list, msg).await;
        let recv1 = tokio::time::timeout(RECV_TIMEOUT, rx1.recv()).await;
        let recv2 = tokio::time::timeout(RECV_TIMEOUT, rx2.recv()).await;
        assert!(recv1.is_ok() && recv1.unwrap().is_some());
        assert!(recv2.is_err() || recv2.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_trait_expire_send_sends_to_unlisted_peers() {
        let senders = create_rw_client_senders();
        let (tx1, mut rx1) = mpsc::channel(8);
        let (tx2, mut rx2) = mpsc::channel(8);
        let (tx3, mut rx3) = mpsc::channel(8);
        senders.add("peer1", tx1).await;
        senders.add("peer2", tx2).await;
        senders.add("peer3", tx3).await;
        let valid_list = vec!["peer2".to_string()];
        senders.expire_send(&valid_list).await;
        let recv1 = tokio::time::timeout(RECV_TIMEOUT, rx1.recv()).await;
        let recv2 = tokio::time::timeout(RECV_TIMEOUT, rx2.recv()).await;
        let recv3 = tokio::time::timeout(RECV_TIMEOUT, rx3.recv()).await;
        assert!(recv1.is_ok() && recv1.unwrap().is_some());
        assert!(recv2.is_err() || recv2.unwrap().is_none());
        assert!(recv3.is_ok() && recv3.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_trait_is_active_through_rwlock() {
        let senders = create_rw_client_senders();
        let (tx, _rx) = mpsc::channel(8);
        assert!(!senders.is_active("peer1").await);
        senders.add("peer1", tx).await;
        assert!(senders.is_active("peer1").await);
        assert!(!senders.is_active("peer2").await);
    }

    #[tokio::test]
    async fn test_trait_send_updates_time_on_success() {
        let senders = create_rw_client_senders();
        let (tx, mut rx) = mpsc::channel(8);
        senders.add("peer1", tx).await;
        let initial_time = {
            let time = senders.clients.get("peer1").unwrap().send_time;
            assert!(time > 0);
            time
        };
        let msg = Message::Binary(Bytes::from_static(b"test"));
        let result = senders.send("peer1", msg).await;
        assert!(result);
        let _ = rx.recv().await;
        {
            let time = senders.clients.get("peer1").unwrap().send_time;
            assert!(time >= initial_time);
        }
    }

    #[tokio::test]
    async fn test_trait_send_removes_peer_on_failure() {
        let senders = create_rw_client_senders();
        let (tx, rx) = mpsc::channel(1);
        senders.add("peer1", tx).await;
        assert!(senders.is_active("peer1").await);
        drop(rx);
        let msg = Message::Binary(Bytes::from_static(b"test"));
        let result = senders.send("peer1", msg).await;
        assert!(!result);
        assert!(!senders.is_active("peer1").await);
    }
}