1use std::io::{self, Read, Write};
17use std::net::{TcpStream, ToSocketAddrs};
18use std::sync::Arc;
19use std::sync::atomic::{AtomicBool, Ordering};
20use std::thread;
21use std::time::Duration;
22
23use zerodds_amqp_bridge::extended_types::AmqpExtValue;
24use zerodds_amqp_bridge::frame::FrameType;
25use zerodds_amqp_bridge::performatives;
26use zerodds_amqp_endpoint::sasl::{SaslMechanism, SaslState};
27use zerodds_amqp_endpoint::{ConnectionState, MetricsHub};
28
29use crate::frame_io::{
30 AmqpProtocol, FrameIoError, read_frame, read_protocol_header, write_frame,
31 write_protocol_header,
32};
33use crate::handler::HandlerError;
34
/// Default delay before the first reconnect attempt, in milliseconds.
pub const DEFAULT_BACKOFF_INITIAL_MS: u64 = 1_000;
/// Default factor by which the reconnect delay grows per failed attempt.
pub const DEFAULT_BACKOFF_MULT: u64 = 2;
/// Default upper bound for the reconnect delay, in milliseconds (60 s).
pub const DEFAULT_BACKOFF_CAP_MS: u64 = 60 * 1_000;
41
/// Settings for one outbound AMQP connection.
///
/// `Debug` is implemented by hand so that the password stored in
/// [`ClientConfig::plain_credentials`] never ends up in logs or panic
/// messages.
#[derive(Clone)]
pub struct ClientConfig {
    /// Address of the upstream broker, resolved with `ToSocketAddrs`
    /// (for example `"host:5672"`).
    pub upstream_addr: String,
    /// Container id sent in the `open` performative.
    pub container_id: String,
    /// Frame-size limit handed to the frame reader for every inbound frame.
    pub max_frame_size: u32,
    /// Whether the transport is protected by TLS; SASL PLAIN is refused
    /// when this is `false`.
    pub tls_active: bool,
    /// `(username, password)` used to build the SASL initial response.
    pub plain_credentials: Option<(String, String)>,
    /// Read and write timeout applied to the socket; `None` leaves the
    /// socket's timeouts untouched.
    pub io_timeout: Option<Duration>,
}

impl core::fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Show the username (useful for diagnostics) but never the password.
        let credentials = self
            .plain_credentials
            .as_ref()
            .map(|(user, _password)| (user.as_str(), "<redacted>"));
        f.debug_struct("ClientConfig")
            .field("upstream_addr", &self.upstream_addr)
            .field("container_id", &self.container_id)
            .field("max_frame_size", &self.max_frame_size)
            .field("tls_active", &self.tls_active)
            .field("plain_credentials", &credentials)
            .field("io_timeout", &self.io_timeout)
            .finish()
    }
}
60
/// Exponential-backoff policy used when re-establishing the connection.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Delay before the first retry, in milliseconds.
    pub initial_ms: u64,
    /// Factor applied to the delay for every further failed attempt.
    pub multiplier: u64,
    /// Largest delay ever returned by the backoff computation, in milliseconds.
    pub cap_ms: u64,
    /// Maximum number of connection attempts; `None` retries indefinitely.
    pub max_attempts: Option<u32>,
}
73
74impl Default for ReconnectConfig {
75 fn default() -> Self {
76 Self {
77 initial_ms: DEFAULT_BACKOFF_INITIAL_MS,
78 multiplier: DEFAULT_BACKOFF_MULT,
79 cap_ms: DEFAULT_BACKOFF_CAP_MS,
80 max_attempts: None,
81 }
82 }
83}
84
85impl ReconnectConfig {
86 #[must_use]
89 pub fn next_backoff_ms(&self, attempt: u32) -> u64 {
90 if attempt == 0 {
91 return self.initial_ms;
92 }
93 let mut v = self.initial_ms;
95 for _ in 0..attempt {
96 v = v.saturating_mul(self.multiplier);
97 if v >= self.cap_ms {
98 return self.cap_ms;
99 }
100 }
101 v
102 }
103}
104
105#[derive(Debug, Clone)]
107pub struct OutboundSession {
108 pub state: ConnectionState,
111 pub remote_container_id: Option<String>,
113 pub sasl_mechanism: Option<SaslMechanism>,
115}
116
117#[derive(Debug)]
119pub enum ClientError {
120 Io(io::Error),
122 FrameIo(FrameIoError),
124 Handler(HandlerError),
126 BrokerReject(String),
129 PlainRejectedNoTls,
132 NoAcceptableSaslMechanism,
134 ReconnectExhausted(u32),
136}
137
138impl core::fmt::Display for ClientError {
139 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
140 match self {
141 Self::Io(e) => write!(f, "io error: {e}"),
142 Self::FrameIo(e) => write!(f, "frame io: {e}"),
143 Self::Handler(e) => write!(f, "handler: {e}"),
144 Self::BrokerReject(s) => write!(f, "broker reject: {s}"),
145 Self::PlainRejectedNoTls => write!(
146 f,
147 "SASL PLAIN refused over unencrypted transport (Spec §2.2 Cl. 5)"
148 ),
149 Self::NoAcceptableSaslMechanism => write!(f, "no acceptable SASL mechanism offered"),
150 Self::ReconnectExhausted(n) => {
151 write!(f, "reconnect attempts exhausted after {n} tries")
152 }
153 }
154 }
155}
156
157impl std::error::Error for ClientError {}
158
159impl From<io::Error> for ClientError {
160 fn from(e: io::Error) -> Self {
161 Self::Io(e)
162 }
163}
164impl From<FrameIoError> for ClientError {
165 fn from(e: FrameIoError) -> Self {
166 Self::FrameIo(e)
167 }
168}
169impl From<HandlerError> for ClientError {
170 fn from(e: HandlerError) -> Self {
171 Self::Handler(e)
172 }
173}
174
175pub fn connect_outbound(cfg: &ClientConfig) -> Result<(TcpStream, OutboundSession), ClientError> {
184 let mut stream = tcp_connect(&cfg.upstream_addr)?;
185 if let Some(t) = cfg.io_timeout {
186 stream.set_read_timeout(Some(t))?;
187 stream.set_write_timeout(Some(t))?;
188 }
189 let session = drive_outbound_handshake(&mut stream, cfg)?;
190 Ok((stream, session))
191}
192
/// Resolves `addr` and connects to the first address that accepts the
/// connection, allowing ten seconds per candidate.
///
/// # Errors
///
/// Returns the resolution error, the error of the last failed connection
/// attempt, or `AddrNotAvailable` when resolution produced no address.
fn tcp_connect(addr: &str) -> io::Result<TcpStream> {
    const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);

    let mut failure = None;
    for candidate in addr.to_socket_addrs()? {
        match TcpStream::connect_timeout(&candidate, CONNECT_TIMEOUT) {
            Ok(stream) => return Ok(stream),
            // Remember the error but keep trying the remaining addresses.
            Err(err) => failure = Some(err),
        }
    }
    Err(failure.unwrap_or_else(|| {
        io::Error::new(io::ErrorKind::AddrNotAvailable, "no resolvable address")
    }))
}
208
209fn drive_outbound_handshake<S: Read + Write>(
210 stream: &mut S,
211 cfg: &ClientConfig,
212) -> Result<OutboundSession, ClientError> {
213 write_protocol_header(stream, AmqpProtocol::Sasl)?;
217 let server_proto = read_protocol_header(stream)?;
218 let mechanism = if server_proto.protocol == AmqpProtocol::Sasl {
219 Some(do_outbound_sasl(stream, cfg)?)
220 } else {
221 None
223 };
224
225 write_protocol_header(stream, AmqpProtocol::Amqp)?;
227 let amqp_proto = read_protocol_header(stream)?;
228 if amqp_proto.protocol != AmqpProtocol::Amqp {
229 return Err(ClientError::FrameIo(FrameIoError::UnsupportedProtocolId(
230 amqp_proto.protocol.as_bytes()[4],
231 )));
232 }
233
234 let open = performatives::open(&cfg.container_id)
236 .map_err(|e| ClientError::Handler(HandlerError::PerformativeDecode(format!("{e}"))))?;
237 write_frame(stream, FrameType::Amqp, 0, &open)?;
238
239 let frame = read_frame(stream, cfg.max_frame_size)?;
241 let remote_container_id = extract_container_id(&frame.body);
242
243 Ok(OutboundSession {
244 state: ConnectionState::Opened,
245 remote_container_id,
246 sasl_mechanism: mechanism,
247 })
248}
249
250fn do_outbound_sasl<S: Read + Write>(
251 stream: &mut S,
252 cfg: &ClientConfig,
253) -> Result<SaslMechanism, ClientError> {
254 let mechs_frame = read_frame(stream, cfg.max_frame_size)?;
256 if mechs_frame.header.frame_type != FrameType::Sasl {
257 return Err(ClientError::FrameIo(FrameIoError::UnsupportedProtocolId(
258 mechs_frame.header.frame_type.to_u8(),
259 )));
260 }
261 let (_descriptor, body, _) =
263 zerodds_amqp_bridge::performatives::decode_performative(&mechs_frame.body)
264 .map_err(|e| ClientError::Handler(HandlerError::PerformativeDecode(format!("{e}"))))?;
265 let offered = parse_offered_mechanisms(&body);
266
267 let chosen = SaslState::select_outbound(&offered, cfg.tls_active)
269 .ok_or(ClientError::NoAcceptableSaslMechanism)?;
270
271 if chosen == SaslMechanism::Plain && !cfg.tls_active {
275 return Err(ClientError::PlainRejectedNoTls);
276 }
277
278 let init_descriptor: u64 = 0x0000_0000_0000_0041;
280 let init_body = build_sasl_init(chosen, cfg);
281 let init_payload = performatives::encode_performative(init_descriptor, &init_body)
282 .map_err(|e| ClientError::Handler(HandlerError::PerformativeDecode(format!("{e}"))))?;
283 write_frame(stream, FrameType::Sasl, 0, &init_payload)?;
284
285 let outcome_frame = read_frame(stream, cfg.max_frame_size)?;
287 let (descriptor, outcome_body, _) =
288 zerodds_amqp_bridge::performatives::decode_performative(&outcome_frame.body)
289 .map_err(|e| ClientError::Handler(HandlerError::PerformativeDecode(format!("{e}"))))?;
290 if descriptor != 0x0000_0000_0000_0044 {
291 return Err(ClientError::BrokerReject(format!(
292 "expected sasl-outcome (0x44), got descriptor 0x{descriptor:x}"
293 )));
294 }
295 let code = extract_outcome_code(&outcome_body);
296 if code != Some(0) {
297 return Err(ClientError::BrokerReject(format!(
298 "sasl outcome code {code:?}"
299 )));
300 }
301 Ok(chosen)
302}
303
304fn parse_offered_mechanisms(body: &AmqpExtValue) -> Vec<SaslMechanism> {
305 let mut out = Vec::new();
306 if let AmqpExtValue::List(items) = body {
307 if let Some(AmqpExtValue::Array(arr)) = items.first() {
308 for sym in arr {
309 if let AmqpExtValue::Symbol(s) = sym {
310 if let Some(m) = SaslMechanism::from_name(s) {
311 out.push(m);
312 }
313 }
314 }
315 } else if let Some(AmqpExtValue::Symbol(s)) = items.first() {
316 if let Some(m) = SaslMechanism::from_name(s) {
318 out.push(m);
319 }
320 }
321 }
322 out
323}
324
325fn build_sasl_init(mech: SaslMechanism, cfg: &ClientConfig) -> AmqpExtValue {
326 let mut items: Vec<AmqpExtValue> = Vec::new();
328 items.push(AmqpExtValue::Symbol(mech.name().to_string()));
329 let response = match (mech, &cfg.plain_credentials) {
330 (SaslMechanism::Plain, Some((user, pw))) => {
331 let mut buf: Vec<u8> = Vec::new();
333 buf.push(0);
334 buf.extend(user.as_bytes());
335 buf.push(0);
336 buf.extend(pw.as_bytes());
337 AmqpExtValue::Binary(buf)
338 }
339 (SaslMechanism::Anonymous, _) => AmqpExtValue::Binary(b"anonymous".to_vec()),
340 (SaslMechanism::External, _) => AmqpExtValue::Binary(Vec::new()),
341 (SaslMechanism::Plain, None) => AmqpExtValue::Binary(Vec::new()),
342 (SaslMechanism::ScramSha256, Some((user, _pw))) => {
343 let body = format!("n,,n={user},r=");
351 AmqpExtValue::Binary(body.into_bytes())
352 }
353 (SaslMechanism::ScramSha256, None) => {
354 AmqpExtValue::Binary(Vec::new())
357 }
358 };
359 items.push(response);
360 AmqpExtValue::List(items)
361}
362
363fn extract_outcome_code(body: &AmqpExtValue) -> Option<u8> {
364 if let AmqpExtValue::List(items) = body {
365 if let Some(AmqpExtValue::Ubyte(code)) = items.first() {
366 return Some(*code);
367 }
368 }
369 None
370}
371
372fn extract_container_id(performative_body: &[u8]) -> Option<String> {
373 let (_descriptor, body, _) =
374 zerodds_amqp_bridge::performatives::decode_performative(performative_body).ok()?;
375 if let AmqpExtValue::List(items) = body {
376 if let Some(AmqpExtValue::Str(s)) = items.first() {
377 return Some(s.clone());
378 }
379 }
380 None
381}
382
383pub fn connect_with_reconnect(
396 cfg: &ClientConfig,
397 reconnect: &ReconnectConfig,
398 shutdown_signal: &Arc<AtomicBool>,
399 metrics: &Arc<MetricsHub>,
400) -> Result<(TcpStream, OutboundSession), ClientError> {
401 let mut attempt: u32 = 0;
402 let mut last_err: Option<ClientError> = None;
403 loop {
404 if shutdown_signal.load(Ordering::Relaxed) {
405 return Err(last_err.unwrap_or(ClientError::ReconnectExhausted(attempt)));
406 }
407 if let Some(max) = reconnect.max_attempts {
408 if attempt >= max {
409 return Err(ClientError::ReconnectExhausted(attempt));
410 }
411 }
412 match connect_outbound(cfg) {
413 Ok(ok) => return Ok(ok),
414 Err(e) => {
415 metrics.on_decode_error(); last_err = Some(e);
418 let wait_ms = reconnect.next_backoff_ms(attempt);
419 attempt = attempt.saturating_add(1);
420 thread::sleep(Duration::from_millis(wait_ms));
421 }
422 }
423 }
424}
425
// Tests may unwrap and panic freely: a failure is reported by the harness.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::sync::atomic::AtomicBool;

    /// Client configuration for `addr` without TLS and without credentials.
    fn cfg(addr: &str) -> ClientConfig {
        ClientConfig {
            upstream_addr: addr.into(),
            container_id: "client-test".into(),
            max_frame_size: 65_536,
            tls_active: false,
            plain_credentials: None,
            io_timeout: Some(Duration::from_secs(2)),
        }
    }

    /// Unwraps the fields of a list value, failing the test for other shapes.
    fn list_items(body: AmqpExtValue) -> Vec<AmqpExtValue> {
        match body {
            AmqpExtValue::List(items) => items,
            other => panic!("expected a list, got {other:?}"),
        }
    }

    // ---- backoff -------------------------------------------------------

    #[test]
    fn backoff_starts_at_initial() {
        let r = ReconnectConfig::default();
        assert_eq!(r.next_backoff_ms(0), 1_000);
    }

    #[test]
    fn backoff_doubles_until_cap() {
        let r = ReconnectConfig::default();
        assert_eq!(r.next_backoff_ms(1), 2_000);
        assert_eq!(r.next_backoff_ms(2), 4_000);
        assert_eq!(r.next_backoff_ms(3), 8_000);
        assert_eq!(r.next_backoff_ms(4), 16_000);
        assert_eq!(r.next_backoff_ms(5), 32_000);
        // 64 000 would exceed the cap, so the cap is returned from here on.
        assert_eq!(r.next_backoff_ms(6), 60_000);
        assert_eq!(r.next_backoff_ms(20), 60_000);
    }

    #[test]
    fn backoff_respects_custom_cap() {
        let r = ReconnectConfig {
            initial_ms: 100,
            multiplier: 3,
            cap_ms: 5_000,
            max_attempts: None,
        };
        assert_eq!(r.next_backoff_ms(0), 100);
        assert_eq!(r.next_backoff_ms(1), 300);
        assert_eq!(r.next_backoff_ms(2), 900);
        assert_eq!(r.next_backoff_ms(3), 2_700);
        // 8 100 would exceed the cap.
        assert_eq!(r.next_backoff_ms(4), 5_000);
    }

    #[test]
    fn backoff_with_unit_multiplier_stays_at_initial() {
        let r = ReconnectConfig {
            initial_ms: 500,
            multiplier: 1,
            cap_ms: 60_000,
            max_attempts: None,
        };
        assert_eq!(r.next_backoff_ms(0), 500);
        assert_eq!(r.next_backoff_ms(5), 500);
    }

    // ---- mechanism parsing --------------------------------------------

    #[test]
    fn parse_offered_mechanisms_array_form() {
        let body = AmqpExtValue::List(vec![AmqpExtValue::Array(vec![
            AmqpExtValue::Symbol("PLAIN".into()),
            AmqpExtValue::Symbol("ANONYMOUS".into()),
        ])]);
        let mechs = parse_offered_mechanisms(&body);
        assert_eq!(mechs.len(), 2);
        assert!(mechs.contains(&SaslMechanism::Plain));
        assert!(mechs.contains(&SaslMechanism::Anonymous));
    }

    #[test]
    fn parse_offered_mechanisms_single_symbol() {
        let body = AmqpExtValue::List(vec![AmqpExtValue::Symbol("EXTERNAL".into())]);
        let mechs = parse_offered_mechanisms(&body);
        assert_eq!(mechs, vec![SaslMechanism::External]);
    }

    #[test]
    fn parse_offered_mechanisms_unknown_filtered() {
        let body = AmqpExtValue::List(vec![AmqpExtValue::Array(vec![
            AmqpExtValue::Symbol("BOGUS".into()),
            AmqpExtValue::Symbol("ANONYMOUS".into()),
        ])]);
        let mechs = parse_offered_mechanisms(&body);
        assert_eq!(mechs, vec![SaslMechanism::Anonymous]);
    }

    // ---- sasl-init -----------------------------------------------------

    #[test]
    fn sasl_init_plain_includes_credentials() {
        let mut c = cfg("x:1");
        c.plain_credentials = Some(("alice".into(), "secret".into()));
        let items = list_items(build_sasl_init(SaslMechanism::Plain, &c));
        assert_eq!(items[0], AmqpExtValue::Symbol("PLAIN".into()));
        assert_eq!(
            items[1],
            AmqpExtValue::Binary(b"\0alice\0secret".to_vec())
        );
    }

    #[test]
    fn sasl_init_anonymous_uses_marker() {
        let items = list_items(build_sasl_init(SaslMechanism::Anonymous, &cfg("x:1")));
        assert_eq!(items[0], AmqpExtValue::Symbol("ANONYMOUS".into()));
        assert_eq!(items[1], AmqpExtValue::Binary(b"anonymous".to_vec()));
    }

    #[test]
    fn sasl_init_external_has_empty_response() {
        let items = list_items(build_sasl_init(SaslMechanism::External, &cfg("x:1")));
        assert_eq!(items[1], AmqpExtValue::Binary(Vec::new()));
    }

    // ---- connection ----------------------------------------------------

    #[test]
    fn outbound_connect_to_local_server() {
        use crate::handler::{HandlerConfig, handle_connection};

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        // In-process server that handles exactly one connection.
        let server_metrics = Arc::new(MetricsHub::new());
        let server = thread::spawn(move || {
            if let Ok((mut sock, _)) = listener.accept() {
                let _ = sock.set_read_timeout(Some(Duration::from_secs(2)));
                let _ = sock.set_write_timeout(Some(Duration::from_secs(2)));
                let handler_cfg = HandlerConfig::for_tests(server_metrics);
                let _ = handle_connection(&mut sock, &handler_cfg);
            }
        });

        let client_cfg = cfg(&format!("127.0.0.1:{port}"));
        let metrics = Arc::new(MetricsHub::new());
        let shutdown = Arc::new(AtomicBool::new(false));
        let single_attempt = ReconnectConfig {
            max_attempts: Some(1),
            ..ReconnectConfig::default()
        };
        let result = connect_with_reconnect(&client_cfg, &single_attempt, &shutdown, &metrics);
        assert!(result.is_ok(), "connect failed: {result:?}");
        let (mut stream, session) = result.unwrap();
        assert_eq!(session.state, ConnectionState::Opened);
        assert!(session.remote_container_id.is_some());

        // Close politely so that the server thread terminates.
        let close = performatives::close().unwrap();
        write_frame(&mut stream, FrameType::Amqp, 0, &close).unwrap();
        drop(stream);
        let _ = server.join();
    }

    #[test]
    fn reconnect_exhausts_with_max_attempts() {
        // Port 1 on loopback is expected to refuse connections.
        let client_cfg = cfg("127.0.0.1:1");
        let metrics = Arc::new(MetricsHub::new());
        let shutdown = Arc::new(AtomicBool::new(false));
        let r = ReconnectConfig {
            initial_ms: 1,
            multiplier: 1,
            cap_ms: 1,
            max_attempts: Some(2),
        };
        let err = connect_with_reconnect(&client_cfg, &r, &shutdown, &metrics).unwrap_err();
        assert!(
            matches!(err, ClientError::ReconnectExhausted(2)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn reconnect_aborts_on_shutdown_signal() {
        let client_cfg = cfg("127.0.0.1:1");
        let metrics = Arc::new(MetricsHub::new());
        let shutdown = Arc::new(AtomicBool::new(false));
        let flag = shutdown.clone();
        // Request shutdown while the client is waiting between attempts.
        let setter = thread::spawn(move || {
            thread::sleep(Duration::from_millis(50));
            flag.store(true, Ordering::Relaxed);
        });
        let r = ReconnectConfig {
            initial_ms: 200,
            multiplier: 1,
            cap_ms: 200,
            // Unlimited attempts: only the shutdown flag can end the loop.
            max_attempts: None,
        };
        let outcome = connect_with_reconnect(&client_cfg, &r, &shutdown, &metrics);
        assert!(outcome.is_err());
        setter.join().unwrap();
    }
}