1use std::io::{Read, Write};
17use std::sync::Arc;
18
19use zerodds_amqp_bridge::extended_types::AmqpExtValue;
20use zerodds_amqp_bridge::frame::FrameType;
21use zerodds_amqp_bridge::performatives;
22use zerodds_amqp_bridge::types::AmqpValue;
23use zerodds_amqp_endpoint::security::SaslSubject;
24use zerodds_amqp_endpoint::security::{
25 AccessControlPlugin, AccessDecision, AccessOp, IdentityToken, build_identity_token,
26};
27use zerodds_amqp_endpoint::session::InboundFrameKind;
28use zerodds_amqp_endpoint::{ConnectionState, EndpointError, MetricsHub, advance_connection};
29
30use crate::frame_io::{
31 AmqpProtocol, FrameIoError, read_frame, read_protocol_header, write_frame,
32 write_protocol_header,
33};
34
/// Counters and flags describing what happened on one handled connection.
///
/// A value of this type is returned by `handle_connection` on success.
/// `PartialEq`/`Eq` are derived so that callers and tests can compare whole
/// snapshots instead of field-by-field.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ConnectionStats {
    /// Number of frames successfully read (SASL and AMQP phases); protocol
    /// headers are not counted.
    pub frames_received: u64,
    /// Number of frames written (SASL and AMQP phases); protocol headers are
    /// not counted.
    pub frames_sent: u64,
    /// Set once the SASL exchange has run to its outcome frame.
    pub sasl_completed: bool,
    /// Set when an `open` performative was received from the peer.
    pub open_received: bool,
    /// Set by the connection handler right before it returns successfully.
    pub closed: bool,
}
49
50#[derive(Debug)]
52pub enum HandlerError {
53 FrameIo(FrameIoError),
55 Endpoint(EndpointError),
57 PerformativeDecode(String),
59 UnexpectedEof,
61}
62
63impl core::fmt::Display for HandlerError {
64 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
65 match self {
66 Self::FrameIo(e) => write!(f, "frame io: {e}"),
67 Self::Endpoint(e) => write!(f, "endpoint: {e:?}"),
68 Self::PerformativeDecode(s) => write!(f, "performative decode: {s}"),
69 Self::UnexpectedEof => write!(f, "unexpected eof"),
70 }
71 }
72}
73
74impl std::error::Error for HandlerError {}
75
76impl From<FrameIoError> for HandlerError {
77 fn from(e: FrameIoError) -> Self {
78 Self::FrameIo(e)
79 }
80}
81
82impl From<EndpointError> for HandlerError {
83 fn from(e: EndpointError) -> Self {
84 Self::Endpoint(e)
85 }
86}
87
88#[derive(Clone)]
90pub struct HandlerConfig {
91 pub container_id: String,
94 pub max_frame_size: u32,
96 pub tls_active: bool,
98 pub metrics: Arc<MetricsHub>,
100 pub access_control: Option<Arc<dyn AccessControlPlugin + Send + Sync>>,
104 pub default_identity: IdentityToken,
108}
109
110impl core::fmt::Debug for HandlerConfig {
111 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
112 f.debug_struct("HandlerConfig")
113 .field("container_id", &self.container_id)
114 .field("max_frame_size", &self.max_frame_size)
115 .field("tls_active", &self.tls_active)
116 .field("access_control_present", &self.access_control.is_some())
117 .field(
118 "default_identity_subject",
119 &self.default_identity.subject_name,
120 )
121 .finish()
122 }
123}
124
125impl HandlerConfig {
126 #[must_use]
128 pub fn for_tests(metrics: Arc<MetricsHub>) -> Self {
129 Self {
130 container_id: "zerodds-amqp-endpoint".to_string(),
131 max_frame_size: 1_048_576,
132 tls_active: false,
133 metrics,
134 access_control: None,
135 default_identity: build_identity_token(&SaslSubject::Anonymous),
136 }
137 }
138
139 #[must_use]
141 pub fn with_access_control(
142 mut self,
143 plugin: Arc<dyn AccessControlPlugin + Send + Sync>,
144 ) -> Self {
145 self.access_control = Some(plugin);
146 self
147 }
148
149 #[must_use]
151 pub fn with_identity(mut self, identity: IdentityToken) -> Self {
152 self.default_identity = identity;
153 self
154 }
155}
156
157pub fn handle_connection<S: Read + Write>(
167 stream: &mut S,
168 cfg: &HandlerConfig,
169) -> Result<ConnectionStats, HandlerError> {
170 cfg.metrics.on_connection_open();
171 let mut stats = ConnectionStats::default();
172
173 let first = read_protocol_header(stream)?;
176 match first.protocol {
177 AmqpProtocol::Sasl => {
178 do_sasl_phase(stream, cfg, &mut stats)?;
181 let second = read_protocol_header(stream)?;
184 if second.protocol != AmqpProtocol::Amqp {
185 return Err(HandlerError::FrameIo(FrameIoError::UnsupportedProtocolId(
186 second.protocol.as_bytes()[4],
187 )));
188 }
189 write_protocol_header(stream, AmqpProtocol::Amqp)?;
191 }
192 AmqpProtocol::Amqp => {
193 write_protocol_header(stream, AmqpProtocol::Amqp)?;
195 }
196 }
197
198 let mut state = ConnectionState::Start;
200 state = advance_connection(state, InboundFrameKind::Header)?;
201 state = advance_connection(state, InboundFrameKind::Header)?;
202
203 do_amqp_phase(stream, cfg, &mut stats, &mut state)?;
205
206 cfg.metrics.on_connection_close();
207 stats.closed = true;
208 Ok(stats)
209}
210
211fn do_sasl_phase<S: Read + Write>(
212 stream: &mut S,
213 cfg: &HandlerConfig,
214 stats: &mut ConnectionStats,
215) -> Result<(), HandlerError> {
216 write_protocol_header(stream, AmqpProtocol::Sasl)?;
218
219 let mechs = build_sasl_mechanisms(cfg.tls_active);
223 let sasl_mechanisms_descriptor: u64 = 0x0000_0000_0000_0040;
224 let body = performatives::encode_performative(sasl_mechanisms_descriptor, &mechs)
225 .map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
226 write_frame(stream, FrameType::Sasl, 0, &body)?;
227 stats.frames_sent += 1;
228
229 let init_frame = read_frame(stream, cfg.max_frame_size)?;
234 stats.frames_received += 1;
235 if init_frame.header.frame_type != FrameType::Sasl {
236 return Err(HandlerError::FrameIo(FrameIoError::UnsupportedProtocolId(
237 init_frame.header.frame_type.to_u8(),
238 )));
239 }
240
241 let outcome_descriptor: u64 = 0x0000_0000_0000_0044;
243 let outcome_body = AmqpExtValue::List(vec![AmqpExtValue::Ubyte(0)]);
245 let body = performatives::encode_performative(outcome_descriptor, &outcome_body)
246 .map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
247 write_frame(stream, FrameType::Sasl, 0, &body)?;
248 stats.frames_sent += 1;
249 stats.sasl_completed = true;
250 Ok(())
251}
252
253fn build_sasl_mechanisms(tls_active: bool) -> AmqpExtValue {
254 let mut mechs: Vec<AmqpExtValue> = Vec::new();
256 if tls_active {
257 mechs.push(AmqpExtValue::Symbol("PLAIN".to_string()));
258 }
259 mechs.push(AmqpExtValue::Symbol("ANONYMOUS".to_string()));
260 mechs.push(AmqpExtValue::Symbol("EXTERNAL".to_string()));
261 AmqpExtValue::List(vec![AmqpExtValue::Array(mechs)])
262}
263
264fn do_amqp_phase<S: Read + Write>(
265 stream: &mut S,
266 cfg: &HandlerConfig,
267 stats: &mut ConnectionStats,
268 state: &mut ConnectionState,
269) -> Result<(), HandlerError> {
270 loop {
271 let frame = match read_frame(stream, cfg.max_frame_size) {
272 Ok(f) => f,
273 Err(FrameIoError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
274 return Ok(());
276 }
277 Err(e) => return Err(HandlerError::FrameIo(e)),
278 };
279 stats.frames_received += 1;
280
281 if frame.body.is_empty() {
283 continue;
285 }
286
287 let kind = match classify_performative(&frame.body) {
289 Some(k) => k,
290 None => {
291 cfg.metrics.on_decode_error();
293 continue;
294 }
295 };
296
297 *state = advance_connection(*state, kind)?;
299
300 match kind {
301 InboundFrameKind::Open => {
302 stats.open_received = true;
303 let open = performatives::open(&cfg.container_id)
307 .map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
308 write_frame(stream, FrameType::Amqp, 0, &open)?;
309 stats.frames_sent += 1;
310 *state = advance_connection(*state, InboundFrameKind::Open)?;
311 }
312 InboundFrameKind::Begin => {
313 let begin = performatives::begin(Some(0), 0, 1024, 1024)
315 .map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
316 write_frame(stream, FrameType::Amqp, frame.header.channel, &begin)?;
317 stats.frames_sent += 1;
318 }
319 InboundFrameKind::Attach => {
320 let (link_name, target_addr, is_sender) = parse_attach(&frame.body);
324 if !check_access(
325 cfg,
326 &target_addr,
327 if is_sender {
328 AccessOp::AttachReceiver
329 } else {
330 AccessOp::AttachSender
331 },
332 ) {
333 cfg.metrics.on_unauthorized();
334 let detach = performatives::detach(0, true)
335 .map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
336 write_frame(stream, FrameType::Amqp, frame.header.channel, &detach)?;
337 stats.frames_sent += 1;
338 continue;
339 }
340 let attach = performatives::attach(&link_name, 0, true)
342 .map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
343 write_frame(stream, FrameType::Amqp, frame.header.channel, &attach)?;
344 stats.frames_sent += 1;
345 }
346 InboundFrameKind::Transfer => {
347 if !check_access(cfg, "<transfer>", AccessOp::ReceiveSample) {
352 cfg.metrics.on_unauthorized();
353 continue;
354 }
355 cfg.metrics.on_transfer_received();
356 }
358 InboundFrameKind::Close => {
359 let close = performatives::close()
362 .map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
363 write_frame(stream, FrameType::Amqp, 0, &close)?;
364 stats.frames_sent += 1;
365 *state = advance_connection(*state, InboundFrameKind::Close)?;
366 return Ok(());
367 }
368 InboundFrameKind::End => {
369 let end = performatives::end()
370 .map_err(|e| HandlerError::PerformativeDecode(format!("{e}")))?;
371 write_frame(stream, FrameType::Amqp, frame.header.channel, &end)?;
372 stats.frames_sent += 1;
373 }
374 InboundFrameKind::Flow | InboundFrameKind::Disposition | InboundFrameKind::Detach => {}
376 InboundFrameKind::Header => {
377 }
379 }
380 }
381}
382
383fn check_access(cfg: &HandlerConfig, address: &str, op: AccessOp) -> bool {
389 let Some(plugin) = cfg.access_control.as_ref() else {
390 return true;
392 };
393 matches!(
394 plugin.check(&cfg.default_identity, address, op),
395 AccessDecision::Allow
396 )
397}
398
399fn parse_attach(body: &[u8]) -> (String, String, bool) {
405 let default = ("link".to_string(), "<unknown>".to_string(), true);
406 let Ok((_, body_value, _)) = zerodds_amqp_bridge::performatives::decode_performative(body)
407 else {
408 return default;
409 };
410 let AmqpExtValue::List(items) = body_value else {
411 return default;
412 };
413 let link_name = items
416 .first()
417 .and_then(|v| match v {
418 AmqpExtValue::Str(s) => Some(s.clone()),
419 _ => None,
420 })
421 .unwrap_or_else(|| default.0.clone());
422 let is_sender_from_role = items
424 .get(2)
425 .map(|v| matches!(v, AmqpExtValue::Boolean(false)))
426 .unwrap_or(default.2);
427 let target_addr = items
432 .get(6)
433 .and_then(extract_address)
434 .or_else(|| items.get(5).and_then(extract_address))
435 .unwrap_or_else(|| default.1.clone());
436 (link_name, target_addr, is_sender_from_role)
437}
438
439fn extract_address(v: &AmqpExtValue) -> Option<String> {
440 match v {
441 AmqpExtValue::Str(s) => Some(s.clone()),
442 AmqpExtValue::Symbol(s) => Some(s.clone()),
443 AmqpExtValue::List(items) => items.first().and_then(|x| match x {
444 AmqpExtValue::Str(s) | AmqpExtValue::Symbol(s) => Some(s.clone()),
445 _ => None,
446 }),
447 _ => None,
448 }
449}
450
451#[must_use]
454pub fn classify_performative(body: &[u8]) -> Option<InboundFrameKind> {
455 let (descriptor, _, _) = zerodds_amqp_bridge::performatives::decode_performative(body).ok()?;
458 descriptor_to_kind(descriptor)
459}
460
461const fn descriptor_to_kind(descriptor: u64) -> Option<InboundFrameKind> {
462 use zerodds_amqp_bridge::performatives::descriptor as d;
463 let kind = match descriptor {
464 d::OPEN => InboundFrameKind::Open,
465 d::BEGIN => InboundFrameKind::Begin,
466 d::ATTACH => InboundFrameKind::Attach,
467 d::FLOW => InboundFrameKind::Flow,
468 d::TRANSFER => InboundFrameKind::Transfer,
469 d::DISPOSITION => InboundFrameKind::Disposition,
470 d::DETACH => InboundFrameKind::Detach,
471 d::END => InboundFrameKind::End,
472 d::CLOSE => InboundFrameKind::Close,
473 _ => return None,
474 };
475 Some(kind)
476}
477
478const _: Option<AmqpValue> = None;
480
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Handler configuration backed by a fresh metrics hub.
    fn cfg() -> HandlerConfig {
        HandlerConfig::for_tests(Arc::new(MetricsHub::new()))
    }

    /// In-memory stream: reads are served from `input`, writes are captured
    /// in `output`.
    struct DuplexCursor {
        input: Cursor<Vec<u8>>,
        output: Vec<u8>,
    }

    impl Read for DuplexCursor {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.input.read(buf)
        }
    }

    impl Write for DuplexCursor {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.output.write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    fn duplex(input: Vec<u8>) -> DuplexCursor {
        DuplexCursor {
            input: Cursor::new(input),
            output: Vec::new(),
        }
    }

    /// Appends one frame to `buf`: an 8-byte header (doff 2, channel 0)
    /// followed by `payload`.
    fn push_frame(buf: &mut Vec<u8>, frame_type: FrameType, payload: &[u8]) {
        let header = zerodds_amqp_bridge::frame::FrameHeader {
            size: 8 + payload.len() as u32,
            doff: 2,
            frame_type,
            channel: 0,
        };
        buf.extend(zerodds_amqp_bridge::frame::encode_frame_header(header));
        buf.extend(payload);
    }

    #[test]
    fn descriptor_classification_covers_9_performatives() {
        use zerodds_amqp_bridge::performatives::descriptor as d;
        let table = [
            (d::OPEN, InboundFrameKind::Open),
            (d::BEGIN, InboundFrameKind::Begin),
            (d::ATTACH, InboundFrameKind::Attach),
            (d::FLOW, InboundFrameKind::Flow),
            (d::TRANSFER, InboundFrameKind::Transfer),
            (d::DISPOSITION, InboundFrameKind::Disposition),
            (d::DETACH, InboundFrameKind::Detach),
            (d::END, InboundFrameKind::End),
            (d::CLOSE, InboundFrameKind::Close),
        ];
        for (code, expected) in table {
            assert_eq!(descriptor_to_kind(code), Some(expected));
        }
        assert_eq!(descriptor_to_kind(0xFFFF), None);
    }

    #[test]
    fn handle_connection_open_close_round_trip() {
        // Peer script: AMQP header, open, close.
        let open = performatives::open("client").unwrap();
        let close = performatives::close().unwrap();
        let mut input = Vec::new();
        input.extend(AmqpProtocol::Amqp.as_bytes());
        push_frame(&mut input, FrameType::Amqp, &open);
        push_frame(&mut input, FrameType::Amqp, &close);

        let mut io = duplex(input);
        let stats = handle_connection(&mut io, &cfg()).unwrap();
        assert!(stats.open_received);
        assert!(stats.closed);
        assert_eq!(stats.frames_received, 2);
        assert!(stats.frames_sent >= 2);
        // The reply stream starts with the echoed protocol header.
        assert_eq!(&io.output[0..4], b"AMQP");
    }

    #[test]
    fn handle_connection_invalid_magic_rejected() {
        let bad = b"NOPE\x00\x01\x00\x00";
        let mut io = duplex(bad.to_vec());
        let err = handle_connection(&mut io, &cfg()).unwrap_err();
        assert!(matches!(
            err,
            HandlerError::FrameIo(FrameIoError::InvalidProtocolMagic(_))
        ));
    }

    #[test]
    fn handle_connection_sasl_then_amqp() {
        // Peer script: SASL header, init frame choosing ANONYMOUS, AMQP
        // header, open, close.
        let sasl_init_descriptor = 0x0000_0000_0000_0041u64;
        let init_body = AmqpExtValue::List(vec![AmqpExtValue::Symbol("ANONYMOUS".into())]);
        let init_payload =
            performatives::encode_performative(sasl_init_descriptor, &init_body).unwrap();
        let open = performatives::open("client").unwrap();
        let close = performatives::close().unwrap();

        let mut input = Vec::new();
        input.extend(AmqpProtocol::Sasl.as_bytes());
        push_frame(&mut input, FrameType::Sasl, &init_payload);
        input.extend(AmqpProtocol::Amqp.as_bytes());
        push_frame(&mut input, FrameType::Amqp, &open);
        push_frame(&mut input, FrameType::Amqp, &close);

        let mut io = duplex(input);
        let stats = handle_connection(&mut io, &cfg()).unwrap();
        assert!(stats.sasl_completed);
        assert!(stats.open_received);
        assert!(stats.closed);
    }

    #[test]
    fn access_control_deny_attach_yields_unauthorized_metric() {
        use zerodds_amqp_endpoint::security::{
            AccessControlPlugin, AccessDecision, AccessOp, IdentityToken,
        };

        /// Plugin that refuses every request.
        struct DenyAll;
        impl AccessControlPlugin for DenyAll {
            fn check(&self, _: &IdentityToken, _: &str, _: AccessOp) -> AccessDecision {
                AccessDecision::Deny
            }
        }

        let metrics = Arc::new(MetricsHub::new());
        let cfg = HandlerConfig::for_tests(metrics.clone()).with_access_control(Arc::new(DenyAll));

        // Peer script: AMQP header, open, attach, close.
        let open = performatives::open("c").unwrap();
        let attach = performatives::attach("L", 0, true).unwrap();
        let close = performatives::close().unwrap();
        let mut input = Vec::new();
        input.extend(AmqpProtocol::Amqp.as_bytes());
        push_frame(&mut input, FrameType::Amqp, &open);
        push_frame(&mut input, FrameType::Amqp, &attach);
        push_frame(&mut input, FrameType::Amqp, &close);

        let mut io = duplex(input);
        handle_connection(&mut io, &cfg).unwrap();
        assert!(metrics.snapshot("errors.unauthorized").unwrap_or(0) >= 1);
    }

    #[test]
    fn access_control_allow_does_not_increment_unauthorized() {
        use zerodds_amqp_endpoint::security::AllowAll;
        let metrics = Arc::new(MetricsHub::new());
        let cfg = HandlerConfig::for_tests(metrics.clone()).with_access_control(Arc::new(AllowAll));

        // Peer script: AMQP header, open, close.
        let open = performatives::open("c").unwrap();
        let close = performatives::close().unwrap();
        let mut input = Vec::new();
        input.extend(AmqpProtocol::Amqp.as_bytes());
        push_frame(&mut input, FrameType::Amqp, &open);
        push_frame(&mut input, FrameType::Amqp, &close);

        let mut io = duplex(input);
        handle_connection(&mut io, &cfg).unwrap();
        assert_eq!(metrics.snapshot("errors.unauthorized"), Some(0));
    }

    #[test]
    fn metrics_counter_incremented_on_connection() {
        let m = Arc::new(MetricsHub::new());
        let cfg = HandlerConfig::for_tests(m.clone());

        // Peer script: AMQP header, open, close.
        let open = performatives::open("c").unwrap();
        let close = performatives::close().unwrap();
        let mut input = Vec::new();
        input.extend(AmqpProtocol::Amqp.as_bytes());
        push_frame(&mut input, FrameType::Amqp, &open);
        push_frame(&mut input, FrameType::Amqp, &close);

        let mut io = duplex(input);
        handle_connection(&mut io, &cfg).unwrap();
        // The connection was counted once and is no longer active.
        assert_eq!(m.snapshot("connections.active"), Some(0));
        assert_eq!(m.snapshot("connections.total"), Some(1));
    }
}