use std::net::SocketAddr;
use rocketmq_common::common::tls_config::TlsConfig;
use rocketmq_error::RocketMQResult;
use rocketmq_rust::ArcMut;
use tokio::net::TcpStream;
use tokio::sync::broadcast;
use tokio::sync::mpsc::Receiver;
use crate::base::connection_net_event::ConnectionNetEvent;
use crate::base::response_future::ResponseFuture;
use crate::connection::Connection;
use crate::error_helpers::connection_invalid;
use crate::error_helpers::io_error;
use crate::error_helpers::remote_error;
use crate::net::channel::Channel;
use crate::net::channel::ChannelInner;
use crate::protocol::remoting_command::RemotingCommand;
use crate::remoting::inner::RemotingGeneralHandler;
use crate::remoting_server::rocketmq_tokio_server::Shutdown;
use crate::runtime::connection_handler_context::ConnectionHandlerContext;
use crate::runtime::connection_handler_context::ConnectionHandlerContextWrapper;
use crate::runtime::processor::RequestProcessor;
#[cfg(feature = "tls")]
use crate::tls::connect_tls_stream;
#[cfg(not(feature = "tls"))]
use crate::tls::tls_disabled_error;
#[cfg(not(feature = "tls"))]
use crate::tls::TLS_DISABLED_ERROR_REASON;
/// Handle to one remoting connection.
///
/// Clones share the same inner state (`ArcMut`), the same shutdown broadcaster and the same
/// outbound queue, so any clone can send on the connection.
///
/// NOTE: field order is also drop order; keep it unless a change in teardown order is intended.
#[derive(Clone)]
pub struct Client<PR> {
/// State shared with the background receive/send tasks spawned by `ClientInner::connect`.
inner: ArcMut<ClientInner<PR>>,
/// Sender side of the shutdown broadcast; its subscriber backs `ClientInner::shutdown`, which
/// `run_recv` waits on.
notify_shutdown: broadcast::Sender<()>,
/// Outbound queue drained by `ClientInner::run_send`.
tx: tokio::sync::mpsc::Sender<SendMessage>,
}
/// One item on the outbound queue consumed by `ClientInner::run_send`:
/// `(request, optional reply channel, optional timeout in milliseconds)`.
///
/// When the reply channel is `Some`, `ClientInner::send` registers a `ResponseFuture` for the
/// request's opaque before writing it; when it is `None` the request is fire-and-forget.
type SendMessage = (
RemotingCommand,
Option<tokio::sync::oneshot::Sender<RocketMQResult<RemotingCommand>>>,
Option<u64>,
);
/// Connection state shared (via `ArcMut`) between a `Client` and its background I/O tasks.
///
/// NOTE: field order is also drop order; keep it unless a change in teardown order is intended.
struct ClientInner<PR> {
/// Dispatches received commands and owns the `response_table` used to correlate replies by
/// opaque.
cmd_handler: ArcMut<RemotingGeneralHandler<PR>>,
/// Handler context wrapping the `Channel` (and therefore the underlying `Connection`).
ctx: ConnectionHandlerContext,
/// Shutdown listener built from the client's broadcast receiver; `run_recv` stops when it fires.
shutdown: Shutdown,
}
impl<PR> ClientInner<PR>
where
PR: RequestProcessor + Sync + 'static,
{
pub async fn connect(
addr: String,
cmd_handler: ArcMut<RemotingGeneralHandler<PR>>,
tx: Option<&tokio::sync::broadcast::Sender<ConnectionNetEvent>>,
notify: broadcast::Receiver<()>,
tls_config: TlsConfig,
) -> RocketMQResult<(tokio::sync::mpsc::Sender<SendMessage>, ArcMut<ClientInner<PR>>)> {
let stream = TcpStream::connect(addr.as_str()).await.map_err(io_error)?;
let local_addr = stream.local_addr()?;
let remote_address = stream.peer_addr()?;
let connection = if tls_config.enable {
#[cfg(feature = "tls")]
{
let server_name = server_name_from_addr(addr.as_str());
let tls_stream = connect_tls_stream(stream, &server_name, &tls_config).await?;
Connection::new_with_stream(tls_stream)
}
#[cfg(not(feature = "tls"))]
{
let _ = stream;
debug_assert_eq!(
TLS_DISABLED_ERROR_REASON,
"rocketmq-remoting was compiled without the tls feature"
);
return Err(tls_disabled_error());
}
} else {
Connection::new(stream)
};
let channel_inner = ArcMut::new(ChannelInner::new(connection, cmd_handler.response_table.clone()));
let channel = Channel::new(channel_inner, local_addr, remote_address);
let (tx_, rx) = tokio::sync::mpsc::channel(1024);
let client = ClientInner {
cmd_handler,
ctx: ArcMut::new(ConnectionHandlerContextWrapper::new(
channel,
)),
shutdown: Shutdown::new(notify),
};
let client_inner = ArcMut::new(client);
let mut client_ = client_inner.clone();
tokio::spawn(async move {
let _ = client_.run_recv().await;
});
let mut client_ = client_inner.clone();
tokio::spawn(async move {
client_.run_send(rx).await;
});
if let Some(tx) = tx {
let _ = tx.send(ConnectionNetEvent::CONNECTED(client_inner.ctx.channel.remote_address()));
}
Ok((tx_, client_inner))
}
async fn run_recv(&mut self) -> RocketMQResult<()> {
loop {
let channel = self.ctx.channel_mut();
let frame = tokio::select! {
res = channel.connection_mut().receive_command() => res,
_ = self.shutdown.recv() =>{
channel.connection_mut().close();
return Ok(());
}
};
let cmd = match frame {
Some(frame) => frame?,
None => {
return Ok(());
}
};
self.cmd_handler.process_message_received(&mut self.ctx, cmd).await;
}
}
async fn run_send(&mut self, mut rx: Receiver<SendMessage>) {
while let Some((request, tx, timeout)) = rx.recv().await {
let _ = self.send(request, tx, timeout).await;
}
}
pub async fn send(
&mut self,
request: RemotingCommand,
tx: Option<tokio::sync::oneshot::Sender<RocketMQResult<RemotingCommand>>>,
timeout_millis: Option<u64>,
) -> RocketMQResult<()> {
let opaque = request.opaque();
if let Some(tx) = tx {
self.cmd_handler.response_table.insert(
opaque,
ResponseFuture::new(opaque, timeout_millis.unwrap_or(0), true, tx),
);
}
match self.ctx.connection_mut().send_command(request).await {
Ok(_) => Ok(()),
Err(error) => {
if matches!(error, rocketmq_error::RocketMQError::IO(_)) {
self.cmd_handler.response_table.remove(&opaque);
return Err(connection_invalid(error.to_string()));
}
self.cmd_handler.response_table.remove(&opaque);
Err(error)
}
}
}
}
impl<PR> Client<PR>
where
PR: RequestProcessor + Sync + 'static,
{
pub(crate) async fn connect(
addr: String,
cmd_handler: ArcMut<RemotingGeneralHandler<PR>>,
tx: Option<&tokio::sync::broadcast::Sender<ConnectionNetEvent>>,
tls_config: TlsConfig,
) -> RocketMQResult<Client<PR>> {
let (notify_shutdown, _) = broadcast::channel(1);
let receiver = notify_shutdown.subscribe();
let (tx, inner) = ClientInner::connect(addr, cmd_handler, tx, receiver, tls_config).await?;
Ok(Client {
inner,
notify_shutdown,
tx,
})
}
pub async fn send_read(
&mut self,
request: RemotingCommand,
timeout_millis: u64,
) -> RocketMQResult<RemotingCommand> {
let (tx, rx) = tokio::sync::oneshot::channel::<RocketMQResult<RemotingCommand>>();
if let Err(err) = self.tx.send((request, Some(tx), Some(timeout_millis))).await {
return Err(remote_error(err.to_string()));
}
match rx.await {
Ok(value) => value,
Err(error) => Err(remote_error(error.to_string())),
}
}
pub async fn invoke_with_callback<F>(&self, request: RemotingCommand, mut func: F)
where
F: FnMut(),
{
let (tx, rx) = tokio::sync::oneshot::channel::<RocketMQResult<RemotingCommand>>();
if self.tx.send((request, Some(tx), None)).await.is_err() {
return;
}
let _ = rx.await;
func();
}
pub async fn send(&mut self, request: RemotingCommand) -> RocketMQResult<()> {
if let Err(err) = self.tx.send((request, None, None)).await {
return Err(remote_error(err.to_string()));
}
Ok(())
}
pub async fn send_batch(&mut self, requests: Vec<RemotingCommand>) -> RocketMQResult<()> {
for request in requests {
if let Err(err) = self.tx.send((request, None, None)).await {
return Err(remote_error(err.to_string()));
}
}
Ok(())
}
pub async fn send_batch_read(
&mut self,
requests: Vec<RemotingCommand>,
timeout_millis: u64,
) -> RocketMQResult<Vec<RocketMQResult<RemotingCommand>>> {
let mut receivers = Vec::with_capacity(requests.len());
for request in requests {
let (tx, rx) = tokio::sync::oneshot::channel::<RocketMQResult<RemotingCommand>>();
if let Err(err) = self.tx.send((request, Some(tx), Some(timeout_millis))).await {
return Err(remote_error(err.to_string()));
}
receivers.push(rx);
}
let mut results = Vec::with_capacity(receivers.len());
for rx in receivers {
let result = match rx.await {
Ok(value) => value,
Err(error) => Err(remote_error(error.to_string())),
};
results.push(result);
}
Ok(results)
}
async fn read(&mut self) -> RocketMQResult<RemotingCommand> {
match self.inner.ctx.connection_mut().receive_command().await {
Some(Ok(response)) => Ok(response),
Some(Err(error)) => {
if matches!(error, rocketmq_error::RocketMQError::IO(_)) {
Err(connection_invalid(error.to_string()))
} else {
Err(error)
}
}
None => Err(connection_invalid("connection disconnected")),
}
}
pub fn connection(&self) -> &Connection {
self.inner.ctx.connection_ref()
}
pub fn remote_address(&self) -> SocketAddr {
self.inner.ctx.channel.remote_address()
}
pub fn connection_mut(&mut self) -> &mut Connection {
self.inner.ctx.connection_mut()
}
}
/// Extracts the host portion of a socket-address string for use as the TLS server name.
///
/// * `"[v6]:port"` / `"[v6]"` yields the text between the brackets.
/// * `"host:port"`, where the host part has no further colon, yields `host`.
/// * Anything else (a bare IPv6 literal, no port, malformed brackets) is returned unchanged.
fn server_name_from_addr(addr: &str) -> String {
    let bracketed = addr
        .strip_prefix('[')
        .and_then(|inner| inner.split_once(']'))
        .map(|(host, _)| host);
    if let Some(host) = bracketed {
        return host.to_owned();
    }

    // A colon inside the candidate host means this is an unbracketed IPv6 literal, not host:port.
    let host = match addr.rsplit_once(':') {
        Some((candidate, _port)) if !candidate.contains(':') => candidate,
        _ => addr,
    };
    host.to_owned()
}
#[cfg(test)]
mod tls_tests {
    use super::server_name_from_addr;

    /// Host extraction for TLS server names across hostname, IPv4, bracketed IPv6 and bare
    /// IPv6 address forms.
    #[test]
    fn server_name_parser_handles_common_socket_forms() {
        let cases = [
            ("broker.example.com:10911", "broker.example.com"),
            ("127.0.0.1:10911", "127.0.0.1"),
            ("[::1]:10911", "::1"),
            ("::1", "::1"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(server_name_from_addr(input), expected, "input: {input}");
        }
    }
}