use std::marker::PhantomData;
use ntex_bytes::ByteString;
use ntex_error::Error;
use ntex_io::IoBoxed;
use ntex_net::connect::{self, Address};
use ntex_service::{Service, ServiceCtx, ServiceFactory, cfg::Cfg, cfg::SharedCfg};
use ntex_util::time::timeout_checked;
use crate::codec::protocol::{Frame, ProtocolId, SaslCode, SaslFrameBody, SaslInit};
use crate::codec::{AmqpCodec, AmqpFrame, ProtocolIdCodec, SaslFrame, types::Symbol};
use crate::{AmqpServiceConfig, Connection, RemoteServiceConfig, error::ProtocolIdError};
use super::{Connect, SaslAuth, connection::Client, error::ConnectError};
/// Client-side AMQP connector factory.
///
/// `A` is the address type handed to the transport connector, and `T` is the
/// transport connector factory itself. `T` defaults to `()`; `Connector::new`
/// fills it with `connect::Connector2`, and `Connector::with` /
/// `Connector::connector` accept a custom one.
pub struct Connector<A, T = ()> {
    /// Marks the address type without storing a value of it.
    _t: PhantomData<A>,
    /// Transport connector factory used to open the underlying I/O stream.
    connector: T,
}
/// Service produced by `Connector`'s `ServiceFactory::create`.
///
/// Holds the created transport connector service together with the
/// `AmqpServiceConfig` read from the shared configuration at creation time.
pub struct ConnectorService<A, T> {
    /// Marks the address type without storing a value of it.
    _t: PhantomData<A>,
    /// Transport connector service used to open the underlying I/O stream.
    connector: T,
    /// AMQP settings (handshake timeout, frame size, host name, ...) used for every handshake.
    config: Cfg<AmqpServiceConfig>,
}
impl<A> Connector<A> {
    /// Creates a connector whose transport layer is a default-constructed
    /// `connect::Connector2`.
    pub fn new() -> Connector<A, connect::Connector2<A>> {
        let connector: connect::Connector2<A> = Default::default();
        Connector {
            _t: PhantomData,
            connector,
        }
    }
}
impl<A, T> Connector<A, T>
where
    A: Address,
    T: ServiceFactory<connect::Connect<A>, SharedCfg, Error = Error<connect::ConnectError>>,
    IoBoxed: From<T::Response>,
{
    /// Builds a connector around a custom transport connector factory.
    ///
    /// The factory's response must be convertible into [`IoBoxed`]. That boxed
    /// stream is what the AMQP handshake runs over.
    pub fn with(connector: T) -> Connector<A, T> {
        Self {
            _t: PhantomData,
            connector,
        }
    }
}
impl<A, T> Connector<A, T>
where
    A: Address,
{
    /// Replaces the transport connector factory, consuming `self`.
    ///
    /// The previous factory is dropped; the returned connector uses `connector`
    /// instead. Its response must be convertible into [`IoBoxed`].
    pub fn connector<U>(self, connector: U) -> Connector<A, U>
    where
        U: ServiceFactory<connect::Connect<A>, SharedCfg, Error = Error<connect::ConnectError>>,
        IoBoxed: From<U::Response>,
    {
        Connector {
            _t: PhantomData,
            connector,
        }
    }
}
impl<A, T> ServiceFactory<Connect<A>, SharedCfg> for Connector<A, T>
where
A: Address,
T: ServiceFactory<connect::Connect<A>, SharedCfg, Error = Error<connect::ConnectError>>,
IoBoxed: From<T::Response>,
{
type Response = Client;
type Error = Error<ConnectError>;
type Service = ConnectorService<A, T::Service>;
type InitError = T::InitError;
async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
Ok(ConnectorService {
config: cfg.get(),
connector: self.connector.create(cfg).await?,
_t: PhantomData,
})
}
}
impl<A, T> Service<Connect<A>> for ConnectorService<A, T>
where
A: Address,
T: Service<connect::Connect<A>, Error = Error<connect::ConnectError>>,
IoBoxed: From<T::Response>,
{
type Response = Client;
type Error = Error<ConnectError>;
async fn call(
&self,
req: Connect<A>,
ctx: ServiceCtx<'_, Self>,
) -> Result<Client, Self::Error> {
let fut = async {
let (addr, sasl, hostname) = req.into_parts();
let io = ctx
.call(&self.connector, connect::Connect::new(addr))
.await
.map_err(|e| e.map(ConnectError::from))?;
let io = IoBoxed::from(io);
if let Some(auth) = sasl {
connect_sasl_inner(io, auth, self.config.clone(), hostname).await
} else {
connect_plain_inner(io, self.config.clone(), hostname).await
}
};
timeout_checked(self.config.handshake_timeout, fut)
.await
.map_err(|()| Error::from(ConnectError::HandshakeTimeout))
.and_then(|res| res)
.map_err(|e| e.set_service(self.config.service()))
}
}
impl<A, T> ConnectorService<A, T>
where
    A: Address,
    T: Service<connect::Connect<A>, Error = Error<connect::ConnectError>>,
    IoBoxed: From<T::Response>,
{
    /// Runs the plain (non-SASL) AMQP handshake over an already opened `io`.
    ///
    /// `hostname`, when given, replaces the configured host name in the `Open` frame.
    /// Unlike `Service::call`, this does not open a transport connection and is not
    /// wrapped in `handshake_timeout`. Callers wanting a bound must apply one themselves.
    /// Errors are tagged with the configured service via `set_service`.
    pub async fn negotiate(
        &self,
        io: IoBoxed,
        hostname: Option<ByteString>,
    ) -> Result<Client, Error<ConnectError>> {
        // `connect_plain_inner` emits the protocol-id trace itself. Logging here as well
        // printed the identical line twice.
        connect_plain_inner(io, self.config.clone(), hostname)
            .await
            .map_err(|e| e.set_service(self.config.service()))
    }

    /// Runs the SASL PLAIN exchange and then the AMQP handshake over an already opened `io`.
    ///
    /// `auth` supplies the SASL PLAIN credentials. `hostname`, when given, overrides the
    /// configured host name. Like `negotiate`, this is not wrapped in `handshake_timeout`.
    /// Errors are tagged with the configured service via `set_service`.
    pub async fn negotiate_sasl(
        &self,
        io: IoBoxed,
        auth: SaslAuth,
        hostname: Option<ByteString>,
    ) -> Result<Client, Error<ConnectError>> {
        // The previous trace here claimed an `Amqp` negotiation while this path starts with
        // `AmqpSasl`. `connect_sasl_inner` already logs the correct protocol id.
        connect_sasl_inner(io, auth, self.config.clone(), hostname)
            .await
            .map_err(|e| e.set_service(self.config.service()))
    }
}
async fn connect_sasl_inner(
io: IoBoxed,
auth: SaslAuth,
config: Cfg<AmqpServiceConfig>,
hostname: Option<ByteString>,
) -> Result<Client, Error<ConnectError>> {
log::trace!("{}: Negotiation client protocol id: AmqpSasl", io.tag());
io.send(ProtocolId::AmqpSasl, &ProtocolIdCodec)
.await
.map_err(ConnectError::from)?;
let proto = io
.recv(&ProtocolIdCodec)
.await
.map_err(ConnectError::from)?
.ok_or_else(|| {
log::trace!("{}: Amqp server is disconnected during handshake", io.tag());
ConnectError::Disconnected
})?;
if proto != ProtocolId::AmqpSasl {
return Err(Error::from(ConnectError::from(
ProtocolIdError::Unexpected {
exp: ProtocolId::AmqpSasl,
got: proto,
},
)));
}
let codec = AmqpCodec::<SaslFrame>::new();
let _ = io
.recv(&codec)
.await
.map_err(ConnectError::from)?
.ok_or(ConnectError::Disconnected)?;
let initial_response =
SaslInit::prepare_response(&auth.authz_id, &auth.authn_id, &auth.password);
let sasl_init = SaslInit {
hostname: config.hostname.clone(),
mechanism: Symbol::from("PLAIN"),
initial_response: Some(initial_response),
};
io.send(sasl_init.into(), &codec)
.await
.map_err(ConnectError::from)?;
let sasl_frame = io
.recv(&codec)
.await
.map_err(ConnectError::from)?
.ok_or(ConnectError::Disconnected)?;
if let SaslFrame {
body: SaslFrameBody::SaslOutcome(outcome),
} = sasl_frame
{
if outcome.code() != SaslCode::Ok {
return Err(Error::from(ConnectError::Sasl(outcome.code())));
}
} else {
return Err(Error::from(ConnectError::Disconnected));
}
connect_plain_inner(io, config, hostname).await
}
async fn connect_plain_inner(
io: IoBoxed,
config: Cfg<AmqpServiceConfig>,
hostname: Option<ByteString>,
) -> Result<Client, Error<ConnectError>> {
log::trace!("{}: Negotiation client protocol id: Amqp", io.tag());
io.send(ProtocolId::Amqp, &ProtocolIdCodec)
.await
.map_err(ConnectError::from)?;
let proto = io
.recv(&ProtocolIdCodec)
.await
.map_err(ConnectError::from)?
.ok_or_else(|| {
log::trace!("{}: Amqp server is disconnected during handshake", io.tag());
Error::from(ConnectError::Disconnected)
})?;
if proto != ProtocolId::Amqp {
return Err(Error::from(ConnectError::from(
ProtocolIdError::Unexpected {
exp: ProtocolId::Amqp,
got: proto,
},
)));
}
let mut open = config.to_open();
if let Some(hostname) = hostname {
*open.hostname_mut() = Some(hostname);
}
let codec = AmqpCodec::<AmqpFrame>::new().max_size(config.max_frame_size as usize);
log::trace!("{}: Open client amqp connection: {:?}", io.tag(), open);
io.send(AmqpFrame::new(0, Frame::Open(open)), &codec)
.await
.map_err(ConnectError::from)?;
let frame = io
.recv(&codec)
.await
.map_err(ConnectError::from)?
.ok_or_else(|| {
log::trace!("{}: Amqp server is disconnected during handshake", io.tag());
Error::from(ConnectError::Disconnected)
})?;
if let Frame::Open(open) = frame.performative() {
log::trace!("{}: Open confirmed: {:?}", io.tag(), open);
let remote_config = RemoteServiceConfig::new(open);
let connection = Connection::new(io.get_ref(), &config, &remote_config);
let client = Client::new(io, codec, connection, remote_config);
Ok(client)
} else {
Err(Error::from(ConnectError::ExpectOpenFrame(Box::new(frame))))
}
}