#![allow(clippy::expect_used)]
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use bytes::BytesMut;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use magnetar_proto::{ConnectionConfig, FrameError, decode_one, encode_command, pb};
use magnetar_runtime_moonpool::{Client, ClientError, EngineError, MoonpoolEngine};
use moonpool_core::{NetworkProvider, Providers, TaskProvider, TcpListenerTrait};
use moonpool_sim::providers::SimProviders;
use moonpool_sim::{SimContext, SimulationBuilder, SimulationError, SimulationResult, Workload};
use parking_lot::Mutex;
/// TCP port the fake broker binds on its own IP and the client workload dials.
const BROKER_PORT: u16 = 6650;
/// Message the first test's broker places in its `CommandError` reply; the
/// client looks for it verbatim in the `HandshakeFailed` reason.
const BROKER_MESSAGE: &str = "token expired";
/// Used (plus 128 bytes of slack) as the upper bound on the handshake failure
/// reason length in the oversized-message test.
/// NOTE(review): presumably mirrors the engine's cap on broker-supplied strings — confirm.
const MAX_BROKER_STR: usize = 256;
/// Performs one `read` on `stream` (up to 64 KiB) and appends whatever arrived
/// to `buf`.
///
/// Returns the number of bytes appended; `Ok(0)` means the peer reached EOF.
/// I/O errors from the underlying read are propagated unchanged.
async fn read_into<S: AsyncRead + Unpin>(
    stream: &mut S,
    buf: &mut BytesMut,
) -> std::io::Result<usize> {
    const CHUNK_LEN: usize = 64 * 1024;
    let mut scratch = vec![0u8; CHUNK_LEN];
    let filled = stream.read(scratch.as_mut_slice()).await?;
    buf.extend_from_slice(&scratch[..filled]);
    Ok(filled)
}
/// Serves one inbound broker connection that always rejects the handshake.
///
/// Inbound bytes are buffered and decoded with `decode_one`. Once a CONNECT
/// command has been seen, a single `CommandError` (`AuthenticationError`,
/// carrying `message`) is written and flushed, and the session ends.
///
/// Always returns `Ok(())`. EOF, read/write errors and undecodable frames all
/// simply end the session.
async fn handle_reject_handshake_session<S>(mut stream: S, message: String) -> SimulationResult<()>
where
    S: AsyncRead + AsyncWrite + Unpin + Send,
{
    let mut pending = BytesMut::with_capacity(64 * 1024);
    let mut connect_seen = false;
    loop {
        // Drain every complete frame currently buffered. Decoding runs on a
        // frozen copy, so an incomplete tail stays untouched in `pending`.
        loop {
            let mut view = pending.clone().freeze();
            let available = view.len();
            match decode_one(&mut view) {
                Ok(frame) => {
                    let used = available - view.len();
                    let _ = pending.split_to(used);
                    let is_connect = pb::base_command::Type::try_from(frame.command.r#type)
                        == Ok(pb::base_command::Type::Connect);
                    connect_seen |= is_connect;
                }
                Err(FrameError::Incomplete { .. }) => break,
                // Any other decode failure: abandon this session.
                Err(_) => return Ok(()),
            }
        }

        if !connect_seen {
            // No CONNECT yet: pull more bytes. EOF or a read error ends the session.
            match read_into(&mut stream, &mut pending).await {
                Ok(0) | Err(_) => return Ok(()),
                Ok(_) => continue,
            }
        }

        let reply = pb::BaseCommand {
            r#type: pb::base_command::Type::Error as i32,
            error: Some(pb::CommandError {
                request_id: 0,
                error: pb::ServerError::AuthenticationError as i32,
                message: message.clone(),
            }),
            ..Default::default()
        };
        let mut encoded = BytesMut::new();
        let _ = encode_command(&mut encoded, &reply);
        // Best effort: only flush if the write itself succeeded.
        if stream.write_all(&encoded).await.is_ok() {
            let _ = stream.flush().await;
        }
        return Ok(());
    }
}
/// Fake broker workload that answers every CONNECT with a `CommandError`.
///
/// `sessions_handled` sits behind an `Arc`, so the test can keep a handle to
/// it after the workload is moved into the simulation.
struct RejectHandshakeBroker {
    /// Count of accepted inbound connections. Incremented on each accept and
    /// never reset by this type.
    sessions_handled: Arc<Mutex<u32>>,
    /// Text sent in the `message` field of the `CommandError` reply.
    message: String,
}

impl RejectHandshakeBroker {
    /// Builds a broker that rejects handshakes with `message`, starting from
    /// a zeroed session counter.
    fn new(message: String) -> Self {
        let sessions_handled = Arc::new(Mutex::new(0));
        Self {
            sessions_handled,
            message,
        }
    }
}
#[async_trait]
impl Workload for RejectHandshakeBroker {
fn name(&self) -> &str {
"broker"
}
async fn run(&mut self, ctx: &SimContext) -> SimulationResult<()> {
let network = ctx.network().clone();
let bind_addr = format!("{}:{BROKER_PORT}", ctx.my_ip());
let listener = network
.bind(&bind_addr)
.await
.map_err(|e| SimulationError::InvalidState(format!("broker bind: {e}")))?;
let shutdown = ctx.shutdown().clone();
let handled = self.sessions_handled.clone();
let task = ctx.providers().task().clone();
loop {
tokio::select! {
() = shutdown.cancelled() => return Ok(()),
inbound = listener.accept() => {
match inbound {
Ok((stream, _peer)) => {
*handled.lock() += 1;
let message = self.message.clone();
let _handle = task.spawn_task(
"reject-handshake-session",
async move {
let _ = handle_reject_handshake_session(stream, message).await;
},
);
}
Err(_) => return Ok(()),
}
}
}
}
}
}
/// Client workload that dials the fake broker and records what
/// `Client::connect_plain` reported.
///
/// Every field sits behind an `Arc`, so the test can keep handles after the
/// workload is moved into the simulation.
struct HandshakeFailureClient {
    /// Set when a `HandshakeFailed` reason mentions `"AuthenticationError"`.
    saw_server_error: Arc<Mutex<bool>>,
    /// Set when a `HandshakeFailed` reason contains `BROKER_MESSAGE`.
    saw_broker_message: Arc<Mutex<bool>>,
    /// Description of the most recent connect failure, if any.
    last_error: Arc<Mutex<Option<String>>>,
    /// The most recent `HandshakeFailed` reason string, if any.
    last_handshake_reason: Arc<Mutex<Option<String>>>,
}

impl HandshakeFailureClient {
    /// Builds a client with both flags cleared and no recorded errors.
    fn new() -> Self {
        let cleared_flag = || Arc::new(Mutex::new(false));
        let empty_slot = || Arc::new(Mutex::new(None));
        Self {
            saw_server_error: cleared_flag(),
            saw_broker_message: cleared_flag(),
            last_error: empty_slot(),
            last_handshake_reason: empty_slot(),
        }
    }
}
#[async_trait]
impl Workload for HandshakeFailureClient {
    fn name(&self) -> &str {
        "client"
    }

    /// Dials the `broker` peer once via `Client::connect_plain` (20 s timeout)
    /// and records the outcome in the shared observation cells.
    ///
    /// - Timeout: a description is stored in `last_error`.
    /// - Connect error: its `Debug` form is stored in `last_error`. If it is
    ///   `HandshakeFailed`, the reason is also stored in
    ///   `last_handshake_reason`, and the `saw_server_error` /
    ///   `saw_broker_message` flags are set when the reason mentions
    ///   `AuthenticationError` / `BROKER_MESSAGE`.
    /// - Successful connect: nothing is recorded.
    ///
    /// # Errors
    /// Returns `SimulationError::InvalidState` only if no `broker` peer exists.
    async fn run(&mut self, ctx: &SimContext) -> SimulationResult<()> {
        let broker_ip = ctx
            .peer("broker")
            .ok_or_else(|| SimulationError::InvalidState("broker peer missing".into()))?;
        let addr = format!("{broker_ip}:{BROKER_PORT}");
        let engine = MoonpoolEngine::new(ctx.providers().clone());
        let connect = tokio::time::timeout(
            Duration::from_secs(20),
            Client::connect_plain(&engine, &addr, ConnectionConfig::default()),
        )
        .await;

        let err = match connect {
            Err(elapsed) => {
                // Record the timeout so the test's failure message explains
                // why no handshake reason was observed, instead of `None`.
                *self.last_error.lock() = Some(format!("connect to {addr} timed out: {elapsed}"));
                return Ok(());
            }
            Ok(Ok(_client)) => return Ok(()),
            Ok(Err(err)) => err,
        };

        *self.last_error.lock() = Some(format!("{err:?}"));
        if let ClientError::Engine(EngineError::HandshakeFailed(reason)) = err {
            if reason.contains("AuthenticationError") {
                *self.saw_server_error.lock() = true;
            }
            if reason.contains(BROKER_MESSAGE) {
                *self.saw_broker_message.lock() = true;
            }
            *self.last_handshake_reason.lock() = Some(reason);
        }
        Ok(())
    }
}
// The broker rejects CONNECT with `AuthenticationError` and `BROKER_MESSAGE`.
// Across the debug seeds, at least one iteration must surface both the
// ServerError variant name and the verbatim message in `HandshakeFailed`.
#[test]
fn connect_plain_surfaces_handshake_failure_reason_from_broker_command_error() {
    let broker = RejectHandshakeBroker::new(BROKER_MESSAGE.to_owned());
    let client = HandshakeFailureClient::new();

    // Keep handles to the shared observation cells; the workloads themselves
    // are moved into the simulation below.
    let (sessions_handled, saw_server_error, saw_broker_message, last_error) = (
        Arc::clone(&broker.sessions_handled),
        Arc::clone(&client.saw_server_error),
        Arc::clone(&client.saw_broker_message),
        Arc::clone(&client.last_error),
    );

    let report = SimulationBuilder::new()
        .workload(broker)
        .workload(client)
        .set_debug_seeds(vec![1, 2, 3, 42])
        .set_iterations(4)
        .run();

    let handled = *sessions_handled.lock();
    assert!(
        handled >= 1,
        "broker must have handled at least one inbound handshake \
         (sessions_handled={handled}, report={report:?})",
    );

    let last = last_error.lock().clone();
    assert!(
        *saw_server_error.lock(),
        "HandshakeFailed reason must mention the ServerError variant \
         (\"AuthenticationError\") on at least one iteration \
         (last_error={last:?}, report={report:?})",
    );
    assert!(
        *saw_broker_message.lock(),
        "HandshakeFailed reason must carry the verbatim broker message \
         (\"{BROKER_MESSAGE}\") on at least one iteration \
         (last_error={last:?}, report={report:?})",
    );
}
// The broker replies with a 400-char (800-byte) multi-byte message. The
// surfaced `HandshakeFailed` reason must stay within MAX_BROKER_STR + 128
// bytes, yet still contain a 64-char prefix of the original message.
#[test]
fn connect_plain_bounds_oversized_broker_handshake_message() {
    let oversized = "é".repeat(400);
    let broker = RejectHandshakeBroker::new(oversized.clone());
    let client = HandshakeFailureClient::new();

    let sessions_handled = Arc::clone(&broker.sessions_handled);
    let saw_server_error = Arc::clone(&client.saw_server_error);
    let last_reason = Arc::clone(&client.last_handshake_reason);
    let last_error = Arc::clone(&client.last_error);

    let report = SimulationBuilder::new()
        .workload(broker)
        .workload(client)
        .set_debug_seeds(vec![1, 2, 3, 42])
        .set_iterations(4)
        .run();

    let handled = *sessions_handled.lock();
    assert!(
        handled >= 1,
        "broker must have handled at least one inbound handshake \
         (sessions_handled={handled}, report={report:?})",
    );

    let last = last_error.lock().clone();
    assert!(
        *saw_server_error.lock(),
        "HandshakeFailed reason must mention the ServerError variant on at least \
         one iteration (last_error={last:?}, report={report:?})",
    );

    let reason = last_reason
        .lock()
        .clone()
        .expect("at least one iteration must surface a HandshakeFailed reason");

    // Lengths are compared in bytes (`str::len`).
    let envelope_budget = MAX_BROKER_STR + 128;
    let reason_len = reason.len();
    assert!(
        reason_len <= envelope_budget,
        "oversized broker handshake message must be bounded \
         (reason len {reason_len} > budget {envelope_budget}): {reason}",
    );

    let bounded_prefix = oversized.chars().take(64).collect::<String>();
    assert!(
        reason.contains(bounded_prefix.as_str()),
        "a bounded prefix of the broker message must still surface (got: {reason})",
    );
}
/// Compile-time check that `MoonpoolEngine` can be built over `SimProviders`.
// Never called: it exists only so the type checker validates the
// instantiation, hence the dead-code allowance.
#[allow(dead_code)]
fn _engine_sim_providers_compiles(providers: SimProviders) {
    let engine: MoonpoolEngine<SimProviders> = MoonpoolEngine::new(providers);
    drop(engine);
}