use std::io::IoSlice;
use bytes::Bytes;
use bytes::BytesMut;
use cheetah_string::CheetahString;
use futures_util::SinkExt;
use futures_util::StreamExt;
use rocketmq_error::RocketMQError;
use rocketmq_error::RocketMQResult;
use tokio::io::AsyncWriteExt;
use tokio::net::tcp::OwnedReadHalf;
use tokio::net::tcp::OwnedWriteHalf;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::watch;
use tokio::task::JoinHandle;
use tokio_util::codec::FramedRead;
use tokio_util::codec::FramedWrite;
use uuid::Uuid;
use crate::codec::remoting_command_codec::CompositeCodec;
use crate::protocol::remoting_command::RemotingCommand;
/// Writes every byte referenced by `slices` to `writer` using vectored (scatter/gather) writes.
///
/// This is the async counterpart of `std::io::Write::write_all_vectored`. It repeatedly calls
/// `write_vectored` and advances the slice window past whatever was written until nothing is
/// left.
///
/// # Errors
///
/// Returns `RocketMQError::Network` if the underlying write fails (other than being
/// interrupted, which is retried). It also returns an error if the socket accepts zero bytes
/// while data is still pending, which would otherwise loop forever.
async fn write_all_vectored(writer: &mut OwnedWriteHalf, mut slices: &mut [IoSlice<'_>]) -> RocketMQResult<()> {
    // Strip leading empty slices up front. Without this, an input made only of empty buffers
    // would reach `write_vectored`, get `Ok(0)` back and be misreported as a write failure
    // even though there was nothing to send. `std::io::Write::write_all_vectored` does the same.
    IoSlice::advance_slices(&mut slices, 0);
    while !slices.is_empty() {
        let written = match writer.write_vectored(slices).await {
            Ok(n) => n,
            // Mirrors std's write_all: an interrupted write made no progress, so just retry.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => {
                return Err(RocketMQError::Network(
                    rocketmq_error::NetworkError::connection_failed("write_vectored", e.to_string()),
                ));
            }
        };
        if written == 0 {
            return Err(RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
                "write_vectored",
                "Write returned 0 bytes",
            )));
        }
        // Drops fully written slices and trims the first partially written one.
        IoSlice::advance_slices(&mut slices, written);
    }
    Ok(())
}
/// Coarse health of a connection, published to observers through a `tokio::sync::watch` channel.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// Usable; every connection starts in this state.
    Healthy,
    /// Still open, but flagged as impaired.
    Degraded,
    /// Closed; no further writes are expected.
    Closed,
}
/// A write request handed to the `ConcurrentConnection` writer task.
///
/// Every variant carries a oneshot sender on which the writer task reports the outcome of
/// that particular request.
pub(crate) enum WriteCommand {
    /// Encode one command and send it as a single frame.
    SendCommand(RemotingCommand, oneshot::Sender<RocketMQResult<()>>),
    /// Send already-encoded bytes through the codec.
    SendBytes(Bytes, oneshot::Sender<RocketMQResult<()>>),
    /// Encode several commands, buffering them and flushing once at the end.
    SendCommandsBatch(Vec<RemotingCommand>, oneshot::Sender<RocketMQResult<()>>),
    /// Send several already-encoded buffers through the codec, flushing once at the end.
    SendBytesBatch(Vec<Bytes>, oneshot::Sender<RocketMQResult<()>>),
    /// Write the buffers straight to the socket with vectored writes, bypassing the codec.
    SendZeroCopy(Vec<Bytes>, oneshot::Sender<RocketMQResult<()>>),
    /// Send an encoded response header as a frame, then write the bodies raw (vectored).
    SendHybrid(RemotingCommand, Vec<Bytes>, oneshot::Sender<RocketMQResult<()>>),
    /// Write a pre-encoded header followed by the bodies raw, using vectored writes.
    SendHybridVectored(Bytes, Vec<Bytes>, oneshot::Sender<RocketMQResult<()>>),
    /// Flush pending output, mark the connection closed and stop the writer task.
    Close(oneshot::Sender<RocketMQResult<()>>),
}
/// A TCP connection that owns both framed halves and is driven through `&mut self`.
///
/// Writes run inline on the calling task (there is no background writer). The connection
/// health is kept in a `watch` channel so other tasks can observe state changes.
pub struct RefactoredConnection {
    /// Decodes incoming frames into `RemotingCommand`s.
    framed_reader: FramedRead<OwnedReadHalf, CompositeCodec>,
    /// Codec-backed writer; its inner write half is also used directly for raw writes.
    framed_writer: FramedWrite<OwnedWriteHalf, CompositeCodec>,
    /// Scratch buffer reused to encode outgoing commands.
    encode_buffer: BytesMut,
    /// Publishes state transitions.
    state_tx: watch::Sender<ConnectionState>,
    /// Local receiver used to read the current state and to hand out subscriptions.
    state_rx: watch::Receiver<ConnectionState>,
    /// Identifier of this connection (a random UUID string).
    connection_id: CheetahString,
}
/// A TCP connection whose write half is owned by a dedicated background writer task.
///
/// Writes are submitted as messages over a bounded channel and executed one at a time by the
/// writer task. This lets many tasks write through `&self` (or through cloned senders) without
/// interleaving frames. Reading still requires `&mut self`.
pub struct ConcurrentConnection {
    /// Decodes incoming frames into `RemotingCommand`s.
    framed_reader: FramedRead<OwnedReadHalf, CompositeCodec>,
    /// Queue feeding the writer task.
    write_tx: mpsc::Sender<WriteCommand>,
    /// Observes the state published by the writer task.
    state_rx: watch::Receiver<ConnectionState>,
    /// Handle of the spawned writer task; awaited by `close`.
    writer_handle: JoinHandle<()>,
    /// Identifier of this connection, formatted as `concurrent-<uuid>`.
    connection_id: CheetahString,
}
impl RefactoredConnection {
pub fn new(stream: TcpStream) -> Self {
Self::with_capacity(stream, 1024 * 1024) }
pub fn with_capacity(stream: TcpStream, capacity: usize) -> Self {
let (read_half, write_half) = stream.into_split();
let framed_reader = FramedRead::with_capacity(read_half, CompositeCodec::new(), capacity);
let framed_writer = FramedWrite::new(write_half, CompositeCodec::new());
let (state_tx, state_rx) = watch::channel(ConnectionState::Healthy);
Self {
framed_reader,
framed_writer,
encode_buffer: BytesMut::with_capacity(capacity),
state_tx,
state_rx,
connection_id: CheetahString::from_string(Uuid::new_v4().to_string()),
}
}
pub async fn send_command(&mut self, mut command: RemotingCommand) -> RocketMQResult<()> {
command.fast_header_encode(&mut self.encode_buffer);
if let Some(body) = command.take_body() {
self.encode_buffer.extend_from_slice(&body);
}
let bytes = self.encode_buffer.split().freeze();
self.framed_writer.send(bytes).await
}
pub async fn recv_command(&mut self) -> RocketMQResult<Option<RemotingCommand>> {
self.framed_reader.next().await.transpose()
}
pub async fn send_bytes(&mut self, bytes: Bytes) -> RocketMQResult<()> {
self.framed_writer.flush().await?;
let inner = self.framed_writer.get_mut();
inner.write_all(&bytes).await?;
inner.flush().await?;
Ok(())
}
pub async fn send_commands_batch(&mut self, commands: Vec<RemotingCommand>) -> RocketMQResult<()> {
for mut command in commands {
command.fast_header_encode(&mut self.encode_buffer);
if let Some(body) = command.take_body() {
self.encode_buffer.extend_from_slice(&body);
}
let bytes = self.encode_buffer.split().freeze();
self.framed_writer.feed(bytes).await?;
}
self.framed_writer.flush().await
}
pub async fn send_bytes_batch(&mut self, chunks: Vec<Bytes>) -> RocketMQResult<()> {
self.framed_writer.flush().await?;
let inner = self.framed_writer.get_mut();
for chunk in chunks {
inner.write_all(&chunk).await?;
}
inner.flush().await?;
Ok(())
}
pub async fn send_bytes_zero_copy(&mut self, chunks: Vec<Bytes>) -> RocketMQResult<()> {
use std::io::IoSlice;
self.framed_writer.flush().await?;
let inner = self.framed_writer.get_mut();
let mut slices: Vec<IoSlice> = chunks.iter().map(|b| IoSlice::new(b.as_ref())).collect();
write_all_vectored(inner, &mut slices).await?;
inner.flush().await?;
Ok(())
}
pub async fn send_bytes_zero_copy_single(&mut self, data: Bytes) -> RocketMQResult<()> {
self.framed_writer.flush().await?;
let inner = self.framed_writer.get_mut();
inner.write_all(&data).await?;
inner.flush().await?;
Ok(())
}
pub async fn send_response_hybrid(
&mut self,
mut response_header: RemotingCommand,
message_bodies: Vec<Bytes>,
) -> RocketMQResult<()> {
response_header.fast_header_encode(&mut self.encode_buffer);
if let Some(body) = response_header.take_body() {
self.encode_buffer.extend_from_slice(&body);
}
let header_bytes = self.encode_buffer.split().freeze();
self.framed_writer.send(header_bytes).await?;
self.framed_writer.flush().await?;
let inner = self.framed_writer.get_mut();
for body in message_bodies {
inner.write_all(&body).await?;
}
inner.flush().await?;
Ok(())
}
pub async fn send_response_hybrid_vectored(
&mut self,
response_header_bytes: Bytes,
message_bodies: Vec<Bytes>,
) -> RocketMQResult<()> {
use std::io::IoSlice;
self.framed_writer.flush().await?;
let mut slices = Vec::with_capacity(1 + message_bodies.len());
slices.push(IoSlice::new(response_header_bytes.as_ref()));
for body in &message_bodies {
slices.push(IoSlice::new(body.as_ref()));
}
let inner = self.framed_writer.get_mut();
write_all_vectored(inner, &mut slices).await?;
inner.flush().await?;
Ok(())
}
pub fn state(&self) -> ConnectionState {
*self.state_rx.borrow()
}
pub fn mark_degraded(&self) {
let _ = self.state_tx.send(ConnectionState::Degraded);
}
pub fn mark_healthy(&self) {
let _ = self.state_tx.send(ConnectionState::Healthy);
}
pub async fn close(&mut self) -> RocketMQResult<()> {
let _ = self.state_tx.send(ConnectionState::Closed);
self.framed_writer.flush().await?;
self.framed_writer.get_mut().shutdown().await.map_err(|e| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
format!("{}", e),
))
})
}
pub fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
self.state_rx.clone()
}
pub fn connection_id(&self) -> &CheetahString {
&self.connection_id
}
pub fn framed_reader_mut(&mut self) -> &mut FramedRead<OwnedReadHalf, CompositeCodec> {
&mut self.framed_reader
}
pub fn framed_writer_mut(&mut self) -> &mut FramedWrite<OwnedWriteHalf, CompositeCodec> {
&mut self.framed_writer
}
}
impl ConcurrentConnection {
pub fn new(stream: TcpStream) -> Self {
Self::with_channel_capacity(stream, 1024)
}
pub fn with_channel_capacity(stream: TcpStream, channel_capacity: usize) -> Self {
let (read_half, write_half) = stream.into_split();
let framed_reader = FramedRead::new(read_half, CompositeCodec::default());
let framed_writer = FramedWrite::new(write_half, CompositeCodec::default());
let (write_tx, write_rx) = mpsc::channel(channel_capacity);
let (state_tx, state_rx) = watch::channel(ConnectionState::Healthy);
let writer_handle = tokio::spawn(Self::writer_task(framed_writer, write_rx, state_tx));
Self {
framed_reader,
write_tx,
state_rx,
writer_handle,
connection_id: CheetahString::from_string(format!("concurrent-{}", uuid::Uuid::new_v4())),
}
}
async fn writer_task(
mut framed_writer: FramedWrite<OwnedWriteHalf, CompositeCodec>,
mut write_rx: mpsc::Receiver<WriteCommand>,
state_tx: watch::Sender<ConnectionState>,
) {
let mut encode_buffer = BytesMut::with_capacity(1024 * 1024);
while let Some(cmd) = write_rx.recv().await {
match cmd {
WriteCommand::SendCommand(remote_cmd, response_tx) => {
let result = Self::handle_send_command(&mut framed_writer, &mut encode_buffer, remote_cmd).await;
let _ = response_tx.send(result);
}
WriteCommand::SendBytes(bytes, response_tx) => {
let result = Self::handle_send_bytes(&mut framed_writer, bytes).await;
let _ = response_tx.send(result);
}
WriteCommand::SendCommandsBatch(commands, response_tx) => {
let result =
Self::handle_send_commands_batch(&mut framed_writer, &mut encode_buffer, commands).await;
let _ = response_tx.send(result);
}
WriteCommand::SendBytesBatch(bytes_vec, response_tx) => {
let result = Self::handle_send_bytes_batch(&mut framed_writer, bytes_vec).await;
let _ = response_tx.send(result);
}
WriteCommand::SendZeroCopy(bytes_vec, response_tx) => {
let result = Self::handle_send_zero_copy(&mut framed_writer, bytes_vec).await;
let _ = response_tx.send(result);
}
WriteCommand::SendHybrid(remote_cmd, bodies, response_tx) => {
let result =
Self::handle_send_hybrid(&mut framed_writer, &mut encode_buffer, remote_cmd, bodies).await;
let _ = response_tx.send(result);
}
WriteCommand::SendHybridVectored(header_bytes, bodies, response_tx) => {
let result = Self::handle_send_hybrid_vectored(&mut framed_writer, header_bytes, bodies).await;
let _ = response_tx.send(result);
}
WriteCommand::Close(response_tx) => {
let _ = framed_writer.flush().await;
let _ = response_tx.send(Ok(()));
let _ = state_tx.send(ConnectionState::Closed);
break;
}
}
}
}
async fn handle_send_command(
framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
encode_buffer: &mut BytesMut,
mut remote_cmd: RemotingCommand,
) -> RocketMQResult<()> {
remote_cmd.fast_header_encode(encode_buffer);
if let Some(body) = remote_cmd.take_body() {
encode_buffer.extend_from_slice(&body);
}
let bytes = encode_buffer.split().freeze();
framed_writer.send(bytes).await?;
framed_writer.flush().await?;
Ok(())
}
async fn handle_send_bytes(
framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
bytes: Bytes,
) -> RocketMQResult<()> {
framed_writer.send(bytes).await?;
framed_writer.flush().await?;
Ok(())
}
async fn handle_send_commands_batch(
framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
encode_buffer: &mut BytesMut,
commands: Vec<RemotingCommand>,
) -> RocketMQResult<()> {
for mut cmd in commands {
cmd.fast_header_encode(encode_buffer);
if let Some(body) = cmd.take_body() {
encode_buffer.extend_from_slice(&body);
}
let bytes = encode_buffer.split().freeze();
framed_writer.feed(bytes).await?;
}
framed_writer.flush().await?;
Ok(())
}
async fn handle_send_bytes_batch(
framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
bytes_vec: Vec<Bytes>,
) -> RocketMQResult<()> {
for bytes in bytes_vec {
framed_writer.feed(bytes).await?;
}
framed_writer.flush().await?;
Ok(())
}
async fn handle_send_zero_copy(
framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
bytes_vec: Vec<Bytes>,
) -> RocketMQResult<()> {
let mut io_slices: Vec<IoSlice> = bytes_vec.iter().map(|b| IoSlice::new(b.as_ref())).collect();
write_all_vectored(framed_writer.get_mut(), &mut io_slices).await?;
framed_writer.flush().await?;
Ok(())
}
async fn handle_send_hybrid(
framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
encode_buffer: &mut BytesMut,
mut remote_cmd: RemotingCommand,
bodies: Vec<Bytes>,
) -> RocketMQResult<()> {
remote_cmd.fast_header_encode(encode_buffer);
if let Some(body) = remote_cmd.take_body() {
encode_buffer.extend_from_slice(&body);
}
let header_bytes = encode_buffer.split().freeze();
framed_writer.send(header_bytes).await?;
let mut io_slices: Vec<IoSlice> = bodies.iter().map(|b| IoSlice::new(b.as_ref())).collect();
write_all_vectored(framed_writer.get_mut(), &mut io_slices).await?;
framed_writer.flush().await?;
Ok(())
}
async fn handle_send_hybrid_vectored(
framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
header_bytes: Bytes,
bodies: Vec<Bytes>,
) -> RocketMQResult<()> {
let mut all_bytes = vec![header_bytes];
all_bytes.extend(bodies);
let mut io_slices: Vec<IoSlice> = all_bytes.iter().map(|b| IoSlice::new(b.as_ref())).collect();
write_all_vectored(framed_writer.get_mut(), &mut io_slices).await?;
framed_writer.flush().await?;
Ok(())
}
pub async fn send_command(&self, remote_cmd: RemotingCommand) -> RocketMQResult<()> {
let (tx, rx) = oneshot::channel();
self.write_tx
.send(WriteCommand::SendCommand(remote_cmd, tx))
.await
.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Writer task closed",
))
})?;
rx.await.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Response channel closed",
))
})?
}
pub async fn send_bytes(&self, bytes: Bytes) -> RocketMQResult<()> {
let (tx, rx) = oneshot::channel();
self.write_tx
.send(WriteCommand::SendBytes(bytes, tx))
.await
.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Writer task closed",
))
})?;
rx.await.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Response channel closed",
))
})?
}
pub async fn send_commands_batch(&self, commands: Vec<RemotingCommand>) -> RocketMQResult<()> {
let (tx, rx) = oneshot::channel();
self.write_tx
.send(WriteCommand::SendCommandsBatch(commands, tx))
.await
.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Writer task closed",
))
})?;
rx.await.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Response channel closed",
))
})?
}
pub async fn send_bytes_batch(&self, bytes_vec: Vec<Bytes>) -> RocketMQResult<()> {
let (tx, rx) = oneshot::channel();
self.write_tx
.send(WriteCommand::SendBytesBatch(bytes_vec, tx))
.await
.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Writer task closed",
))
})?;
rx.await.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Response channel closed",
))
})?
}
pub async fn send_bytes_zero_copy(&self, bytes_vec: Vec<Bytes>) -> RocketMQResult<()> {
let (tx, rx) = oneshot::channel();
self.write_tx
.send(WriteCommand::SendZeroCopy(bytes_vec, tx))
.await
.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Writer task closed",
))
})?;
rx.await.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Response channel closed",
))
})?
}
pub async fn send_response_hybrid(&self, response: RemotingCommand, bodies: Vec<Bytes>) -> RocketMQResult<()> {
let (tx, rx) = oneshot::channel();
self.write_tx
.send(WriteCommand::SendHybrid(response, bodies, tx))
.await
.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Writer task closed",
))
})?;
rx.await.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Response channel closed",
))
})?
}
pub async fn send_response_hybrid_vectored(&self, header_bytes: Bytes, bodies: Vec<Bytes>) -> RocketMQResult<()> {
let (tx, rx) = oneshot::channel();
self.write_tx
.send(WriteCommand::SendHybridVectored(header_bytes, bodies, tx))
.await
.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Writer task closed",
))
})?;
rx.await.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Response channel closed",
))
})?
}
pub async fn recv_command(&mut self) -> RocketMQResult<Option<RemotingCommand>> {
self.framed_reader.next().await.transpose()
}
pub fn state(&self) -> ConnectionState {
*self.state_rx.borrow()
}
pub fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
self.state_rx.clone()
}
pub fn connection_id(&self) -> &CheetahString {
&self.connection_id
}
pub(crate) fn clone_sender(&self) -> mpsc::Sender<WriteCommand> {
self.write_tx.clone()
}
pub async fn close(self) -> RocketMQResult<()> {
let (tx, rx) = oneshot::channel();
self.write_tx.send(WriteCommand::Close(tx)).await.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Writer task closed",
))
})?;
rx.await.map_err(|_| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
"Response channel closed",
))
})??;
self.writer_handle.await.map_err(|e| {
RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
"connection",
format!("{}", e),
))
})?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;
    use tokio::net::TcpListener;

    use super::*;
    use crate::protocol::header::empty_header::EmptyHeader;
    use crate::protocol::remoting_command::RemotingCommand;

    /// Collects everything the peer writes until it shuts down its write side.
    ///
    /// Reading to EOF replaces the old "sleep 100 ms, then `try_read` once" pattern. That
    /// pattern failed spuriously when data had not arrived yet (`WouldBlock`) or arrived split
    /// across segments. The client side closes its write half when its connection is dropped.
    async fn read_until_eof(mut socket: TcpStream) -> Vec<u8> {
        let mut received = Vec::new();
        socket.read_to_end(&mut received).await.unwrap();
        received
    }

    #[tokio::test]
    async fn test_framed_connection_basic() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let mut conn = RefactoredConnection::new(stream);
            let cmd = RemotingCommand::create_request_command(100, EmptyHeader {}).set_body(Bytes::from("test data"));
            conn.send_command(cmd).await.unwrap();
        });
        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = RefactoredConnection::new(socket);
        let received = server_conn.recv_command().await.unwrap();
        assert!(received.is_some());
        let cmd = received.unwrap();
        assert_eq!(cmd.code(), 100);
        let expected = Bytes::from("test data");
        assert_eq!(&expected, cmd.body().as_ref().unwrap());
        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_batch_send() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let mut conn = RefactoredConnection::new(stream);
            let commands = vec![
                RemotingCommand::create_request_command(101, EmptyHeader {}),
                RemotingCommand::create_request_command(102, EmptyHeader {}),
                RemotingCommand::create_request_command(103, EmptyHeader {}),
            ];
            conn.send_commands_batch(commands).await.unwrap();
        });
        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = RefactoredConnection::new(socket);
        for expected_code in [101, 102, 103] {
            let received = server_conn.recv_command().await.unwrap();
            assert!(received.is_some());
            assert_eq!(received.unwrap().code(), expected_code);
        }
        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_zero_copy_send() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let mut conn = RefactoredConnection::new(stream);
            let chunks = vec![Bytes::from("Part1"), Bytes::from("Part2"), Bytes::from("Part3")];
            conn.send_bytes_zero_copy(chunks).await.unwrap();
        });
        let (socket, _) = listener.accept().await.unwrap();
        let received = read_until_eof(socket).await;
        assert_eq!(received, b"Part1Part2Part3");
        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_hybrid_vectored() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let mut conn = RefactoredConnection::new(stream);
            let header = Bytes::from("HEADER:");
            let bodies = vec![Bytes::from("Body1"), Bytes::from("|"), Bytes::from("Body2")];
            conn.send_response_hybrid_vectored(header, bodies).await.unwrap();
        });
        let (socket, _) = listener.accept().await.unwrap();
        let received = read_until_eof(socket).await;
        assert_eq!(received, b"HEADER:Body1|Body2");
        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_connection_state() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let stream = TcpStream::connect(addr).await.unwrap();
        let mut conn = RefactoredConnection::new(stream);
        assert_eq!(conn.state(), ConnectionState::Healthy);
        conn.mark_degraded();
        assert_eq!(conn.state(), ConnectionState::Degraded);
        conn.mark_healthy();
        assert_eq!(conn.state(), ConnectionState::Healthy);
        conn.close().await.unwrap();
        assert_eq!(conn.state(), ConnectionState::Closed);
    }

    #[tokio::test]
    async fn test_state_subscription() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let _accept_handle = tokio::spawn(async move {
            let _ = listener.accept().await;
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        let conn = RefactoredConnection::new(stream);
        let state_rx = conn.subscribe_state();
        conn.mark_degraded();
        assert_eq!(*state_rx.borrow(), ConnectionState::Degraded);
    }

    #[tokio::test]
    async fn test_zero_copy_single() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let mut conn = RefactoredConnection::new(stream);
            let data = Bytes::from("LargeDataBlock");
            conn.send_bytes_zero_copy_single(data).await.unwrap();
        });
        let (socket, _) = listener.accept().await.unwrap();
        let received = read_until_eof(socket).await;
        assert_eq!(received, b"LargeDataBlock");
        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_hybrid_standard() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let mut conn = RefactoredConnection::new(stream);
            let response = RemotingCommand::create_response_command();
            let bodies = vec![Bytes::from("Message1"), Bytes::from("Message2")];
            conn.send_response_hybrid(response, bodies).await.unwrap();
        });
        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = RefactoredConnection::new(socket);
        let received = server_conn.recv_command().await.unwrap();
        assert!(received.is_some());
        client.await.unwrap();
    }
}
#[cfg(test)]
mod concurrent_tests {
    use bytes::Bytes;
    use tokio::io::AsyncReadExt;
    use tokio::net::TcpListener;
    use tokio::net::TcpStream;

    use super::*;
    use crate::protocol::header::pull_message_response_header::PullMessageResponseHeader;

    /// Collects everything the peer writes until it shuts down its write side.
    ///
    /// When a client `ConcurrentConnection` is dropped, its writer task sees the channel
    /// close and drops the write half, which shuts it down. Reading to EOF therefore gathers
    /// all bytes deterministically, unlike a single `try_read` after a fixed sleep.
    async fn read_until_eof(mut socket: TcpStream) -> Vec<u8> {
        let mut received = Vec::new();
        socket.read_to_end(&mut received).await.unwrap();
        received
    }

    #[tokio::test]
    async fn test_concurrent_basic() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let conn = ConcurrentConnection::new(stream);
            let cmd = RemotingCommand::create_request_command(100, PullMessageResponseHeader::default());
            conn.send_command(cmd).await.unwrap();
        });
        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = ConcurrentConnection::new(socket);
        let received = server_conn.recv_command().await.unwrap();
        assert!(received.is_some());
        client.await.unwrap();
        server_conn.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_multi_writers() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let conn = ConcurrentConnection::new(stream);
            let mut handles = vec![];
            for i in 0..3 {
                let sender = conn.clone_sender();
                let handle = tokio::spawn(async move {
                    let cmd = RemotingCommand::create_request_command(100 + i, PullMessageResponseHeader::default());
                    let (tx, rx) = oneshot::channel();
                    sender.send(WriteCommand::SendCommand(cmd, tx)).await.unwrap();
                    rx.await.unwrap().unwrap();
                });
                handles.push(handle);
            }
            for handle in handles {
                handle.await.unwrap();
            }
            conn.close().await.unwrap();
        });
        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = ConcurrentConnection::new(socket);
        for _ in 0..3 {
            let received = server_conn.recv_command().await.unwrap();
            assert!(received.is_some());
        }
        client.await.unwrap();
        server_conn.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_batch() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let conn = ConcurrentConnection::new(stream);
            let bytes_vec = vec![
                Bytes::from("Message1"),
                Bytes::from("Message2"),
                Bytes::from("Message3"),
            ];
            conn.send_bytes_batch(bytes_vec).await.unwrap();
        });
        let (socket, _) = listener.accept().await.unwrap();
        let raw = read_until_eof(socket).await;
        let received = String::from_utf8_lossy(&raw);
        assert!(received.contains("Message1"));
        assert!(received.contains("Message2"));
        assert!(received.contains("Message3"));
        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_zero_copy() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let conn = ConcurrentConnection::new(stream);
            let chunks = vec![Bytes::from("Zero"), Bytes::from("Copy"), Bytes::from("Test")];
            conn.send_bytes_zero_copy(chunks).await.unwrap();
        });
        let (socket, _) = listener.accept().await.unwrap();
        let received = read_until_eof(socket).await;
        assert_eq!(received, b"ZeroCopyTest");
        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_hybrid() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let conn = ConcurrentConnection::new(stream);
            let response = RemotingCommand::create_response_command();
            let bodies = vec![Bytes::from("Body1"), Bytes::from("Body2")];
            conn.send_response_hybrid(response, bodies).await.unwrap();
        });
        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = ConcurrentConnection::new(socket);
        let received = server_conn.recv_command().await.unwrap();
        assert!(received.is_some());
        client.await.unwrap();
        server_conn.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_state() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let conn = ConcurrentConnection::new(stream);
            assert_eq!(conn.state(), ConnectionState::Healthy);
            let cmd = RemotingCommand::create_request_command(100, PullMessageResponseHeader::default());
            conn.send_command(cmd).await.unwrap();
            conn.close().await.unwrap();
        });
        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = ConcurrentConnection::new(socket);
        assert_eq!(server_conn.state(), ConnectionState::Healthy);
        let received = server_conn.recv_command().await.unwrap();
        assert!(received.is_some());
        client.await.unwrap();
        server_conn.close().await.unwrap();
    }
}