use std::collections::HashMap;
use std::fmt::Debug;
use std::fmt::Display;
use std::hash::Hash;
use std::hash::Hasher;
use std::net::SocketAddr;
use std::time::Duration;
use cheetah_string::CheetahString;
use flume::Receiver;
use flume::Sender;
use rocketmq_error::RocketMQError;
use rocketmq_rust::ArcMut;
use tokio::time::timeout;
use tracing::error;
use uuid::Uuid;
use crate::base::response_future::ResponseFuture;
use crate::connection::Connection;
use crate::protocol::remoting_command::RemotingCommand;
/// Identifier of a [`Channel`]; `Channel::new` fills it with a random UUID v4 string.
pub type ChannelId = CheetahString;
/// `ArcMut`-wrapped [`Channel`] handle.
pub type ArcChannel = ArcMut<Channel>;
/// Handle to one remoting connection plus its local/remote endpoint pair and id.
///
/// Cloning copies the addresses and id and clones the `ArcMut` handle to `inner`.
/// Equality and hashing consider only the two addresses and the channel id.
#[derive(Clone)]
pub struct Channel {
    /// Outbound queue, connection and pending-response table behind this handle.
    inner: ArcMut<ChannelInner>,
    local_address: SocketAddr,
    remote_address: SocketAddr,
    channel_id: ChannelId,
}
impl Channel {
    /// Builds a handle over `inner` for the given endpoint pair.
    ///
    /// Every call generates a new random (UUID v4) channel id.
    pub fn new(inner: ArcMut<ChannelInner>, local_address: SocketAddr, remote_address: SocketAddr) -> Self {
        Self {
            channel_id: Uuid::new_v4().to_string().into(),
            inner,
            local_address,
            remote_address,
        }
    }

    /// Replaces the recorded local endpoint address.
    #[inline]
    pub fn set_local_address(&mut self, local_address: SocketAddr) {
        self.local_address = local_address;
    }

    /// Replaces the recorded remote endpoint address.
    #[inline]
    pub fn set_remote_address(&mut self, remote_address: SocketAddr) {
        self.remote_address = remote_address;
    }

    /// Replaces the channel id with anything convertible into a `CheetahString`.
    #[inline]
    pub fn set_channel_id(&mut self, channel_id: impl Into<CheetahString>) {
        self.channel_id = channel_id.into();
    }

    /// Local endpoint address (returned by value, `SocketAddr` is `Copy`).
    #[inline]
    pub fn local_address(&self) -> SocketAddr {
        self.local_address
    }

    /// Remote endpoint address (returned by value, `SocketAddr` is `Copy`).
    #[inline]
    pub fn remote_address(&self) -> SocketAddr {
        self.remote_address
    }

    /// Channel id borrowed as a string slice.
    #[inline]
    pub fn channel_id(&self) -> &str {
        self.channel_id.as_str()
    }

    /// Owned copy of the channel id.
    pub fn channel_id_owned(&self) -> CheetahString {
        self.channel_id.clone()
    }

    /// Mutable access to the connection held by the shared inner state.
    #[inline]
    pub fn connection_mut(&mut self) -> &mut Connection {
        self.inner.connection.as_mut()
    }

    /// Shared access to the connection held by the inner state.
    #[inline]
    pub fn connection_ref(&self) -> &Connection {
        self.inner.connection_ref()
    }

    /// Borrows the shared inner state.
    pub fn channel_inner(&self) -> &ChannelInner {
        self.inner.as_ref()
    }

    /// Mutably borrows the shared inner state.
    pub fn channel_inner_mut(&mut self) -> &mut ChannelInner {
        self.inner.as_mut()
    }
}
/// Two handles are equal when their endpoint pair and channel id match; the
/// shared `inner` state is not compared.
impl PartialEq for Channel {
    fn eq(&self, other: &Self) -> bool {
        // Tuple equality compares element by element and short-circuits in order.
        let lhs = (self.local_address, self.remote_address, &self.channel_id);
        let rhs = (other.local_address, other.remote_address, &other.channel_id);
        lhs == rhs
    }
}

// Marker impl: relies on the compared fields having total equality.
impl Eq for Channel {}
/// Hashes the local address, remote address and channel id, the same fields
/// `PartialEq` compares, so `Hash` stays consistent with `Eq`.
impl Hash for Channel {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Tuple hashing feeds each element to `state` in order, with no prefix.
        (self.local_address, self.remote_address, &self.channel_id).hash(state);
    }
}
/// Diagnostic rendering: addresses via `Debug`, channel id via `Display`; `inner` is omitted.
impl Debug for Channel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            local_address,
            remote_address,
            channel_id,
            ..
        } = self;
        write!(
            f,
            "Channel {{ local_address: {local:?}, remote_address: {remote:?}, channel_id: {id} }}",
            local = local_address,
            remote = remote_address,
            id = channel_id
        )
    }
}
/// Human-readable rendering of the endpoint pair and channel id; `inner` is omitted.
impl Display for Channel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            local_address,
            remote_address,
            channel_id,
            ..
        } = self;
        write!(
            f,
            "Channel {{ local_address: {local}, remote_address: {remote}, channel_id: {id} }}",
            local = local_address,
            remote = remote_address,
            id = channel_id
        )
    }
}
/// Reply slot handed to the background writer, which stores it inside a
/// `ResponseFuture` in the response table.
type ResponseSender = tokio::sync::oneshot::Sender<rocketmq_error::RocketMQResult<RemotingCommand>>;

/// One unit of work for the background writer. It holds:
/// - the command to write;
/// - an optional reply slot, present only for request/response exchanges;
/// - an optional timeout in milliseconds, which the writer reads only when a reply slot is present.
type ChannelMessage = (RemotingCommand, Option<ResponseSender>, Option<u64>);
/// State behind a [`Channel`]: the outbound queue, the connection and the
/// table of pending request/response exchanges.
pub struct ChannelInner {
    /// Producer side of the bounded queue drained by the background writer task (`handle_send`).
    outbound_queue_tx: Sender<ChannelMessage>,
    /// Connection that commands are written to; the writer task holds another handle to it.
    pub(crate) connection: ArcMut<Connection>,
    /// Pending exchanges keyed by request opaque. The writer task inserts and removes
    /// entries and `send_wait_response` removes them on failure or timeout.
    // NOTE(review): this map is mutated both from the spawned writer task and from
    // callers; confirm `ArcMut` provides the synchronization that requires.
    pub(crate) response_table: ArcMut<HashMap<i32, ResponseFuture>>,
}
pub(crate) async fn handle_send(
mut connection: ArcMut<Connection>,
rx: Receiver<ChannelMessage>,
mut response_table: ArcMut<HashMap<i32, ResponseFuture>>,
) {
loop {
let msg = match rx.recv_async().await {
Ok(msg) => msg,
Err(_) => {
break;
}
};
let (send, tx, timeout_millis) = msg;
let opaque = send.opaque();
if let Some(tx) = tx {
response_table.insert(
opaque,
ResponseFuture::new(opaque, timeout_millis.unwrap_or(0), true, tx),
);
}
match connection.send_command(send).await {
Ok(_) => {}
Err(error) => match error {
rocketmq_error::RocketMQError::IO(error) => {
error!("send request failed: {}", error);
response_table.remove(&opaque);
return;
}
_ => {
response_table.remove(&opaque);
}
},
};
}
}
impl ChannelInner {
    /// Outbound queue capacity used by [`ChannelInner::new`].
    pub const DEFAULT_OUTBOUND_QUEUE_CAPACITY: usize = 1024;

    /// Creates the channel state with the default outbound queue capacity
    /// ([`Self::DEFAULT_OUTBOUND_QUEUE_CAPACITY`]) and spawns its background writer task.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because it calls `tokio::spawn`.
    pub fn new(connection: Connection, response_table: ArcMut<HashMap<i32, ResponseFuture>>) -> Self {
        Self::with_queue_capacity(connection, response_table, Self::DEFAULT_OUTBOUND_QUEUE_CAPACITY)
    }

    /// Creates the channel state with a bounded outbound queue and spawns the background
    /// writer task (`handle_send`) that drains it.
    ///
    /// `queue_capacity` is the maximum number of messages the queue holds before senders
    /// must wait for space. A capacity of 0 makes every enqueue wait until the writer
    /// takes the message.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because it calls `tokio::spawn`.
    pub fn with_queue_capacity(
        connection: Connection,
        response_table: ArcMut<HashMap<i32, ResponseFuture>>,
        queue_capacity: usize,
    ) -> Self {
        let (outbound_queue_tx, outbound_queue_rx) = flume::bounded(queue_capacity);
        let connection = ArcMut::new(connection);
        // The writer gets its own handles to the connection and response table. Its
        // JoinHandle is discarded: the task ends by itself once every queue sender is
        // dropped or the connection reports an I/O error.
        tokio::spawn(handle_send(
            connection.clone(),
            outbound_queue_rx,
            response_table.clone(),
        ));
        Self {
            outbound_queue_tx,
            connection,
            response_table,
        }
    }
}
impl ChannelInner {
#[inline]
pub fn connection(&self) -> ArcMut<Connection> {
self.connection.clone()
}
#[inline]
pub fn connection_ref(&self) -> &Connection {
self.connection.as_ref()
}
#[inline]
pub fn connection_mut(&mut self) -> &mut Connection {
self.connection.as_mut()
}
pub async fn send_wait_response(
&mut self,
request: RemotingCommand,
timeout_millis: u64,
) -> rocketmq_error::RocketMQResult<RemotingCommand> {
let (response_tx, response_rx) =
tokio::sync::oneshot::channel::<rocketmq_error::RocketMQResult<RemotingCommand>>();
let opaque = request.opaque();
if let Err(err) = self
.outbound_queue_tx
.send_async((request, Some(response_tx), Some(timeout_millis)))
.await
{
return Err(RocketMQError::network_connection_failed(
"channel",
format!("send failed: {}", err),
));
}
match timeout(Duration::from_millis(timeout_millis), response_rx).await {
Ok(result) => match result {
Ok(response) => response,
Err(e) => {
self.response_table.remove(&opaque);
Err(RocketMQError::network_connection_failed(
"channel",
format!("connection dropped: {}", e),
))
}
},
Err(_) => {
self.response_table.remove(&opaque);
Err(RocketMQError::Timeout {
operation: "channel_recv",
timeout_ms: timeout_millis,
})
}
}
}
pub async fn send_oneway(
&self,
request: RemotingCommand,
timeout_millis: u64,
) -> rocketmq_error::RocketMQResult<()> {
let request = request.mark_oneway_rpc();
if let Err(err) = self
.outbound_queue_tx
.send_async((request, None, Some(timeout_millis)))
.await
{
error!("send oneway request failed: {}", err);
return Err(RocketMQError::network_connection_failed(
"channel",
format!("send oneway failed: {}", err),
));
}
Ok(())
}
pub async fn send(
&self,
request: RemotingCommand,
timeout_millis: Option<u64>,
) -> rocketmq_error::RocketMQResult<()> {
if let Err(err) = self.outbound_queue_tx.send_async((request, None, timeout_millis)).await {
error!("send request failed: {}", err);
return Err(RocketMQError::network_connection_failed(
"channel",
format!("send failed: {}", err),
));
}
Ok(())
}
#[inline]
pub fn is_healthy(&self) -> bool {
self.connection.is_healthy()
}
#[inline]
#[deprecated(since = "0.1.0", note = "Use `is_healthy()` instead")]
pub fn is_ok(&self) -> bool {
self.connection.is_healthy()
}
}