use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use cheetah_string::CheetahString;
use dashmap::DashMap;
use rocketmq_rust::ArcMut;
use rocketmq_rust::WeakArcMut;
use tokio::time;
use tokio_util::sync::CancellationToken;
use tracing::debug;
use tracing::error;
use tracing::info;
use tracing::warn;
use crate::base::connection_net_event::ConnectionNetEvent;
use crate::clients::connection_pool::ConnectionPool;
use crate::clients::nameserver_selector::LatencyTracker;
use crate::clients::reconnect::CircuitBreaker;
use crate::clients::Client;
use crate::clients::RemotingClient;
use crate::protocol::remoting_command::RemotingCommand;
use crate::remoting::inner::RemotingGeneralHandler;
use crate::remoting::RemotingService;
use crate::request_processor::default_request_processor::DefaultRemotingRequestProcessor;
use crate::runtime::config::client_config::TokioClientConfig;
use crate::runtime::processor::RequestProcessor;
use crate::runtime::RPCHook;
use crate::tls::TlsConfig;
/// Remoting client that keeps one [`Client`] per remote address and selects a
/// nameserver when the caller does not name a target.
///
/// Every field is either reference-counted or cheaply cloneable; the `Clone`
/// impl copies the handles field by field.
pub struct RocketmqDefaultClient<PR = DefaultRemotingRequestProcessor> {
// Client configuration; read for TLS settings, the connect timeout and the
// idle-connection scan interval.
tokio_client_config: Arc<TokioClientConfig>,
// Established clients keyed by remote address.
connection_tables: Arc<DashMap<CheetahString , Client<PR>>>,
// Configured nameserver addresses (stored shuffled on update).
namesrv_addr_list: ArcMut<Vec<CheetahString>>,
// Nameserver most recently selected for address-less requests, if any.
namesrv_addr_choosed: ArcMut<Option<CheetahString>>,
// Nameservers that answered the most recent availability scan.
available_namesrv_addr_set: ArcMut<HashSet<CheetahString>>,
// Latency/error bookkeeping used to pick and health-check nameservers.
latency_tracker: LatencyTracker,
// Circuit breaker per address, consulted before each connect attempt.
circuit_breakers: Arc<DashMap<CheetahString, CircuitBreaker>>,
// Optional connection pool; `None` until `enable_connection_pool` is called.
connection_pool: Option<ConnectionPool<PR>>,
// Cancelled by `shutdown`; the background scan tasks exit when it fires.
shutdown_token: CancellationToken,
// Handler bundling the request processor, RPC hooks and response table.
cmd_handler: ArcMut<RemotingGeneralHandler<PR>>,
// Optional broadcast sender handed to `Client::connect` for connection events.
tx: Option<tokio::sync::broadcast::Sender<ConnectionNetEvent>>,
}
impl<PR> Clone for RocketmqDefaultClient<PR> {
fn clone(&self) -> Self {
Self {
tokio_client_config: self.tokio_client_config.clone(),
connection_tables: self.connection_tables.clone(),
namesrv_addr_list: self.namesrv_addr_list.clone(),
namesrv_addr_choosed: self.namesrv_addr_choosed.clone(),
available_namesrv_addr_set: self.available_namesrv_addr_set.clone(),
latency_tracker: self.latency_tracker.clone(),
circuit_breakers: self.circuit_breakers.clone(),
connection_pool: self.connection_pool.clone(),
shutdown_token: self.shutdown_token.clone(),
cmd_handler: self.cmd_handler.clone(),
tx: self.tx.clone(),
}
}
}
impl<PR: RequestProcessor + Sync + Clone + 'static> RocketmqDefaultClient<PR> {
pub fn new(tokio_client_config: Arc<TokioClientConfig>, processor: PR) -> Self {
Self::new_with_cl(tokio_client_config, processor, None)
}
pub fn new_with_cl(
tokio_client_config: Arc<TokioClientConfig>,
processor: PR,
tx: Option<tokio::sync::broadcast::Sender<ConnectionNetEvent>>,
) -> Self {
let handler = RemotingGeneralHandler {
request_processor: processor,
rpc_hooks: vec![],
response_table: ArcMut::new(HashMap::with_capacity(512)),
};
Self {
tokio_client_config,
connection_tables: Arc::new(DashMap::with_capacity(64)),
namesrv_addr_list: ArcMut::new(Default::default()),
namesrv_addr_choosed: ArcMut::new(Default::default()),
available_namesrv_addr_set: ArcMut::new(Default::default()),
latency_tracker: LatencyTracker::new(),
circuit_breakers: Arc::new(DashMap::with_capacity(64)),
connection_pool: None,
shutdown_token: CancellationToken::new(),
cmd_handler: ArcMut::new(handler),
tx,
}
}
#[inline]
pub fn is_use_tls(&self) -> bool {
self.tokio_client_config.use_tls
}
#[inline]
pub fn tls_config(&self) -> &TlsConfig {
&self.tokio_client_config.tls_config
}
}
impl<PR: RequestProcessor + Sync + Clone + 'static> RocketmqDefaultClient<PR> {
/// Installs a connection pool on this client and starts its cleanup task.
///
/// Any previously installed pool is replaced. The returned handle belongs to
/// the task started by `ConnectionPool::start_cleanup_task`.
pub fn enable_connection_pool(
    &mut self,
    max_connections: usize,
    max_idle_duration: Duration,
    cleanup_interval: Duration,
) -> tokio::task::JoinHandle<()> {
    let new_pool = ConnectionPool::new(max_connections, max_idle_duration);
    // Start the cleanup task before the pool is moved into `self`.
    let cleanup_handle = new_pool.start_cleanup_task(cleanup_interval);
    self.connection_pool = Some(new_pool);

    info!(
        "Connection pool enabled: max={}, idle_timeout={:?}, cleanup_interval={:?}",
        max_connections, max_idle_duration, cleanup_interval
    );

    cleanup_handle
}
/// Returns the pool's statistics, or `None` when no pool has been enabled.
pub fn get_pool_stats(&self) -> Option<crate::clients::connection_pool::PoolStats> {
    let pool = self.connection_pool.as_ref()?;
    Some(pool.stats())
}
async fn get_and_create_nameserver_client(&self) -> Option<Client<PR>> {
let cached_addr = self.namesrv_addr_choosed.as_ref().clone();
if let Some(ref addr) = cached_addr {
if let Some(client) = self.connection_tables.get(addr) {
if client.connection().is_healthy() && self.latency_tracker.is_healthy(addr) {
return Some(client.value().clone());
}
debug!("Cached nameserver {} is unhealthy, selecting new one", addr);
}
}
let addr_list = self.namesrv_addr_list.as_ref();
if addr_list.is_empty() {
warn!("No nameservers configured in namesrv_addr_list");
return None;
}
let selected_addr = match self.latency_tracker.select_best(addr_list) {
Some(addr) => addr,
None => {
error!(
"Failed to select healthy nameserver. Available list: {:?}, Available set: {:?}",
addr_list,
self.available_namesrv_addr_set.as_ref()
);
return None;
}
};
info!(
"Selected nameserver: {} (P99: {:?}, errors: {})",
selected_addr,
self.latency_tracker
.get_p99(selected_addr)
.unwrap_or(Duration::from_secs(0)),
self.latency_tracker.get_error_count(selected_addr)
);
self.namesrv_addr_choosed.mut_from_ref().replace(selected_addr.clone());
self.create_client(
selected_addr,
Duration::from_millis(self.tokio_client_config.connect_timeout_millis as u64),
)
.await
}
/// Returns a client for `addr`, or for a nameserver when `addr` is `None` or
/// empty.
///
/// A healthy cached connection is returned as-is; otherwise a new connection
/// is attempted with the configured connect timeout.
async fn get_and_create_client(&self, addr: Option<&CheetahString>) -> Option<Client<PR>> {
    // A missing or empty address means "talk to whichever nameserver is best".
    let Some(target_addr) = addr.filter(|candidate| !candidate.is_empty()) else {
        return self.get_and_create_nameserver_client().await;
    };

    // Clone the cached client out so the map guard is released immediately.
    let cached = self
        .connection_tables
        .get(target_addr)
        .map(|entry| entry.value().clone());
    if let Some(client) = cached {
        if client.connection().is_healthy() {
            return Some(client);
        }
        debug!("Cached client for {} is unhealthy, reconnecting...", target_addr);
    }

    let connect_timeout = Duration::from_millis(self.tokio_client_config.connect_timeout_millis as u64);
    self.create_client(target_addr, connect_timeout).await
}
/// Returns a healthy client for `addr`, connecting when nothing reusable exists.
///
/// Lookup order: the connection pool (when enabled), then `connection_tables`,
/// then a fresh `Client::connect` bounded by `duration` and gated by the
/// per-address circuit breaker. Returns `None` when the breaker rejects the
/// attempt, the connect fails, or the connect times out.
async fn create_client(&self, addr: &CheetahString, duration: Duration) -> Option<Client<PR>> {
// 1. Prefer a healthy pooled connection; evict an unhealthy one.
if let Some(ref pool) = self.connection_pool {
if let Some(pooled_conn) = pool.get(addr) {
if pooled_conn.is_healthy() {
debug!("Reusing pooled connection to {}", addr);
return Some(pooled_conn.client().clone());
}
pool.remove(addr);
}
}
// 2. Fall back to the connection table; drop an unhealthy entry.
if let Some(client_ref) = self.connection_tables.get(addr) {
let client = client_ref.value().clone();
if client.connection().is_healthy() {
return Some(client);
}
// The map guard is released before `remove` is called on the same map.
drop(client_ref); self.connection_tables.remove(addr);
}
// 3. Consult the circuit breaker. A clone is taken out of the map here and
// written back after the attempt with the recorded outcome.
let mut breaker = self
.circuit_breakers
.entry(addr.clone())
.or_insert_with(CircuitBreaker::default_breaker)
.clone();
if !breaker.allow_request() {
warn!("Circuit breaker OPEN for {}, rejecting connection attempt", addr);
return None;
}
let addr_inner = addr.to_string();
// TLS is enabled or disabled per the client-level `use_tls` flag, overriding
// whatever `enable` value the stored TLS config carries.
let mut tls_config = self.tokio_client_config.tls_config.clone();
tls_config.enable = self.tokio_client_config.use_tls;
// 4. Connect, bounded by `duration`.
let connect_result = time::timeout(duration, async {
Client::connect(addr_inner, self.cmd_handler.clone(), self.tx.as_ref(), tls_config).await
})
.await;
match connect_result {
Ok(Ok(new_client)) => {
breaker.record_success();
self.circuit_breakers.insert(addr.clone(), breaker);
// NOTE(review): the pool insert happens before the race check below, so a
// connection that loses the race may still be placed in the pool — confirm
// this is intended.
if let Some(ref pool) = self.connection_pool {
if pool.insert(addr.clone(), new_client.clone()) {
info!("Added connection to pool: {} (pool size: {})", addr, pool.stats().total);
} else {
warn!("Connection pool at capacity, falling back to DashMap");
}
}
// Another task may have connected to the same address meanwhile; keep its
// connection if it is healthy, otherwise replace it.
match self.connection_tables.entry(addr.clone()) {
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
if entry.get().connection().is_healthy() {
info!("Race condition: {} already connected by another task", addr);
return Some(entry.get().clone());
}
entry.insert(new_client.clone());
}
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(new_client.clone());
}
}
info!("Successfully created client for {}", addr);
Some(new_client)
}
// Connect returned an error within the timeout.
Ok(Err(e)) => {
error!("Failed to connect to {}: {:?}", addr, e);
breaker.record_failure();
self.circuit_breakers.insert(addr.clone(), breaker);
None
}
// The timeout elapsed before connect finished.
Err(_) => {
error!("Connection to {} timed out after {:?}", addr, duration);
breaker.record_failure();
self.circuit_breakers.insert(addr.clone(), breaker);
None
}
}
}
/// Calls `create_client` repeatedly, sleeping between attempts with an
/// exponential backoff built from a 1 s and a 10 s duration and `max_attempts`.
///
/// Returns the first client obtained, or `None` once the backoff yields no
/// further delay.
async fn create_client_with_retry(
    &self,
    addr: &CheetahString,
    duration: Duration,
    max_attempts: u32,
) -> Option<Client<PR>> {
    use crate::clients::reconnect::ExponentialBackoff;

    let first_delay = Duration::from_secs(1);
    let second_delay = Duration::from_secs(10);
    let mut backoff = ExponentialBackoff::new(first_delay, second_delay, max_attempts);

    loop {
        if let Some(client) = self.create_client(addr, duration).await {
            return Some(client);
        }

        match backoff.next_delay() {
            Some(delay) => {
                debug!(
                    "Connection to {} failed, retrying in {:?} (attempt {}/{})",
                    addr,
                    delay,
                    backoff.current_attempt(),
                    max_attempts
                );
                time::sleep(delay).await;
            }
            None => {
                warn!(
                    "Connection to {} failed after {} attempts",
                    addr,
                    backoff.current_attempt()
                );
                return None;
            }
        }
    }
}
/// Probes every configured nameserver concurrently and refreshes the
/// available-nameserver set.
///
/// Addresses no longer in the configured list are dropped from the set first.
/// Each configured address is then probed by obtaining (or creating) a client
/// for it; success adds it to the set, failure removes it.
async fn scan_available_name_srv(&self) {
    use futures::future::join_all;

    let addr_list = self.namesrv_addr_list.as_ref();
    if addr_list.is_empty() {
        debug!("No nameservers configured, skipping availability scan");
        return;
    }

    // Collect first, then remove, so the set is never mutated while iterated.
    let mut stale_addrs: Vec<CheetahString> = Vec::new();
    for known in self.available_namesrv_addr_set.as_ref().iter() {
        if !addr_list.contains(known) {
            stale_addrs.push(known.clone());
        }
    }
    for stale_addr in stale_addrs {
        warn!("Removing stale nameserver from available set: {}", stale_addr);
        self.available_namesrv_addr_set.mut_from_ref().remove(&stale_addr);
    }

    let mut probes = Vec::with_capacity(addr_list.len());
    for addr in addr_list.iter() {
        let candidate = addr.clone();
        probes.push(async move {
            let reachable = self.get_and_create_client(Some(&candidate)).await.is_some();
            (candidate, reachable)
        });
    }

    for (namesrv_addr, is_available) in join_all(probes).await {
        let available = self.available_namesrv_addr_set.mut_from_ref();
        if is_available {
            // Log only on an actual state change.
            if available.insert(namesrv_addr.clone()) {
                info!("Nameserver {} is now available", namesrv_addr);
            }
        } else if available.remove(&namesrv_addr) {
            warn!("Nameserver {} is now unavailable", namesrv_addr);
        }
    }

    debug!(
        "Availability scan complete: {}/{} nameservers available",
        self.available_namesrv_addr_set.as_ref().len(),
        addr_list.len()
    );
}
fn scan_idle_connections(&self) {
let interval_ms = self.tokio_client_config.channel_not_active_interval;
if interval_ms <= 0 {
return;
}
let idle_threshold = Duration::from_millis(interval_ms as u64);
let mut stale_addrs = Vec::new();
for entry in self.connection_tables.iter() {
let addr = entry.key().clone();
let client = entry.value();
if !client.connection().is_healthy() {
stale_addrs.push(addr);
continue;
}
if let Some(ref pool) = self.connection_pool {
if let Some(pooled) = pool.get(&addr) {
if pooled.is_idle(idle_threshold) {
stale_addrs.push(addr);
}
}
}
}
for addr in &stale_addrs {
if self.connection_tables.remove(addr).is_some() {
warn!("[SCAN] Removed idle/unhealthy connection: {}", addr);
if let Some(ref pool) = self.connection_pool {
pool.remove(addr);
}
}
}
}
}
#[allow(unused_variables)]
impl<PR: RequestProcessor + Sync + Clone + 'static> RemotingService for RocketmqDefaultClient<PR> {
async fn start(&self, this: WeakArcMut<Self>) {
if let Some(client) = this.upgrade() {
let connect_timeout_millis = self.tokio_client_config.connect_timeout_millis as u64;
let token = self.shutdown_token.clone();
let client_for_scan = client.clone();
let scan_token = token.clone();
tokio::spawn(async move {
loop {
tokio::select! {
() = scan_token.cancelled() => break,
() = async {
client_for_scan.scan_available_name_srv().await;
time::sleep(Duration::from_millis(connect_timeout_millis)).await;
} => {}
}
}
});
let channel_not_active_interval = self.tokio_client_config.channel_not_active_interval as u64;
if channel_not_active_interval > 0 {
let idle_token = token.clone();
tokio::spawn(async move {
loop {
tokio::select! {
() = idle_token.cancelled() => break,
() = time::sleep(Duration::from_millis(channel_not_active_interval)) => {
client.scan_idle_connections();
}
}
}
});
}
}
}
fn shutdown(&mut self) {
self.shutdown_token.cancel();
self.connection_tables.clear();
self.namesrv_addr_list.clear();
self.available_namesrv_addr_set.clear();
info!("RemotingClient shutdown complete");
}
/// Forwards `hook` to the command handler's RPC-hook registration.
fn register_rpc_hook(&mut self, hook: Arc<dyn RPCHook>) {
    let handler = &mut self.cmd_handler;
    handler.register_rpc_hook(hook);
}
/// Asks the command handler to clear its RPC hooks.
fn clear_rpc_hook(&mut self) {
    let handler = &mut self.cmd_handler;
    handler.clear_rpc_hook();
}
}
#[allow(unused_variables)]
impl<PR: RequestProcessor + Sync + Clone + 'static> RemotingClient for RocketmqDefaultClient<PR> {
/// Replaces the configured nameserver list with a shuffled copy of `addrs`
/// when it differs from the current list.
///
/// An empty `addrs` is ignored. The lists are considered equal when they have
/// the same length and every new address is already present. If the cached
/// nameserver choice is not part of the new list, the choice is cleared and
/// its connection is removed from the connection table.
async fn update_name_server_address_list(&self, addrs: Vec<CheetahString>) {
    use rand::seq::SliceRandom;

    if addrs.is_empty() {
        return;
    }

    let unchanged = {
        let current: &Vec<CheetahString> = &self.namesrv_addr_list;
        !current.is_empty()
            && addrs.len() == current.len()
            && addrs.iter().all(|candidate| current.contains(candidate))
    };
    if unchanged {
        return;
    }

    info!(
        "name server address updated. NEW : {:?} , OLD: {:?}",
        addrs,
        self.namesrv_addr_list.as_ref() as &Vec<CheetahString>
    );

    // Store the addresses in shuffled order.
    let mut shuffled = addrs.clone();
    shuffled.shuffle(&mut rand::rng());
    let stored = self.namesrv_addr_list.mut_from_ref();
    stored.clear();
    stored.extend(shuffled);

    // Forget the cached choice (and its connection) once it is no longer listed.
    let previous_choice: Option<CheetahString> = self.namesrv_addr_choosed.as_ref().clone();
    if let Some(removed_addr) = previous_choice.filter(|chosen| !addrs.contains(chosen)) {
        self.namesrv_addr_choosed.mut_from_ref().take();
        self.connection_tables.remove(&removed_addr);
    }
}
/// Borrows the configured nameserver addresses as a slice.
fn get_name_server_address_list(&self) -> &[CheetahString] {
    let configured: &Vec<CheetahString> = self.namesrv_addr_list.as_ref();
    configured.as_slice()
}
/// Returns a copy of the addresses currently in the available-nameserver set.
fn get_available_name_srv_list(&self) -> Vec<CheetahString> {
    self.available_namesrv_addr_set
        .as_ref()
        .iter()
        .cloned()
        .collect()
}
/// Sends `request` to `addr` (or to the selected nameserver when `addr` is
/// `None` or empty) and waits up to `timeout_millis` for the response.
///
/// Registered RPC hooks run before the request is sent and after a response
/// arrives. Success and failure are reported to the latency tracker, and to
/// the connection pool when one is enabled.
///
/// # Errors
///
/// * `ClientNotStarted` when the client has been shut down.
/// * A network-connection error when no client could be obtained.
/// * Any error returned by an RPC hook or by the send itself.
/// * `Timeout` when no response arrives within `timeout_millis`.
async fn invoke_request(
    &self,
    addr: Option<&CheetahString>,
    request: RemotingCommand,
    timeout_millis: u64,
) -> rocketmq_error::RocketMQResult<RemotingCommand> {
    // Refuse before doing any work: a shut-down client must not open a new
    // connection only to discard it.
    if self.shutdown_token.is_cancelled() {
        return Err(rocketmq_error::RocketMQError::ClientNotStarted);
    }

    let start = time::Instant::now();
    // An empty address means "use a nameserver", exactly as in
    // `get_and_create_client`.
    let explicit_addr = addr.filter(|candidate| !candidate.is_empty());
    let client = self.get_and_create_client(addr).await;
    // Resolve the metrics target only AFTER the client lookup: nameserver
    // selection may have just replaced the cached choice, and metrics must be
    // attributed to the address that was actually used.
    let target_addr = explicit_addr
        .cloned()
        .or_else(|| self.namesrv_addr_choosed.as_ref().clone());

    let Some(mut client) = client else {
        let target = explicit_addr.map(|a| a.as_str()).unwrap_or("<nameserver>");
        if explicit_addr.is_none() {
            error!(
                "Failed to get client for <nameserver>. Diagnostics: configured_list={:?}, available_set={:?}, \
                 cached_choice={:?}, connections={}",
                self.namesrv_addr_list.as_ref(),
                self.available_namesrv_addr_set.as_ref(),
                self.namesrv_addr_choosed.as_ref(),
                self.connection_tables.len()
            );
        } else {
            error!("Failed to get client for {}", target);
        }
        if let Some(ref addr) = target_addr {
            self.latency_tracker.record_error(addr);
        }
        return Err(rocketmq_error::RocketMQError::network_connection_failed(
            target.to_string(),
            "Failed to connect",
        ));
    };

    // Shutdown may have been requested while the connection was being set up.
    if self.shutdown_token.is_cancelled() {
        return Err(rocketmq_error::RocketMQError::ClientNotStarted);
    }

    let mut request = request;
    let remote_address = client.remote_address();
    // A copy of the request is kept only when hooks exist, for the after-hook.
    let request_for_after = if self.cmd_handler.has_rpc_hooks() {
        request.make_custom_header_to_net();
        self.cmd_handler
            .do_before_rpc_hooks_with_addr(remote_address, Some(&mut request))?;
        Some(request.clone())
    } else {
        None
    };

    let send_result = time::timeout(
        Duration::from_millis(timeout_millis),
        client.send_read(request, timeout_millis),
    )
    .await;
    let latency = start.elapsed();

    match send_result {
        Ok(Ok(mut response)) => {
            if let Some(request) = request_for_after.as_ref() {
                self.cmd_handler
                    .do_after_rpc_hooks_with_addr(remote_address, request, Some(&mut response))?;
            }
            if let Some(ref addr) = target_addr {
                let latency_ms = latency.as_millis() as u64;
                self.latency_tracker.record_success(addr, latency);
                if let Some(ref pool) = self.connection_pool {
                    pool.record_success(addr, latency_ms);
                }
                debug!("Request to {} completed in {:?}", addr, latency);
            }
            Ok(response)
        }
        Ok(Err(err)) => {
            if let Some(ref addr) = target_addr {
                self.latency_tracker.record_error(addr);
                if let Some(ref pool) = self.connection_pool {
                    pool.record_error(addr);
                }
                warn!("Request to {} failed after {:?}: {:?}", addr, latency, err);
            }
            Err(err)
        }
        Err(_) => {
            if let Some(ref addr) = target_addr {
                self.latency_tracker.record_error(addr);
                if let Some(ref pool) = self.connection_pool {
                    pool.record_error(addr);
                }
            }
            Err(rocketmq_error::RocketMQError::Timeout {
                operation: "send_request",
                timeout_ms: timeout_millis,
            })
        }
    }
}
/// Sends `request` to `addr` as a one-way RPC without waiting for a response.
///
/// The client lookup and the before-request hooks run inline; the send itself
/// is moved to a spawned task bounded by `timeout_millis`. All failures are
/// logged only, never returned.
async fn invoke_request_oneway(&self, addr: &CheetahString, request: RemotingCommand, timeout_millis: u64) {
    let Some(mut client) = self.get_and_create_client(Some(addr)).await else {
        error!("invokeOneway: get client for {} failed", addr);
        return;
    };

    let mut request = request;
    if self.cmd_handler.has_rpc_hooks() {
        let remote_address = client.remote_address();
        request.make_custom_header_to_net();
        let hook_outcome = self
            .cmd_handler
            .do_before_rpc_hooks_with_addr(remote_address, Some(&mut request));
        if let Err(error) = hook_outcome {
            warn!("invokeOneway: before RPC hook failed for {}: {:?}", addr, error);
            return;
        }
    }

    let addr_clone = addr.clone();
    let send_deadline = Duration::from_millis(timeout_millis);
    tokio::spawn(async move {
        // The request is marked one-way only here, after the hooks have run.
        let send = async move {
            let mut request = request;
            request.mark_oneway_rpc_ref();
            client.send(request).await
        };
        match time::timeout(send_deadline, send).await {
            Ok(Ok(())) => {}
            Ok(Err(e)) => {
                warn!("invokeOneway: send request to {} failed: {:?}", addr_clone, e);
            }
            Err(_) => {
                warn!(
                    "invokeOneway: send request to {} timeout ({}ms)",
                    addr_clone, timeout_millis
                );
            }
        }
    });
}
fn invoke_oneway_unbounded(&self, addr: CheetahString, request: RemotingCommand) {
let client_owner = self.clone();
tokio::spawn(async move {
if client_owner.shutdown_token.is_cancelled() {
tracing::debug!(
"invoke_oneway_unbounded: client is shut down, skipping send to {}",
addr
);
return;
}
let Some(mut client) = client_owner.get_and_create_client(Some(&addr)).await else {
tracing::warn!("invoke_oneway_unbounded: failed to get or create client for {}", addr);
return;
};
let mut request = request;
request.mark_oneway_rpc_ref();
if client_owner.cmd_handler.has_rpc_hooks() {
let remote_address = client.remote_address();
request.make_custom_header_to_net();
if let Err(error) = client_owner
.cmd_handler
.do_before_rpc_hooks_with_addr(remote_address, Some(&mut request))
{
tracing::warn!(
"invoke_oneway_unbounded: before RPC hook failed for {}: {:?}",
addr,
error
);
return;
}
}
if let Err(error) = client.send(request).await {
tracing::warn!("invoke_oneway_unbounded: send request to {} failed: {:?}", addr, error);
}
});
}
/// Drops the cached connection for `addr` if it is no longer healthy.
///
/// A healthy connection is left untouched; a missing one is only logged.
fn is_address_reachable(&mut self, addr: &CheetahString) {
    // Evaluate health in one expression so the map guard is released before
    // the entry is removed.
    let cached_health = self
        .connection_tables
        .get(addr)
        .map(|entry| entry.value().connection().is_healthy());

    match cached_health {
        Some(true) => {}
        Some(false) => {
            self.connection_tables.remove(addr);
            warn!("Removed unhealthy connection for {}", addr);
        }
        None => {
            debug!("No connection found for {}", addr);
        }
    }
}
/// Removes the cached connection for each of `addrs` from the connection
/// table, logging every address that actually had one.
fn close_clients(&mut self, addrs: Vec<String>) {
    for raw_addr in addrs.iter() {
        let lookup_key = CheetahString::from(raw_addr.as_str());
        let removed = self.connection_tables.remove(&lookup_key);
        if let Some(_entry) = removed {
            info!("Closed client connection for {}", raw_addr);
        }
    }
}
/// Not supported: the request processor is fixed at construction time, so this
/// only logs a warning and drops `processor`.
fn register_processor(&mut self, processor: impl RequestProcessor + Sync) {
    // Touch the argument so it counts as used.
    let _ignored = &processor;
    warn!("dynamic request processor registration is not supported by RocketmqDefaultClient after construction");
}
}
#[cfg(test)]
mod tests {
use std::net::SocketAddr;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use rocketmq_error::RocketMQResult;
use tokio::net::TcpListener;
use super::*;
use crate::code::request_code::RequestCode;
use crate::code::response_code::ResponseCode;
use crate::connection::Connection;
use crate::request_processor::default_request_processor::DefaultRemotingRequestProcessor;
use crate::runtime::config::client_config::TokioClientConfig;
/// Test RPC hook that counts how often each callback runs.
#[derive(Default)]
struct CountingHook {
// Incremented once per `do_before_request` call.
before_count: AtomicUsize,
// Incremented once per `do_after_response` call.
after_count: AtomicUsize,
}
impl RPCHook for CountingHook {
    /// Counts the call and tags the outgoing request with `hooked=true`.
    fn do_before_request(&self, _remote_addr: SocketAddr, request: &mut RemotingCommand) -> RocketMQResult<()> {
        let _previous = self.before_count.fetch_add(1, Ordering::SeqCst);

        request.ensure_ext_fields_initialized();
        request.add_ext_field("hooked", "true");

        Ok(())
    }

    /// Counts the call and tags the response with `afterHook=true`.
    fn do_after_response(
        &self,
        _remote_addr: SocketAddr,
        _request: &RemotingCommand,
        response: &mut RemotingCommand,
    ) -> RocketMQResult<()> {
        let _previous = self.after_count.fetch_add(1, Ordering::SeqCst);

        response.ensure_ext_fields_initialized();
        response.add_ext_field("afterHook", "true");

        Ok(())
    }
}
/// The TLS accessors must mirror what the client configuration says.
#[test]
fn is_use_tls_reflects_client_config() {
    let tls_config = TlsConfig {
        enable: true,
        ..TlsConfig::default()
    };
    let config = TokioClientConfig {
        use_tls: true,
        tls_config,
        ..Default::default()
    };

    let client = RocketmqDefaultClient::new(Arc::new(config), DefaultRemotingRequestProcessor);

    assert!(client.is_use_tls());
    assert!(client.tls_config().enable);
}
/// `invoke_request` must run the before-hook on the outgoing request and the
/// after-hook on the response, exactly once each.
#[tokio::test]
async fn invoke_request_runs_outbound_rpc_hooks() {
// Bind to port 0 so the OS picks a free port.
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind listener");
let addr = listener.local_addr().expect("listener addr");
// Minimal server: accept one connection, check the hook-added field on the
// request, and answer with a success response carrying the same opaque id.
let server = tokio::spawn(async move {
let (socket, _) = listener.accept().await.expect("accept client");
let mut connection = Connection::new(socket);
let request = connection
.receive_command()
.await
.expect("request frame")
.expect("request command");
let hooked = request
.ext_fields()
.and_then(|fields| fields.get("hooked"))
.map(|value| value.as_str());
assert_eq!(hooked, Some("true"));
let mut response = RemotingCommand::create_response_command_with_code(ResponseCode::Success);
response.set_opaque_mut(request.opaque());
connection.send_command(response).await.expect("send response");
});
let hook = Arc::new(CountingHook::default());
let mut client =
RocketmqDefaultClient::new(Arc::new(TokioClientConfig::default()), DefaultRemotingRequestProcessor);
client.register_rpc_hook(hook.clone());
let target = CheetahString::from_string(addr.to_string());
let request = RemotingCommand::create_remoting_command(RequestCode::GetBrokerClusterInfo);
let response = client
.invoke_request(Some(&target), request, 3_000)
.await
.expect("invoke request");
// Each hook ran once, and the after-hook's field is visible on the response.
assert_eq!(hook.before_count.load(Ordering::SeqCst), 1);
assert_eq!(hook.after_count.load(Ordering::SeqCst), 1);
assert_eq!(
response
.ext_fields()
.and_then(|fields| fields.get("afterHook"))
.map(|value| value.as_str()),
Some("true")
);
// Surface any assertion failure raised inside the server task.
server.await.expect("server task");
client.shutdown();
}
/// `invoke_oneway_unbounded` must connect on demand, mark the request as
/// one-way, run the before-hook once, and never run the after-hook.
#[tokio::test]
async fn invoke_oneway_unbounded_creates_connection_before_sending() {
// Bind to port 0 so the OS picks a free port.
let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind listener");
let addr = listener.local_addr().expect("listener addr");
// The server reports what it received back to the test body over this channel.
let (received_tx, received_rx) = tokio::sync::oneshot::channel();
let server = tokio::spawn(async move {
let (socket, _) = listener.accept().await.expect("accept client");
let mut connection = Connection::new(socket);
let request = time::timeout(Duration::from_secs(3), connection.receive_command())
.await
.expect("oneway request should arrive")
.expect("request frame")
.expect("request command");
let hooked = request
.ext_fields()
.and_then(|fields| fields.get("hooked"))
.map(|value| value.as_str() == "true")
.unwrap_or(false);
let _ = received_tx.send((request.code(), request.is_oneway_rpc(), hooked));
});
let hook = Arc::new(CountingHook::default());
let mut client =
RocketmqDefaultClient::new(Arc::new(TokioClientConfig::default()), DefaultRemotingRequestProcessor);
client.register_rpc_hook(hook.clone());
let target = CheetahString::from_string(addr.to_string());
let request = RemotingCommand::create_remoting_command(RequestCode::GetBrokerClusterInfo);
// No connection exists yet; the call has to create one on its spawned task.
client.invoke_oneway_unbounded(target, request);
// Waiting on the channel also ensures the before-hook ran before the counters
// are checked below.
let (code, is_oneway, hooked) = time::timeout(Duration::from_secs(3), received_rx)
.await
.expect("server should receive unbounded oneway request")
.expect("server should report received request");
assert_eq!(code, RequestCode::GetBrokerClusterInfo.to_i32());
assert!(is_oneway);
assert!(hooked);
assert_eq!(hook.before_count.load(Ordering::SeqCst), 1);
// One-way sends have no response, so the after-hook must not run.
assert_eq!(hook.after_count.load(Ordering::SeqCst), 0);
server.await.expect("server task");
client.shutdown();
}
}