1use std::io::IoSlice;
24
25use bytes::Bytes;
26use bytes::BytesMut;
27use cheetah_string::CheetahString;
28use futures_util::SinkExt;
29use futures_util::StreamExt;
30use rocketmq_error::RocketMQError;
31use rocketmq_error::RocketMQResult;
32use tokio::io::AsyncWriteExt;
33use tokio::net::tcp::OwnedReadHalf;
34use tokio::net::tcp::OwnedWriteHalf;
35use tokio::net::TcpStream;
36use tokio::sync::mpsc;
37use tokio::sync::oneshot;
38use tokio::sync::watch;
39use tokio::task::JoinHandle;
40use tokio_util::codec::FramedRead;
41use tokio_util::codec::FramedWrite;
42use uuid::Uuid;
43
44use crate::codec::remoting_command_codec::CompositeCodec;
45use crate::protocol::remoting_command::RemotingCommand;
46
47async fn write_all_vectored(
51 writer: &mut OwnedWriteHalf,
52 mut slices: &mut [IoSlice<'_>],
53) -> RocketMQResult<()> {
54 while !slices.is_empty() {
55 let written = writer.write_vectored(slices).await.map_err(|e| {
56 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
57 "write_vectored",
58 format!("{}", e),
59 ))
60 })?;
61
62 if written == 0 {
63 return Err(RocketMQError::Network(
64 rocketmq_error::NetworkError::connection_failed(
65 "write_vectored",
66 "Write returned 0 bytes",
67 ),
68 ));
69 }
70
71 IoSlice::advance_slices(&mut slices, written);
73 }
74 Ok(())
75}
76
/// Coarse health of a connection, published to observers through a `watch` channel.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Normal operation; the initial state of every connection.
    Healthy,
    /// Explicitly flagged as impaired but still usable.
    Degraded,
    /// Closed; no further writes are expected.
    Closed,
}
84
85pub(crate) enum WriteCommand {
87 SendCommand(RemotingCommand, oneshot::Sender<RocketMQResult<()>>),
89 SendBytes(Bytes, oneshot::Sender<RocketMQResult<()>>),
91 SendCommandsBatch(Vec<RemotingCommand>, oneshot::Sender<RocketMQResult<()>>),
93 SendBytesBatch(Vec<Bytes>, oneshot::Sender<RocketMQResult<()>>),
95 SendZeroCopy(Vec<Bytes>, oneshot::Sender<RocketMQResult<()>>),
97 SendHybrid(
99 RemotingCommand,
100 Vec<Bytes>,
101 oneshot::Sender<RocketMQResult<()>>,
102 ),
103 SendHybridVectored(Bytes, Vec<Bytes>, oneshot::Sender<RocketMQResult<()>>),
105 Close(oneshot::Sender<RocketMQResult<()>>),
107}
108
109pub struct RefactoredConnection {
144 framed_reader: FramedRead<OwnedReadHalf, CompositeCodec>,
148
149 framed_writer: FramedWrite<OwnedWriteHalf, CompositeCodec>,
153
154 encode_buffer: BytesMut,
158
159 state_tx: watch::Sender<ConnectionState>,
161 state_rx: watch::Receiver<ConnectionState>,
162
163 connection_id: CheetahString,
165}
166
167pub struct ConcurrentConnection {
203 framed_reader: FramedRead<OwnedReadHalf, CompositeCodec>,
205
206 write_tx: mpsc::Sender<WriteCommand>,
208
209 state_rx: watch::Receiver<ConnectionState>,
211
212 writer_handle: JoinHandle<()>,
214
215 connection_id: CheetahString,
217}
218
219impl RefactoredConnection {
220 pub fn new(stream: TcpStream) -> Self {
222 Self::with_capacity(stream, 1024 * 1024) }
224
225 pub fn with_capacity(stream: TcpStream, capacity: usize) -> Self {
227 let (read_half, write_half) = stream.into_split();
229
230 let framed_reader = FramedRead::with_capacity(read_half, CompositeCodec::new(), capacity);
232 let framed_writer = FramedWrite::new(write_half, CompositeCodec::new());
233
234 let (state_tx, state_rx) = watch::channel(ConnectionState::Healthy);
236
237 Self {
238 framed_reader,
239 framed_writer,
240 encode_buffer: BytesMut::with_capacity(capacity),
241 state_tx,
242 state_rx,
243 connection_id: CheetahString::from_string(Uuid::new_v4().to_string()),
244 }
245 }
246
247 pub async fn send_command(&mut self, mut command: RemotingCommand) -> RocketMQResult<()> {
258 command.fast_header_encode(&mut self.encode_buffer);
260 if let Some(body) = command.take_body() {
261 self.encode_buffer.extend_from_slice(&body);
262 }
263
264 let bytes = self.encode_buffer.split().freeze();
266
267 self.framed_writer.send(bytes).await
269 }
270
271 pub async fn recv_command(&mut self) -> RocketMQResult<Option<RemotingCommand>> {
277 self.framed_reader.next().await.transpose()
279 }
280
281 pub async fn send_bytes(&mut self, bytes: Bytes) -> RocketMQResult<()> {
290 self.framed_writer.flush().await?;
292
293 let inner = self.framed_writer.get_mut();
295 inner.write_all(&bytes).await?;
296 inner.flush().await?;
297
298 Ok(())
299 }
300
301 pub async fn send_commands_batch(
310 &mut self,
311 commands: Vec<RemotingCommand>,
312 ) -> RocketMQResult<()> {
313 for mut command in commands {
315 command.fast_header_encode(&mut self.encode_buffer);
317 if let Some(body) = command.take_body() {
318 self.encode_buffer.extend_from_slice(&body);
319 }
320
321 let bytes = self.encode_buffer.split().freeze();
323
324 self.framed_writer.feed(bytes).await?;
325 }
326
327 self.framed_writer.flush().await
329 }
330
331 pub async fn send_bytes_batch(&mut self, chunks: Vec<Bytes>) -> RocketMQResult<()> {
341 self.framed_writer.flush().await?;
343
344 let inner = self.framed_writer.get_mut();
346 for chunk in chunks {
347 inner.write_all(&chunk).await?;
348 }
349
350 inner.flush().await?;
352
353 Ok(())
354 }
355
356 pub async fn send_bytes_zero_copy(&mut self, chunks: Vec<Bytes>) -> RocketMQResult<()> {
369 use std::io::IoSlice;
370
371 self.framed_writer.flush().await?;
373
374 let inner = self.framed_writer.get_mut();
376
377 let mut slices: Vec<IoSlice> = chunks.iter().map(|b| IoSlice::new(b.as_ref())).collect();
379
380 write_all_vectored(inner, &mut slices).await?;
382 inner.flush().await?;
383
384 Ok(())
385 }
386
387 pub async fn send_bytes_zero_copy_single(&mut self, data: Bytes) -> RocketMQResult<()> {
391 self.framed_writer.flush().await?;
393
394 let inner = self.framed_writer.get_mut();
396 inner.write_all(&data).await?;
397 inner.flush().await?;
398
399 Ok(())
400 }
401
402 pub async fn send_response_hybrid(
419 &mut self,
420 mut response_header: RemotingCommand,
421 message_bodies: Vec<Bytes>,
422 ) -> RocketMQResult<()> {
423 response_header.fast_header_encode(&mut self.encode_buffer);
426 if let Some(body) = response_header.take_body() {
427 self.encode_buffer.extend_from_slice(&body);
428 }
429 let header_bytes = self.encode_buffer.split().freeze();
430
431 self.framed_writer.send(header_bytes).await?;
432
433 self.framed_writer.flush().await?;
435
436 let inner = self.framed_writer.get_mut();
438 for body in message_bodies {
439 inner.write_all(&body).await?;
440 }
441
442 inner.flush().await?;
444
445 Ok(())
446 }
447
448 pub async fn send_response_hybrid_vectored(
461 &mut self,
462 response_header_bytes: Bytes,
463 message_bodies: Vec<Bytes>,
464 ) -> RocketMQResult<()> {
465 use std::io::IoSlice;
466
467 self.framed_writer.flush().await?;
469
470 let mut slices = Vec::with_capacity(1 + message_bodies.len());
472 slices.push(IoSlice::new(response_header_bytes.as_ref()));
473 for body in &message_bodies {
474 slices.push(IoSlice::new(body.as_ref()));
475 }
476
477 let inner = self.framed_writer.get_mut();
479 write_all_vectored(inner, &mut slices).await?;
480 inner.flush().await?;
481
482 Ok(())
483 }
484
485 pub fn state(&self) -> ConnectionState {
489 *self.state_rx.borrow()
490 }
491
492 pub fn mark_degraded(&self) {
496 let _ = self.state_tx.send(ConnectionState::Degraded);
497 }
498
499 pub fn mark_healthy(&self) {
501 let _ = self.state_tx.send(ConnectionState::Healthy);
502 }
503
504 pub async fn close(&mut self) -> RocketMQResult<()> {
512 let _ = self.state_tx.send(ConnectionState::Closed);
513
514 self.framed_writer.flush().await?;
516
517 self.framed_writer.get_mut().shutdown().await.map_err(|e| {
518 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
519 "connection",
520 format!("{}", e),
521 ))
522 })
523 }
524
525 pub fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
529 self.state_rx.clone()
530 }
531
532 pub fn connection_id(&self) -> &CheetahString {
534 &self.connection_id
535 }
536
537 pub fn framed_reader_mut(&mut self) -> &mut FramedRead<OwnedReadHalf, CompositeCodec> {
545 &mut self.framed_reader
546 }
547
548 pub fn framed_writer_mut(&mut self) -> &mut FramedWrite<OwnedWriteHalf, CompositeCodec> {
554 &mut self.framed_writer
555 }
556}
557
558impl ConcurrentConnection {
559 pub fn new(stream: TcpStream) -> Self {
561 Self::with_channel_capacity(stream, 1024)
562 }
563
564 pub fn with_channel_capacity(stream: TcpStream, channel_capacity: usize) -> Self {
566 let (read_half, write_half) = stream.into_split();
567
568 let framed_reader = FramedRead::new(read_half, CompositeCodec::default());
569 let framed_writer = FramedWrite::new(write_half, CompositeCodec::default());
570
571 let (write_tx, write_rx) = mpsc::channel(channel_capacity);
572 let (state_tx, state_rx) = watch::channel(ConnectionState::Healthy);
573
574 let writer_handle =
576 tokio::spawn(Self::writer_task(framed_writer, write_rx, state_tx.clone()));
577
578 Self {
579 framed_reader,
580 write_tx,
581 state_rx,
582 writer_handle,
583 connection_id: CheetahString::from_string(format!(
584 "concurrent-{}",
585 uuid::Uuid::new_v4()
586 )),
587 }
588 }
589
590 async fn writer_task(
592 mut framed_writer: FramedWrite<OwnedWriteHalf, CompositeCodec>,
593 mut write_rx: mpsc::Receiver<WriteCommand>,
594 state_tx: watch::Sender<ConnectionState>,
595 ) {
596 let mut encode_buffer = BytesMut::with_capacity(1024 * 1024);
597
598 while let Some(cmd) = write_rx.recv().await {
599 match cmd {
600 WriteCommand::SendCommand(remote_cmd, response_tx) => {
601 let result = Self::handle_send_command(
602 &mut framed_writer,
603 &mut encode_buffer,
604 remote_cmd,
605 )
606 .await;
607 let _ = response_tx.send(result);
608 }
609 WriteCommand::SendBytes(bytes, response_tx) => {
610 let result = Self::handle_send_bytes(&mut framed_writer, bytes).await;
611 let _ = response_tx.send(result);
612 }
613 WriteCommand::SendCommandsBatch(commands, response_tx) => {
614 let result = Self::handle_send_commands_batch(
615 &mut framed_writer,
616 &mut encode_buffer,
617 commands,
618 )
619 .await;
620 let _ = response_tx.send(result);
621 }
622 WriteCommand::SendBytesBatch(bytes_vec, response_tx) => {
623 let result = Self::handle_send_bytes_batch(&mut framed_writer, bytes_vec).await;
624 let _ = response_tx.send(result);
625 }
626 WriteCommand::SendZeroCopy(bytes_vec, response_tx) => {
627 let result = Self::handle_send_zero_copy(&mut framed_writer, bytes_vec).await;
628 let _ = response_tx.send(result);
629 }
630 WriteCommand::SendHybrid(remote_cmd, bodies, response_tx) => {
631 let result = Self::handle_send_hybrid(
632 &mut framed_writer,
633 &mut encode_buffer,
634 remote_cmd,
635 bodies,
636 )
637 .await;
638 let _ = response_tx.send(result);
639 }
640 WriteCommand::SendHybridVectored(header_bytes, bodies, response_tx) => {
641 let result =
642 Self::handle_send_hybrid_vectored(&mut framed_writer, header_bytes, bodies)
643 .await;
644 let _ = response_tx.send(result);
645 }
646 WriteCommand::Close(response_tx) => {
647 let _ = framed_writer.flush().await;
648 let _ = response_tx.send(Ok(()));
649 let _ = state_tx.send(ConnectionState::Closed);
650 break;
651 }
652 }
653 }
654 }
655
656 async fn handle_send_command(
658 framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
659 encode_buffer: &mut BytesMut,
660 mut remote_cmd: RemotingCommand,
661 ) -> RocketMQResult<()> {
662 remote_cmd.fast_header_encode(encode_buffer);
663 if let Some(body) = remote_cmd.take_body() {
664 encode_buffer.extend_from_slice(&body);
665 }
666 let bytes = encode_buffer.split().freeze();
667 framed_writer.send(bytes).await?;
668 framed_writer.flush().await?;
669 Ok(())
670 }
671
672 async fn handle_send_bytes(
674 framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
675 bytes: Bytes,
676 ) -> RocketMQResult<()> {
677 framed_writer.send(bytes).await?;
678 framed_writer.flush().await?;
679 Ok(())
680 }
681
682 async fn handle_send_commands_batch(
684 framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
685 encode_buffer: &mut BytesMut,
686 commands: Vec<RemotingCommand>,
687 ) -> RocketMQResult<()> {
688 for mut cmd in commands {
689 cmd.fast_header_encode(encode_buffer);
690 if let Some(body) = cmd.take_body() {
691 encode_buffer.extend_from_slice(&body);
692 }
693 let bytes = encode_buffer.split().freeze();
694 framed_writer.feed(bytes).await?;
695 }
696 framed_writer.flush().await?;
697 Ok(())
698 }
699
700 async fn handle_send_bytes_batch(
702 framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
703 bytes_vec: Vec<Bytes>,
704 ) -> RocketMQResult<()> {
705 for bytes in bytes_vec {
706 framed_writer.feed(bytes).await?;
707 }
708 framed_writer.flush().await?;
709 Ok(())
710 }
711
712 async fn handle_send_zero_copy(
714 framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
715 bytes_vec: Vec<Bytes>,
716 ) -> RocketMQResult<()> {
717 let mut io_slices: Vec<IoSlice> =
718 bytes_vec.iter().map(|b| IoSlice::new(b.as_ref())).collect();
719 write_all_vectored(framed_writer.get_mut(), &mut io_slices).await?;
720 framed_writer.flush().await?;
721 Ok(())
722 }
723
724 async fn handle_send_hybrid(
726 framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
727 encode_buffer: &mut BytesMut,
728 mut remote_cmd: RemotingCommand,
729 bodies: Vec<Bytes>,
730 ) -> RocketMQResult<()> {
731 remote_cmd.fast_header_encode(encode_buffer);
733 if let Some(body) = remote_cmd.take_body() {
734 encode_buffer.extend_from_slice(&body);
735 }
736 let header_bytes = encode_buffer.split().freeze();
737 framed_writer.send(header_bytes).await?;
738
739 let mut io_slices: Vec<IoSlice> = bodies.iter().map(|b| IoSlice::new(b.as_ref())).collect();
741 write_all_vectored(framed_writer.get_mut(), &mut io_slices).await?;
742 framed_writer.flush().await?;
743 Ok(())
744 }
745
746 async fn handle_send_hybrid_vectored(
748 framed_writer: &mut FramedWrite<OwnedWriteHalf, CompositeCodec>,
749 header_bytes: Bytes,
750 bodies: Vec<Bytes>,
751 ) -> RocketMQResult<()> {
752 let mut all_bytes = vec![header_bytes];
753 all_bytes.extend(bodies);
754
755 let mut io_slices: Vec<IoSlice> =
756 all_bytes.iter().map(|b| IoSlice::new(b.as_ref())).collect();
757 write_all_vectored(framed_writer.get_mut(), &mut io_slices).await?;
758 framed_writer.flush().await?;
759 Ok(())
760 }
761
762 pub async fn send_command(&self, remote_cmd: RemotingCommand) -> RocketMQResult<()> {
764 let (tx, rx) = oneshot::channel();
765 self.write_tx
766 .send(WriteCommand::SendCommand(remote_cmd, tx))
767 .await
768 .map_err(|_| {
769 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
770 "connection",
771 "Writer task closed",
772 ))
773 })?;
774 rx.await.map_err(|_| {
775 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
776 "connection",
777 "Response channel closed",
778 ))
779 })?
780 }
781
782 pub async fn send_bytes(&self, bytes: Bytes) -> RocketMQResult<()> {
784 let (tx, rx) = oneshot::channel();
785 self.write_tx
786 .send(WriteCommand::SendBytes(bytes, tx))
787 .await
788 .map_err(|_| {
789 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
790 "connection",
791 "Writer task closed",
792 ))
793 })?;
794 rx.await.map_err(|_| {
795 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
796 "connection",
797 "Response channel closed",
798 ))
799 })?
800 }
801
802 pub async fn send_commands_batch(&self, commands: Vec<RemotingCommand>) -> RocketMQResult<()> {
804 let (tx, rx) = oneshot::channel();
805 self.write_tx
806 .send(WriteCommand::SendCommandsBatch(commands, tx))
807 .await
808 .map_err(|_| {
809 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
810 "connection",
811 "Writer task closed",
812 ))
813 })?;
814 rx.await.map_err(|_| {
815 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
816 "connection",
817 "Response channel closed",
818 ))
819 })?
820 }
821
822 pub async fn send_bytes_batch(&self, bytes_vec: Vec<Bytes>) -> RocketMQResult<()> {
824 let (tx, rx) = oneshot::channel();
825 self.write_tx
826 .send(WriteCommand::SendBytesBatch(bytes_vec, tx))
827 .await
828 .map_err(|_| {
829 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
830 "connection",
831 "Writer task closed",
832 ))
833 })?;
834 rx.await.map_err(|_| {
835 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
836 "connection",
837 "Response channel closed",
838 ))
839 })?
840 }
841
842 pub async fn send_bytes_zero_copy(&self, bytes_vec: Vec<Bytes>) -> RocketMQResult<()> {
844 let (tx, rx) = oneshot::channel();
845 self.write_tx
846 .send(WriteCommand::SendZeroCopy(bytes_vec, tx))
847 .await
848 .map_err(|_| {
849 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
850 "connection",
851 "Writer task closed",
852 ))
853 })?;
854 rx.await.map_err(|_| {
855 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
856 "connection",
857 "Response channel closed",
858 ))
859 })?
860 }
861
862 pub async fn send_response_hybrid(
864 &self,
865 response: RemotingCommand,
866 bodies: Vec<Bytes>,
867 ) -> RocketMQResult<()> {
868 let (tx, rx) = oneshot::channel();
869 self.write_tx
870 .send(WriteCommand::SendHybrid(response, bodies, tx))
871 .await
872 .map_err(|_| {
873 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
874 "connection",
875 "Writer task closed",
876 ))
877 })?;
878 rx.await.map_err(|_| {
879 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
880 "connection",
881 "Response channel closed",
882 ))
883 })?
884 }
885
886 pub async fn send_response_hybrid_vectored(
888 &self,
889 header_bytes: Bytes,
890 bodies: Vec<Bytes>,
891 ) -> RocketMQResult<()> {
892 let (tx, rx) = oneshot::channel();
893 self.write_tx
894 .send(WriteCommand::SendHybridVectored(header_bytes, bodies, tx))
895 .await
896 .map_err(|_| {
897 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
898 "connection",
899 "Writer task closed",
900 ))
901 })?;
902 rx.await.map_err(|_| {
903 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
904 "connection",
905 "Response channel closed",
906 ))
907 })?
908 }
909
910 pub async fn recv_command(&mut self) -> RocketMQResult<Option<RemotingCommand>> {
912 self.framed_reader.next().await.transpose()
913 }
914
915 pub fn state(&self) -> ConnectionState {
917 *self.state_rx.borrow()
918 }
919
920 pub fn subscribe_state(&self) -> watch::Receiver<ConnectionState> {
922 self.state_rx.clone()
923 }
924
925 pub fn connection_id(&self) -> &CheetahString {
927 &self.connection_id
928 }
929
930 pub(crate) fn clone_sender(&self) -> mpsc::Sender<WriteCommand> {
932 self.write_tx.clone()
933 }
934
935 pub async fn close(self) -> RocketMQResult<()> {
937 let (tx, rx) = oneshot::channel();
938 self.write_tx
939 .send(WriteCommand::Close(tx))
940 .await
941 .map_err(|_| {
942 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
943 "connection",
944 "Writer task closed",
945 ))
946 })?;
947 rx.await.map_err(|_| {
948 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
949 "connection",
950 "Response channel closed",
951 ))
952 })??;
953 self.writer_handle.await.map_err(|e| {
954 RocketMQError::Network(rocketmq_error::NetworkError::connection_failed(
955 "connection",
956 format!("{}", e),
957 ))
958 })?;
959 Ok(())
960 }
961}
962
#[cfg(test)]
mod tests {
    use tokio::io::AsyncReadExt;
    use tokio::net::TcpListener;

    use super::*;
    use crate::protocol::header::empty_header::EmptyHeader;
    use crate::protocol::remoting_command::RemotingCommand;

    /// Accepts one connection on `listener` and returns every byte the peer sends
    /// until it closes its write side.
    ///
    /// Reading to EOF replaces the old "sleep, then `try_read`" pattern. That pattern
    /// could fail with `WouldBlock` or see only part of the data on a slow machine.
    async fn accept_and_read_to_end(listener: &TcpListener) -> Vec<u8> {
        let (mut socket, _) = listener.accept().await.unwrap();
        let mut received = Vec::new();
        socket.read_to_end(&mut received).await.unwrap();
        received
    }

    #[tokio::test]
    async fn test_framed_connection_basic() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let mut conn = RefactoredConnection::new(stream);

            let cmd = RemotingCommand::create_request_command(100, EmptyHeader {})
                .set_body(Bytes::from("test data"));

            conn.send_command(cmd).await.unwrap();
        });

        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = RefactoredConnection::new(socket);

        let received = server_conn.recv_command().await.unwrap();
        assert!(received.is_some());

        let cmd = received.unwrap();
        assert_eq!(cmd.code(), 100);
        let expected = Bytes::from("test data");
        assert_eq!(&expected, cmd.body().as_ref().unwrap());

        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_batch_send() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let mut conn = RefactoredConnection::new(stream);

            let commands = vec![
                RemotingCommand::create_request_command(101, EmptyHeader {}),
                RemotingCommand::create_request_command(102, EmptyHeader {}),
                RemotingCommand::create_request_command(103, EmptyHeader {}),
            ];

            conn.send_commands_batch(commands).await.unwrap();
        });

        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = RefactoredConnection::new(socket);

        for expected_code in [101, 102, 103] {
            let received = server_conn.recv_command().await.unwrap();
            assert!(received.is_some());
            assert_eq!(received.unwrap().code(), expected_code);
        }

        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_zero_copy_send() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let mut conn = RefactoredConnection::new(stream);

            let chunks = vec![
                Bytes::from("Part1"),
                Bytes::from("Part2"),
                Bytes::from("Part3"),
            ];

            conn.send_bytes_zero_copy(chunks).await.unwrap();
            // Dropping `conn` closes the socket, which ends the server's read.
        });

        let received = accept_and_read_to_end(&listener).await;
        assert_eq!(received.as_slice(), b"Part1Part2Part3");
        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_hybrid_vectored() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let mut conn = RefactoredConnection::new(stream);

            let header = Bytes::from("HEADER:");
            let bodies = vec![Bytes::from("Body1"), Bytes::from("|"), Bytes::from("Body2")];

            conn.send_response_hybrid_vectored(header, bodies)
                .await
                .unwrap();
        });

        let received = accept_and_read_to_end(&listener).await;
        assert_eq!(received.as_slice(), b"HEADER:Body1|Body2");
        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_connection_state() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let stream = TcpStream::connect(addr).await.unwrap();
        let mut conn = RefactoredConnection::new(stream);

        assert_eq!(conn.state(), ConnectionState::Healthy);

        conn.mark_degraded();
        assert_eq!(conn.state(), ConnectionState::Degraded);

        conn.mark_healthy();
        assert_eq!(conn.state(), ConnectionState::Healthy);

        conn.close().await.unwrap();
        assert_eq!(conn.state(), ConnectionState::Closed);
    }

    #[tokio::test]
    async fn test_state_subscription() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let _accept_handle = tokio::spawn(async move {
            let _ = listener.accept().await;
        });

        let stream = TcpStream::connect(addr).await.unwrap();
        let conn = RefactoredConnection::new(stream);

        let state_rx = conn.subscribe_state();

        conn.mark_degraded();

        assert_eq!(*state_rx.borrow(), ConnectionState::Degraded);
    }

    #[tokio::test]
    async fn test_zero_copy_single() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let mut conn = RefactoredConnection::new(stream);

            let data = Bytes::from("LargeDataBlock");
            conn.send_bytes_zero_copy_single(data).await.unwrap();
        });

        let received = accept_and_read_to_end(&listener).await;
        assert_eq!(received.as_slice(), b"LargeDataBlock");
        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_hybrid_standard() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let mut conn = RefactoredConnection::new(stream);

            let response = RemotingCommand::create_response_command();
            let bodies = vec![Bytes::from("Message1"), Bytes::from("Message2")];

            conn.send_response_hybrid(response, bodies).await.unwrap();
        });

        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = RefactoredConnection::new(socket);

        let received = server_conn.recv_command().await.unwrap();
        assert!(received.is_some());

        client.await.unwrap();
    }
}
1225
#[cfg(test)]
mod concurrent_tests {
    use bytes::Bytes;
    use tokio::io::AsyncReadExt;
    use tokio::net::TcpListener;
    use tokio::net::TcpStream;

    use super::*;
    use crate::protocol::header::pull_message_response_header::PullMessageResponseHeader;

    /// Accepts one connection on `listener` and returns every byte the peer sends
    /// until it closes its write side.
    ///
    /// For a `ConcurrentConnection` peer, EOF arrives once the client drops the
    /// connection: its writer task then exits and drops the write half. This avoids
    /// the timing-dependent "sleep, then `try_read`" pattern.
    async fn accept_and_read_to_end(listener: &TcpListener) -> Vec<u8> {
        let (mut socket, _) = listener.accept().await.unwrap();
        let mut received = Vec::new();
        socket.read_to_end(&mut received).await.unwrap();
        received
    }

    #[tokio::test]
    async fn test_concurrent_basic() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let conn = ConcurrentConnection::new(stream);

            let cmd =
                RemotingCommand::create_request_command(100, PullMessageResponseHeader::default());
            conn.send_command(cmd).await.unwrap();
        });

        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = ConcurrentConnection::new(socket);

        let received = server_conn.recv_command().await.unwrap();
        assert!(received.is_some());

        client.await.unwrap();
        server_conn.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_multi_writers() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let conn = ConcurrentConnection::new(stream);

            let mut handles = vec![];
            for i in 0..3 {
                let conn_clone = conn.clone_sender();
                let handle = tokio::spawn(async move {
                    let cmd = RemotingCommand::create_request_command(
                        100 + i,
                        PullMessageResponseHeader::default(),
                    );
                    let (tx, rx) = oneshot::channel();
                    conn_clone
                        .send(WriteCommand::SendCommand(cmd, tx))
                        .await
                        .unwrap();
                    rx.await.unwrap().unwrap();
                });
                handles.push(handle);
            }

            for handle in handles {
                handle.await.unwrap();
            }

            conn.close().await.unwrap();
        });

        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = ConcurrentConnection::new(socket);

        for _ in 0..3 {
            let received = server_conn.recv_command().await.unwrap();
            assert!(received.is_some());
        }

        client.await.unwrap();
        server_conn.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_batch() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let conn = ConcurrentConnection::new(stream);

            let bytes_vec = vec![
                Bytes::from("Message1"),
                Bytes::from("Message2"),
                Bytes::from("Message3"),
            ];

            conn.send_bytes_batch(bytes_vec).await.unwrap();
        });

        let raw = accept_and_read_to_end(&listener).await;
        let received = String::from_utf8_lossy(&raw);
        assert!(received.contains("Message1"));
        assert!(received.contains("Message2"));
        assert!(received.contains("Message3"));

        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_zero_copy() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let conn = ConcurrentConnection::new(stream);

            let chunks = vec![
                Bytes::from("Zero"),
                Bytes::from("Copy"),
                Bytes::from("Test"),
            ];

            conn.send_bytes_zero_copy(chunks).await.unwrap();
        });

        let received = accept_and_read_to_end(&listener).await;
        assert_eq!(received.as_slice(), b"ZeroCopyTest");
        client.await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_hybrid() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let conn = ConcurrentConnection::new(stream);

            let response = RemotingCommand::create_response_command();
            let bodies = vec![Bytes::from("Body1"), Bytes::from("Body2")];

            conn.send_response_hybrid(response, bodies).await.unwrap();
        });

        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = ConcurrentConnection::new(socket);

        let received = server_conn.recv_command().await.unwrap();
        assert!(received.is_some());

        client.await.unwrap();
        server_conn.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_concurrent_state() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client = tokio::spawn(async move {
            let stream = TcpStream::connect(addr).await.unwrap();
            let conn = ConcurrentConnection::new(stream);

            assert_eq!(conn.state(), ConnectionState::Healthy);

            let cmd =
                RemotingCommand::create_request_command(100, PullMessageResponseHeader::default());
            conn.send_command(cmd).await.unwrap();

            conn.close().await.unwrap();
        });

        let (socket, _) = listener.accept().await.unwrap();
        let mut server_conn = ConcurrentConnection::new(socket);

        assert_eq!(server_conn.state(), ConnectionState::Healthy);

        let received = server_conn.recv_command().await.unwrap();
        assert!(received.is_some());

        client.await.unwrap();
        server_conn.close().await.unwrap();
    }
}