use std::fmt;
use ntex_bytes::{ByteString, Bytes};
use ntex_io::IoBoxed;
use ntex_service::cfg::Cfg;
use crate::codec::protocol::{
self, ProtocolId, SaslChallenge, SaslCode, SaslFrameBody, SaslMechanisms, SaslOutcome, Symbols,
};
use crate::codec::{AmqpCodec, AmqpFrame, ProtocolIdCodec, ProtocolIdError, SaslFrame};
use crate::{AmqpServiceConfig, RemoteServiceConfig, connection::Connection};
use super::{HandshakeError, handshake::HandshakeAmqpOpened};
/// Starting point of the server-side SASL exchange.
///
/// Holds the connection io, the mechanisms that `init` will advertise to the
/// peer, and the local service configuration that is handed on to the later
/// handshake stages.
#[derive(Debug)]
pub struct Sasl {
/// Io object of the connection being negotiated.
state: IoBoxed,
/// Mechanisms sent to the peer as `sasl_server_mechanisms` by `init`.
mechanisms: Symbols,
/// Local service configuration, passed through to the following stages.
local_config: Cfg<AmqpServiceConfig>,
}
impl Sasl {
    /// Creates the SASL negotiation state for `state` with an empty
    /// (default) mechanism list.
    pub(crate) fn new(state: IoBoxed, local_config: Cfg<AmqpServiceConfig>) -> Self {
        let mechanisms = Symbols::default();
        Self {
            state,
            mechanisms,
            local_config,
        }
    }
}
impl Sasl {
    /// Returns a reference to the io object of the connection.
    pub fn io(&self) -> &IoBoxed {
        &self.state
    }

    /// Appends `symbol` to the list of mechanisms that `init` advertises.
    #[must_use]
    pub fn mechanism<U: Into<String>>(mut self, symbol: U) -> Self {
        let name: String = symbol.into();
        self.mechanisms.push(ByteString::from(name).into());
        self
    }

    /// Sends the `SaslMechanisms` frame and waits for the peer's reply.
    ///
    /// # Errors
    ///
    /// Returns a [`HandshakeError`] when sending or receiving fails,
    /// `HandshakeError::Disconnected` when the peer goes away before
    /// answering, and `HandshakeError::UnexpectedSaslBodyFrame` when the
    /// reply is anything other than a `SaslInit` frame.
    pub async fn init(self) -> Result<SaslInit, HandshakeError> {
        let Sasl {
            state,
            mechanisms,
            local_config,
        } = self;
        let codec = AmqpCodec::<SaslFrame>::new();

        // Advertise the configured mechanisms.
        let advertised = SaslMechanisms {
            sasl_server_mechanisms: mechanisms,
        }
        .into();
        state.send(advertised, &codec).await?;

        // The only acceptable answer is a `SaslInit` frame.
        let reply = state
            .recv(&codec)
            .await?
            .ok_or(HandshakeError::Disconnected(None))?;
        match reply.body {
            SaslFrameBody::SaslInit(init) => Ok(SaslInit {
                frame: init,
                state,
                codec,
                local_config,
            }),
            unexpected => Err(HandshakeError::UnexpectedSaslBodyFrame(Box::new(
                unexpected,
            ))),
        }
    }
}
/// Handshake stage reached once the peer's `SaslInit` frame has been received.
///
/// Created by `Sasl::init`; from here the server either issues a challenge or
/// sends the outcome directly.
pub struct SaslInit {
/// The `SaslInit` frame received from the peer.
frame: protocol::SaslInit,
/// Io object of the connection being negotiated.
state: IoBoxed,
/// Codec used for the SASL frames sent and received on `state`.
codec: AmqpCodec<SaslFrame>,
/// Local service configuration, passed through to the following stages.
local_config: Cfg<AmqpServiceConfig>,
}
impl fmt::Debug for SaslInit {
    /// Formats only the received frame; the io object, codec and
    /// configuration are left out of the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("SaslInit");
        out.field("frame", &self.frame);
        out.finish()
    }
}
impl SaslInit {
pub fn io(&self) -> &IoBoxed {
&self.state
}
pub fn mechanism(&self) -> &str {
self.frame.mechanism.as_str()
}
pub fn initial_response(&self) -> Option<&[u8]> {
self.frame.initial_response.as_ref().map(AsRef::as_ref)
}
pub fn hostname(&self) -> Option<&str> {
self.frame.hostname.as_ref().map(AsRef::as_ref)
}
pub async fn challenge(self) -> Result<SaslResponse, HandshakeError> {
self.challenge_with(Bytes::new()).await
}
pub async fn challenge_with(self, challenge: Bytes) -> Result<SaslResponse, HandshakeError> {
let state = self.state;
let codec = self.codec;
let local_config = self.local_config;
let frame = SaslChallenge { challenge }.into();
state
.send(frame, &codec)
.await
.map_err(HandshakeError::from)?;
let frame = state
.recv(&codec)
.await?
.ok_or(HandshakeError::Disconnected(None))?;
match frame.body {
SaslFrameBody::SaslResponse(frame) => Ok(SaslResponse {
frame,
state,
codec,
local_config,
}),
body => Err(HandshakeError::UnexpectedSaslBodyFrame(Box::new(body))),
}
}
pub async fn outcome(self, code: SaslCode) -> Result<SaslSuccess, HandshakeError> {
let state = self.state;
let codec = self.codec;
let local_config = self.local_config;
let frame = SaslOutcome {
code,
additional_data: None,
}
.into();
state
.send(frame, &codec)
.await
.map_err(HandshakeError::from)?;
Ok(SaslSuccess {
state,
local_config,
})
}
}
/// Handshake stage reached once the peer has answered a challenge with a
/// `SaslResponse` frame.
///
/// Created by `SaslInit::challenge_with`.
pub struct SaslResponse {
/// The `SaslResponse` frame received from the peer.
frame: protocol::SaslResponse,
/// Io object of the connection being negotiated.
state: IoBoxed,
/// Codec used for the SASL frames sent and received on `state`.
codec: AmqpCodec<SaslFrame>,
/// Local service configuration, passed through to the following stages.
local_config: Cfg<AmqpServiceConfig>,
}
impl fmt::Debug for SaslResponse {
    /// Formats only the received frame; the io object, codec and
    /// configuration are left out of the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("SaslResponse");
        out.field("frame", &self.frame);
        out.finish()
    }
}
impl SaslResponse {
    /// Returns a reference to the io object of the connection.
    pub fn io(&self) -> &IoBoxed {
        &self.state
    }

    /// Returns the payload of the peer's `SaslResponse` frame.
    pub fn response(&self) -> &[u8] {
        &self.frame.response[..]
    }

    /// Sends a `SaslOutcome` frame with `code` (and no additional data),
    /// ending the SASL exchange.
    ///
    /// The returned [`SaslSuccess`] is produced for any `code`, as long as
    /// the frame was sent.
    ///
    /// # Errors
    ///
    /// Returns a [`HandshakeError`] when the frame cannot be sent.
    pub async fn outcome(self, code: SaslCode) -> Result<SaslSuccess, HandshakeError> {
        let state = self.state;
        let codec = self.codec;
        let local_config = self.local_config;

        let frame = SaslOutcome {
            code,
            additional_data: None,
        }
        .into();
        state
            .send(frame, &codec)
            .await
            .map_err(HandshakeError::from)?;

        // Deliberately no read here. The outcome is the final frame of the
        // SASL layer; what the peer sends next is the AMQP protocol header,
        // which `SaslSuccess::open` reads with `ProtocolIdCodec`. Decoding
        // those bytes with the SASL frame codec would treat the header as a
        // frame length and stall (or fail) the handshake. This also matches
        // `SaslInit::outcome`, which does not read after sending.
        Ok(SaslSuccess {
            state,
            local_config,
        })
    }
}
/// Handshake stage reached once a `SaslOutcome` frame has been sent.
///
/// Note: the `outcome` methods in this file construct it for whatever
/// `SaslCode` was sent, not only for a successful one.
pub struct SaslSuccess {
/// Io object of the connection being negotiated.
state: IoBoxed,
/// Local service configuration used when the connection is opened.
local_config: Cfg<AmqpServiceConfig>,
}
impl SaslSuccess {
    /// Returns a reference to the io object of the connection.
    pub fn io(&self) -> &IoBoxed {
        &self.state
    }

    /// Completes the handshake after SASL: exchanges the AMQP protocol id
    /// and waits for the peer's `Open` frame.
    ///
    /// On success the connection object is built from the local
    /// configuration and the configuration derived from the `Open` frame.
    ///
    /// # Errors
    ///
    /// Returns a [`HandshakeError`] when sending or receiving fails,
    /// `HandshakeError::Disconnected` when the peer goes away, a
    /// `ProtocolIdError::Unexpected` (converted) when the peer requests a
    /// protocol other than `ProtocolId::Amqp`, and
    /// `HandshakeError::Unexpected` when the first AMQP frame is not `Open`.
    pub async fn open(self) -> Result<HandshakeAmqpOpened, HandshakeError> {
        let state = self.state;

        // The peer must ask for plain AMQP; anything else aborts the handshake.
        let requested = state
            .recv(&ProtocolIdCodec)
            .await?
            .ok_or(HandshakeError::Disconnected(None))?;
        if !matches!(requested, ProtocolId::Amqp) {
            return Err(ProtocolIdError::Unexpected {
                exp: ProtocolId::Amqp,
                got: requested,
            }
            .into());
        }

        // Echo the protocol id back before switching to AMQP frames.
        state.send(ProtocolId::Amqp, &ProtocolIdCodec).await?;

        let amqp_codec = AmqpCodec::<AmqpFrame>::new();
        let received = state
            .recv(&amqp_codec)
            .await?
            .ok_or(HandshakeError::Disconnected(None))?;
        let performative = received.into_parts().1;

        match performative {
            protocol::Frame::Open(open) => {
                log::trace!("{}: Got open frame: {:?}", state.tag(), open);
                let local_config = self.local_config;
                let remote_config = RemoteServiceConfig::new(&open);
                let sink = Connection::new(state.clone(), &local_config, &remote_config);
                Ok(HandshakeAmqpOpened::new(
                    open,
                    sink,
                    state,
                    local_config,
                    remote_config,
                ))
            }
            other => Err(HandshakeError::Unexpected(other)),
        }
    }
}