use crate::frame;
use crate::{FromServer, Message, Result, ToServer};
use anyhow::{anyhow, bail};
use bytes::{Buf, BytesMut};
use futures::prelude::*;
use futures::sink::SinkExt;
use rustls::pki_types::ServerName;
use std::fmt;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use tokio_rustls::client::TlsStream;
use tokio_util::codec::{Decoder, Encoder, Framed};
use typed_builder::TypedBuilder;
use winnow::Partial;
use winnow::error::ErrMode;
use winnow::stream::Offset;
/// A STOMP client connection: a [`TransportStream`] framed with [`ClientCodec`].
pub type ClientTransport = Framed<TransportStream, ClientCodec>;
/// The byte stream underneath a [`ClientTransport`]: plain TCP or TLS over TCP.
// Clippy's `large_enum_variant` fires because the TLS variant is much larger than
// `TcpStream`; the size difference is accepted here instead of boxing the TLS stream.
#[allow(clippy::large_enum_variant)]
pub enum TransportStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// Client-side TLS session running over a TCP connection.
    Tls(TlsStream<TcpStream>),
}
impl tokio::io::AsyncRead for TransportStream {
    /// Reads from whichever stream (TCP or TLS) this transport wraps.
    ///
    /// `TransportStream` is `Unpin`, so the pinned reference can be re-borrowed
    /// mutably and each inner stream re-pinned with `Pin::new`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        match &mut *self {
            Self::Plain(tcp) => Pin::new(tcp).poll_read(cx, buf),
            Self::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
        }
    }
}
impl tokio::io::AsyncWrite for TransportStream {
    /// Writes `buf` to the wrapped TCP or TLS stream.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        match &mut *self {
            Self::Plain(tcp) => Pin::new(tcp).poll_write(cx, buf),
            Self::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
        }
    }

    /// Flushes the wrapped TCP or TLS stream.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match &mut *self {
            Self::Plain(tcp) => Pin::new(tcp).poll_flush(cx),
            Self::Tls(tls) => Pin::new(tls).poll_flush(cx),
        }
    }

    /// Shuts down the write half of the wrapped TCP or TLS stream.
    fn poll_shutdown(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        match &mut *self {
            Self::Plain(tcp) => Pin::new(tcp).poll_shutdown(cx),
            Self::Tls(tls) => Pin::new(tls).poll_shutdown(cx),
        }
    }
}
impl fmt::Debug for TransportStream {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TransportStream::Plain(_) => write!(f, "Plain TCP connection"),
TransportStream::Tls(_) => write!(f, "TLS connection"),
}
}
}
/// Settings for opening a STOMP client connection.
///
/// Create one with `Connector::builder()`. The generated `build` method is hidden
/// (`vis = ""`, renamed `__build`), so the builder is finished with `.connect()` or
/// `.msg()` instead.
// NOTE(review): the hand-written `ConnectorBuilder` impl below lists these fields
// positionally in its type-state tuple; keep the declaration order in sync with it.
#[derive(TypedBuilder)]
#[builder(build_method(vis="", name=__build))]
pub struct Connector<S: tokio::net::ToSocketAddrs + Clone, V: Into<String> + Clone> {
    /// Address the TCP connection is opened to.
    server: S,
    /// Sent as `host` in the CONNECT frame.
    virtualhost: V,
    /// Optional `login` sent in the CONNECT frame.
    #[builder(default, setter(strip_option))]
    login: Option<String>,
    /// Optional `passcode` sent in the CONNECT frame.
    #[builder(default, setter(strip_option))]
    passcode: Option<String>,
    /// Extra header name/value pairs appended to the CONNECT frame.
    #[builder(default)]
    headers: Vec<(String, String)>,
    /// Wrap the TCP stream in TLS before the STOMP handshake (default `false`).
    #[builder(default = false)]
    use_tls: bool,
    /// Name the server certificate is verified against; when unset, the peer's IP
    /// address is used instead.
    #[builder(default, setter(strip_option))]
    tls_server_name: Option<String>,
}
// Convenience finishers on the typed-builder generated `ConnectorBuilder`, so callers can
// write `Connector::builder().server(..).virtualhost(..).connect()` directly.
//
// The type-state tuple `((S,), (V,), ...)` restricts these methods to builders on which the
// required `server` and `virtualhost` setters have been called. The `where` clauses require
// each optional field to be either set or defaultable; they appear to mirror the bounds
// typed-builder emits for its own build method.
// NOTE(review): the tuple and bounds are positional — update them if `Connector`'s field
// list or order changes.
// `non_camel_case_types` is allowed because the type-state parameters are named `__login` etc.
#[allow(non_camel_case_types)]
impl<
    S: tokio::net::ToSocketAddrs + Clone,
    V: Into<String> + Clone,
    __login,
    __passcode,
    __headers,
    __use_tls,
    __tls_server_name,
>
    ConnectorBuilder<
        S,
        V,
        (
            (S,),
            (V,),
            __login,
            __passcode,
            __headers,
            __use_tls,
            __tls_server_name,
        ),
    >
where
    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
        (
            &'__typed_builder_lifetime_for_default S,
            &'__typed_builder_lifetime_for_default V,
            __login,
        ),
        Output = Option<String>,
    >,
    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
        (
            &'__typed_builder_lifetime_for_default S,
            &'__typed_builder_lifetime_for_default V,
            &'__typed_builder_lifetime_for_default Option<String>,
            __passcode,
        ),
        Output = Option<String>,
    >,
    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
        (
            &'__typed_builder_lifetime_for_default S,
            &'__typed_builder_lifetime_for_default V,
            &'__typed_builder_lifetime_for_default Option<String>,
            &'__typed_builder_lifetime_for_default Option<String>,
            __headers,
        ),
        Output = Vec<(String, String)>,
    >,
    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
        (
            &'__typed_builder_lifetime_for_default S,
            &'__typed_builder_lifetime_for_default V,
            &'__typed_builder_lifetime_for_default Option<String>,
            &'__typed_builder_lifetime_for_default Option<String>,
            &'__typed_builder_lifetime_for_default Vec<(String, String)>,
            __use_tls,
        ),
        Output = bool,
    >,
    Connector<S, V>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
        (
            &'__typed_builder_lifetime_for_default S,
            &'__typed_builder_lifetime_for_default V,
            &'__typed_builder_lifetime_for_default Option<String>,
            &'__typed_builder_lifetime_for_default Option<String>,
            &'__typed_builder_lifetime_for_default Vec<(String, String)>,
            &'__typed_builder_lifetime_for_default bool,
            __tls_server_name,
        ),
        Output = Option<String>,
    >,
{
    /// Finishes the builder and opens the connection; see [`Connector::connect`].
    pub async fn connect(self) -> Result<ClientTransport> {
        let connector: Connector<S, V> = self.__build();
        connector.connect().await
    }
    /// Finishes the builder and returns its CONNECT message; see [`Connector::msg`].
    pub fn msg(self) -> Message<ToServer> {
        let connector = self.__build();
        connector.msg()
    }
}
impl<S: tokio::net::ToSocketAddrs + Clone, V: Into<String> + Clone> Connector<S, V> {
async fn create_tls_connector(&self) -> Result<TlsConnector> {
let root_store = rustls::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
};
let config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
Ok(TlsConnector::from(Arc::new(config)))
}
pub async fn connect(self) -> Result<ClientTransport> {
let tcp = TcpStream::connect(self.server.clone()).await?;
let transport_stream = if self.use_tls {
let server_name = if let Some(name) = &self.tls_server_name {
name.clone()
} else {
let server_addr = tcp.peer_addr()?;
let hostname = server_addr.ip().to_string();
if hostname.is_empty() {
return Err(anyhow!(
"Could not determine server hostname for TLS verification"
));
}
hostname
};
let tls_connector = self.create_tls_connector().await?;
let server_name_copy = server_name.clone();
let dns_name = if let Ok(ip_addr) = server_name_copy.parse::<IpAddr>() {
match ip_addr {
IpAddr::V4(ipv4) => ServerName::IpAddress(ipv4.into()),
IpAddr::V6(ipv6) => ServerName::IpAddress(ipv6.into()),
}
} else {
ServerName::DnsName(
server_name_copy
.try_into()
.map_err(|_| anyhow!("Invalid DNS name: {}", server_name))?,
)
};
let tls_stream = tls_connector.connect(dns_name, tcp).await?;
TransportStream::Tls(tls_stream)
} else {
TransportStream::Plain(tcp)
};
let mut transport = ClientCodec.framed(transport_stream);
client_handshake(
&mut transport,
self.virtualhost.into(),
self.login,
self.passcode,
self.headers,
)
.await?;
Ok(transport)
}
pub fn msg(self) -> Message<ToServer> {
let extra_headers = self
.headers
.into_iter()
.map(|(k, v)| (k.as_bytes().to_vec(), v.as_bytes().to_vec()))
.collect();
Message {
content: ToServer::Connect {
accept_version: "1.2".into(),
host: self.virtualhost.into(),
login: self.login,
passcode: self.passcode,
heartbeat: None,
},
extra_headers,
}
}
}
/// Sends the STOMP CONNECT frame over `transport` and waits for the server's reply.
///
/// # Errors
/// Returns an error in any of these cases:
/// - sending the frame fails;
/// - reading or decoding the reply fails;
/// - the connection closes before any reply arrives;
/// - the reply is anything other than a CONNECTED frame.
async fn client_handshake(
    transport: &mut ClientTransport,
    virtualhost: String,
    login: Option<String>,
    passcode: Option<String>,
    headers: Vec<(String, String)>,
) -> Result<()> {
    // `headers` is owned, so move the bytes out rather than copying each string.
    let extra_headers = headers
        .into_iter()
        .map(|(k, v)| (k.into_bytes(), v.into_bytes()))
        .collect();
    let connect = Message {
        content: ToServer::Connect {
            accept_version: "1.2".into(),
            host: virtualhost,
            login,
            passcode,
            heartbeat: None,
        },
        extra_headers,
    };
    transport.send(connect).await?;
    match transport.next().await.transpose()? {
        Some(reply) if matches!(reply.content, FromServer::Connected { .. }) => Ok(()),
        Some(reply) => Err(anyhow!("unexpected reply: {:?}", reply)),
        // End of stream: report it explicitly instead of "unexpected reply: None".
        None => Err(anyhow!(
            "connection closed before the server replied to CONNECT"
        )),
    }
}
/// Settings for a SUBSCRIBE frame.
///
/// Create one with `Subscriber::builder()`. The generated `build` method is hidden
/// (`vis = ""`, renamed `__build`), so the builder is finished with `.subscribe()`.
// NOTE(review): the hand-written `SubscriberBuilder` impl below lists these fields
// positionally in its type-state tuple; keep the declaration order in sync with it.
#[derive(TypedBuilder)]
#[builder(build_method(vis="", name=__build))]
pub struct Subscriber<S: Into<String>, I: Into<String>> {
    /// Destination to subscribe to.
    destination: S,
    /// Subscription id placed in the SUBSCRIBE frame.
    id: I,
    /// Extra header name/value pairs appended to the SUBSCRIBE frame.
    #[builder(default)]
    headers: Vec<(String, String)>,
}
// Convenience finisher on the typed-builder generated `SubscriberBuilder`, so callers can
// write `Subscriber::builder().destination(..).id(..).subscribe()` directly.
//
// The type-state tuple `((S,), (I,), __headers)` restricts `subscribe` to builders on which
// `destination` and `id` have been set; the `where` clause requires `headers` to be either set
// or defaultable, apparently mirroring typed-builder's own build-method bounds.
// `non_camel_case_types` is allowed because of the `__headers` type-state parameter.
#[allow(non_camel_case_types)]
impl<S: Into<String>, I: Into<String>, __headers> SubscriberBuilder<S, I, ((S,), (I,), __headers)>
where
    Subscriber<S, I>: for<'__typed_builder_lifetime_for_default> ::typed_builder::NextFieldDefault<
        (
            &'__typed_builder_lifetime_for_default S,
            &'__typed_builder_lifetime_for_default I,
            __headers,
        ),
        Output = Vec<(String, String)>,
    >,
{
    /// Finishes the builder and returns the SUBSCRIBE message; see [`Subscriber::subscribe`].
    pub fn subscribe(self) -> Message<ToServer> {
        let subscriber = self.__build();
        subscriber.subscribe()
    }
}
impl<S: Into<String>, I: Into<String>> Subscriber<S, I> {
    /// Builds the SUBSCRIBE message for this subscription (no `ack` mode set), with
    /// `headers` appended as extra headers.
    pub fn subscribe(self) -> Message<ToServer> {
        let mut msg: Message<ToServer> = ToServer::Subscribe {
            destination: self.destination.into(),
            id: self.id.into(),
            ack: None,
        }
        .into();
        // `self` is consumed, so move the header bytes instead of copying each string.
        msg.extra_headers = self
            .headers
            .into_iter()
            .map(|(k, v)| (k.into_bytes(), v.into_bytes()))
            .collect();
        msg
    }
}
/// tokio-util codec that decodes server frames into [`Message<FromServer>`] and
/// encodes [`Message<ToServer>`] into client frames.
pub struct ClientCodec;

impl Decoder for ClientCodec {
    type Item = Message<FromServer>;
    type Error = anyhow::Error;

    /// Parses one complete frame from the front of `src`.
    ///
    /// Returns `Ok(None)` when the buffered bytes do not yet hold a whole frame, and
    /// fails on any other parse error. After a successful parse, the consumed bytes are
    /// removed from `src` before the frame-to-message conversion result (which may
    /// itself be an error) is returned.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>> {
        let mut input = Partial::new(src.chunk());
        let converted = match frame::parse_frame(&mut input) {
            Err(ErrMode::Incomplete(_)) => return Ok(None),
            Err(err) => bail!("Parse failed: {:?}", err),
            Ok(parsed) => Message::<FromServer>::from_frame(parsed),
        };
        // How far the parser advanced relative to the start of the buffer.
        let consumed = input.offset_from(&Partial::new(src.chunk()));
        src.advance(consumed);
        converted.map(Some)
    }
}
impl Encoder<Message<ToServer>> for ClientCodec {
    type Error = anyhow::Error;

    /// Appends the serialized frame for `item` to `dst`; always returns `Ok(())`.
    fn encode(
        &mut self,
        item: Message<ToServer>,
        dst: &mut BytesMut,
    ) -> std::result::Result<(), Self::Error> {
        let outgoing = item.to_frame();
        outgoing.serialize(dst);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        Message, ToServer,
        client::{Connector, Subscriber},
    };
    use bytes::BytesMut;

    /// Serializes `msg` into a fresh buffer so two messages can be compared on the wire.
    fn wire_bytes(msg: Message<ToServer>) -> BytesMut {
        let mut buffer = BytesMut::new();
        msg.to_frame().serialize(&mut buffer);
        buffer
    }

    #[test]
    fn subscription_message() {
        let headers = vec![(
            "activemq.subscriptionName".to_string(),
            "ClientTest".to_string(),
        )];

        let actual = Subscriber::builder()
            .destination("queue.test")
            .id("custom-subscriber-id")
            .headers(headers.clone())
            .subscribe();

        let mut expected: Message<ToServer> = ToServer::Subscribe {
            destination: "queue.test".to_string(),
            id: "custom-subscriber-id".to_string(),
            ack: None,
        }
        .into();
        expected.extra_headers = headers
            .into_iter()
            .map(|(name, value)| (name.as_bytes().to_vec(), value.as_bytes().to_vec()))
            .collect();

        assert_eq!(wire_bytes(expected), wire_bytes(actual));
    }

    #[test]
    fn connection_message() {
        let headers = vec![("client-id".to_string(), "ClientTest".to_string())];

        let actual = Connector::builder()
            .server("stomp.example.com")
            .virtualhost("virtual.stomp.example.com")
            .login("guest_login".to_string())
            .passcode("guest_passcode".to_string())
            .headers(headers.clone())
            .msg();

        let mut expected: Message<ToServer> = ToServer::Connect {
            accept_version: "1.2".into(),
            host: "virtual.stomp.example.com".into(),
            login: Some("guest_login".to_string()),
            passcode: Some("guest_passcode".to_string()),
            heartbeat: None,
        }
        .into();
        expected.extra_headers = headers
            .into_iter()
            .map(|(name, value)| (name.as_bytes().to_vec(), value.as_bytes().to_vec()))
            .collect();

        assert_eq!(wire_bytes(expected), wire_bytes(actual));
    }
}