use crate::connection_options::ConnectionOptions;
use crate::errors::*;
use crate::io_loop::{Channel0Handle, IoLoop};
use crate::{Channel, FieldTable, IoStream, Sasl};
use crossbeam_channel::Receiver;
use log::debug;
use std::thread::JoinHandle;
#[cfg(feature = "native-tls")]
use crate::TlsConnector;
/// Notification that the broker has blocked or unblocked this connection.
///
/// Values are delivered on the receiver returned by
/// [`Connection::listen_for_connection_blocked`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionBlockedNotification {
    /// The connection was blocked. NOTE(review): the string is presumably the reason
    /// text supplied by the server — confirm against the IO loop.
    Blocked(String),
    /// A previous block was lifted.
    Unblocked,
}
/// Tuning knobs for the internal machinery of a [`Connection`].
///
/// Start from [`ConnectionTuning::default`] and adjust individual values with the
/// by-value builder methods, e.g. `ConnectionTuning::default().mem_channel_bound(32)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionTuning {
    /// Bound for internal in-memory channels (default `16`).
    /// NOTE(review): presumably the channels between the IO thread and the channel
    /// handles — confirm against `IoLoop`.
    pub mem_channel_bound: usize,
    /// High-water mark for buffered outgoing writes (default `16 << 20`, i.e. 16 MiB if
    /// the unit is bytes — TODO confirm the unit against `IoLoop`).
    pub buffered_writes_high_water: usize,
    /// Low-water mark for buffered outgoing writes (default `0`).
    /// NOTE(review): presumably the level the buffer must drain to after hitting the
    /// high-water mark — confirm against `IoLoop`.
    pub buffered_writes_low_water: usize,
}

impl Default for ConnectionTuning {
    fn default() -> Self {
        ConnectionTuning {
            mem_channel_bound: 16,
            buffered_writes_high_water: 16 << 20,
            buffered_writes_low_water: 0,
        }
    }
}

impl ConnectionTuning {
    /// Returns `self` with `mem_channel_bound` replaced.
    ///
    /// The receiver is consumed, so the result must be used or the change is lost.
    #[must_use]
    pub fn mem_channel_bound(self, mem_channel_bound: usize) -> Self {
        ConnectionTuning {
            mem_channel_bound,
            ..self
        }
    }

    /// Returns `self` with `buffered_writes_high_water` replaced.
    #[must_use]
    pub fn buffered_writes_high_water(self, buffered_writes_high_water: usize) -> Self {
        ConnectionTuning {
            buffered_writes_high_water,
            ..self
        }
    }

    /// Returns `self` with `buffered_writes_low_water` replaced.
    #[must_use]
    pub fn buffered_writes_low_water(self, buffered_writes_low_water: usize) -> Self {
        ConnectionTuning {
            buffered_writes_low_water,
            ..self
        }
    }
}
#[derive(Debug)]
pub struct Connection {
join_handle: Option<JoinHandle<Result<()>>>,
channel0: Channel0Handle,
server_properties: FieldTable,
}
impl Drop for Connection {
    /// Best-effort close for connections that were never closed explicitly.
    ///
    /// `drop` cannot report failures, so any error from closing is discarded; call
    /// [`Connection::close`] to observe it. After an explicit close this is a no-op,
    /// because the IO thread handle has already been taken.
    fn drop(&mut self) {
        std::mem::drop(self.close_impl());
    }
}
impl Connection {
#[cfg(feature = "native-tls")]
pub fn open(url: &str) -> Result<Connection> {
Self::open_tuned(url, ConnectionTuning::default())
}
#[cfg(feature = "native-tls")]
pub fn open_tuned(url: &str, tuning: ConnectionTuning) -> Result<Connection> {
self::amqp_url::open(url, tuning, false)
}
pub fn insecure_open(url: &str) -> Result<Connection> {
Self::insecure_open_tuned(url, ConnectionTuning::default())
}
pub fn insecure_open_tuned(url: &str, tuning: ConnectionTuning) -> Result<Connection> {
self::amqp_url::open(url, tuning, true)
}
#[cfg(feature = "native-tls")]
pub fn open_tls_stream<Auth: Sasl, C: Into<TlsConnector>, S: IoStream>(
connector: C,
domain: &str,
stream: S,
options: ConnectionOptions<Auth>,
tuning: ConnectionTuning,
) -> Result<Connection> {
let stream = connector.into().connect(domain, stream)?;
let io_loop = IoLoop::new(tuning)?;
let (join_handle, server_properties, channel0) = io_loop.start_tls(stream, options)?;
Ok(Connection {
join_handle: Some(join_handle),
channel0,
server_properties,
})
}
pub fn insecure_open_stream<Auth: Sasl, S: IoStream>(
stream: S,
options: ConnectionOptions<Auth>,
tuning: ConnectionTuning,
) -> Result<Connection> {
let io_loop = IoLoop::new(tuning)?;
let (join_handle, server_properties, channel0) = io_loop.start(stream, options)?;
Ok(Connection {
join_handle: Some(join_handle),
channel0,
server_properties,
})
}
pub fn server_properties(&self) -> &FieldTable {
&self.server_properties
}
pub fn open_channel(&mut self, channel_id: Option<u16>) -> Result<Channel> {
let handle = self.channel0.open_channel(channel_id)?;
Ok(Channel::new(handle))
}
pub fn listen_for_connection_blocked(
&mut self,
) -> Result<Receiver<ConnectionBlockedNotification>> {
let (tx, rx) = crossbeam_channel::unbounded();
self.channel0.set_blocked_tx(tx)?;
Ok(rx)
}
pub fn close(mut self) -> Result<()> {
self.close_impl()
}
fn close_impl(&mut self) -> Result<()> {
if let Some(join_handle) = self.join_handle.take() {
debug!("closing connection");
let close_result = self.channel0.close_connection();
join_handle.join().map_err(|_| Error::IoThreadPanic)??;
close_result
} else {
Ok(())
}
}
}
mod amqp_url {
use super::*;
use crate::{Auth, Error};
use mio::net::TcpStream;
use snafu::ResultExt;
use std::borrow::Cow;
use std::time::Duration;
use url::Url;
/// Parses `url`, fills in the default host and port, decodes the connection options,
/// and dispatches to the plain or TLS opener based on the scheme.
///
/// `amqp://` URLs are only accepted when `allow_insecure` is true; otherwise they fail
/// with `InsecureUrl`.
pub fn open(url: &str, tuning: ConnectionTuning, allow_insecure: bool) -> Result<Connection> {
    let mut parsed = Url::parse(url).context(UrlParseSnafu)?;
    let scheme = populate_host_and_port(&mut parsed)?;
    let options = decode(&parsed)?;
    match scheme {
        Scheme::Amqps => open_amqps(parsed, options, tuning),
        Scheme::Amqp if allow_insecure => open_amqp(parsed, options, tuning),
        Scheme::Amqp => InsecureUrlSnafu { url: parsed }.fail(),
    }
}
fn open_amqp(
url: Url,
options: ConnectionOptions<Auth>,
tuning: ConnectionTuning,
) -> Result<Connection> {
let mut last_err: Option<Error> = None;
for addr in url
.socket_addrs(|| None)
.with_context(|_| ResolveUrlToSocketAddrSnafu { url: url.clone() })?
{
let result = TcpStream::connect(&addr)
.with_context(|_| FailedToConnectSnafu { url: url.clone() })
.and_then(|stream| {
Connection::insecure_open_stream(stream, options.clone(), tuning.clone())
});
match result {
Ok(connection) => return Ok(connection),
Err(err) => {
last_err = Some(err);
}
}
}
let last_err = last_err.unwrap_or(Error::UrlNoSocketAddrs { url });
Err(last_err)
}
#[cfg(not(feature = "native-tls"))]
/// Fallback used when the crate is built without `native-tls`: `amqps://` URLs always
/// fail with `TlsFeatureNotEnabled`.
fn open_amqps(
    _url: Url,
    _options: ConnectionOptions<Auth>,
    _tuning: ConnectionTuning,
) -> Result<Connection> {
    // No TLS backend is compiled in, so there is nothing to hand the stream to.
    TlsFeatureNotEnabledSnafu.fail()
}
#[cfg(feature = "native-tls")]
/// Opens a TLS connection to the first reachable address `url` resolves to.
///
/// The TLS connector is created before anything else, and the URL's domain is used as
/// the TLS domain. NOTE(review): `Url::domain` is `None` for IP-address hosts, so such
/// URLs fail here with `UrlMissingDomain`.
///
/// Each resolved address is tried in order. The first successful connection is
/// returned. Otherwise the error from the last attempt is returned, or
/// `UrlNoSocketAddrs` if resolution produced no addresses.
fn open_amqps(
    url: Url,
    options: ConnectionOptions<Auth>,
    tuning: ConnectionTuning,
) -> Result<Connection> {
    let connector = native_tls::TlsConnector::new().context(CreateTlsConnectorSnafu)?;
    let domain = if let Some(domain) = url.domain() {
        domain
    } else {
        return UrlMissingDomainSnafu { url: url.clone() }.fail();
    };
    let addrs = url
        .socket_addrs(|| None)
        .with_context(|_| ResolveUrlToSocketAddrSnafu { url: url.clone() })?;
    let mut most_recent_failure = None;
    for addr in addrs {
        let attempt = TcpStream::connect(&addr)
            .with_context(|_| FailedToConnectSnafu { url: url.clone() })
            .and_then(|stream| {
                Connection::open_tls_stream(
                    connector.clone(),
                    domain,
                    stream,
                    options.clone(),
                    tuning.clone(),
                )
            });
        match attempt {
            Ok(connection) => return Ok(connection),
            Err(err) => most_recent_failure = Some(err),
        }
    }
    Err(most_recent_failure.unwrap_or(Error::UrlNoSocketAddrs { url }))
}
/// URL schemes understood by `open`.
///
/// `populate_host_and_port` maps `amqp` to default port 5672 and `amqps` to 5671.
#[derive(PartialEq, Debug)]
enum Scheme {
    /// `amqp://`: opened without TLS (only when insecure URLs are allowed).
    Amqp,
    /// `amqps://`: opened over TLS.
    Amqps,
}
/// Validates the scheme of `url` and fills in defaults for a missing host and port.
///
/// An absent or empty host becomes `localhost`. A missing port becomes 5672 for
/// `amqp` or 5671 for `amqps`.
///
/// # Errors
///
/// * `InvalidUrlScheme` if the scheme is neither `amqp` nor `amqps`.
/// * `UrlParse` if the default host cannot be set.
/// * `SpecifyUrlPort` if the port cannot be set on this URL.
fn populate_host_and_port(url: &mut Url) -> Result<Scheme> {
    // Check the scheme before mutating anything. Otherwise unsupported URLs on which
    // `set_host` fails (cannot-be-a-base URLs such as `mailto:x`) would surface as an
    // unrelated parse error instead of `InvalidUrlScheme`.
    let (scheme, default_port) = match url.scheme() {
        "amqp" => (Scheme::Amqp, 5672),
        "amqps" => (Scheme::Amqps, 5671),
        _ => return InvalidUrlSchemeSnafu { url: url.clone() }.fail(),
    };
    if !url.has_host() || url.host_str() == Some("") {
        url.set_host(Some("localhost")).context(UrlParseSnafu)?;
    }
    let port = url.port().unwrap_or(default_port);
    url.set_port(Some(port))
        .map_err(|()| Error::SpecifyUrlPort { url: url.clone() })?;
    Ok(scheme)
}
/// Builds [`ConnectionOptions`] from the path, userinfo and query of an AMQP URL.
///
/// * Path: the first segment names the virtual host, percent-decoded (so `%2f` yields
///   `/`). An empty segment keeps the default. A second segment is an error.
/// * Userinfo: if a username or password is present, `Auth::Plain` is used, with
///   `"guest"` substituted for whichever part is missing.
/// * Query: `heartbeat` (u16), `channel_max` (u16), `connection_timeout` (u64,
///   milliseconds) and `auth_mechanism=external` are recognized.
///
/// # Errors
///
/// Fails on extra path segments, unparsable numeric parameters, an unsupported
/// `auth_mechanism`, or any unrecognized query parameter.
fn decode(url: &Url) -> Result<ConnectionOptions<Auth>> {
    /// Percent-decodes `s`, replacing invalid UTF-8 sequences lossily.
    fn percent_decode(s: &str) -> Cow<'_, str> {
        let s = percent_encoding::percent_decode(s.as_bytes());
        s.decode_utf8_lossy()
    }
    let mut options = ConnectionOptions::default();
    if let Some(mut path_segments) = url.path_segments() {
        // The first segment (possibly empty) is the vhost. An empty one leaves the
        // default in place rather than panicking on a missing segment.
        if let Some(vhost) = path_segments.next().filter(|vhost| !vhost.is_empty()) {
            options = options.virtual_host(percent_decode(vhost));
        }
        // A literal `/` inside a vhost name must be escaped as `%2f`.
        if path_segments.next().is_some() {
            return ExtraUrlPathSegmentsSnafu { url: url.clone() }.fail();
        }
    }
    let raw_username = url.username();
    if !raw_username.is_empty() || url.password().is_some() {
        let username = if raw_username.is_empty() {
            "guest"
        } else {
            raw_username
        };
        let auth = Auth::Plain {
            username: percent_decode(username).to_string(),
            password: percent_decode(url.password().unwrap_or("guest")).to_string(),
        };
        options = options.auth(auth);
    }
    for (k, v) in url.query_pairs() {
        match k.as_ref() {
            "heartbeat" => {
                let v = v
                    .parse::<u16>()
                    .with_context(|_| UrlParseHeartbeatSnafu { url: url.clone() })?;
                options = options.heartbeat(v);
            }
            "channel_max" => {
                let v = v
                    .parse::<u16>()
                    .with_context(|_| UrlParseChannelMaxSnafu { url: url.clone() })?;
                options = options.channel_max(v);
            }
            "connection_timeout" => {
                let v = v
                    .parse::<u64>()
                    .with_context(|_| UrlParseConnectionTimeoutSnafu { url: url.clone() })?;
                options = options.connection_timeout(Some(Duration::from_millis(v)));
            }
            "auth_mechanism" => {
                if v == "external" {
                    options = options.auth(Auth::External);
                } else {
                    return UrlInvalidAuthMechanismSnafu {
                        url: url.clone(),
                        mechanism: v,
                    }
                    .fail();
                }
            }
            parameter => {
                return UrlUnsupportedParameterSnafu {
                    url: url.clone(),
                    parameter,
                }
                .fail();
            }
        }
    }
    Ok(options)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `s` (which must be a valid URL) and decodes its connection options.
    fn decode_s(s: &str) -> Result<ConnectionOptions<Auth>> {
        decode(&Url::parse(s).unwrap())
    }

    #[test]
    #[cfg(feature = "native-tls")]
    fn open_rejects_amqp_urls() {
        let result = Connection::open("amqp://localhost/");
        match result.unwrap_err() {
            Error::InsecureUrl { .. } => (),
            err => panic!("unexpected error {}", err),
        }
    }

    #[test]
    fn empty_default() {
        let options = decode_s("amqp://").unwrap();
        assert_eq!(options, ConnectionOptions::default());
        let options = decode_s("amqps://").unwrap();
        assert_eq!(options, ConnectionOptions::default());
    }

    #[test]
    fn vhost() {
        let options = decode_s("amqp:///vhost").unwrap();
        assert_eq!(options, ConnectionOptions::default().virtual_host("vhost"));
        let options = decode_s("amqp:///v%2fhost").unwrap();
        assert_eq!(options, ConnectionOptions::default().virtual_host("v/host"));
        assert!(decode_s("amqp:///vhost/nonescapedslash").is_err());
    }

    #[test]
    fn vhost_missing_or_empty_uses_default() {
        assert_eq!(
            decode_s("amqp://localhost").unwrap(),
            ConnectionOptions::default()
        );
        assert_eq!(
            decode_s("amqp://localhost/").unwrap(),
            ConnectionOptions::default()
        );
    }

    #[test]
    fn vhost_encoded_root() {
        let options = decode_s("amqp:///%2f").unwrap();
        assert_eq!(options, ConnectionOptions::default().virtual_host("/"));
    }

    #[test]
    fn user_pass() {
        let options = decode_s("amqp://user:pass@localhost/").unwrap();
        assert_eq!(
            options,
            ConnectionOptions::default().auth(Auth::Plain {
                username: "user".to_string(),
                password: "pass".to_string()
            })
        );
        let options = decode_s("amqp://user%61:pass%62@localhost/").unwrap();
        assert_eq!(
            options,
            ConnectionOptions::default().auth(Auth::Plain {
                username: "usera".to_string(),
                password: "passb".to_string()
            })
        );
    }

    #[test]
    fn user_or_pass_only_defaults_other_to_guest() {
        let options = decode_s("amqp://user@localhost/").unwrap();
        assert_eq!(
            options,
            ConnectionOptions::default().auth(Auth::Plain {
                username: "user".to_string(),
                password: "guest".to_string()
            })
        );
        let options = decode_s("amqp://:pass@localhost/").unwrap();
        assert_eq!(
            options,
            ConnectionOptions::default().auth(Auth::Plain {
                username: "guest".to_string(),
                password: "pass".to_string()
            })
        );
    }

    #[test]
    fn heartbeat() {
        let options = decode_s("amqp://?heartbeat=13").unwrap();
        assert_eq!(options, ConnectionOptions::default().heartbeat(13));
    }

    #[test]
    fn channel_max() {
        let options = decode_s("amqp://?channel_max=13").unwrap();
        assert_eq!(options, ConnectionOptions::default().channel_max(13));
    }

    #[test]
    fn connection_timeout() {
        let options = decode_s("amqp://?connection_timeout=13").unwrap();
        assert_eq!(
            options,
            ConnectionOptions::default().connection_timeout(Some(Duration::from_millis(13)))
        );
    }

    #[test]
    fn auth_mechanism() {
        let options = decode_s("amqp://?auth_mechanism=external").unwrap();
        assert_eq!(options, ConnectionOptions::default().auth(Auth::External));
    }

    #[test]
    fn invalid_query_parameters() {
        assert!(matches!(
            decode_s("amqp://?heartbeat=abc"),
            Err(Error::UrlParseHeartbeat { .. })
        ));
        assert!(matches!(
            decode_s("amqp://?channel_max=70000"),
            Err(Error::UrlParseChannelMax { .. })
        ));
        assert!(matches!(
            decode_s("amqp://?connection_timeout=-1"),
            Err(Error::UrlParseConnectionTimeout { .. })
        ));
        assert!(matches!(
            decode_s("amqp://?auth_mechanism=plain"),
            Err(Error::UrlInvalidAuthMechanism { .. })
        ));
        assert!(matches!(
            decode_s("amqp://?bogus=1"),
            Err(Error::UrlUnsupportedParameter { .. })
        ));
    }

    #[test]
    fn populate_host() {
        let mut url = Url::parse("amqp://").unwrap();
        populate_host_and_port(&mut url).unwrap();
        assert_eq!(url.host_str(), Some("localhost"));
        let mut url = Url::parse("amqp://foo.com").unwrap();
        populate_host_and_port(&mut url).unwrap();
        assert_eq!(url.host_str(), Some("foo.com"));
    }

    #[test]
    fn populate_port() {
        let mut url = Url::parse("amqp://").unwrap();
        populate_host_and_port(&mut url).unwrap();
        assert_eq!(url.port(), Some(5672));
        let mut url = Url::parse("amqps://").unwrap();
        populate_host_and_port(&mut url).unwrap();
        assert_eq!(url.port(), Some(5671));
        let mut url = Url::parse("amqp://foo.com:35").unwrap();
        populate_host_and_port(&mut url).unwrap();
        assert_eq!(url.port(), Some(35));
        let mut url = Url::parse("amqps://foo.com:35").unwrap();
        populate_host_and_port(&mut url).unwrap();
        assert_eq!(url.port(), Some(35));
    }

    #[test]
    fn populate_rejects_unknown_scheme() {
        let mut url = Url::parse("http://foo.com").unwrap();
        assert!(matches!(
            populate_host_and_port(&mut url),
            Err(Error::InvalidUrlScheme { .. })
        ));
    }
}
}