use crate::frame::{
ContentBody, Frame, FrameHeader, FRAME_CONTENT_BODY, FRAME_END, FRAME_HEADER_SIZE,
};
use amqp_serde::{
to_buffer,
types::{AmqpChannelId, LongUint},
};
use bytes::{Buf, BufMut, BytesMut};
use serde::Serialize;
use std::{
io::{self, Cursor},
pin::Pin,
};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf},
net::TcpStream,
};
#[cfg(feature = "tls")]
use tokio_rustls::{client::TlsStream, TlsConnector};
#[cfg(feature = "traces")]
use tracing::trace;
use super::Error;
/// Result alias used throughout this module, carrying the parent module's `Error`.
type Result<T> = std::result::Result<T, Error>;
/// Initial capacity, in bytes, of every read and write staging buffer.
const DEFAULT_IO_BUFFER_SIZE: usize = 8192;
/// A broker connection made of a buffered read half and a buffered write half.
///
/// The halves can be used through this type directly or taken apart with
/// `into_split` so that reading and writing can proceed independently.
pub(crate) struct SplitConnection {
    // Receives and decodes incoming frames.
    reader: BufIoReader,
    // Serializes and sends outgoing frames.
    writer: BufIoWriter,
}
/// Read half of a connection.
///
/// Incoming bytes accumulate in `buffer` until at least one complete frame
/// can be decoded from it.
pub(crate) struct BufIoReader {
    stream: ReadHalf<SplitIoStream>,
    // Bytes received from the network but not yet consumed by a decoded frame.
    buffer: BytesMut,
}
/// Write half of a connection.
///
/// Outgoing values are serialized into `buffer` first and then written to
/// the stream in one `write_all` call.
pub(crate) struct BufIoWriter {
    stream: WriteHalf<SplitIoStream>,
    // Staging area for serialized bytes that are about to be written.
    buffer: BytesMut,
}
/// The transport underneath a connection: plain TCP, or TLS over TCP when the
/// `tls` feature is enabled.
// Silences clippy because (with `tls` enabled) the TLS variant is much larger
// than the TCP one. NOTE(review): presumably accepted because there is only
// one such stream per connection — confirm before boxing is considered.
#[allow(clippy::large_enum_variant)]
enum SplitIoStream {
    TcpStream(TcpStream),
    #[cfg(feature = "tls")]
    TlsStream(TlsStream<TcpStream>),
}
impl From<TcpStream> for SplitIoStream {
fn from(stream: TcpStream) -> Self {
SplitIoStream::TcpStream(stream)
}
}
/// Wraps an established client TLS session running over TCP.
#[cfg(feature = "tls")]
impl From<TlsStream<TcpStream>> for SplitIoStream {
    fn from(tls: TlsStream<TcpStream>) -> Self {
        Self::TlsStream(tls)
    }
}
/// Reading simply forwards to whichever transport is wrapped.
impl AsyncRead for SplitIoStream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        // Both transports are `Unpin`, so re-pinning the inner stream is sound.
        match &mut *self {
            Self::TcpStream(tcp) => Pin::new(tcp).poll_read(cx, buf),
            #[cfg(feature = "tls")]
            Self::TlsStream(tls) => Pin::new(tls).poll_read(cx, buf),
        }
    }
}
/// Writing, flushing and shutdown all forward to whichever transport is wrapped.
impl AsyncWrite for SplitIoStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<io::Result<usize>> {
        match &mut *self {
            Self::TcpStream(tcp) => Pin::new(tcp).poll_write(cx, buf),
            #[cfg(feature = "tls")]
            Self::TlsStream(tls) => Pin::new(tls).poll_write(cx, buf),
        }
    }

    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        match &mut *self {
            Self::TcpStream(tcp) => Pin::new(tcp).poll_flush(cx),
            #[cfg(feature = "tls")]
            Self::TlsStream(tls) => Pin::new(tls).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        match &mut *self {
            Self::TcpStream(tcp) => Pin::new(tcp).poll_shutdown(cx),
            #[cfg(feature = "tls")]
            Self::TlsStream(tls) => Pin::new(tls).poll_shutdown(cx),
        }
    }
}
impl SplitConnection {
    /// Splits a connected transport into independently owned read and write
    /// halves, each with its own staging buffer.
    fn from_stream(stream: SplitIoStream) -> Self {
        let (reader, writer) = tokio::io::split(stream);
        Self {
            reader: BufIoReader {
                stream: reader,
                buffer: BytesMut::with_capacity(DEFAULT_IO_BUFFER_SIZE),
            },
            writer: BufIoWriter {
                stream: writer,
                buffer: BytesMut::with_capacity(DEFAULT_IO_BUFFER_SIZE),
            },
        }
    }

    /// Opens a plain TCP connection to `addr` (for example `"localhost:5672"`).
    ///
    /// # Errors
    ///
    /// Fails if the TCP connection cannot be established or `TCP_NODELAY`
    /// cannot be set on the socket.
    pub async fn open(addr: &str) -> Result<Self> {
        let stream = TcpStream::connect(addr).await?;
        // Disable Nagle's algorithm so small frames are sent without delay.
        stream.set_nodelay(true)?;
        Ok(Self::from_stream(stream.into()))
    }

    /// Opens a TCP connection to `addr` and performs a TLS handshake with
    /// `connector`, verifying the server as `domain`.
    ///
    /// # Errors
    ///
    /// Fails if `domain` is not a valid server name, if the TCP connection
    /// cannot be established, or if the TLS handshake fails.
    #[cfg(feature = "tls")]
    pub async fn open_tls(addr: &str, domain: &str, connector: &TlsConnector) -> Result<Self> {
        let domain = rustls_pki_types::ServerName::try_from(domain.to_owned())
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;
        let tcp_stream = TcpStream::connect(addr).await?;
        tcp_stream.set_nodelay(true)?;
        let tls_stream = connector.connect(domain, tcp_stream).await?;
        Ok(Self::from_stream(tls_stream.into()))
    }

    /// Consumes the connection and returns its read and write halves.
    pub(crate) fn into_split(self) -> (BufIoReader, BufIoWriter) {
        (self.reader, self.writer)
    }

    /// Releases the reader and shuts down the write side of the stream.
    // Only exercised by tests.
    #[allow(dead_code)]
    pub async fn close(self) -> Result<()> {
        self.reader.close().await;
        self.writer.close().await
    }

    /// Serializes `value` and writes it unframed; see [`BufIoWriter::write`].
    pub async fn write<V: Serialize>(&mut self, value: &V) -> Result<usize> {
        self.writer.write(value).await
    }

    /// Writes `frame` on `channel`; see [`BufIoWriter::write_frame`].
    pub async fn write_frame(
        &mut self,
        channel: AmqpChannelId,
        frame: Frame,
        frame_max: LongUint,
    ) -> Result<usize> {
        self.writer.write_frame(channel, frame, frame_max).await
    }

    /// Reads the next complete frame; see [`BufIoReader::read_frame`].
    pub async fn read_frame(&mut self) -> Result<ChannelFrame> {
        self.reader.read_frame().await
    }
}
impl BufIoWriter {
pub async fn write<V: Serialize>(&mut self, value: &V) -> Result<usize> {
to_buffer(value, &mut self.buffer)
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
let len = self.buffer.len();
self.stream.write_all(&self.buffer).await?;
self.buffer.advance(len);
Ok(len)
}
async fn serialize_frame_into_buffer(
&mut self,
channel: AmqpChannelId,
frame: Frame,
) -> Result<()> {
let start_index = self.buffer.len();
let header = FrameHeader {
frame_type: frame.get_frame_type(),
channel,
payload_size: 0,
};
to_buffer(&header, &mut self.buffer).unwrap();
let payload_size = to_buffer(&frame, &mut self.buffer)?;
for (i, v) in (payload_size as u32).to_be_bytes().iter().enumerate() {
let p = self.buffer.get_mut(i + 3 + start_index).unwrap();
*p = *v;
}
self.buffer.put_u8(FRAME_END);
Ok(())
}
async fn serialize_content_body_into_buffer(
&mut self,
channel: AmqpChannelId,
body: ContentBody,
frame_max: usize,
) -> Result<()> {
if body.inner.is_empty() {
return Ok(());
}
let mut cursor = Cursor::new(body.inner);
while cursor.has_remaining() {
let start_index = self.buffer.len();
let header = FrameHeader {
frame_type: FRAME_CONTENT_BODY,
channel,
payload_size: 0,
};
to_buffer(&header, &mut self.buffer).unwrap();
const FRAME_HEADER_AND_ENDER_SIZE: usize = FRAME_HEADER_SIZE + 1;
let payload_size = if cursor.remaining() > (frame_max - FRAME_HEADER_AND_ENDER_SIZE) {
frame_max - FRAME_HEADER_AND_ENDER_SIZE
} else {
cursor.remaining()
};
let current = cursor.position() as usize;
self.buffer
.put(&cursor.get_ref()[current..current + payload_size]);
cursor.advance(payload_size);
for (i, v) in (payload_size as u32).to_be_bytes().iter().enumerate() {
let p = self.buffer.get_mut(i + 3 + start_index).unwrap();
*p = *v;
}
self.buffer.put_u8(FRAME_END);
}
Ok(())
}
pub async fn write_frame(
&mut self,
channel: AmqpChannelId,
frame: Frame,
frame_max: LongUint,
) -> Result<usize> {
#[cfg(feature = "traces")]
trace!("SENT on channel {}: {}", channel, frame);
if let Frame::PublishCombo(publish, content_header, content_body) = frame {
self.serialize_frame_into_buffer(channel, publish.into_frame())
.await?;
self.serialize_frame_into_buffer(channel, content_header.into_frame())
.await?;
self.serialize_content_body_into_buffer(channel, content_body, frame_max as usize)
.await?;
} else {
self.serialize_frame_into_buffer(channel, frame).await?;
}
self.stream.write_all(&self.buffer).await?;
let len = self.buffer.len();
self.buffer.advance(len);
Ok(len)
}
pub async fn close(mut self) -> Result<()> {
self.stream.shutdown().await?;
Ok(())
}
}
/// A decoded frame paired with the channel id it arrived on.
type ChannelFrame = (AmqpChannelId, Frame);
impl BufIoReader {
fn decode(&mut self) -> Result<Option<ChannelFrame>> {
match Frame::decode(&self.buffer)? {
Some((len, channel_id, frame)) => {
self.buffer.advance(len);
#[cfg(feature = "traces")]
trace!("RECV on channel {}: {}", channel_id, frame);
Ok(Some((channel_id, frame)))
}
None => Ok(None),
}
}
pub async fn read_frame(&mut self) -> Result<ChannelFrame> {
let result = self.decode()?;
if let Some(frame) = result {
return Ok(frame);
}
loop {
let len = self.stream.read_buf(&mut self.buffer).await?;
if len == 0 {
if self.buffer.is_empty() {
return Err(Error::PeerShutdown);
} else {
return Err(Error::Interrupted);
}
}
#[cfg(feature = "traces")]
trace!("{} bytes read from network", len);
let result = self.decode()?;
match result {
Some(frame) => return Ok(frame),
None => continue,
}
}
}
pub async fn close(self) {}
}
// Tests for `SplitConnection` and its split halves.
// NOTE(review): every test connects to `localhost:5672`, so a reachable AMQP
// broker is required; without one they fail at `open(...).unwrap()`.
#[cfg(test)]
mod test {
    use super::SplitConnection;
    use crate::{frame::*, test_utils::setup_logging};
    use amqp_serde::types::AmqpPeerProperties;
    use tokio::sync::mpsc;

    // Drives a connection handshake by hand. The writer and reader halves run
    // in separate tasks connected by channels: requests go through `tx_req`,
    // and every received frame comes back on `rx_resp`.
    #[tokio::test]
    async fn test_open_amqp_connection() {
        setup_logging();
        let (tx_resp, mut rx_resp) = mpsc::channel(1024);
        let (tx_req, mut rx_req) = mpsc::channel(1024);
        let (mut reader, mut writer) = SplitConnection::open("localhost:5672")
            .await
            .unwrap()
            .into_split();
        writer.write(&ProtocolHeader::default()).await.unwrap();
        // Writer task: send every requested frame, each bounded by FRAME_MIN_SIZE.
        tokio::spawn(async move {
            while let Some((channel_id, frame)) = rx_req.recv().await {
                writer
                    .write_frame(channel_id, frame, FRAME_MIN_SIZE)
                    .await
                    .unwrap();
            }
        });
        // Reader task: forward received frames until reading fails.
        tokio::spawn(async move {
            while let Ok((channel_id, frame)) = reader.read_frame().await {
                tx_resp.send((channel_id, frame)).await.unwrap();
            }
        });
        let _start = rx_resp.recv().await.unwrap();
        // NOTE(review): presumably relies on the broker offering the
        // RABBIT-CR-DEMO mechanism with a "user" account — confirm broker setup.
        let start_ok = StartOk::new(
            AmqpPeerProperties::new(),
            "RABBIT-CR-DEMO".try_into().unwrap(),
            "user".try_into().unwrap(),
            "en_US".try_into().unwrap(),
        );
        tx_req
            .send((DEFAULT_CONN_CHANNEL, start_ok.into_frame()))
            .await
            .unwrap();
        // Presumably the broker's Secure challenge, answered with SecureOk below.
        rx_resp.recv().await.unwrap();
        let secure_ok = SecureOk::new("My password is bitnami".try_into().unwrap());
        tx_req
            .send((DEFAULT_CONN_CHANNEL, secure_ok.into_frame()))
            .await
            .unwrap();
        let tune = rx_resp.recv().await.unwrap();
        let tune = match tune.1 {
            Frame::Tune(_, v) => v,
            _ => panic!("expect Tune message"),
        };
        // Accept exactly the limits the broker proposed.
        let tune_ok = TuneOk::new(tune.channel_max(), tune.frame_max(), tune.heartbeat());
        tx_req
            .send((DEFAULT_CONN_CHANNEL, tune_ok.into_frame()))
            .await
            .unwrap();
        let open = Open::default().into_frame();
        tx_req.send((DEFAULT_CONN_CHANNEL, open)).await.unwrap();
        let _open_ok = rx_resp.recv().await.unwrap();
        tx_req
            .send((DEFAULT_CONN_CHANNEL, Close::default().into_frame()))
            .await
            .unwrap();
        let _close_ok = rx_resp.recv().await.unwrap();
    }

    // Unsplit connection: the first frame after the protocol header arrives
    // on the default connection channel, then the connection closes cleanly.
    #[tokio::test]
    async fn test_connection_open_close() {
        let mut connection = SplitConnection::open("localhost:5672").await.unwrap();
        connection.write(&ProtocolHeader::default()).await.unwrap();
        let (channel_id, _frame) = connection.read_frame().await.unwrap();
        assert_eq!(DEFAULT_CONN_CHANNEL, channel_id);
        connection.close().await.unwrap();
    }

    // Same check as above, using the reader and writer halves separately.
    #[tokio::test]
    async fn test_split_open_close() {
        let (mut reader, mut writer) = SplitConnection::open("localhost:5672")
            .await
            .unwrap()
            .into_split();
        writer.write(&ProtocolHeader::default()).await.unwrap();
        let (channel_id, _frame) = reader.read_frame().await.unwrap();
        assert_eq!(DEFAULT_CONN_CHANNEL, channel_id);
        reader.close().await;
        writer.close().await.unwrap();
    }
}