use std::{
fmt,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc,
},
};
use amqp_serde::types::{
AmqpChannelId, AmqpPeerProperties, FieldTable, FieldValue, LongStr, LongUint, ShortUint,
};
use tokio::sync::{broadcast, mpsc, oneshot};
use crate::{
frame::{
Blocked, Close, CloseOk, Frame, MethodHeader, Open, OpenChannel, OpenChannelOk,
ProtocolHeader, StartOk, TuneOk, Unblocked, UpdateSecret, UpdateSecretOk,
DEFAULT_CONN_CHANNEL, FRAME_MIN_SIZE,
},
net::{
ChannelResource, ConnManagementCommand, IncomingMessage, OutgoingMessage, ReaderHandler,
RegisterChannelResource, RegisterConnectionCallback, RegisterResponder, SplitConnection,
WriterHandler,
},
};
use super::{
callbacks::ConnectionCallback,
channel::{Channel, ChannelDispatcher},
error::Error,
security::SecurityCredentials,
Result,
};
#[cfg(feature = "tls")]
use super::tls::TlsAdaptor;
#[cfg(feature = "compliance_assert")]
use crate::api::compliance_asserts::assert_path;
#[cfg(feature = "traces")]
use tracing::{debug, error, info};
#[cfg(feature = "urispec")]
use uriparse::URIReference;
// Port used for `amqp` URIs without an explicit port, and by `OpenConnectionArguments::default()`.
const DEFAULT_AMQP_PORT: u16 = 5672;
// Port used for `amqps` URIs without an explicit port.
const DEFAULT_AMQPS_PORT: u16 = 5671;
// Heartbeat applied when a URI query has no (or an unparsable) `heartbeat` parameter.
const DEFAULT_HEARTBEAT: u16 = 60;
// URI scheme of a regular (non-TLS) connection.
const AMQP_SCHEME: &str = "amqp";
// URI scheme of a TLS connection.
const AMQPS_SCHEME: &str = "amqps";
// Capacity of the bounded channel carrying outgoing messages to the writer handler.
const OUTGOING_MESSAGE_BUFFER_SIZE: usize = 8192;
// Capacity of the bounded channel carrying connection management commands.
const CONNECTION_MANAGEMENT_COMMAND_BUFFER_SIZE: usize = 256;
// Locale that must be offered by the server and is sent back in `StartOk`.
const DEFAULT_LOCALE: &str = "en_US";
/// Capability flags read from the `capabilities` table of the server properties
/// received during connection negotiation.
///
/// A flag is `false` when the server did not include the corresponding entry.
#[derive(Debug, Clone)]
pub struct ServerCapabilities {
/// Value of the `consumer_cancel_notify` entry.
consumer_cancel_notify: bool,
/// Value of the `publisher_confirms` entry.
publisher_confirms: bool,
/// Value of the `consumer_priorities` entry.
consumer_priorities: bool,
/// Value of the `authentication_failure_close` entry.
authentication_failure_close: bool,
/// Value of the `per_consumer_qos` entry.
per_consumer_qos: bool,
/// Value of the `connection.blocked` entry.
connection_blocked: bool,
/// Value of the `exchange_exchange_bindings` entry.
exchange_exchange_bindings: bool,
/// Value of the `basic.nack` entry.
basic_nack: bool,
/// Value of the `direct_reply_to` entry.
direct_reply_to: bool,
}
impl ServerCapabilities {
/// Returns the `consumer_cancel_notify` flag.
pub fn consumer_cancel_notify(&self) -> bool {
self.consumer_cancel_notify
}
/// Returns the `publisher_confirms` flag.
pub fn publisher_confirms(&self) -> bool {
self.publisher_confirms
}
/// Returns the `consumer_priorities` flag.
pub fn consumer_priorities(&self) -> bool {
self.consumer_priorities
}
/// Returns the `authentication_failure_close` flag.
pub fn authentication_failure_close(&self) -> bool {
self.authentication_failure_close
}
/// Returns the `per_consumer_qos` flag.
pub fn per_consumer_qos(&self) -> bool {
self.per_consumer_qos
}
/// Returns the `connection.blocked` flag.
pub fn connection_blocked(&self) -> bool {
self.connection_blocked
}
/// Returns the `exchange_exchange_bindings` flag.
pub fn exchange_exchange_bindings(&self) -> bool {
self.exchange_exchange_bindings
}
/// Returns the `basic.nack` flag.
pub fn basic_nack(&self) -> bool {
self.basic_nack
}
/// Returns the `direct_reply_to` flag.
pub fn direct_reply_to(&self) -> bool {
self.direct_reply_to
}
}
/// Properties the server reported about itself during connection negotiation.
///
/// The textual properties hold `"unknown"` when the server did not send the
/// corresponding entry.
#[derive(Debug, Clone)]
pub struct ServerProperties {
    capabilities: ServerCapabilities,
    product: String,
    cluster_name: String,
    version: String,
}

impl ServerProperties {
    /// Capability flags advertised by the server.
    pub fn capabilities(&self) -> &ServerCapabilities {
        &self.capabilities
    }

    /// Value of the server's `cluster_name` property.
    pub fn cluster_name(&self) -> &str {
        self.cluster_name.as_str()
    }

    /// Value of the server's `version` property.
    pub fn version(&self) -> &str {
        self.version.as_str()
    }

    /// Value of the server's `product` property.
    pub fn product(&self) -> &str {
        self.product.as_str()
    }
}
/// Guard whose `Drop` implementation queues a `Close` frame if the connection
/// is still flagged as open.
struct DropGuard {
// Sender used to queue the `Close` frame at drop.
outgoing_tx: mpsc::Sender<OutgoingMessage>,
// Open/closed flag shared with the owning `Connection`.
is_open: Arc<AtomicBool>,
// Name of the connection, used in log messages.
connection_name: String,
}
impl DropGuard {
/// Builds a guard from the outgoing sender, the shared open flag and the connection name.
fn new(
outgoing_tx: mpsc::Sender<OutgoingMessage>,
is_open: Arc<AtomicBool>,
connection_name: String,
) -> Self {
Self {
outgoing_tx,
is_open,
connection_name,
}
}
}
/// Handle to an AMQP connection.
///
/// Clones share the same negotiated state, open flag and drop guard.
#[derive(Clone)]
pub struct Connection {
// Immutable state negotiated at open time plus the channel senders.
shared: Arc<SharedConnectionInner>,
// Open/closed flag shared by every handle of this connection.
is_open: Arc<AtomicBool>,
// `Some` for user-facing handles; `None` for handles made by `clone_no_drop_guard`.
_guard: Option<Arc<DropGuard>>,
}
/// State shared by all handles of one connection.
#[derive(Debug)]
struct SharedConnectionInner {
// Properties reported by the server in its `Start` method.
server_properties: ServerProperties,
// Name given by the user or generated by `generate_connection_name`.
connection_name: String,
// Channel limit taken from the server's `Tune` method.
channel_max: ShortUint,
// Frame size limit taken from the server's `Tune` method.
frame_max: LongUint,
// Heartbeat value agreed on in `tuning_parameters`.
heartbeat: ShortUint,
// Sender of outgoing messages consumed by the writer handler.
outgoing_tx: mpsc::Sender<OutgoingMessage>,
// Sender of management commands consumed by the reader handler.
conn_mgmt_tx: mpsc::Sender<ConnManagementCommand>,
// Broadcast sender from which shutdown listeners subscribe.
shutdown_subscriber: broadcast::Sender<bool>,
}
/// Arguments used by `Connection::open` to establish a connection.
#[derive(Clone)]
pub struct OpenConnectionArguments {
// Host part of the `host:port` address to connect to.
host: String,
// Port part of the `host:port` address to connect to.
port: u16,
// Virtual host sent in the `Open` method.
virtual_host: String,
// Optional connection name; one is generated when `None`.
connection_name: Option<String>,
// Credentials supplying the mechanism name and response for `StartOk`.
credentials: SecurityCredentials,
// Heartbeat requested by the client, negotiated against the server's value.
heartbeat: u16,
// Optional URI scheme, checked against the presence of a TLS adaptor at open time.
scheme: Option<String>,
// Optional TLS configuration; when set, the connection is opened over TLS.
#[cfg(feature = "tls")]
tls_adaptor: Option<TlsAdaptor>,
}
impl Default for OpenConnectionArguments {
    /// Returns arguments for `localhost` on the default AMQP port with the
    /// plain credentials `guest`/`guest`.
    ///
    /// Delegates to [`OpenConnectionArguments::new`] so every other field
    /// (virtual host, heartbeat, scheme, ...) shares a single set of defaults
    /// instead of being duplicated here.
    fn default() -> Self {
        Self::new("localhost", DEFAULT_AMQP_PORT, "guest", "guest")
    }
}
impl OpenConnectionArguments {
pub fn new(host: &str, port: u16, username: &str, password: &str) -> Self {
Self {
host: host.to_owned(),
port,
virtual_host: String::from("/"),
connection_name: None,
credentials: SecurityCredentials::new_plain(username, password),
heartbeat: 60,
scheme: None,
#[cfg(feature = "tls")]
tls_adaptor: None,
}
}
pub fn host(&mut self, host: &str) -> &mut Self {
self.host = host.to_owned();
self
}
pub fn get_host(&self) -> &str {
&self.host
}
pub fn port(&mut self, port: u16) -> &mut Self {
self.port = port;
self
}
pub fn get_port(&self) -> u16 {
self.port
}
pub fn virtual_host(&mut self, virtual_host: &str) -> &mut Self {
#[cfg(feature = "compliance_assert")]
assert_path(virtual_host);
self.virtual_host = virtual_host.to_owned();
self
}
pub fn get_virtual_host(&self) -> &str {
&self.virtual_host
}
pub fn connection_name(&mut self, connection_name: &str) -> &mut Self {
self.connection_name = Some(connection_name.to_owned());
self
}
pub fn get_connection_name(&self) -> Option<&str> {
self.connection_name.as_deref()
}
pub fn credentials(&mut self, credentials: SecurityCredentials) -> &mut Self {
self.credentials = credentials;
self
}
pub fn get_credentials(&self) -> &SecurityCredentials {
&self.credentials
}
pub fn heartbeat(&mut self, heartbeat: u16) -> &mut Self {
self.heartbeat = heartbeat;
self
}
pub fn get_heartbeat(&self) -> u16 {
self.heartbeat
}
pub fn scheme(&mut self, scheme: &str) -> &mut Self {
self.scheme = Some(scheme.to_owned());
self
}
pub fn get_scheme(&self) -> Option<&str> {
self.scheme.as_deref()
}
#[cfg(feature = "tls")]
pub fn tls_adaptor(&mut self, tls_adaptor: TlsAdaptor) -> &mut Self {
self.tls_adaptor = Some(tls_adaptor);
self
}
#[cfg(feature = "tls")]
pub fn get_tls_adaptor(&self) -> Option<&TlsAdaptor> {
self.tls_adaptor.as_ref()
}
pub fn finish(&mut self) -> Self {
self.clone()
}
}
#[cfg(feature = "urispec")]
impl TryFrom<&str> for OpenConnectionArguments {
    type Error = Error;

    /// Builds connection arguments from an `amqp://` or `amqps://` URI.
    ///
    /// Missing user name and password default to `guest`, a missing port
    /// defaults to the scheme's default port, and an empty path maps to the
    /// virtual host `/`. The only query parameter interpreted is `heartbeat`;
    /// a missing or unparsable value falls back to `DEFAULT_HEARTBEAT`.
    ///
    /// # Errors
    ///
    /// Returns an error when the URI cannot be parsed, has no scheme or
    /// authority, uses an unsupported scheme, uses `amqps` without the `tls`
    /// feature, or when the TLS adaptor cannot be created.
    fn try_from(uri: &str) -> Result<Self> {
        let pu = URIReference::try_from(uri)?;
        let scheme = pu
            .scheme()
            .ok_or_else(|| Error::UriError(String::from("No URI scheme")))?
            .as_str();

        let default_port: u16 = match scheme {
            AMQP_SCHEME => DEFAULT_AMQP_PORT,
            AMQPS_SCHEME => {
                if cfg!(feature = "tls") {
                    DEFAULT_AMQPS_PORT
                } else {
                    return Err(Error::UriError(format!(
                        "TLS feature should be enabled to use scheme: {}",
                        scheme
                    )));
                }
            }
            _ => {
                return Err(Error::UriError(format!(
                    "Unsupported URI scheme: {}",
                    scheme
                )))
            }
        };

        let pu_authority = pu
            .authority()
            .ok_or_else(|| Error::UriError(String::from("Invalid URI authority")))?;
        let pu_authority_username = pu_authority
            .username()
            .map(|v| v.as_str())
            .unwrap_or("guest");
        let pu_authority_password = pu_authority
            .password()
            .map(|v| v.as_str())
            .unwrap_or("guest");

        let host = pu_authority.host().to_string();
        let mut args = OpenConnectionArguments::new(
            host.as_str(),
            pu_authority.port().unwrap_or(default_port),
            pu_authority_username,
            pu_authority_password,
        );
        args.scheme = Some(scheme.to_owned());

        // Drop the leading '/' of the path; an empty or root path means vhost "/".
        let pu_path = pu.path().to_string();
        if pu_path.len() <= 1 {
            args.virtual_host("/");
        } else {
            args.virtual_host(&pu_path[1..]);
        }

        if scheme == AMQPS_SCHEME {
            #[cfg(feature = "tls")]
            args.tls_adaptor(
                TlsAdaptor::without_client_auth(None, host.to_string())
                    .map_err(|e| Error::UriError(format!("error creating TLS adaptor: {}", e)))?,
            );
            #[cfg(not(feature = "tls"))]
            return Err(Error::UriError(
                "can't create amqps url without the `tls` feature enabled".to_string(),
            ));
        }

        let pu_q = pu.query().map(|v| v.as_str()).unwrap_or("");
        if pu_q.is_empty() {
            return Ok(args);
        }

        // The query string is user input: a parameter without '=' (for example
        // `?heartbeat`) must not panic, so a missing value is treated as empty.
        let pu_q_map: std::collections::HashMap<&str, &str> = pu_q
            .split('&')
            .map(|pair| {
                let mut parts = pair.split('=');
                let key = parts.next().unwrap_or("");
                let value = parts.next().unwrap_or("");
                (key, value)
            })
            .collect();

        let heartbeat = pu_q_map
            .get("heartbeat")
            .map(|v| v.parse::<u16>().unwrap_or(DEFAULT_HEARTBEAT))
            .unwrap_or(DEFAULT_HEARTBEAT);
        args.heartbeat(heartbeat);

        Ok(args)
    }
}
impl Connection {
pub async fn open(args: &OpenConnectionArguments) -> Result<Self> {
#[cfg(feature = "tls")]
let mut io_conn = match &args.tls_adaptor {
Some(tls_adaptor) => {
if let Some(scheme) = &args.scheme {
if scheme == AMQP_SCHEME {
return Err(Error::UriError(format!(
"Try to open a secure connection with '{}' scheme",
scheme
)));
}
}
SplitConnection::open_tls(
&format!("{}:{}", args.host, args.port),
&tls_adaptor.domain,
&tls_adaptor.connector,
)
.await?
}
None => {
if let Some(scheme) = &args.scheme {
if scheme == AMQPS_SCHEME {
return Err(Error::UriError(format!(
"Try to open a regular connection with '{}' scheme",
scheme
)));
}
}
SplitConnection::open(&format!("{}:{}", args.host, args.port)).await?
}
};
#[cfg(not(feature = "tls"))]
let mut io_conn = {
if let Some(scheme) = &args.scheme {
if scheme == AMQPS_SCHEME {
return Err(Error::UriError(format!(
"Try to open a regular connection with '{}' scheme",
scheme
)));
}
}
SplitConnection::open(&format!("{}:{}", args.host, args.port)).await?
};
Self::negotiate_protocol(&mut io_conn).await?;
let connection_name = match args.connection_name {
Some(ref given_name) => given_name.clone(),
None => generate_connection_name(&format!(
"{}:{}{}",
args.host, args.port, args.virtual_host
)),
};
let mut client_properties = AmqpPeerProperties::new();
client_properties.insert(
"connection_name".try_into().unwrap(),
FieldValue::S(connection_name.clone().try_into().unwrap()),
);
client_properties.insert(
"product".try_into().unwrap(),
FieldValue::S("AMQPRS".try_into().unwrap()),
);
client_properties.insert(
"platform".try_into().unwrap(),
FieldValue::S("Rust".try_into().unwrap()),
);
client_properties.insert(
"version".try_into().unwrap(),
FieldValue::S("0.1".try_into().unwrap()),
);
let mut client_properties_capabilities = FieldTable::new();
client_properties_capabilities
.insert("consumer_cancel_notify".try_into().unwrap(), true.into());
client_properties.insert(
"capabilities".try_into().unwrap(),
FieldValue::F(client_properties_capabilities),
);
let server_properties =
Self::start_connection_negotiation(&mut io_conn, client_properties, args).await?;
let (channel_max, frame_max, heartbeat) =
Self::tuning_parameters(&mut io_conn, args.heartbeat).await?;
let open = Open::new(
args.virtual_host.clone().try_into().unwrap(),
"".try_into().unwrap(),
)
.into_frame();
io_conn
.write_frame(DEFAULT_CONN_CHANNEL, open, FRAME_MIN_SIZE)
.await?;
let (_, frame) = io_conn.read_frame().await?;
unwrap_expected_method!(
frame,
Frame::OpenOk,
Error::ConnectionOpenError(format!("failed to open connection, reason: {}", frame))
)?;
let (outgoing_tx, outgoing_rx) = mpsc::channel(OUTGOING_MESSAGE_BUFFER_SIZE);
let (conn_mgmt_tx, conn_mgmt_rx) = mpsc::channel(CONNECTION_MANAGEMENT_COMMAND_BUFFER_SIZE);
let (shutdown_notifer, _) = broadcast::channel::<bool>(1);
let shared = Arc::new(SharedConnectionInner {
server_properties,
connection_name,
channel_max,
frame_max,
heartbeat,
outgoing_tx,
conn_mgmt_tx,
shutdown_subscriber: shutdown_notifer.clone(),
});
let is_open = Arc::new(AtomicBool::new(true));
let _guard = Some(Arc::new(DropGuard::new(
shared.outgoing_tx.clone(),
is_open.clone(),
shared.connection_name.clone(),
)));
let new_amqp_conn = Self {
shared,
is_open,
_guard,
};
new_amqp_conn
.spawn_handlers(
io_conn,
outgoing_rx,
conn_mgmt_rx,
heartbeat,
shutdown_notifer,
)
.await;
new_amqp_conn
.register_channel_resource(Some(DEFAULT_CONN_CHANNEL), ChannelResource::new(None))
.await
.ok_or_else(|| {
Error::ConnectionOpenError("failed to register channel resource".to_string())
})?;
#[cfg(feature = "traces")]
info!("open connection {}", new_amqp_conn.connection_name());
Ok(new_amqp_conn)
}
/// Writes the default protocol header to the freshly opened transport.
async fn negotiate_protocol(io_conn: &mut SplitConnection) -> Result<()> {
    let header = ProtocolHeader::default();
    io_conn.write(&header).await?;
    Ok(())
}
/// Handles the `Start`/`StartOk` step of the handshake.
///
/// Reads the server's `Start` method, checks that the server offers
/// `DEFAULT_LOCALE` and the mechanism of `args.credentials`, extracts the
/// server properties, and replies with `StartOk`.
///
/// Returns the extracted server properties. Fails when the frame read is
/// not `Start`, when locale or mechanism is unsupported, or on I/O errors.
async fn start_connection_negotiation(
    io_conn: &mut SplitConnection,
    client_properties: AmqpPeerProperties,
    args: &OpenConnectionArguments,
) -> Result<ServerProperties> {
    let (_, frame) = io_conn.read_frame().await?;
    let mut start = unwrap_expected_method!(
        frame,
        Frame::Start,
        Error::ConnectionOpenError(format!(
            "failed to negotiate connection params, reason: {}",
            frame
        ))
    )?;

    // Both lists are space separated.
    let locale_supported = start
        .locales
        .as_ref()
        .split(' ')
        .any(|locale| DEFAULT_LOCALE == locale);
    if !locale_supported {
        return Err(Error::ConnectionOpenError(format!(
            "locale '{}' is not supported by server",
            DEFAULT_LOCALE
        )));
    }

    let mechanism_supported = start
        .mechanisms
        .as_ref()
        .split(' ')
        .any(|mechanism| args.credentials.get_mechanism_name() == mechanism);
    if !mechanism_supported {
        return Err(Error::ConnectionOpenError(format!(
            "authentication '{}' is not supported by server",
            args.credentials.get_mechanism_name()
        )));
    }

    // A missing `capabilities` entry is treated as an empty table.
    let mut caps_table: FieldTable = start
        .server_properties
        .remove(&"capabilities".try_into().unwrap())
        .unwrap_or_else(|| FieldValue::F(FieldTable::default()))
        .try_into()
        .unwrap();
    // Removes a capability flag from the table; absent flags read as `false`.
    let mut take_flag = |key: &str| -> bool {
        caps_table
            .remove(&key.try_into().unwrap())
            .unwrap_or(FieldValue::t(false))
            .try_into()
            .unwrap()
    };
    let consumer_cancel_notify = take_flag("consumer_cancel_notify");
    let publisher_confirms = take_flag("publisher_confirms");
    let consumer_priorities = take_flag("consumer_priorities");
    let authentication_failure_close = take_flag("authentication_failure_close");
    let per_consumer_qos = take_flag("per_consumer_qos");
    let connection_blocked = take_flag("connection.blocked");
    let exchange_exchange_bindings = take_flag("exchange_exchange_bindings");
    let basic_nack = take_flag("basic.nack");
    let direct_reply_to = take_flag("direct_reply_to");
    let capabilities = ServerCapabilities {
        consumer_cancel_notify,
        publisher_confirms,
        consumer_priorities,
        authentication_failure_close,
        per_consumer_qos,
        connection_blocked,
        exchange_exchange_bindings,
        basic_nack,
        direct_reply_to,
    };

    // Removes a textual property; absent properties read as "unknown".
    let mut take_text = |key: &str| -> LongStr {
        start
            .server_properties
            .remove(&key.try_into().unwrap())
            .unwrap_or_else(|| FieldValue::S("unknown".try_into().unwrap()))
            .try_into()
            .unwrap()
    };
    let product = take_text("product").into();
    let cluster_name = take_text("cluster_name").into();
    let version = take_text("version").into();
    let server_properties = ServerProperties {
        capabilities,
        product,
        cluster_name,
        version,
    };

    let response = args.credentials.get_response().try_into().unwrap();
    let start_ok = StartOk::new(
        client_properties,
        args.credentials.get_mechanism_name().try_into().unwrap(),
        response,
        DEFAULT_LOCALE.try_into().unwrap(),
    );
    io_conn
        .write_frame(DEFAULT_CONN_CHANNEL, start_ok.into_frame(), FRAME_MIN_SIZE)
        .await?;

    Ok(server_properties)
}
/// Handles the `Tune`/`TuneOk` step of the handshake.
///
/// The server's channel and frame limits are accepted as sent. For the
/// heartbeat, if either side proposes `0` the other side's value is used;
/// otherwise the smaller of the two wins.
///
/// Returns `(channel_max, frame_max, heartbeat)` as sent back in `TuneOk`.
async fn tuning_parameters(
    io_conn: &mut SplitConnection,
    heartbeat: ShortUint,
) -> Result<(ShortUint, LongUint, ShortUint)> {
    let (_, frame) = io_conn.read_frame().await?;
    let tune = unwrap_expected_method!(
        frame,
        Frame::Tune,
        Error::ConnectionOpenError(format!(
            "failed to tune connection params, reason: {}",
            frame
        ))
    )?;

    let new_heartbeat = match (tune.heartbeat(), heartbeat) {
        // One side proposes zero: take whatever the other side proposes.
        (0, other) | (other, 0) => other,
        (server, client) => std::cmp::min(server, client),
    };

    #[cfg(feature = "compliance_assert")]
    {
        assert_ne!(0, tune.channel_max());
        assert!(tune.frame_max() >= FRAME_MIN_SIZE);
    }

    let new_channel_max = tune.channel_max();
    let new_frame_max = tune.frame_max();
    let reply = TuneOk::new(new_channel_max, new_frame_max, new_heartbeat).into_frame();
    io_conn
        .write_frame(DEFAULT_CONN_CHANNEL, reply, FRAME_MIN_SIZE)
        .await?;

    Ok((new_channel_max, new_frame_max, new_heartbeat))
}
/// Returns the connection name (user supplied, or generated at open time).
pub fn connection_name(&self) -> &str {
&self.shared.connection_name
}
/// Returns the channel limit stored at open time.
pub fn channel_max(&self) -> u16 {
self.shared.channel_max
}
/// Returns the frame size limit stored at open time.
pub fn frame_max(&self) -> u32 {
self.shared.frame_max
}
/// Returns the properties the server reported during negotiation.
pub fn server_properties(&self) -> &ServerProperties {
&self.shared.server_properties
}
/// Registers a responder for `method_header` on `channel_id`.
///
/// Sends a `RegisterResponder` management command and waits for its
/// acknowledgement before returning the receiver on which the response
/// message will be delivered.
async fn register_responder(
    &self,
    channel_id: AmqpChannelId,
    method_header: &'static MethodHeader,
) -> Result<oneshot::Receiver<IncomingMessage>> {
    let (reply_tx, reply_rx) = oneshot::channel();
    let (ack_tx, ack_rx) = oneshot::channel();
    let registration = ConnManagementCommand::RegisterResponder(RegisterResponder {
        channel_id,
        method_header,
        responder: reply_tx,
        acker: ack_tx,
    });

    self.shared.conn_mgmt_tx.send(registration).await?;
    // Wait until the registration has been acknowledged.
    ack_rx.await?;
    Ok(reply_rx)
}
/// Registers `callback` as the connection callback.
///
/// The callback is boxed and handed over through a
/// `RegisterConnectionCallback` management command. Fails only when the
/// command cannot be sent.
pub async fn register_callback<F>(&self, callback: F) -> Result<()>
where
    F: ConnectionCallback + Send + 'static,
{
    let command =
        ConnManagementCommand::RegisterConnectionCallback(RegisterConnectionCallback {
            callback: Box::new(callback),
        });
    self.shared.conn_mgmt_tx.send(command).await?;
    Ok(())
}
/// Overwrites the shared open flag (relaxed store).
pub(crate) fn set_is_open(&self, is_open: bool) {
self.is_open.store(is_open, Ordering::Relaxed);
}
/// Returns the current value of the shared open flag (relaxed load).
pub fn is_open(&self) -> bool {
self.is_open.load(Ordering::Relaxed)
}
/// Returns the heartbeat value agreed on during tuning.
pub fn heartbeat(&self) -> u16 {
self.shared.heartbeat
}
/// Registers `resource` for a channel through a `RegisterChannelResource`
/// management command.
///
/// `channel_id` is forwarded as given. Returns the id carried by the
/// acknowledgement, or `None` when the command could not be sent, the
/// acknowledgement was lost, or the acknowledgement carried no id.
pub(crate) async fn register_channel_resource(
    &self,
    channel_id: Option<AmqpChannelId>,
    resource: ChannelResource,
) -> Option<AmqpChannelId> {
    let (ack_tx, ack_rx) = oneshot::channel();
    let request = ConnManagementCommand::RegisterChannelResource(RegisterChannelResource {
        channel_id,
        resource,
        acker: ack_tx,
    });

    let sent = self.shared.conn_mgmt_tx.send(request).await;
    if let Err(err) = sent {
        #[cfg(feature = "traces")]
        debug!(
            "failed to register channel resource on connection {}, cause: {}",
            self, err
        );
        return None;
    }

    match ack_rx.await {
        Ok(Some(id)) => Some(id),
        Ok(None) => {
            #[cfg(feature = "traces")]
            debug!(
                "failed to allocate/reserve channel id on connection {}",
                self
            );
            None
        }
        Err(err) => {
            #[cfg(feature = "traces")]
            debug!(
                "failed to register channel resource on connection {}, cause: {}",
                self, err
            );
            None
        }
    }
}
/// Splits the transport and spawns the writer and reader handlers as
/// separate tasks.
///
/// Both handlers receive a guard-less handle of this connection, so the
/// background tasks do not hold the drop guard.
pub(crate) async fn spawn_handlers(
    &self,
    io_conn: SplitConnection,
    outgoing_rx: mpsc::Receiver<OutgoingMessage>,
    conn_mgmt_rx: mpsc::Receiver<ConnManagementCommand>,
    heartbeat: ShortUint,
    shutdown_notifer: broadcast::Sender<bool>,
) {
    let (read_half, write_half) = io_conn.into_split();

    // Writer side: subscribes to the shutdown broadcast.
    let shutdown_rx = shutdown_notifer.subscribe();
    let writer_handler = WriterHandler::new(
        write_half,
        outgoing_rx,
        shutdown_rx,
        self.clone_no_drop_guard(),
    );
    tokio::spawn(async move {
        writer_handler.run_until_shutdown(heartbeat).await;
    });

    // Reader side: takes ownership of the shutdown broadcast sender.
    let reader_handler = ReaderHandler::new(
        read_half,
        self.clone_no_drop_guard(),
        self.shared.outgoing_tx.clone(),
        conn_mgmt_rx,
        self.shared.channel_max,
        shutdown_notifer,
    );
    tokio::spawn(async move {
        reader_handler.run_until_shutdown(heartbeat).await;
    });
}
/// Opens a new channel on this connection.
///
/// `channel_id` is forwarded to the channel resource registration; the id
/// returned by that registration is the one used for the channel.
///
/// # Panics
///
/// Panics if `channel_id` is `Some(DEFAULT_CONN_CHANNEL)`.
///
/// # Errors
///
/// Returns `Error::ChannelOpenError` when the channel resource cannot be
/// registered, and propagates failures of the responder registration and of
/// the `OpenChannel` request.
pub async fn open_channel(&self, channel_id: Option<AmqpChannelId>) -> Result<Channel> {
// The default connection channel cannot be opened as a regular channel.
assert_ne!(Some(DEFAULT_CONN_CHANNEL), channel_id);
let (dispatcher_tx, dispatcher_rx) = mpsc::unbounded_channel();
let (dispatcher_mgmt_tx, dispatcher_mgmt_rx) = mpsc::unbounded_channel();
// Register the channel resource first; the registration yields the channel id.
let channel_id = self
.register_channel_resource(channel_id, ChannelResource::new(Some(dispatcher_tx)))
.await
.ok_or_else(|| {
Error::ChannelOpenError("failed to register channel resource".to_string())
})?;
// The responder is registered before the request is sent.
// NOTE(review): if a step below fails, nothing visible here releases the
// channel resource registered above - confirm it is cleaned up elsewhere.
let responder_rx = self
.register_responder(channel_id, OpenChannelOk::header())
.await?;
synchronous_request!(
self.shared.outgoing_tx,
(channel_id, OpenChannel::new().into_frame()),
responder_rx,
Frame::OpenChannelOk,
Error::ChannelOpenError
)?;
// The channel holds a guard-less handle of this connection.
let channel = Channel::new(
AtomicBool::new(true),
self.clone_no_drop_guard(),
channel_id,
self.shared.outgoing_tx.clone(),
self.shared.conn_mgmt_tx.clone(),
dispatcher_mgmt_tx,
);
let dispatcher = ChannelDispatcher::new(
channel.clone_as_secondary(),
dispatcher_rx,
dispatcher_mgmt_rx,
);
dispatcher.spawn().await;
#[cfg(feature = "traces")]
info!("open channel {}", channel);
Ok(channel)
}
/// Queues a `Blocked` frame carrying `reason` on the default connection
/// channel.
///
/// # Panics
///
/// Panics if `reason` cannot be converted into the frame's string type.
pub async fn blocked(&self, reason: &str) -> Result<()> {
    let reason = reason.to_owned().try_into().unwrap();
    let frame = Blocked::new(reason).into_frame();
    self.shared
        .outgoing_tx
        .send((DEFAULT_CONN_CHANNEL, frame))
        .await?;
    Ok(())
}
/// Queues an `Unblocked` frame on the default connection channel.
pub async fn unblocked(&self) -> Result<()> {
    let frame = Unblocked.into_frame();
    self.shared
        .outgoing_tx
        .send((DEFAULT_CONN_CHANNEL, frame))
        .await?;
    Ok(())
}
/// Closes the connection.
///
/// The close handshake runs only if this call is the one that flips the
/// open flag from `true` to `false`; otherwise the call returns `Ok(())`
/// without doing anything.
pub async fn close(self) -> Result<()> {
    // A successful exchange means the flag was `true` until now.
    let was_open = self
        .is_open
        .compare_exchange(true, false, Ordering::Acquire, Ordering::Relaxed)
        .is_ok();
    if !was_open {
        return Ok(());
    }

    #[cfg(feature = "traces")]
    info!("close connection {}", self);
    self.close_handshake().await?;
    Ok(())
}
/// Performs the `Close`/`CloseOk` exchange on the default connection channel.
///
/// Failures are reported as `Error::ConnectionCloseError`.
async fn close_handshake(&self) -> Result<()> {
// Register the `CloseOk` responder before the `Close` request is sent.
let responder_rx = self
.register_responder(DEFAULT_CONN_CHANNEL, CloseOk::header())
.await
.map_err(|err| {
Error::ConnectionCloseError(format!("failed to register responder {}", err))
})?;
let close = Close::default();
synchronous_request!(
self.shared.outgoing_tx,
(DEFAULT_CONN_CHANNEL, close.into_frame()),
responder_rx,
Frame::CloseOk,
Error::ConnectionCloseError
)?;
Ok(())
}
pub(crate) fn clone_no_drop_guard(&self) -> Self {
Self {
shared: self.shared.clone(),
is_open: self.is_open.clone(),
_guard: None,
}
}
/// Waits for the next value on the shutdown broadcast and returns it.
///
/// Returns `false` when receiving fails.
pub async fn listen_network_io_failure(&self) -> bool {
    let mut shutdown_rx = self.shared.shutdown_subscriber.subscribe();
    match shutdown_rx.recv().await {
        Ok(flag) => flag,
        Err(_) => false,
    }
}
/// Sends an `UpdateSecret` request carrying `new_secret` and `reason` on the
/// default connection channel and waits for `UpdateSecretOk`.
///
/// # Panics
///
/// Panics if `new_secret` or `reason` cannot be converted into the frame's
/// string types.
pub async fn update_secret(&self, new_secret: &str, reason: &str) -> Result<()> {
// Register the `UpdateSecretOk` responder before the request is sent.
let responder_rx = self
.register_responder(DEFAULT_CONN_CHANNEL, UpdateSecretOk::header())
.await?;
let update_secret = UpdateSecret::new(
new_secret.to_owned().try_into().unwrap(),
reason.to_owned().try_into().unwrap(),
);
synchronous_request!(
self.shared.outgoing_tx,
(DEFAULT_CONN_CHANNEL, update_secret.into_frame()),
responder_rx,
Frame::UpdateSecretOk,
Error::UpdateSecretError
)?;
Ok(())
}
}
impl Drop for DropGuard {
    /// Queues a `Close` frame when the last guarded handle is dropped while
    /// the connection is still flagged as open.
    ///
    /// Nothing happens if the flag is already `false`. The frame is only
    /// queued; no reply is awaited.
    fn drop(&mut self) {
        // Only the caller that flips the flag from `true` to `false` sends `Close`.
        let was_open = self
            .is_open
            .compare_exchange(true, false, Ordering::Acquire, Ordering::Relaxed)
            .is_ok();
        if !was_open {
            return;
        }

        // `tokio::spawn` panics outside a runtime context, and a panic inside
        // `drop` can abort the process. Look the runtime up explicitly and
        // degrade to a non-blocking send when there is none.
        match tokio::runtime::Handle::try_current() {
            Ok(runtime) => {
                let connection_name = self.connection_name.clone();
                let outgoing_tx = self.outgoing_tx.clone();
                runtime.spawn(async move {
                    #[cfg(feature = "traces")]
                    info!("try to close connection {} at drop", connection_name);
                    let close = Close::default();
                    if let Err(err) = outgoing_tx
                        .send((DEFAULT_CONN_CHANNEL, close.into_frame()))
                        .await
                    {
                        #[cfg(feature = "traces")]
                        error!(
                            "failed to gracefully close connection {} at drop, cause: '{}'",
                            connection_name, err
                        );
                    }
                });
            }
            Err(_) => {
                // No runtime to await on: best-effort attempt that never blocks.
                let close = Close::default();
                if let Err(err) = self
                    .outgoing_tx
                    .try_send((DEFAULT_CONN_CHANNEL, close.into_frame()))
                {
                    #[cfg(feature = "traces")]
                    error!(
                        "failed to gracefully close connection {} at drop, cause: '{}'",
                        self.connection_name, err
                    );
                }
            }
        }
    }
}
impl fmt::Display for Connection {
    /// Formats the connection as `'<name> [open]'` or `'<name> [closed]'`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let state = if self.is_open() { "open" } else { "closed" };
        let name = self.connection_name();
        write!(f, "'{} [{}]'", name, state)
    }
}
/// Generates a connection name of the form `AMQPRS<nnn>@<domain>`.
///
/// `<nnn>` comes from a process-wide counter that is incremented on every
/// call and zero-padded to at least three digits.
fn generate_connection_name(domain: &str) -> String {
    const PREFIX: &str = "AMQPRS";
    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    let sequence = COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("{}{:03}@{}", PREFIX, sequence, domain)
}
// Tests for connection arguments and connection lifecycle.
// The tests that call `Connection::open` connect to localhost:5672 with the
// credentials `user`/`bitnami`.
#[cfg(test)]
mod tests {
use super::{generate_connection_name, Connection, OpenConnectionArguments};
use crate::test_utils::setup_logging;
use crate::{connection::AMQPS_SCHEME, security::SecurityCredentials};
use std::{collections::HashSet, thread};
use tokio::time;
// Opens a connection and one channel, then lets both go out of scope.
#[tokio::test]
async fn test_channel_open_close() {
setup_logging();
{
let args = OpenConnectionArguments::new("localhost", 5672, "user", "bitnami");
let connection = Connection::open(&args).await.unwrap();
{
let _channel = connection.open_channel(None).await.unwrap();
}
time::sleep(time::Duration::from_millis(100)).await;
}
time::sleep(time::Duration::from_millis(100)).await;
}
// Checks the getter values of arguments built with `new`.
#[test]
fn test_connection_getters() {
let args = OpenConnectionArguments::new("localhost", 5672, "user", "bitnami");
assert_eq!(args.get_host(), "localhost");
assert_eq!(args.get_port(), 5672);
assert_eq!(args.get_virtual_host(), "/");
assert!(args.get_connection_name().is_none());
assert!(args.get_credentials() == &SecurityCredentials::new_plain("user", "bitnami"));
assert_eq!(args.get_heartbeat(), 60);
assert!(args.get_scheme().is_none());
#[cfg(feature = "tls")]
assert!(args.get_tls_adaptor().is_none());
}
// Checks that every setter is reflected by the matching getter.
#[test]
fn test_custom_connection_getters() {
let mut default_args = OpenConnectionArguments {
..Default::default()
};
let args = default_args
.host("localhost")
.port(1234)
.virtual_host("/vhost")
.connection_name("test")
.credentials(SecurityCredentials::new_plain("user", "bitnami"))
.heartbeat(30)
.scheme("amqps");
assert_eq!(args.get_host(), "localhost");
assert_eq!(args.get_port(), 1234);
assert_eq!(args.get_virtual_host(), "/vhost");
assert!(args.get_connection_name() == Some("test"));
assert!(args.get_credentials() == &SecurityCredentials::new_plain("user", "bitnami"));
assert_eq!(args.get_heartbeat(), 30);
assert!(args.get_scheme() == Some(AMQPS_SCHEME));
}
// Opens and explicitly closes ten connections from concurrently spawned tasks.
#[tokio::test(flavor = "multi_thread", worker_threads = 10)]
async fn test_multi_conn_open_close() {
setup_logging();
let mut handles = vec![];
for _ in 0..10 {
let handle = tokio::spawn(async {
let args = OpenConnectionArguments::new("localhost", 5672, "user", "bitnami");
time::sleep(time::Duration::from_millis(200)).await;
let connection = Connection::open(&args).await.unwrap();
time::sleep(time::Duration::from_millis(200)).await;
connection.close().await.unwrap();
});
handles.push(handle);
}
for h in handles {
h.await.unwrap();
}
}
// Clones a connection, moves the clone into a task, and drops both handles.
#[tokio::test]
async fn test_connection_clone_and_drop() {
setup_logging();
{
let args = OpenConnectionArguments::new("localhost", 5672, "user", "bitnami");
let conn1 = Connection::open(&args).await.unwrap();
let conn2 = conn1.clone();
tokio::spawn(async move {
assert!(conn2.is_open());
});
assert!(conn1.is_open());
}
time::sleep(time::Duration::from_millis(100)).await;
}
// Opens ten channels on one connection and drops each in its own task.
#[tokio::test(flavor = "multi_thread", worker_threads = 10)]
async fn test_multi_channel_open_close() {
setup_logging();
{
let args = OpenConnectionArguments::new("localhost", 5672, "user", "bitnami")
.connection_name("test_multi_channel_open_close")
.finish();
let connection = Connection::open(&args).await.unwrap();
let mut handles = vec![];
let num_loop = 10;
for _ in 0..num_loop {
let ch = connection.open_channel(None).await.unwrap();
let handle = tokio::spawn(async move {
let _ch = ch;
time::sleep(time::Duration::from_millis(100)).await;
});
handles.push(handle);
}
for h in handles {
h.await.unwrap();
}
time::sleep(time::Duration::from_millis(100 * num_loop)).await;
}
time::sleep(time::Duration::from_millis(100)).await;
}
// Names generated concurrently from 100 threads must all be distinct.
#[test]
fn test_name_generation() {
let n = 100;
let mut jh = Vec::with_capacity(n);
let mut res = HashSet::with_capacity(n);
for _ in 0..n {
jh.push(thread::spawn(|| generate_connection_name("testdomain")));
}
for h in jh {
assert!(res.insert(h.join().unwrap()));
}
}
// Two connections opened with the same explicit name must both succeed.
#[tokio::test]
async fn test_duplicated_conn_name_is_accpeted_by_server() {
setup_logging();
let args = OpenConnectionArguments::new("localhost", 5672, "user", "bitnami")
.connection_name("amq.cname-test")
.finish();
let conn1 = Connection::open(&args).await.unwrap();
let conn2 = Connection::open(&args).await.unwrap();
time::sleep(time::Duration::from_millis(100)).await;
conn1.close().await.unwrap();
conn2.close().await.unwrap();
}
// Opens a connection using credentials built with `new_amqplain`.
#[tokio::test]
async fn test_auth_amqplain() {
setup_logging();
let args = OpenConnectionArguments::new("localhost", 5672, "user", "bitnami")
.credentials(SecurityCredentials::new_amqplain("user", "bitnami"))
.finish();
Connection::open(&args).await.unwrap();
}
// Sends `blocked` followed by `unblocked` on an open connection.
#[tokio::test]
async fn test_block_unblock() {
setup_logging();
let args = OpenConnectionArguments::default()
.credentials(SecurityCredentials::new_plain("user", "bitnami"))
.finish();
let conn = Connection::open(&args).await.unwrap();
conn.blocked("test blocked").await.unwrap();
conn.unblocked().await.unwrap();
}
// Opening the same explicit channel id twice must fail on the second attempt.
#[tokio::test]
#[should_panic(expected = "failed to register channel resource")]
async fn test_open_already_opened_channel() {
setup_logging();
let args = OpenConnectionArguments::new("localhost", 5672, "user", "bitnami")
.credentials(SecurityCredentials::new_amqplain("user", "bitnami"))
.finish();
let connection = Connection::open(&args).await.unwrap();
let id = Some(9);
let _ch1 = connection.open_channel(id).await.unwrap();
let _ch2 = connection.open_channel(id).await.unwrap();
}
// Exercises URI parsing: full URIs, percent-encoded parts, missing parts,
// invalid input, IPv6 hosts and the `heartbeat` query parameter.
#[cfg(feature = "urispec")]
#[test]
fn test_openconnectionarguments_try_from() {
let args = OpenConnectionArguments::try_from("amqp://user:pass@host:10000/vhost").unwrap();
assert_eq!(args.host, "host");
assert_eq!(args.port, 10000);
assert_eq!(args.virtual_host, "vhost");
let args =
OpenConnectionArguments::try_from("amqp://user%61:%61pass@ho%61st:10000/v%2fhost")
.unwrap();
assert_eq!(args.host, "ho%61st");
assert_eq!(args.port, 10000);
assert_eq!(args.virtual_host, "v%2fhost");
let args = OpenConnectionArguments::try_from("amqp://").unwrap();
assert_eq!(args.host, "");
assert_eq!(args.port, 5672);
assert_eq!(args.virtual_host, "/");
let args = OpenConnectionArguments::try_from("amqp://:@/").unwrap();
assert_eq!(args.host, "");
assert_eq!(args.port, 5672);
assert_eq!(args.virtual_host, "/");
let args = OpenConnectionArguments::try_from("amqp://user@").unwrap();
assert_eq!(args.host, "");
assert_eq!(args.port, 5672);
assert_eq!(args.virtual_host, "/");
let args = OpenConnectionArguments::try_from("amqp://user:pass@").unwrap();
assert_eq!(args.host, "");
assert_eq!(args.port, 5672);
assert_eq!(args.virtual_host, "/");
let args = OpenConnectionArguments::try_from("amqp://host").unwrap();
assert_eq!(args.host, "host");
assert_eq!(args.port, 5672);
assert_eq!(args.virtual_host, "/");
let args = OpenConnectionArguments::try_from("amqp://:10000").unwrap();
assert_eq!(args.host, "");
assert_eq!(args.port, 10000);
assert_eq!(args.virtual_host, "/");
let args = OpenConnectionArguments::try_from("amqp://host:10000").unwrap();
assert_eq!(args.host, "host");
assert_eq!(args.port, 10000);
assert_eq!(args.virtual_host, "/");
let args = OpenConnectionArguments::try_from("fsdkfjflsd::/fsdfsdfsd:sfsd/");
assert!(args.is_err());
let args = OpenConnectionArguments::try_from("fsdkfjflsd/fsdfsdfsdsfsd/");
assert!(args.is_err());
let args = OpenConnectionArguments::try_from("amqp://[::1]").unwrap();
assert_eq!(args.host, "[::1]");
assert_eq!(args.port, 5672);
assert_eq!(args.virtual_host, "/");
assert_eq!(args.heartbeat, 60);
let args = OpenConnectionArguments::try_from("amqp://[::1]?heartbeat=30").unwrap();
assert_eq!(args.host, "[::1]");
assert_eq!(args.port, 5672);
assert_eq!(args.virtual_host, "/");
assert_eq!(args.heartbeat, 30);
}
// Without the `tls` feature an `amqps` URI must be rejected with `UriError`.
#[cfg(all(feature = "urispec", not(feature = "tls")))]
#[test]
fn test_urispec_amqps_without_tls() {
match OpenConnectionArguments::try_from("amqps://user:bitnami@localhost?heartbeat=10") {
Ok(_) => panic!("Unexpected ok"),
Err(e) => assert!(matches!(e, crate::api::Error::UriError(_))),
}
}
// With the `tls` feature an `amqps` URI yields the TLS default port and an adaptor.
#[cfg(all(feature = "urispec", feature = "tls"))]
#[test]
fn test_urispec_amqps() {
let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
let args = OpenConnectionArguments::try_from("amqps://user:bitnami@localhost?heartbeat=10")
.unwrap();
assert_eq!(args.host, "localhost");
assert_eq!(args.port, 5671);
assert_eq!(args.virtual_host, "/");
assert_eq!(args.heartbeat, 10);
let tls_adaptor = args.tls_adaptor.unwrap();
assert_eq!(tls_adaptor.domain, "localhost");
}
// Same as above for a URI that contains only scheme and host.
#[cfg(all(feature = "urispec", feature = "tls"))]
#[test]
fn test_urispec_amqps_simple() {
let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
let args = OpenConnectionArguments::try_from("amqps://localhost").unwrap();
assert_eq!(args.host, "localhost");
assert_eq!(args.port, 5671);
assert_eq!(args.virtual_host, "/");
let tls_adaptor = args.tls_adaptor.unwrap();
assert_eq!(tls_adaptor.domain, "localhost");
}
// Combining the `amqp` scheme with a TLS adaptor must make `open` fail with `UriError`.
#[cfg(all(feature = "urispec", feature = "tls"))]
#[tokio::test]
#[should_panic(expected = "UriError")]
async fn test_amqp_scheme_with_tls() {
let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
let current_dir = std::env::current_dir().unwrap();
let current_dir = current_dir.join("../rabbitmq_conf/client/");
let root_ca_cert = current_dir.join("ca_certificate.pem");
let client_cert = current_dir.join("client_AMQPRS_TEST_certificate.pem");
let client_private_key = current_dir.join("client_AMQPRS_TEST_key.pem");
let domain = "AMQPRS_TEST";
let tls_adaptor = crate::tls::TlsAdaptor::with_client_auth(
Some(root_ca_cert.as_path()),
client_cert.as_path(),
client_private_key.as_path(),
domain.to_owned(),
)
.unwrap();
let args = OpenConnectionArguments::try_from("amqp://user:bitnami@localhost?heartbeat=10")
.unwrap()
.tls_adaptor(tls_adaptor)
.finish();
Connection::open(&args).await.unwrap();
}
}