use std::io::{Read, Write};
use std::sync::Arc;
use zerodds_amqp_bridge::extended_types::AmqpExtValue;
use zerodds_amqp_bridge::frame::FrameType;
use zerodds_amqp_bridge::performatives;
use zerodds_amqp_bridge::types::AmqpValue;
use zerodds_amqp_endpoint::security::SaslSubject;
use zerodds_amqp_endpoint::security::{
AccessControlPlugin, AccessDecision, AccessOp, IdentityToken, build_identity_token,
};
use zerodds_amqp_endpoint::session::InboundFrameKind;
use zerodds_amqp_endpoint::{ConnectionState, EndpointError, MetricsHub, advance_connection};
use crate::frame_io::{
AmqpProtocol, FrameIoError, read_frame, read_protocol_header, write_frame,
write_protocol_header,
};
#[derive(Debug, Default, Clone)]
pub struct ConnectionStats {
pub frames_received: u64,
pub frames_sent: u64,
pub sasl_completed: bool,
pub open_received: bool,
pub closed: bool,
}
#[derive(Debug)]
pub enum HandlerError {
FrameIo(FrameIoError),
Endpoint(EndpointError),
PerformativeDecode(String),
UnexpectedEof,
}
impl core::fmt::Display for HandlerError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::FrameIo(e) => write!(f, "frame io: {e}"),
Self::Endpoint(e) => write!(f, "endpoint: {e:?}"),
Self::PerformativeDecode(s) => write!(f, "performative decode: {s}"),
Self::UnexpectedEof => write!(f, "unexpected eof"),
}
}
}
impl std::error::Error for HandlerError {}
impl From<FrameIoError> for HandlerError {
fn from(e: FrameIoError) -> Self {
Self::FrameIo(e)
}
}
impl From<EndpointError> for HandlerError {
fn from(e: EndpointError) -> Self {
Self::Endpoint(e)
}
}
#[derive(Clone)]
pub struct HandlerConfig {
pub container_id: String,
pub max_frame_size: u32,
pub tls_active: bool,
pub metrics: Arc<MetricsHub>,
pub access_control: Option<Arc<dyn AccessControlPlugin + Send + Sync>>,
pub default_identity: IdentityToken,
}
impl core::fmt::Debug for HandlerConfig {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("HandlerConfig")
.field("container_id", &self.container_id)
.field("max_frame_size", &self.max_frame_size)
.field("tls_active", &self.tls_active)
.field("access_control_present", &self.access_control.is_some())
.field(
"default_identity_subject",
&self.default_identity.subject_name,
)
.finish()
}
}
impl HandlerConfig {
#[must_use]
pub fn for_tests(metrics: Arc<MetricsHub>) -> Self {
Self {
container_id: "zerodds-amqp-endpoint".to_string(),
max_frame_size: 1_048_576,
tls_active: false,
metrics,
access_control: None,
default_identity: build_identity_token(&SaslSubject::Anonymous),
}
}
#[must_use]
pub fn with_access_control(
mut self,
plugin: Arc<dyn AccessControlPlugin + Send + Sync>,
) -> Self {
self.access_control = Some(plugin);
self
}
#[must_use]
pub fn with_identity(mut self, identity: IdentityToken) -> Self {
self.default_identity = identity;
self
}
}
pub fn handle_connection<S: Read + Write>(
stream: &mut S,
cfg: &HandlerConfig,
) -> Result<ConnectionStats, HandlerError> {
cfg.metrics.on_connection_open();
let mut stats = ConnectionStats::default();
let first = read_protocol_header(stream)?;
match first.protocol {
AmqpProtocol::Sasl => {
do_sasl_phase(stream, cfg, &mut stats)?;
let second = read_protocol_header(stream)?;
if second.protocol != AmqpProtocol::Amqp {
return Err(HandlerError::FrameIo(FrameIoError::UnsupportedProtocolId(
second.protocol.as_bytes()[4],
)));
}
write_protocol_header(stream, AmqpProtocol::Amqp)?;
}
AmqpProtocol::Amqp => {
write_protocol_header(stream, AmqpProtocol::Amqp)?;
}
}
let mut state = ConnectionState::Start;
state = advance_connection(state, InboundFrameKind::Header)?;
state = advance_connection(state, InboundFrameKind::Header)?;
do_amqp_phase(stream, cfg, &mut stats, &mut state)?;
cfg.metrics.on_connection_close();
stats.closed = true;
Ok(stats)
}
fn do_sasl_phase<S: Read + Write>(
stream: &mut S,
cfg: &HandlerConfig,
stats: &mut ConnectionStats,
) -> Result<(), HandlerError> {
write_protocol_header(stream, AmqpProtocol::Sasl)?;
let mechs = build_sasl_mechanisms(cfg.tls_active);
let sasl_mechanisms_descriptor: u64 = 0x0000_0000_0000_0040;
let body = performatives::encode_performative(sasl_mechanisms_descriptor, &mechs)
.map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
write_frame(stream, FrameType::Sasl, 0, &body)?;
stats.frames_sent += 1;
let init_frame = read_frame(stream, cfg.max_frame_size)?;
stats.frames_received += 1;
if init_frame.header.frame_type != FrameType::Sasl {
return Err(HandlerError::FrameIo(FrameIoError::UnsupportedProtocolId(
init_frame.header.frame_type.to_u8(),
)));
}
let outcome_descriptor: u64 = 0x0000_0000_0000_0044;
let outcome_body = AmqpExtValue::List(vec![AmqpExtValue::Ubyte(0)]);
let body = performatives::encode_performative(outcome_descriptor, &outcome_body)
.map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
write_frame(stream, FrameType::Sasl, 0, &body)?;
stats.frames_sent += 1;
stats.sasl_completed = true;
Ok(())
}
fn build_sasl_mechanisms(tls_active: bool) -> AmqpExtValue {
let mut mechs: Vec<AmqpExtValue> = Vec::new();
if tls_active {
mechs.push(AmqpExtValue::Symbol("PLAIN".to_string()));
}
mechs.push(AmqpExtValue::Symbol("ANONYMOUS".to_string()));
mechs.push(AmqpExtValue::Symbol("EXTERNAL".to_string()));
AmqpExtValue::List(vec![AmqpExtValue::Array(mechs)])
}
fn do_amqp_phase<S: Read + Write>(
stream: &mut S,
cfg: &HandlerConfig,
stats: &mut ConnectionStats,
state: &mut ConnectionState,
) -> Result<(), HandlerError> {
loop {
let frame = match read_frame(stream, cfg.max_frame_size) {
Ok(f) => f,
Err(FrameIoError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
return Ok(());
}
Err(e) => return Err(HandlerError::FrameIo(e)),
};
stats.frames_received += 1;
if frame.body.is_empty() {
continue;
}
let kind = match classify_performative(&frame.body) {
Some(k) => k,
None => {
cfg.metrics.on_decode_error();
continue;
}
};
*state = advance_connection(*state, kind)?;
match kind {
InboundFrameKind::Open => {
stats.open_received = true;
let open = performatives::open(&cfg.container_id)
.map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
write_frame(stream, FrameType::Amqp, 0, &open)?;
stats.frames_sent += 1;
*state = advance_connection(*state, InboundFrameKind::Open)?;
}
InboundFrameKind::Begin => {
let begin = performatives::begin(Some(0), 0, 1024, 1024)
.map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
write_frame(stream, FrameType::Amqp, frame.header.channel, &begin)?;
stats.frames_sent += 1;
}
InboundFrameKind::Attach => {
let (link_name, target_addr, is_sender) = parse_attach(&frame.body);
if !check_access(
cfg,
&target_addr,
if is_sender {
AccessOp::AttachReceiver
} else {
AccessOp::AttachSender
},
) {
cfg.metrics.on_unauthorized();
let detach = performatives::detach(0, true)
.map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
write_frame(stream, FrameType::Amqp, frame.header.channel, &detach)?;
stats.frames_sent += 1;
continue;
}
let attach = performatives::attach(&link_name, 0, true)
.map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
write_frame(stream, FrameType::Amqp, frame.header.channel, &attach)?;
stats.frames_sent += 1;
}
InboundFrameKind::Transfer => {
if !check_access(cfg, "<transfer>", AccessOp::ReceiveSample) {
cfg.metrics.on_unauthorized();
continue;
}
cfg.metrics.on_transfer_received();
}
InboundFrameKind::Close => {
let close = performatives::close()
.map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
write_frame(stream, FrameType::Amqp, 0, &close)?;
stats.frames_sent += 1;
*state = advance_connection(*state, InboundFrameKind::Close)?;
return Ok(());
}
InboundFrameKind::End => {
let end = performatives::end()
.map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
write_frame(stream, FrameType::Amqp, frame.header.channel, &end)?;
stats.frames_sent += 1;
}
InboundFrameKind::Flow | InboundFrameKind::Disposition | InboundFrameKind::Detach => {}
InboundFrameKind::Header => {
}
}
}
}
fn check_access(cfg: &HandlerConfig, address: &str, op: AccessOp) -> bool {
let Some(plugin) = cfg.access_control.as_ref() else {
return true;
};
matches!(
plugin.check(&cfg.default_identity, address, op),
AccessDecision::Allow
)
}
fn parse_attach(body: &[u8]) -> (String, String, bool) {
let default = ("link".to_string(), "<unknown>".to_string(), true);
let Ok((_, body_value, _)) = zerodds_amqp_bridge::performatives::decode_performative(body)
else {
return default;
};
let AmqpExtValue::List(items) = body_value else {
return default;
};
let link_name = items
.first()
.and_then(|v| match v {
AmqpExtValue::Str(s) => Some(s.clone()),
_ => None,
})
.unwrap_or_else(|| default.0.clone());
let is_sender_from_role = items
.get(2)
.map(|v| matches!(v, AmqpExtValue::Boolean(false)))
.unwrap_or(default.2);
let target_addr = items
.get(6)
.and_then(extract_address)
.or_else(|| items.get(5).and_then(extract_address))
.unwrap_or_else(|| default.1.clone());
(link_name, target_addr, is_sender_from_role)
}
fn extract_address(v: &AmqpExtValue) -> Option<String> {
match v {
AmqpExtValue::Str(s) => Some(s.clone()),
AmqpExtValue::Symbol(s) => Some(s.clone()),
AmqpExtValue::List(items) => items.first().and_then(|x| match x {
AmqpExtValue::Str(s) | AmqpExtValue::Symbol(s) => Some(s.clone()),
_ => None,
}),
_ => None,
}
}
#[must_use]
pub fn classify_performative(body: &[u8]) -> Option<InboundFrameKind> {
let (descriptor, _, _) = zerodds_amqp_bridge::performatives::decode_performative(body).ok()?;
descriptor_to_kind(descriptor)
}
const fn descriptor_to_kind(descriptor: u64) -> Option<InboundFrameKind> {
use zerodds_amqp_bridge::performatives::descriptor as d;
let kind = match descriptor {
d::OPEN => InboundFrameKind::Open,
d::BEGIN => InboundFrameKind::Begin,
d::ATTACH => InboundFrameKind::Attach,
d::FLOW => InboundFrameKind::Flow,
d::TRANSFER => InboundFrameKind::Transfer,
d::DISPOSITION => InboundFrameKind::Disposition,
d::DETACH => InboundFrameKind::Detach,
d::END => InboundFrameKind::End,
d::CLOSE => InboundFrameKind::Close,
_ => return None,
};
Some(kind)
}
const _: Option<AmqpValue> = None;
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
use super::*;
use std::io::Cursor;
fn cfg() -> HandlerConfig {
HandlerConfig::for_tests(Arc::new(MetricsHub::new()))
}
struct DuplexCursor {
input: Cursor<Vec<u8>>,
output: Vec<u8>,
}
impl Read for DuplexCursor {
fn read(&mut self, b: &mut [u8]) -> std::io::Result<usize> {
self.input.read(b)
}
}
impl Write for DuplexCursor {
fn write(&mut self, b: &[u8]) -> std::io::Result<usize> {
self.output.write(b)
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
fn duplex(input: Vec<u8>) -> DuplexCursor {
DuplexCursor {
input: Cursor::new(input),
output: Vec::new(),
}
}
#[test]
fn descriptor_classification_covers_9_performatives() {
use zerodds_amqp_bridge::performatives::descriptor as d;
for (code, expected) in [
(d::OPEN, InboundFrameKind::Open),
(d::BEGIN, InboundFrameKind::Begin),
(d::ATTACH, InboundFrameKind::Attach),
(d::FLOW, InboundFrameKind::Flow),
(d::TRANSFER, InboundFrameKind::Transfer),
(d::DISPOSITION, InboundFrameKind::Disposition),
(d::DETACH, InboundFrameKind::Detach),
(d::END, InboundFrameKind::End),
(d::CLOSE, InboundFrameKind::Close),
] {
assert_eq!(descriptor_to_kind(code), Some(expected));
}
assert_eq!(descriptor_to_kind(0xFFFF), None);
}
#[test]
fn handle_connection_open_close_round_trip() {
let mut input = Vec::new();
input.extend(AmqpProtocol::Amqp.as_bytes()); let open = performatives::open("client").unwrap();
let header = zerodds_amqp_bridge::frame::FrameHeader {
size: 8 + open.len() as u32,
doff: 2,
frame_type: FrameType::Amqp,
channel: 0,
};
input.extend(zerodds_amqp_bridge::frame::encode_frame_header(header));
input.extend(&open);
let close = performatives::close().unwrap();
let header = zerodds_amqp_bridge::frame::FrameHeader {
size: 8 + close.len() as u32,
doff: 2,
frame_type: FrameType::Amqp,
channel: 0,
};
input.extend(zerodds_amqp_bridge::frame::encode_frame_header(header));
input.extend(&close);
let mut io = duplex(input);
let stats = handle_connection(&mut io, &cfg()).unwrap();
assert!(stats.open_received);
assert!(stats.closed);
assert_eq!(stats.frames_received, 2);
assert!(stats.frames_sent >= 2);
assert_eq!(&io.output[0..4], b"AMQP");
}
#[test]
fn handle_connection_invalid_magic_rejected() {
let bad = b"NOPE\x00\x01\x00\x00";
let mut io = duplex(bad.to_vec());
let err = handle_connection(&mut io, &cfg()).unwrap_err();
assert!(matches!(
err,
HandlerError::FrameIo(FrameIoError::InvalidProtocolMagic(_))
));
}
#[test]
fn handle_connection_sasl_then_amqp() {
let mut input = Vec::new();
input.extend(AmqpProtocol::Sasl.as_bytes());
let sasl_init_descriptor = 0x0000_0000_0000_0041u64;
let init_body = AmqpExtValue::List(vec![AmqpExtValue::Symbol("ANONYMOUS".into())]);
let init_payload =
performatives::encode_performative(sasl_init_descriptor, &init_body).unwrap();
let header = zerodds_amqp_bridge::frame::FrameHeader {
size: 8 + init_payload.len() as u32,
doff: 2,
frame_type: FrameType::Sasl,
channel: 0,
};
input.extend(zerodds_amqp_bridge::frame::encode_frame_header(header));
input.extend(&init_payload);
input.extend(AmqpProtocol::Amqp.as_bytes());
let open = performatives::open("client").unwrap();
let h = zerodds_amqp_bridge::frame::FrameHeader {
size: 8 + open.len() as u32,
doff: 2,
frame_type: FrameType::Amqp,
channel: 0,
};
input.extend(zerodds_amqp_bridge::frame::encode_frame_header(h));
input.extend(&open);
let close = performatives::close().unwrap();
let h = zerodds_amqp_bridge::frame::FrameHeader {
size: 8 + close.len() as u32,
doff: 2,
frame_type: FrameType::Amqp,
channel: 0,
};
input.extend(zerodds_amqp_bridge::frame::encode_frame_header(h));
input.extend(&close);
let mut io = duplex(input);
let stats = handle_connection(&mut io, &cfg()).unwrap();
assert!(stats.sasl_completed);
assert!(stats.open_received);
assert!(stats.closed);
}
#[test]
fn access_control_deny_attach_yields_unauthorized_metric() {
use zerodds_amqp_endpoint::security::{
AccessControlPlugin, AccessDecision, AccessOp, IdentityToken,
};
struct DenyAll;
impl AccessControlPlugin for DenyAll {
fn check(&self, _: &IdentityToken, _: &str, _: AccessOp) -> AccessDecision {
AccessDecision::Deny
}
}
let metrics = Arc::new(MetricsHub::new());
let cfg = HandlerConfig::for_tests(metrics.clone()).with_access_control(Arc::new(DenyAll));
let mut input = Vec::new();
input.extend(AmqpProtocol::Amqp.as_bytes());
let open = performatives::open("c").unwrap();
let h = zerodds_amqp_bridge::frame::FrameHeader {
size: 8 + open.len() as u32,
doff: 2,
frame_type: FrameType::Amqp,
channel: 0,
};
input.extend(zerodds_amqp_bridge::frame::encode_frame_header(h));
input.extend(&open);
let attach = performatives::attach("L", 0, true).unwrap();
let h = zerodds_amqp_bridge::frame::FrameHeader {
size: 8 + attach.len() as u32,
doff: 2,
frame_type: FrameType::Amqp,
channel: 0,
};
input.extend(zerodds_amqp_bridge::frame::encode_frame_header(h));
input.extend(&attach);
let close = performatives::close().unwrap();
let h = zerodds_amqp_bridge::frame::FrameHeader {
size: 8 + close.len() as u32,
doff: 2,
frame_type: FrameType::Amqp,
channel: 0,
};
input.extend(zerodds_amqp_bridge::frame::encode_frame_header(h));
input.extend(&close);
let mut io = duplex(input);
handle_connection(&mut io, &cfg).unwrap();
assert!(metrics.snapshot("errors.unauthorized").unwrap_or(0) >= 1);
}
#[test]
fn access_control_allow_does_not_increment_unauthorized() {
use zerodds_amqp_endpoint::security::AllowAll;
let metrics = Arc::new(MetricsHub::new());
let cfg = HandlerConfig::for_tests(metrics.clone()).with_access_control(Arc::new(AllowAll));
let mut input = Vec::new();
input.extend(AmqpProtocol::Amqp.as_bytes());
let open = performatives::open("c").unwrap();
let h = zerodds_amqp_bridge::frame::FrameHeader {
size: 8 + open.len() as u32,
doff: 2,
frame_type: FrameType::Amqp,
channel: 0,
};
input.extend(zerodds_amqp_bridge::frame::encode_frame_header(h));
input.extend(&open);
let close = performatives::close().unwrap();
let h = zerodds_amqp_bridge::frame::FrameHeader {
size: 8 + close.len() as u32,
doff: 2,
frame_type: FrameType::Amqp,
channel: 0,
};
input.extend(zerodds_amqp_bridge::frame::encode_frame_header(h));
input.extend(&close);
let mut io = duplex(input);
handle_connection(&mut io, &cfg).unwrap();
assert_eq!(metrics.snapshot("errors.unauthorized"), Some(0));
}
#[test]
fn metrics_counter_incremented_on_connection() {
let m = Arc::new(MetricsHub::new());
let cfg = HandlerConfig::for_tests(m.clone());
let mut input = Vec::new();
input.extend(AmqpProtocol::Amqp.as_bytes());
let close = performatives::close().unwrap();
let h = zerodds_amqp_bridge::frame::FrameHeader {
size: 8 + close.len() as u32,
doff: 2,
frame_type: FrameType::Amqp,
channel: 0,
};
let open = performatives::open("c").unwrap();
let oh = zerodds_amqp_bridge::frame::FrameHeader {
size: 8 + open.len() as u32,
doff: 2,
frame_type: FrameType::Amqp,
channel: 0,
};
input.extend(zerodds_amqp_bridge::frame::encode_frame_header(oh));
input.extend(&open);
input.extend(zerodds_amqp_bridge::frame::encode_frame_header(h));
input.extend(&close);
let mut io = duplex(input);
handle_connection(&mut io, &cfg).unwrap();
assert_eq!(m.snapshot("connections.active"), Some(0));
assert_eq!(m.snapshot("connections.total"), Some(1));
}
}