use std::any::Any;
use std::sync::Arc;
use cheetah_string::CheetahString;
use rocketmq_common::common::message::message_queue::MessageQueue;
use rocketmq_error::RocketMQResult;
use rocketmq_error::RpcClientError;
use rocketmq_rust::ArcMut;
use tracing::error;
use tracing::trace;
use crate::clients::rocketmq_tokio_client::RocketmqDefaultClient;
use crate::clients::RemotingClient;
use crate::code::request_code::RequestCode;
use crate::code::response_code::ResponseCode;
use crate::protocol::command_custom_header::CommandCustomHeader;
use crate::protocol::command_custom_header::FromMap;
use crate::protocol::header::get_earliest_msg_storetime_response_header::GetEarliestMsgStoretimeResponseHeader;
use crate::protocol::header::get_max_offset_response_header::GetMaxOffsetResponseHeader;
use crate::protocol::header::get_min_offset_response_header::GetMinOffsetResponseHeader;
use crate::protocol::header::message_operation_header::TopicRequestHeaderTrait;
use crate::protocol::header::pull_message_response_header::PullMessageResponseHeader;
use crate::protocol::header::query_consumer_offset_response_header::QueryConsumerOffsetResponseHeader;
use crate::protocol::header::search_offset_response_header::SearchOffsetResponseHeader;
use crate::protocol::header::update_consumer_offset_header::UpdateConsumerOffsetResponseHeader;
use crate::protocol::remoting_command::RemotingCommand;
use crate::request_processor::default_request_processor::DefaultRemotingRequestProcessor;
use crate::rpc::client_metadata::ClientMetadata;
use crate::rpc::rpc_client::RpcClient;
use crate::rpc::rpc_client_hook::RpcClientHookFn;
use crate::rpc::rpc_client_utils::RpcClientUtils;
use crate::rpc::rpc_request::RpcRequest;
use crate::rpc::rpc_response::RpcResponse;
/// Response codes that a request handler accepts as a successful reply.
struct ResponseConfig {
    /// Accepted codes; a reply whose code is not listed here is rejected by the handler.
    success_codes: &'static [ResponseCode],
}

impl ResponseConfig {
    /// Wraps a static list of accepted response codes. Usable in `const` contexts.
    const fn new(accepted: &'static [ResponseCode]) -> Self {
        ResponseConfig {
            success_codes: accepted,
        }
    }
}
/// RPC client that resolves broker names through client metadata and sends requests over a
/// shared remoting client.
pub struct RpcClientImpl {
    /// Metadata used to map broker names (and message queues) to master broker addresses.
    client_metadata: Arc<ClientMetadata>,
    /// Transport used to send request commands and receive the broker's response.
    remoting_client: ArcMut<RocketmqDefaultClient<DefaultRemotingRequestProcessor>>,
    /// Hooks consulted, in registration order, before a request is sent; the first one that
    /// yields a response short-circuits the call.
    client_hook_list: Vec<RpcClientHookFn>,
}
impl RpcClientImpl {
    /// Creates a client that resolves broker addresses through `client_metadata` and sends
    /// requests over `remoting_client`. No client hooks are registered initially.
    pub fn new(
        client_metadata: Arc<ClientMetadata>,
        remoting_client: ArcMut<RocketmqDefaultClient<DefaultRemotingRequestProcessor>>,
    ) -> Self {
        RpcClientImpl {
            client_metadata,
            remoting_client,
            client_hook_list: Vec::new(),
        }
    }

    /// Appends `client_hook` to the hooks consulted (in registration order) before each request.
    pub fn register_client_hook(&mut self, client_hook: RpcClientHookFn) {
        self.client_hook_list.push(client_hook);
    }

    /// Removes every registered client hook.
    pub fn clear_client_hook(&mut self) {
        self.client_hook_list.clear();
    }

    /// Converts `request` into a wire-level [`RemotingCommand`].
    ///
    /// Returns the request code together with the command so callers can attach it to later
    /// errors.
    ///
    /// # Errors
    ///
    /// [`RpcClientError::RequestFailed`] if the request cannot be converted into a command.
    fn create_request_command<H>(
        &self,
        addr: &CheetahString,
        request: RpcRequest<H>,
        timeout_millis: u64,
    ) -> Result<(i32, RemotingCommand), RpcClientError>
    where
        H: CommandCustomHeader + TopicRequestHeaderTrait,
    {
        let request_code = request.code;
        let request_command = RpcClientUtils::try_create_command_for_rpc_request(request).map_err(|err| {
            RpcClientError::RequestFailed {
                addr: addr.to_string(),
                request_code,
                timeout_ms: timeout_millis,
                source: Box::new(err),
            }
        })?;
        Ok((request_code, request_command))
    }

    /// Looks up the master broker address registered for `broker_name`.
    ///
    /// # Errors
    ///
    /// [`RpcClientError::BrokerNotFound`] when the metadata has no master address for it.
    fn get_broker_addr_by_name(&self, broker_name: &str) -> Result<CheetahString, RpcClientError> {
        self.client_metadata
            .find_master_broker_addr(broker_name)
            .ok_or_else(|| RpcClientError::BrokerNotFound {
                broker_name: broker_name.to_string(),
            })
    }

    /// Builds the command for `request`, sends it to `addr`, and returns the request code
    /// together with the raw response.
    ///
    /// Every request path goes through here so that outgoing requests are traced and
    /// transport failures are logged uniformly. The response code is NOT inspected; that is
    /// left to the caller.
    ///
    /// # Errors
    ///
    /// [`RpcClientError::RequestFailed`] if the command cannot be built or the remoting call
    /// fails.
    async fn send_request<H>(
        &self,
        addr: &CheetahString,
        request: RpcRequest<H>,
        timeout_millis: u64,
    ) -> Result<(i32, RemotingCommand), RpcClientError>
    where
        H: CommandCustomHeader + TopicRequestHeaderTrait,
    {
        trace!(
            "Sending RPC request: addr={}, code={}, timeout={}ms",
            addr,
            request.code,
            timeout_millis
        );
        let (request_code, request_command) = self.create_request_command(addr, request, timeout_millis)?;
        let response = self
            .remoting_client
            .invoke_request(Some(addr), request_command, timeout_millis)
            .await
            .map_err(|err| {
                error!(
                    "RPC request failed: addr={}, code={}, error={}",
                    addr, request_code, err
                );
                RpcClientError::RequestFailed {
                    addr: addr.to_string(),
                    request_code,
                    timeout_ms: timeout_millis,
                    source: Box::new(err),
                }
            })?;
        Ok((request_code, response))
    }

    /// Sends `request` and decodes the response header as `R`.
    ///
    /// The response body, if any, is copied into the returned [`RpcResponse`].
    ///
    /// # Errors
    ///
    /// * [`RpcClientError::RequestFailed`] on send failure or when the header cannot be
    ///   decoded as `R`.
    /// * [`RpcClientError::UnexpectedResponseCode`] when the response code is not listed in
    ///   `config.success_codes`.
    async fn handle_request<H, R>(
        &self,
        addr: &CheetahString,
        request: RpcRequest<H>,
        timeout_millis: u64,
        config: ResponseConfig,
    ) -> Result<RpcResponse, RpcClientError>
    where
        H: CommandCustomHeader + TopicRequestHeaderTrait,
        R: CommandCustomHeader + FromMap<Target = R, Error = rocketmq_error::RocketMQError> + Send + Sync + 'static,
    {
        let (request_code, response) = self.send_request(addr, request, timeout_millis).await?;
        let response_code = ResponseCode::from(response.code());
        if !config.success_codes.contains(&response_code) {
            return Err(RpcClientError::UnexpectedResponseCode {
                code: response.code(),
                code_name: format!("{:?}", response_code),
            });
        }
        let response_header = response
            .decode_command_custom_header::<R>()
            .map_err(|err| RpcClientError::RequestFailed {
                addr: addr.to_string(),
                request_code,
                timeout_ms: timeout_millis,
                source: Box::new(err),
            })?;
        let body = response.body().map(|value| Box::new(value.clone()) as Box<dyn Any>);
        Ok(RpcResponse::new(response.code(), Box::new(response_header), body))
    }

    /// Sends a pull request and decodes the reply as a [`PullMessageResponseHeader`].
    ///
    /// `Success`, `PullNotFound`, `PullRetryImmediately` and `PullOffsetMoved` are all
    /// accepted as valid replies; any other code is an error.
    async fn handle_pull_message<H: CommandCustomHeader + TopicRequestHeaderTrait>(
        &self,
        addr: &CheetahString,
        request: RpcRequest<H>,
        timeout_millis: u64,
    ) -> Result<RpcResponse, RpcClientError> {
        const PULL_SUCCESS_CODES: &[ResponseCode] = &[
            ResponseCode::Success,
            ResponseCode::PullNotFound,
            ResponseCode::PullRetryImmediately,
            ResponseCode::PullOffsetMoved,
        ];
        self.handle_request::<H, PullMessageResponseHeader>(
            addr,
            request,
            timeout_millis,
            ResponseConfig::new(PULL_SUCCESS_CODES),
        )
        .await
    }

    /// Sends a consumer-offset query.
    ///
    /// * `Success` — decodes a [`QueryConsumerOffsetResponseHeader`] and keeps the body.
    /// * `QueryNotFound` — returned as an `Ok` response carrying only the response code.
    /// * anything else — [`RpcClientError::UnexpectedResponseCode`].
    ///
    /// Send and decode failures are reported as [`RpcClientError::RequestFailed`].
    async fn handle_query_consumer_offset<H: CommandCustomHeader + TopicRequestHeaderTrait>(
        &self,
        addr: &CheetahString,
        request: RpcRequest<H>,
        timeout_millis: u64,
    ) -> Result<RpcResponse, RpcClientError> {
        // Shared send path: this handler previously skipped the trace/error logging that
        // `handle_request` performs.
        let (request_code, response) = self.send_request(addr, request, timeout_millis).await?;
        match ResponseCode::from(response.code()) {
            ResponseCode::Success => {
                let response_header = response
                    .decode_command_custom_header::<QueryConsumerOffsetResponseHeader>()
                    .map_err(|err| RpcClientError::RequestFailed {
                        addr: addr.to_string(),
                        request_code,
                        timeout_ms: timeout_millis,
                        source: Box::new(err),
                    })?;
                let body = response.body().map(|value| Box::new(value.clone()) as Box<dyn Any>);
                Ok(RpcResponse::new(response.code(), Box::new(response_header), body))
            }
            ResponseCode::QueryNotFound => Ok(RpcResponse::new_option(response.code(), None)),
            code => Err(RpcClientError::UnexpectedResponseCode {
                code: response.code(),
                code_name: format!("{:?}", code),
            }),
        }
    }

    /// Runs the registered client hooks in registration order.
    ///
    /// Returns the first response a hook produces, or `None` when no hook intercepts the
    /// request. An error returned by a hook is propagated immediately.
    fn execute_hooks<H: CommandCustomHeader + TopicRequestHeaderTrait>(
        &self,
        request: &RpcRequest<H>,
    ) -> RocketMQResult<Option<RpcResponse>> {
        for hook in &self.client_hook_list {
            if let Some(response) = hook(Some(&request.header), None)? {
                trace!("Request intercepted by client hook");
                return Ok(Some(response));
            }
        }
        Ok(None)
    }
}
impl RpcClient for RpcClientImpl {
async fn invoke<H: CommandCustomHeader + TopicRequestHeaderTrait>(
&self,
request: RpcRequest<H>,
timeout_millis: u64,
) -> RocketMQResult<RpcResponse> {
if let Some(response) = self.execute_hooks(&request)? {
return Ok(response);
}
let broker_name = request
.header
.broker_name()
.ok_or_else(|| RpcClientError::BrokerNotFound {
broker_name: "<missing brokerName>".to_string(),
})?;
let addr = self.get_broker_addr_by_name(broker_name.as_ref())?;
let result = match RequestCode::from(request.code) {
RequestCode::PullMessage => self.handle_pull_message(&addr, request, timeout_millis).await?,
RequestCode::GetMinOffset => {
self.handle_request::<H, GetMinOffsetResponseHeader>(
&addr,
request,
timeout_millis,
ResponseConfig::new(&[ResponseCode::Success]),
)
.await?
}
RequestCode::GetMaxOffset => {
self.handle_request::<H, GetMaxOffsetResponseHeader>(
&addr,
request,
timeout_millis,
ResponseConfig::new(&[ResponseCode::Success]),
)
.await?
}
RequestCode::SearchOffsetByTimestamp => {
self.handle_request::<H, SearchOffsetResponseHeader>(
&addr,
request,
timeout_millis,
ResponseConfig::new(&[ResponseCode::Success]),
)
.await?
}
RequestCode::GetEarliestMsgStoreTime => {
self.handle_request::<H, GetEarliestMsgStoretimeResponseHeader>(
&addr,
request,
timeout_millis,
ResponseConfig::new(&[ResponseCode::Success]),
)
.await?
}
RequestCode::QueryConsumerOffset => {
self.handle_query_consumer_offset(&addr, request, timeout_millis)
.await?
}
RequestCode::UpdateConsumerOffset => {
self.handle_request::<H, UpdateConsumerOffsetResponseHeader>(
&addr,
request,
timeout_millis,
ResponseConfig::new(&[ResponseCode::Success]),
)
.await?
}
RequestCode::GetTopicStatsInfo | RequestCode::GetTopicConfig => {
let (request_code, request_command) = self.create_request_command(&addr, request, timeout_millis)?;
let response = self
.remoting_client
.invoke_request(Some(&addr), request_command, timeout_millis)
.await
.map_err(|err| RpcClientError::RequestFailed {
addr: addr.to_string(),
request_code,
timeout_ms: timeout_millis,
source: Box::new(err),
})?;
if response.code() != ResponseCode::Success as i32 {
return Err(RpcClientError::UnexpectedResponseCode {
code: response.code(),
code_name: format!("{:?}", ResponseCode::from(response.code())),
}
.into());
}
let body = response.body().map(|value| Box::new(value.clone()) as Box<dyn Any>);
RpcResponse::new_option(response.code(), body)
}
_ => return Err(RpcClientError::UnsupportedRequestCode { code: request.code }.into()),
};
Ok(result)
}
async fn invoke_mq<H: CommandCustomHeader + TopicRequestHeaderTrait>(
&self,
mq: MessageQueue,
mut request: RpcRequest<H>,
timeout_millis: u64,
) -> RocketMQResult<RpcResponse> {
if let Some(broker_name) = self.client_metadata.get_broker_name_from_message_queue(&mq) {
request.header.set_broker_name(broker_name);
}
self.invoke(request, timeout_millis).await
}
}
impl RpcClientImpl {
pub fn invoke_with_callback<H, F>(
&self,
request: RpcRequest<H>,
timeout_millis: u64,
callback: F,
) -> tokio::task::JoinHandle<()>
where
H: CommandCustomHeader + TopicRequestHeaderTrait + Send + 'static,
F: FnOnce(RocketMQResult<RpcResponse>) + Send + 'static,
{
let client_metadata = self.client_metadata.clone();
let remoting_client = self.remoting_client.clone();
let hooks = self.client_hook_list.clone();
tokio::spawn(async move {
let temp_client = RpcClientImpl {
client_metadata,
remoting_client,
client_hook_list: hooks,
};
let result = temp_client.invoke(request, timeout_millis).await;
callback(result);
})
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::protocol::header::get_min_offset_request_header::GetMinOffsetRequestHeader;
    use crate::runtime::config::client_config::TokioClientConfig;

    /// Builds a client backed by empty metadata and a default remoting client.
    fn client_with_empty_metadata() -> RpcClientImpl {
        let metadata = Arc::new(ClientMetadata::new());
        let transport = ArcMut::new(RocketmqDefaultClient::new(
            Arc::new(TokioClientConfig::default()),
            DefaultRemotingRequestProcessor,
        ));
        RpcClientImpl::new(metadata, transport)
    }

    #[test]
    fn test_error_formatting() {
        let not_found = RpcClientError::BrokerNotFound {
            broker_name: "broker-a".to_string(),
        };
        assert!(not_found.to_string().contains("broker-a"));

        let unexpected = RpcClientError::UnexpectedResponseCode {
            code: 1,
            code_name: "SUCCESS".to_string(),
        };
        assert!(unexpected.to_string().contains("Unexpected response code"));
    }

    #[tokio::test]
    async fn invoke_without_broker_name_returns_typed_error_instead_of_panicking() {
        let rpc_client = client_with_empty_metadata();
        // The test relies on this header carrying no broker name.
        let request = RpcRequest::new(
            RequestCode::GetMinOffset.into(),
            GetMinOffsetRequestHeader {
                topic: CheetahString::from_static_str("TopicA"),
                queue_id: 0,
                topic_request_header: None,
            },
            None,
        );
        match rpc_client.invoke(request, 3000).await {
            Ok(_) => panic!("missing brokerName should return a typed error"),
            Err(error) => assert!(matches!(
                error,
                rocketmq_error::RocketMQError::Rpc(RpcClientError::BrokerNotFound { .. })
            )),
        }
    }
}