use std::collections::HashMap;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use rocketmq_common::common::server::config::ServerConfig;
use rocketmq_rust::wait_for_signal;
use rocketmq_rust::ArcMut;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tokio::sync::OwnedSemaphorePermit;
use tokio::sync::Semaphore;
use tokio::time;
use tracing::error;
use tracing::info;
use tracing::warn;
use crate::base::channel_event_listener::ChannelEventListener;
use crate::base::connection_net_event::ConnectionNetEvent;
use crate::base::tokio_event::TokioEvent;
use crate::net::channel::Channel;
use crate::net::channel::ChannelInner;
use crate::remoting::inner::RemotingGeneralHandler;
use crate::runtime::connection_handler_context::ConnectionHandlerContext;
use crate::runtime::connection_handler_context::ConnectionHandlerContextWrapper;
use crate::runtime::processor::RequestProcessor;
use crate::runtime::RPCHook;
use crate::tls::TlsServerRuntime;
/// Upper bound on concurrently served client connections; the accept loop acquires one
/// semaphore permit per connection before accepting the next socket.
const DEFAULT_MAX_CONNECTIONS: usize = 1000;
/// Seconds a connection may wait for its next command before it is treated as idle and
/// closed.
const DEFAULT_CHANNEL_IDLE_TIMEOUT_SECONDS: u64 = 120;
/// Serves one accepted client connection: reads commands, dispatches them, and tears the
/// connection down on shutdown, idle timeout, decode error, or end of stream.
pub struct ConnectionHandler<RP> {
    /// Per-connection context; wraps the connection's `Channel`.
    connection_handler_context: ConnectionHandlerContext,
    /// Receives the server-wide shutdown notification.
    shutdown: Shutdown,
    /// Never read. Held for the handler's lifetime so that the server, which waits until
    /// every clone of this sender has been dropped, can tell when this connection is done.
    _shutdown_complete: mpsc::Sender<()>,
    /// When set, the remote address is broadcast on it when the handler is dropped.
    conn_disconnect_notify: Option<broadcast::Sender<SocketAddr>>,
    /// Processes every decoded command received on this connection.
    cmd_handler: ArcMut<RemotingGeneralHandler<RP>>,
    /// When set, receives the IDLE and EXCEPTION lifecycle events raised by `handle`.
    event_tx: Option<mpsc::UnboundedSender<TokioEvent>>,
    /// Maximum time to wait for the next command before the connection is closed as idle.
    idle_timeout: Duration,
}
impl<RP> Drop for ConnectionHandler<RP> {
fn drop(&mut self) {
if let Some(ref sender) = self.conn_disconnect_notify {
let socket_addr = self.connection_handler_context.remote_address();
warn!("connection[{}] disconnected, Send notify message.", socket_addr);
let _ = sender.send(socket_addr);
}
}
}
impl<RP: RequestProcessor + Sync + 'static> ConnectionHandler<RP> {
    /// Serves this connection until it ends.
    ///
    /// Each iteration races the next incoming command against the shutdown notification
    /// and the idle timer:
    /// - shutdown: the connection is closed and `Ok(())` is returned;
    /// - idle timeout: an `IDLE` event is emitted (if an event sender is configured), the
    ///   connection is closed and `Ok(())` is returned;
    /// - decode failure: an `EXCEPTION` event is emitted (if configured), the connection is
    ///   closed and the error is returned;
    /// - end of stream: `Ok(())` is returned without an explicit close;
    /// - a command: it is handed to the command handler and the loop continues.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `receive_command` when a frame cannot be decoded.
    #[inline]
    async fn handle(&mut self) -> rocketmq_error::RocketMQResult<()> {
        let timeout = self.idle_timeout;
        let peer = self.connection_handler_context.remote_address();
        while !self.shutdown.is_shutdown() {
            let chan = self.connection_handler_context.channel_mut();
            let received = tokio::select! {
                res = chan.connection_mut().receive_command() => res,
                _ = self.shutdown.recv() => {
                    chan.connection_mut().close();
                    return Ok(());
                }
                _ = tokio::time::sleep(timeout) => {
                    warn!(
                        "Connection idle timeout ({}s), remote: {}",
                        timeout.as_secs(),
                        peer
                    );
                    let snapshot = chan.clone();
                    if let Some(tx) = &self.event_tx {
                        let _ = tx.send(TokioEvent::new(ConnectionNetEvent::IDLE, peer, snapshot));
                    }
                    chan.connection_mut().close();
                    return Ok(());
                }
            };

            // `None` means the peer closed the stream.
            let Some(decoded) = received else {
                return Ok(());
            };

            let cmd = match decoded {
                Ok(cmd) => cmd,
                Err(e) => {
                    error!("Failed to decode command: {:?}", e);
                    let snapshot = chan.clone();
                    if let Some(tx) = &self.event_tx {
                        let _ = tx.send(TokioEvent::new(ConnectionNetEvent::EXCEPTION, peer, snapshot));
                    }
                    chan.connection_mut().close();
                    return Err(e);
                }
            };

            self.cmd_handler
                .process_message_received(&mut self.connection_handler_context, cmd)
                .await;
        }
        Ok(())
    }
}
/// Accept-loop state for one server instance; spawns one task per accepted connection.
struct ConnectionListener<RP> {
    /// Socket that incoming client connections are accepted from.
    listener: TcpListener,
    /// Bounds the number of concurrently served connections; one permit is held per connection.
    limit_connections: Arc<Semaphore>,
    /// Every connection subscribes to this; dropping the sender signals shutdown to all of them.
    notify_shutdown: broadcast::Sender<()>,
    /// Cloned into every connection handler so shutdown can wait for all connections to finish.
    shutdown_complete_tx: mpsc::Sender<()>,
    /// Optional disconnect-notification sender, cloned into every connection handler.
    conn_disconnect_notify: Option<broadcast::Sender<SocketAddr>>,
    /// Optional lifecycle-event listener; taken by `run` and driven from a background task.
    channel_event_listener: Option<Arc<dyn ChannelEventListener>>,
    /// Shared command handler; also supplies the response table each channel is built with.
    cmd_handler: ArcMut<RemotingGeneralHandler<RP>>,
    /// Turns each accepted `TcpStream` into a connection; a `None` result drops the socket.
    /// Presumably applies TLS according to the server's TLS config — TODO confirm.
    tls_runtime: TlsServerRuntime,
}
impl<RP: RequestProcessor + Sync + 'static + Clone> ConnectionListener<RP> {
async fn run(&mut self) -> anyhow::Result<()> {
info!("Server ready to accept connections");
let (event_tx, mut event_rx) = tokio::sync::mpsc::unbounded_channel::<TokioEvent>();
if let Some(listener) = self.channel_event_listener.take() {
tokio::spawn(async move {
while let Some(event) = event_rx.recv().await {
let addr = event.remote_addr();
let addr_str = addr.to_string();
match event.type_() {
ConnectionNetEvent::CONNECTED(_) => {
listener.on_channel_connect(&addr_str, event.channel());
}
ConnectionNetEvent::DISCONNECTED => {
listener.on_channel_close(&addr_str, event.channel());
}
ConnectionNetEvent::EXCEPTION => {
listener.on_channel_exception(&addr_str, event.channel());
}
ConnectionNetEvent::IDLE => {
listener.on_channel_idle(&addr_str, event.channel());
}
}
}
info!("Event dispatcher task terminated");
});
}
loop {
let permit = acquire_connection_permit(&self.limit_connections).await?;
let (socket, remote_addr) = self.accept().await?;
if let Err(e) = socket.set_nodelay(true) {
warn!("Failed to set TCP_NODELAY for {}: {}", remote_addr, e);
}
let local_addr = socket.local_addr()?;
info!("Accepted connection: {} → {}", remote_addr, local_addr);
let tls_runtime = self.tls_runtime.clone();
let cmd_handler = self.cmd_handler.clone();
let notify_shutdown = self.notify_shutdown.subscribe();
let shutdown_complete_tx = self.shutdown_complete_tx.clone();
let conn_disconnect_notify = self.conn_disconnect_notify.clone();
let event_tx_clone = event_tx.clone();
tokio::spawn(async move {
let Some(connection) = tls_runtime.into_connection(socket, remote_addr).await else {
drop(permit);
return;
};
let channel_inner = ArcMut::new(ChannelInner::new(connection, cmd_handler.response_table.clone()));
let channel = Channel::new(channel_inner, local_addr, remote_addr);
let _ = event_tx_clone.send(TokioEvent::new(
ConnectionNetEvent::CONNECTED(remote_addr),
remote_addr,
channel.clone(),
));
let idle_timeout = Duration::from_secs(DEFAULT_CHANNEL_IDLE_TIMEOUT_SECONDS);
let handler_event_tx = event_tx_clone.clone();
let handler = ConnectionHandler {
connection_handler_context: ArcMut::new(ConnectionHandlerContextWrapper {
channel: channel.clone(),
}),
shutdown: Shutdown::new(notify_shutdown),
_shutdown_complete: shutdown_complete_tx,
conn_disconnect_notify,
cmd_handler,
event_tx: Some(handler_event_tx),
idle_timeout,
};
let mut handler = handler;
if let Err(err) = handler.handle().await {
error!(
remote_addr = %remote_addr,
error = ?err,
"Connection handler terminated with error"
);
}
let _ = event_tx_clone.send(TokioEvent::new(
ConnectionNetEvent::DISCONNECTED,
remote_addr,
handler.connection_handler_context.channel.clone(),
));
info!("Client {} disconnected", remote_addr);
drop(permit);
});
}
}
async fn accept(&mut self) -> anyhow::Result<(TcpStream, SocketAddr)> {
let mut backoff = 1;
const MAX_BACKOFF: u64 = 64;
loop {
match self.listener.accept().await {
Ok((socket, remote_addr)) => {
return Ok((socket, remote_addr));
}
Err(err) => {
if backoff > MAX_BACKOFF {
error!("Accept failed after {} retries, last error: {}", MAX_BACKOFF, err);
return Err(err.into());
}
warn!("Accept error (will retry in {}s): {}", backoff, err);
}
}
time::sleep(Duration::from_secs(backoff)).await;
backoff *= 2;
}
}
}
async fn acquire_connection_permit(limit_connections: &Arc<Semaphore>) -> anyhow::Result<OwnedSemaphorePermit> {
limit_connections
.clone()
.acquire_owned()
.await
.map_err(|err| anyhow::anyhow!("connection limit semaphore closed: {err}"))
}
/// Remoting server entry point: binds the configured address and serves requests with a
/// caller-supplied request processor of type `RP`.
pub struct RocketMQServer<RP> {
    /// Server configuration (bind address, listen port and TLS settings are read from it).
    config: Arc<ServerConfig>,
    /// RPC hooks registered so far; taken (leaving `None`) when the server starts running.
    rpc_hooks: Option<Vec<Arc<dyn RPCHook>>>,
    /// Ties the request-processor type `RP` to the server without storing a value of it.
    _phantom_data: std::marker::PhantomData<RP>,
}
impl<RP> RocketMQServer<RP> {
    /// Creates a server for `config` with an empty list of RPC hooks.
    pub fn new(config: Arc<ServerConfig>) -> Self {
        let rpc_hooks = Some(Vec::new());
        Self {
            config,
            rpc_hooks,
            _phantom_data: std::marker::PhantomData,
        }
    }

    /// Appends `hook` to the hooks handed to the request handler when the server starts.
    ///
    /// If the hook list was already taken by a previous run, a fresh list is started.
    pub fn register_rpc_hook(&mut self, hook: Arc<dyn RPCHook>) {
        self.rpc_hooks.get_or_insert_with(Vec::new).push(hook);
    }
}
impl<RP: RequestProcessor + Sync + 'static + Clone> RocketMQServer<RP> {
    /// Runs the server until the future returned by `wait_for_signal()` completes.
    pub async fn run(&mut self, request_processor: RP, channel_event_listener: Option<Arc<dyn ChannelEventListener>>) {
        let stop_signal = wait_for_signal();
        self.run_with_shutdown(request_processor, channel_event_listener, stop_signal)
            .await;
    }

    /// Binds `bind_address:listen_port` from the configuration and serves connections until
    /// `shutdown` completes.
    ///
    /// Registered RPC hooks are moved into the request handler. A later call therefore
    /// starts with an empty hook list unless hooks are registered again. A bind failure is
    /// logged and the method returns without serving.
    pub async fn run_with_shutdown<S>(
        &mut self,
        request_processor: RP,
        channel_event_listener: Option<Arc<dyn ChannelEventListener>>,
        shutdown: S,
    ) where
        S: Future,
    {
        let addr = format!("{}:{}", self.config.bind_address, self.config.listen_port);
        let listener = match TcpListener::bind(&addr).await {
            Ok(bound) => bound,
            Err(err) => {
                error!(addr = %addr, error = %err, "failed to bind remoting_server");
                return;
            }
        };

        let hooks = self.rpc_hooks.take().unwrap_or_default();
        let tls = TlsServerRuntime::new(self.config.tls_config.clone());
        info!("Starting remoting_server at: {}", addr);

        // The initial receiver is discarded here; only the sender is passed on.
        let (disconnect_tx, _) = broadcast::channel::<SocketAddr>(100);
        run_with_tls_config(
            listener,
            shutdown,
            request_processor,
            Some(disconnect_tx),
            hooks,
            channel_event_listener,
            tls,
        )
        .await;
    }
}
/// Serves connections accepted from `listener` until `shutdown` completes.
///
/// Uses a `TlsServerRuntime` built from the default TLS configuration. Whether that means
/// TLS is disabled depends on `TlsServerRuntime` — TODO confirm.
pub async fn run<RP: RequestProcessor + Sync + 'static + Clone>(
    listener: TcpListener,
    shutdown: impl Future,
    request_processor: RP,
    conn_disconnect_notify: Option<broadcast::Sender<SocketAddr>>,
    rpc_hooks: Vec<Arc<dyn RPCHook>>,
    channel_event_listener: Option<Arc<dyn ChannelEventListener>>,
) {
    let default_tls = TlsServerRuntime::new(Default::default());
    run_with_tls_config(
        listener,
        shutdown,
        request_processor,
        conn_disconnect_notify,
        rpc_hooks,
        channel_event_listener,
        default_tls,
    )
    .await;
}
/// Serves connections accepted from `listener` until `shutdown` completes or the accept
/// loop fails, then shuts down gracefully.
///
/// Graceful shutdown proceeds in three steps:
/// 1. The listening socket is closed first, so new clients are refused rather than left
///    waiting in the accept backlog.
/// 2. Every connection is told to stop, by dropping the shutdown broadcast sender.
/// 3. This function waits until every clone of the shutdown-complete sender has been
///    dropped. Each connection task holds one clone.
async fn run_with_tls_config<RP: RequestProcessor + Sync + 'static + Clone>(
    listener: TcpListener,
    shutdown: impl Future,
    request_processor: RP,
    conn_disconnect_notify: Option<broadcast::Sender<SocketAddr>>,
    rpc_hooks: Vec<Arc<dyn RPCHook>>,
    channel_event_listener: Option<Arc<dyn ChannelEventListener>>,
    tls_runtime: TlsServerRuntime,
) {
    let (notify_shutdown, _) = broadcast::channel(1);
    let (shutdown_complete_tx, mut shutdown_complete_rx) = mpsc::channel(1);
    let handler = RemotingGeneralHandler {
        request_processor,
        rpc_hooks,
        response_table: ArcMut::new(HashMap::with_capacity(512)),
    };
    let mut server = ConnectionListener {
        listener,
        notify_shutdown,
        shutdown_complete_tx,
        conn_disconnect_notify,
        limit_connections: Arc::new(Semaphore::new(DEFAULT_MAX_CONNECTIONS)),
        channel_event_listener,
        cmd_handler: ArcMut::new(handler),
        tls_runtime,
    };
    tokio::select! {
        res = server.run() => {
            if let Err(err) = res {
                error!(cause = %err, "failed to accept");
            }
        }
        _ = shutdown => {
            info!("Shutdown now.....");
        }
    }
    let ConnectionListener {
        listener: accept_socket,
        shutdown_complete_tx,
        notify_shutdown,
        ..
    } = server;
    // Stop accepting at the OS level before draining. Otherwise the socket stays bound
    // until this function returns, and clients arriving during the drain are queued but
    // never served.
    drop(accept_socket);
    // Closing the broadcast makes every connection's `Shutdown::recv` complete.
    drop(notify_shutdown);
    // Drop our own sender so `recv` below returns once the connection tasks' clones are gone.
    drop(shutdown_complete_tx);
    let _ = shutdown_complete_rx.recv().await;
}
/// Tracks whether one connection has observed the server's shutdown notification.
#[derive(Debug)]
pub(crate) struct Shutdown {
    /// Set once `recv` has completed.
    is_shutdown: bool,
    /// Receiving half of the server's shutdown broadcast.
    notify: broadcast::Receiver<()>,
}
impl Shutdown {
    /// Wraps `notify`; no shutdown has been observed yet.
    pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown {
        Self {
            notify,
            is_shutdown: false,
        }
    }

    /// Returns `true` once [`Shutdown::recv`] has completed.
    pub(crate) fn is_shutdown(&self) -> bool {
        self.is_shutdown
    }

    /// Waits for the shutdown notification, or returns at once if it has already been seen.
    ///
    /// Any outcome of the underlying `recv` counts as shutdown: a value, the sender being
    /// dropped, or lagging.
    pub(crate) async fn recv(&mut self) {
        if !self.is_shutdown {
            let _ = self.notify.recv().await;
            self.is_shutdown = true;
        }
    }
}
#[cfg(test)]
mod tests {
    use std::future;
    use std::sync::Arc;
    use rocketmq_common::common::server::config::ServerConfig;
    use super::*;
    use crate::request_processor::default_request_processor::DefaultRemotingRequestProcessor;

    /// A closed connection-limit semaphore must surface as an error, not a panic.
    #[tokio::test]
    async fn acquire_connection_permit_closed_semaphore_returns_error_without_panicking() {
        let limit = Arc::new(Semaphore::new(1));
        limit.close();
        let Err(err) = acquire_connection_permit(&limit).await else {
            panic!("closed semaphore should return an error");
        };
        assert!(err.to_string().contains("connection limit semaphore closed"));
    }

    /// An unbindable address (out-of-range port) makes `run_with_shutdown` return cleanly.
    #[tokio::test]
    async fn run_with_shutdown_bind_error_returns_without_panicking() {
        let config = ServerConfig {
            bind_address: "127.0.0.1".to_string(),
            listen_port: 70000,
            ..Default::default()
        };
        let mut server = RocketMQServer::<DefaultRemotingRequestProcessor>::new(Arc::new(config));
        let never = future::pending::<()>();
        server
            .run_with_shutdown(DefaultRemotingRequestProcessor, None, never)
            .await;
    }
}