use std::hash::Hash;
use std::hash::Hasher;
use bytes::BufMut;
use bytes::Bytes;
use bytes::BytesMut;
use cheetah_string::CheetahString;
use futures_util::stream::SplitSink;
use futures_util::stream::SplitStream;
use futures_util::SinkExt;
use futures_util::StreamExt;
use tokio::io::AsyncRead;
use tokio::io::AsyncWrite;
use tokio::net::TcpStream;
use tokio::sync::watch;
use tokio_util::codec::Framed;
use uuid::Uuid;
use crate::codec::remoting_command_codec::CompositeCodec;
use crate::protocol::remoting_command::RemotingCommand;
/// Identifier of a [`Connection`]; each connection gets a string form of a random v4 UUID.
pub type ConnectionId = CheetahString;

/// Byte transport a [`Connection`] can run over: any async reader/writer that is `Send` and
/// `Unpin`.
pub trait ConnectionTransport: AsyncRead + AsyncWrite + Send + Unpin {}

/// Blanket impl: every type satisfying the bounds (for example `TcpStream`) is a transport.
impl<T: AsyncRead + AsyncWrite + Send + Unpin> ConnectionTransport for T {}

/// Type-erased transport, so connections over different stream types share one concrete type.
pub type BoxedConnectionTransport = Box<dyn ConnectionTransport>;

/// The boxed transport framed with [`CompositeCodec`].
pub type ConnectionFramed = Framed<BoxedConnectionTransport, CompositeCodec>;
/// Health of a [`Connection`], published to subscribers through a `watch` channel.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Initial state of every new connection.
    Healthy,
    /// A send on the connection has failed.
    Degraded,
    /// The connection has been marked closed (e.g. via `Connection::close`).
    Closed,
}
/// A framed, bidirectional remoting connection.
///
/// The underlying [`ConnectionFramed`] is split into an outbound sink and an
/// inbound stream. Connection health ([`ConnectionState`]) is tracked in a
/// `watch` channel that observers can subscribe to.
pub struct Connection {
/// Write half of the framed transport; accepts already-encoded frames as `Bytes`.
outbound_sink: SplitSink<ConnectionFramed, Bytes>,
/// Read half of the framed transport; yields decoded `RemotingCommand`s or errors.
inbound_stream: SplitStream<ConnectionFramed>,
/// Publishes state transitions to subscribers.
state_tx: watch::Sender<ConnectionState>,
/// Receiver kept alongside the sender to read the current state; holding it
/// also guarantees the channel always has at least one receiver.
state_rx: watch::Receiver<ConnectionState>,
/// Scratch buffer reused to encode outgoing commands before they are sent.
encode_buffer: BytesMut,
/// Random v4-UUID identifier; the only field used by `Hash` / `PartialEq`.
connection_id: ConnectionId,
}
/// Connections hash by identity (their id) only, consistent with `PartialEq`.
impl Hash for Connection {
    fn hash<H>(&self, hasher: &mut H)
    where
        H: Hasher,
    {
        self.connection_id.hash(hasher);
    }
}

/// Two connections are equal exactly when their ids are equal.
impl PartialEq for Connection {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.connection_id, &other.connection_id);
        lhs == rhs
    }
}

/// Id equality is a full equivalence relation, so `Eq` holds.
impl Eq for Connection {}
impl Connection {
pub fn new(tcp_stream: TcpStream) -> Connection {
Self::new_with_stream(tcp_stream)
}
pub fn new_with_stream<S>(stream: S) -> Connection
where
S: ConnectionTransport + 'static,
{
const CAPACITY: usize = 1024 * 1024; const BUFFER_SIZE: usize = 8 * 1024; let framed = Framed::with_capacity(
Box::new(stream) as BoxedConnectionTransport,
CompositeCodec::new(),
CAPACITY,
);
let (outbound_sink, inbound_stream) = framed.split();
let (state_tx, state_rx) = watch::channel(ConnectionState::Healthy);
Self {
outbound_sink,
inbound_stream,
state_tx,
state_rx,
encode_buffer: BytesMut::with_capacity(BUFFER_SIZE),
connection_id: CheetahString::from_string(Uuid::new_v4().to_string()),
}
}
#[inline]
pub fn inbound_stream(&self) -> &SplitStream<ConnectionFramed> {
&self.inbound_stream
}
#[inline]
pub fn outbound_sink(&self) -> &SplitSink<ConnectionFramed, Bytes> {
&self.outbound_sink
}
pub async fn receive_command(&mut self) -> Option<rocketmq_error::RocketMQResult<RemotingCommand>> {
self.inbound_stream.next().await
}
pub async fn send_command(&mut self, mut command: RemotingCommand) -> rocketmq_error::RocketMQResult<()> {
command.fast_header_encode(&mut self.encode_buffer);
if let Some(body_inner) = command.take_body() {
self.encode_buffer.put(body_inner);
}
let len = self.encode_buffer.len();
let bytes = self.encode_buffer.split_to(len).freeze();
match self.outbound_sink.send(bytes).await {
Ok(()) => Ok(()),
Err(e) => {
self.mark_degraded();
Err(e)
}
}
}
pub async fn send_command_ref(&mut self, command: &mut RemotingCommand) -> rocketmq_error::RocketMQResult<()> {
command.fast_header_encode(&mut self.encode_buffer);
if let Some(body_inner) = command.take_body() {
self.encode_buffer.put(body_inner);
}
let len = self.encode_buffer.len();
let bytes = self.encode_buffer.split_to(len).freeze();
match self.outbound_sink.send(bytes).await {
Ok(()) => Ok(()),
Err(e) => {
self.mark_degraded();
Err(e)
}
}
}
pub async fn send_batch(&mut self, mut commands: Vec<RemotingCommand>) -> rocketmq_error::RocketMQResult<()> {
if commands.is_empty() {
return Ok(());
}
for command in &mut commands {
command.fast_header_encode(&mut self.encode_buffer);
if let Some(body_inner) = command.take_body() {
self.encode_buffer.put(body_inner);
}
}
let len = self.encode_buffer.len();
let bytes = self.encode_buffer.split_to(len).freeze();
match self.outbound_sink.send(bytes).await {
Ok(()) => Ok(()),
Err(e) => {
self.mark_degraded();
Err(e)
}
}
}
pub async fn send_bytes(&mut self, bytes: Bytes) -> rocketmq_error::RocketMQResult<()> {
match self.outbound_sink.send(bytes).await {
Ok(()) => Ok(()),
Err(e) => {
self.mark_degraded();
Err(e)
}
}
}
pub async fn send_slice(&mut self, slice: &'static [u8]) -> rocketmq_error::RocketMQResult<()> {
let bytes = slice.into();
match self.outbound_sink.send(bytes).await {
Ok(()) => Ok(()),
Err(e) => {
self.mark_degraded();
Err(e)
}
}
}
#[inline]
pub fn connection_id(&self) -> &ConnectionId {
&self.connection_id
}
#[inline]
pub fn state(&self) -> ConnectionState {
*self.state_rx.borrow()
}
#[inline]
pub fn is_healthy(&self) -> bool {
self.state() == ConnectionState::Healthy
}
pub fn subscribe(&self) -> watch::Receiver<ConnectionState> {
self.state_tx.subscribe()
}
#[inline]
fn mark_degraded(&self) {
let _ = self.state_tx.send(ConnectionState::Degraded);
}
#[inline]
fn mark_closed(&self) {
let _ = self.state_tx.send(ConnectionState::Closed);
}
pub fn close(&self) {
self.mark_closed();
}
#[inline]
#[deprecated(since = "0.7.0", note = "Use `is_healthy()` or `state()` instead")]
pub fn connection_is_ok(&self) -> bool {
self.is_healthy()
}
}