use std::{
collections::HashMap,
future::Future,
io::Cursor,
ops::DerefMut,
sync::{
Arc,
atomic::{AtomicI32, Ordering},
},
task::Poll,
};
use futures::future::BoxFuture;
use parking_lot::Mutex;
use rsasl::{
mechname::MechanismNameError,
prelude::{Mechname, SASLError, SessionError},
};
use thiserror::Error;
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf},
sync::{
Mutex as AsyncMutex,
oneshot::{Sender, channel},
},
task::JoinHandle,
};
use tracing::{debug, info, warn};
use crate::protocol::{messages::ApiVersionsRequest, traits::ReadType};
use crate::{
backoff::ErrorOrThrottle,
protocol::{
api_key::ApiKey,
api_version::ApiVersion,
error::Error as ApiError,
frame::{AsyncMessageRead, AsyncMessageWrite},
messages::{
ReadVersionedError, ReadVersionedType, RequestBody, RequestHeader, ResponseHeader,
SaslAuthenticateRequest, SaslAuthenticateResponse, SaslHandshakeRequest,
SaslHandshakeResponse, WriteVersionedError, WriteVersionedType,
},
primitives::{Int16, Int32, NullableString, TaggedFields},
},
throttle::maybe_throttle,
};
use crate::{
client::SaslConfig,
protocol::{api_version::ApiVersionRange, primitives::CompactString},
};
/// A response frame routed by the reader task to the request that is waiting for it.
#[derive(Debug)]
struct Response {
    /// Decoded response header (correlation ID and, if applicable, tagged fields).
    // Not read anywhere in this file, hence the lint allowance.
    #[allow(dead_code)]
    header: ResponseHeader,
    /// The whole frame, positioned just after the response header so the
    /// requester can decode the body from here.
    data: Cursor<Vec<u8>>,
}
/// Bookkeeping for one in-flight request, keyed by correlation ID in [`MessengerState`].
#[derive(Debug)]
struct ActiveRequest {
    /// One-shot channel on which the reader task delivers the response (or an error).
    channel: Sender<Result<Response, RequestError>>,
    /// Whether the reader must decode header tagged fields before handing the
    /// frame over (depends on the negotiated body version).
    use_tagged_fields_in_response: bool,
}
/// Connection state shared between request callers and the reader task.
#[derive(Debug)]
enum MessengerState {
    /// Healthy connection: pending requests keyed by correlation ID.
    RequestMap(HashMap<i32, ActiveRequest>),
    /// The connection failed; holds the error that caused it. All later
    /// requests fail with [`RequestError::Poisoned`] wrapping this error.
    Poison(Arc<RequestError>),
}
impl MessengerState {
    /// Move into the poisoned state and return the shared poison error.
    ///
    /// If the state is still healthy, `err` becomes the poison error and every
    /// pending request is failed with [`RequestError::Poisoned`]. If it is
    /// already poisoned, the existing error is returned and `err` is discarded.
    fn poison(&mut self, err: RequestError) -> Arc<RequestError> {
        let pending = match self {
            Self::Poison(existing) => return Arc::clone(existing),
            Self::RequestMap(pending) => pending,
        };

        let shared = Arc::new(err);
        pending.drain().for_each(|(_, request)| {
            // The receiver may already be gone (caller cancelled); that is fine.
            let _ = request
                .channel
                .send(Err(RequestError::Poisoned(Arc::clone(&shared))));
        });

        *self = Self::Poison(Arc::clone(&shared));
        shared
    }
}
/// Multiplexes Kafka requests over a single connection.
///
/// Requests are written through `stream_write`. A background task (see
/// `Messenger::new`) reads responses and routes them back by correlation ID.
#[derive(Debug)]
pub struct Messenger<RW> {
    /// Write half of the connection. It sits behind an `Arc` so an owned lock
    /// guard can be moved into a `'static` write future.
    stream_write: Arc<AsyncMutex<WriteHalf<RW>>>,
    /// Client ID sent in every request header.
    client_id: Arc<str>,
    /// Next correlation ID to hand out (incremented per request).
    correlation_id: AtomicI32,
    /// Broker-supported version range per API key, used for version negotiation.
    version_ranges: HashMap<ApiKey, ApiVersionRange>,
    /// Pending-request map / poison state, shared with the reader task.
    state: Arc<Mutex<MessengerState>>,
    /// Handle of the reader task; aborted when the messenger is dropped.
    join_handle: JoinHandle<()>,
}
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum RequestError {
#[error("Cannot find matching version for: {api_key:?}")]
NoVersionMatch { api_key: ApiKey },
#[error("Cannot write data: {0}")]
WriteError(#[from] WriteVersionedError),
#[error("Cannot write versioned data: {0}")]
WriteMessageError(#[from] crate::protocol::frame::WriteError),
#[error("Cannot read data: {0}")]
ReadError(#[from] crate::protocol::traits::ReadError),
#[error("Cannot read versioned data: {0}")]
ReadVersionedError(#[from] ReadVersionedError),
#[error("Cannot read/write data: {0}")]
IO(#[from] std::io::Error),
#[error(
"Data left at the end of the message. Got {message_size} bytes but only read {read} bytes. api_key={api_key:?} api_version={api_version}"
)]
TooMuchData {
message_size: u64,
read: u64,
api_key: ApiKey,
api_version: ApiVersion,
},
#[error("Cannot read framed message: {0}")]
ReadFramedMessageError(#[from] crate::protocol::frame::ReadError),
#[error("Connection is poisoned: {0}")]
Poisoned(Arc<RequestError>),
}
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum SyncVersionsError {
#[error("Did not found a version for ApiVersion that works with that broker")]
NoWorkingVersion,
#[error("Request error: {0}")]
RequestError(#[from] RequestError),
#[error("Got flipped version from server for API key {api_key:?}: min={min:?} max={max:?}")]
FlippedVersionRange {
api_key: ApiKey,
min: ApiVersion,
max: ApiVersion,
},
}
/// Errors returned by `Messenger::do_sasl` and its helpers.
#[derive(Error, Debug)]
pub enum SaslError {
    /// The underlying SASL request failed at the transport/protocol level.
    #[error("Request error: {0}")]
    RequestError(#[from] RequestError),

    /// The broker answered a SASL request with an error code.
    #[error("API error: {0}")]
    ApiError(#[from] ApiError),

    /// A mechanism name (configured or advertised by the broker) could not be parsed.
    #[error("Invalid sasl mechanism: {0}")]
    InvalidSaslMechanism(#[from] MechanismNameError),

    /// An rsasl session step failed.
    #[error("Sasl session error: {0}")]
    SaslSessionError(#[from] SessionError),

    /// Building the rsasl configuration failed.
    #[error("Invalid SASL config: {0}")]
    InvalidConfig(#[from] SASLError),

    /// A user-supplied callback failed.
    #[error("Error in user defined callback: {0}")]
    Callback(Box<dyn std::error::Error + Send + Sync>),

    /// The configured mechanism is not offered by the broker or cannot be started.
    #[error("unsupported sasl mechanism")]
    UnsupportedSaslMechanism,
}
impl<RW> Messenger<RW>
where
    RW: AsyncRead + AsyncWrite + Send + 'static,
{
    /// Wrap `stream` in a messenger and spawn the background reader task.
    ///
    /// The stream is split. The write half is kept behind an async mutex for
    /// sending. The read half moves into a spawned task that reads frames of at
    /// most `max_message_size` bytes, decodes the response header, and hands
    /// each frame to the pending request with the matching correlation ID.
    /// `client_id` is sent in every request header.
    ///
    /// Version ranges start out empty, so `request` fails with
    /// [`RequestError::NoVersionMatch`] until `sync_versions` (or
    /// `set_version_ranges`) has populated them.
    ///
    /// # Panics
    /// Uses `tokio::spawn`, which panics when called outside a Tokio runtime.
    pub fn new(stream: RW, max_message_size: usize, client_id: Arc<str>) -> Self {
        let (stream_read, stream_write) = tokio::io::split(stream);
        let state = Arc::new(Mutex::new(MessengerState::RequestMap(HashMap::default())));
        let state_captured = Arc::clone(&state);

        let join_handle = tokio::spawn(async move {
            let mut stream_read = stream_read;
            loop {
                match stream_read.read_message(max_message_size).await {
                    Ok(msg) => {
                        // Header v0 yields the correlation ID needed for routing.
                        let mut cursor = Cursor::new(msg);
                        let mut header =
                            match ResponseHeader::read_versioned(&mut cursor, ApiVersion(Int16(0)))
                            {
                                Ok(header) => header,
                                Err(e) => {
                                    // Without a correlation ID the frame cannot be routed;
                                    // drop it and keep the connection alive.
                                    warn!(%e, "Cannot read message header, ignoring message");
                                    continue;
                                }
                            };

                        // Claim the pending request. The lock guard is a temporary of
                        // this statement and is released before anything below runs.
                        let active_request = match state_captured.lock().deref_mut() {
                            MessengerState::RequestMap(map) => {
                                match map.remove(&header.correlation_id.0) {
                                    Some(active_request) => active_request,
                                    _ => {
                                        // E.g. a request whose caller was cancelled before
                                        // the message was sent (its entry was removed).
                                        warn!(
                                            correlation_id = header.correlation_id.0,
                                            "Got response for unknown request",
                                        );
                                        continue;
                                    }
                                }
                            }
                            MessengerState::Poison(_) => {
                                // Poisoned elsewhere (e.g. by a failed write); pending
                                // requests were already failed, so just stop reading.
                                return;
                            }
                        };

                        // At or above the request type's FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION,
                        // header tagged fields follow the v0 header fields.
                        if active_request.use_tagged_fields_in_response {
                            header.tagged_fields = match TaggedFields::read(&mut cursor) {
                                Ok(fields) => Some(fields),
                                Err(e) => {
                                    // Fail only this request; the connection stays usable.
                                    active_request
                                        .channel
                                        .send(Err(RequestError::ReadError(e)))
                                        .ok();
                                    continue;
                                }
                            };
                        }

                        // Ignore send failures: the receiver is gone if the caller was
                        // cancelled while waiting for the response.
                        active_request
                            .channel
                            .send(Ok(Response {
                                header,
                                data: cursor,
                            }))
                            .ok();
                    }
                    Err(e) => {
                        // Framing errors are fatal: fail every pending and future
                        // request and stop the task.
                        state_captured
                            .lock()
                            .poison(RequestError::ReadFramedMessageError(e));
                        return;
                    }
                }
            }
        });

        Self {
            stream_write: Arc::new(AsyncMutex::new(stream_write)),
            client_id,
            correlation_id: AtomicI32::new(0),
            version_ranges: HashMap::new(),
            state,
            join_handle,
        }
    }

    /// Replace the negotiated version ranges (fuzzing builds only).
    #[cfg(feature = "unstable-fuzzing")]
    pub fn override_version_ranges(&mut self, ranges: HashMap<ApiKey, ApiVersionRange>) {
        self.set_version_ranges(ranges);
    }

    /// Replace the broker version ranges used for version negotiation.
    fn set_version_ranges(&mut self, ranges: HashMap<ApiKey, ApiVersionRange>) {
        self.version_ranges = ranges;
    }

    /// Send `msg` and wait for its decoded response, negotiating the body
    /// version against the ranges learned by `sync_versions`.
    ///
    /// # Errors
    /// See `request_with_version_ranges`.
    pub async fn request<R>(&self, msg: R) -> Result<R::ResponseBody, RequestError>
    where
        R: RequestBody + Send + WriteVersionedType<Vec<u8>>,
        R::ResponseBody: ReadVersionedType<Cursor<Vec<u8>>>,
    {
        self.request_with_version_ranges(msg, &self.version_ranges)
            .await
    }

    /// Send `msg` using `version_ranges` for negotiation and wait for the response.
    ///
    /// The negotiated version is the highest version both sides support (see
    /// `match_versions`). The request is registered in the pending map before it
    /// is written. It is removed again if this future is dropped before the
    /// write finished.
    ///
    /// # Errors
    /// - [`RequestError::NoVersionMatch`] if no common version exists.
    /// - [`RequestError::WriteError`] if the body cannot be serialized.
    /// - [`RequestError::Poisoned`] if the connection is, or becomes, poisoned.
    /// - [`RequestError::ReadError`] / [`RequestError::ReadVersionedError`] if the
    ///   response header tagged fields or the body cannot be decoded.
    /// - [`RequestError::TooMuchData`] if the body does not consume the whole frame.
    ///
    /// # Panics
    /// If the response channel's sender is dropped without sending. The reader
    /// task and `MessengerState::poison` always send before dropping an entry.
    async fn request_with_version_ranges<R>(
        &self,
        msg: R,
        version_ranges: &HashMap<ApiKey, ApiVersionRange>,
    ) -> Result<R::ResponseBody, RequestError>
    where
        R: RequestBody + Send + WriteVersionedType<Vec<u8>>,
        R::ResponseBody: ReadVersionedType<Cursor<Vec<u8>>>,
    {
        let body_api_version = version_ranges
            .get(&R::API_KEY)
            .and_then(|range_server| match_versions(*range_server, R::API_VERSION_RANGE))
            .ok_or(RequestError::NoVersionMatch {
                api_key: R::API_KEY,
            })?;

        let use_tagged_fields_in_request =
            body_api_version >= R::FIRST_TAGGED_FIELD_IN_REQUEST_VERSION;
        let use_tagged_fields_in_response =
            body_api_version >= R::FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION;

        // Wraps on overflow (AtomicI32::fetch_add semantics).
        let correlation_id = self.correlation_id.fetch_add(1, Ordering::SeqCst);

        let header = RequestHeader {
            request_api_key: R::API_KEY,
            request_api_version: body_api_version,
            correlation_id: Int32(correlation_id),
            client_id: Some(NullableString(Some(String::from(self.client_id.as_ref())))),
            tagged_fields: Some(TaggedFields::default()),
        };
        // Header v2 is used together with tagged request bodies, v1 otherwise.
        let header_version = if use_tagged_fields_in_request {
            ApiVersion(Int16(2))
        } else {
            ApiVersion(Int16(1))
        };

        let mut buf = Vec::new();
        header
            .write_versioned(&mut buf, header_version)
            .expect("Writing header to buffer should always work");
        msg.write_versioned(&mut buf, body_api_version)?;

        let (tx, rx) = channel();

        // Create the guard before registering. If this future is dropped before
        // `message_sent` is called, the guard removes the map entry again.
        let cleanup_on_cancel =
            CleanupRequestStateOnCancel::new(Arc::clone(&self.state), correlation_id);

        // Register before writing so that even a very fast response finds its entry.
        match self.state.lock().deref_mut() {
            MessengerState::RequestMap(map) => {
                map.insert(
                    correlation_id,
                    ActiveRequest {
                        channel: tx,
                        use_tagged_fields_in_response,
                    },
                );
            }
            MessengerState::Poison(e) => {
                return Err(RequestError::Poisoned(Arc::clone(e)));
            }
        }

        self.send_message(buf).await?;
        // From here on, a cancelled caller leaves the entry in place. The reader
        // removes it when the response arrives and ignores the dropped receiver.
        cleanup_on_cancel.message_sent();

        let mut response = rx.await.expect("Who closed this channel?!")?;
        let body = R::ResponseBody::read_versioned(&mut response.data, body_api_version)?;

        // The body must consume the entire frame; leftover bytes mean it was not
        // decoded the way the broker encoded it.
        let read_bytes = response.data.position();
        let message_bytes = response.data.into_inner().len() as u64;
        if read_bytes != message_bytes {
            return Err(RequestError::TooMuchData {
                message_size: message_bytes,
                read: read_bytes,
                api_key: R::API_KEY,
                api_version: body_api_version,
            });
        }

        Ok(body)
    }

    /// Write one framed message. Any failure poisons the messenger and is
    /// returned as [`RequestError::Poisoned`] wrapping the original error (or
    /// the earlier poison error if the state was already poisoned).
    async fn send_message(&self, msg: Vec<u8>) -> Result<(), RequestError> {
        match self.send_message_inner(msg).await {
            Ok(()) => Ok(()),
            Err(e) => {
                // Treat any write failure as fatal for the whole connection.
                let mut state = self.state.lock();
                Err(RequestError::Poisoned(state.poison(e)))
            }
        }
    }

    /// Write and flush `msg` under the write lock.
    ///
    /// Once the lock is held, the write runs inside a `CancellationSafeFuture`.
    /// If the caller drops this future mid-write, the remaining write continues
    /// in a spawned task, so a frame is not left half-written on the wire.
    async fn send_message_inner(&self, msg: Vec<u8>) -> Result<(), RequestError> {
        // Owned guard so it can move into the `'static` future below.
        let mut stream_write = Arc::clone(&self.stream_write).lock_owned().await;

        let fut = CancellationSafeFuture::new(async move {
            stream_write.write_message(&msg).await?;
            stream_write.flush().await?;
            Ok(())
        });

        fut.await
    }

    /// Ask the broker which API versions it supports and store the result.
    ///
    /// Tries `ApiVersions` request versions from the highest this client
    /// supports down to the lowest. A version is skipped when the broker
    /// returns an error code, or when the response cannot be decoded or has
    /// trailing bytes. Throttle requests from the broker are honoured by
    /// sleeping and retrying the same version.
    ///
    /// # Errors
    /// - [`SyncVersionsError::FlippedVersionRange`] if the broker reports a
    ///   range with `min > max`.
    /// - [`SyncVersionsError::RequestError`] for any other request failure.
    /// - [`SyncVersionsError::NoWorkingVersion`] if every version was skipped.
    pub async fn sync_versions(&mut self) -> Result<(), SyncVersionsError> {
        'iter_upper_bound: for upper_bound in (ApiVersionsRequest::API_VERSION_RANGE.min().0.0
            ..=ApiVersionsRequest::API_VERSION_RANGE.max().0.0)
            .rev()
        {
            // Pretend the broker supports ApiVersions only up to `upper_bound`, which
            // forces negotiation to pick exactly that version.
            let version_ranges = HashMap::from([(
                ApiKey::ApiVersions,
                ApiVersionRange::new(
                    ApiVersionsRequest::API_VERSION_RANGE.min(),
                    ApiVersion(Int16(upper_bound)),
                ),
            )]);
            let body = ApiVersionsRequest {
                client_software_name: Some(CompactString(String::from(env!("CARGO_PKG_NAME")))),
                client_software_version: Some(CompactString(String::from(env!(
                    "CARGO_PKG_VERSION"
                )))),
                tagged_fields: Some(TaggedFields::default()),
            };

            'throttle: loop {
                match self
                    .request_with_version_ranges(&body, &version_ranges)
                    .await
                {
                    Ok(response) => {
                        if let Err(ErrorOrThrottle::Throttle(throttle)) =
                            maybe_throttle::<SyncVersionsError>(response.throttle_time_ms)
                        {
                            info!(
                                ?throttle,
                                request_name = "version sync",
                                "broker asked us to throttle"
                            );
                            tokio::time::sleep(throttle).await;
                            continue 'throttle;
                        }

                        if let Some(e) = response.error_code {
                            debug!(
                                %e,
                                version=upper_bound,
                                "Got error during version sync, cannot use version for ApiVersionRequest",
                            );
                            continue 'iter_upper_bound;
                        }

                        // Refuse to store inverted ranges.
                        for api_key in &response.api_keys {
                            if api_key.min_version.0 > api_key.max_version.0 {
                                return Err(SyncVersionsError::FlippedVersionRange {
                                    api_key: api_key.api_key,
                                    min: api_key.min_version,
                                    max: api_key.max_version,
                                });
                            }
                        }

                        let ranges = response
                            .api_keys
                            .into_iter()
                            .map(|x| {
                                (
                                    x.api_key,
                                    ApiVersionRange::new(x.min_version, x.max_version),
                                )
                            })
                            .collect();
                        debug!(
                            versions=%sorted_ranges_repr(&ranges),
                            "Detected supported broker versions",
                        );
                        self.set_version_ranges(ranges);

                        return Ok(());
                    }
                    Err(RequestError::NoVersionMatch { .. }) => {
                        // `upper_bound` lies inside API_VERSION_RANGE, so the ranges overlap.
                        unreachable!("Just set to version range to a non-empty range")
                    }
                    Err(RequestError::ReadVersionedError(e)) => {
                        debug!(
                            %e,
                            version=upper_bound,
                            "Cannot read ApiVersionResponse for version",
                        );
                        continue 'iter_upper_bound;
                    }
                    Err(RequestError::ReadError(e)) => {
                        debug!(
                            %e,
                            version=upper_bound,
                            "Cannot read ApiVersionResponse for version",
                        );
                        continue 'iter_upper_bound;
                    }
                    Err(e @ RequestError::TooMuchData { .. }) => {
                        debug!(
                            %e,
                            version=upper_bound,
                            "Cannot read ApiVersionResponse for version",
                        );
                        continue 'iter_upper_bound;
                    }
                    Err(e) => {
                        return Err(SyncVersionsError::RequestError(e));
                    }
                }
            }
        }

        Err(SyncVersionsError::NoWorkingVersion)
    }

    /// Send one `SaslAuthenticate` round with `auth_bytes`.
    ///
    /// A broker error code becomes [`SaslError::ApiError`]; the broker's error
    /// message, if any, is logged at debug level.
    async fn sasl_authentication(
        &self,
        auth_bytes: Vec<u8>,
    ) -> Result<SaslAuthenticateResponse, SaslError> {
        let req = SaslAuthenticateRequest::new(auth_bytes);
        let resp = self.request(req).await?;
        if let Some(err) = resp.error_code {
            if let Some(s) = resp.error_message.0 {
                debug!("Sasl auth error message: {s}");
            }
            return Err(SaslError::ApiError(err));
        }

        Ok(resp)
    }

    /// Send `SaslHandshake` for `mechanism`. A broker error code becomes
    /// [`SaslError::ApiError`].
    async fn sasl_handshake(&self, mechanism: &str) -> Result<SaslHandshakeResponse, SaslError> {
        let req = SaslHandshakeRequest::new(mechanism);
        let resp = self.request(req).await?;
        if let Some(err) = resp.error_code {
            return Err(SaslError::ApiError(err));
        }
        Ok(resp)
    }

    /// Authenticate the connection using the mechanism from `config`.
    ///
    /// Performs the handshake and checks that the broker advertises the
    /// configured mechanism. It then drives an rsasl client session: every
    /// message a step produces is sent via `SaslAuthenticate`, and the broker's
    /// reply bytes are fed into the next step until the session reports it is
    /// finished.
    ///
    /// # Errors
    /// - [`SaslError::ApiError`] / [`SaslError::RequestError`] from the SASL requests.
    /// - [`SaslError::InvalidSaslMechanism`] if a mechanism name fails to parse.
    /// - [`SaslError::UnsupportedSaslMechanism`] if the broker does not offer the
    ///   mechanism or rsasl cannot start it.
    /// - [`SaslError::SaslSessionError`] if a session step fails.
    /// - Errors from `config.get_sasl_config()`, converted via `From`.
    pub async fn do_sasl(&self, config: SaslConfig) -> Result<(), SaslError> {
        let mechanism = config.mechanism();
        let resp = self.sasl_handshake(mechanism).await?;

        let config = config.get_sasl_config().await?;
        let sasl = rsasl::prelude::SASLClient::new(config);
        let raw_mechanisms = resp.mechanisms.0.unwrap_or_default();
        let mechanisms = raw_mechanisms
            .iter()
            .map(|mech| Mechname::parse(mech.0.as_bytes()).map_err(SaslError::InvalidSaslMechanism))
            .collect::<Result<Vec<_>, SaslError>>()?;
        debug!(?mechanisms, "Supported SASL mechanisms");
        let prefer_mechanism =
            Mechname::parse(mechanism.as_bytes()).map_err(SaslError::InvalidSaslMechanism)?;
        if !mechanisms.contains(&prefer_mechanism) {
            return Err(SaslError::UnsupportedSaslMechanism);
        }
        let mut session = sasl
            .start_suggested(&[prefer_mechanism])
            .map_err(|_| SaslError::UnsupportedSaslMechanism)?;
        debug!(?mechanism, "Using SASL Mechanism");
        // `None` on the first step: the client speaks first.
        let mut data_received: Option<Vec<u8>> = None;
        loop {
            let mut to_sent = Cursor::new(Vec::new());
            let state = session.step(data_received.as_deref(), &mut to_sent)?;
            if state.has_sent_message() {
                let authentication_response =
                    self.sasl_authentication(to_sent.into_inner()).await?;
                data_received = Some(authentication_response.auth_bytes.0);
            }
            // NOTE(review): a step that neither sends nor finishes is re-run with
            // the same `data_received`; confirm rsasl never reports that state.
            if state.is_finished() {
                break;
            }
        }

        Ok(())
    }
}
impl<RW> Drop for Messenger<RW> {
fn drop(&mut self) {
self.join_handle.abort();
}
}
/// Render `ranges` as `"Key: range, Key: range, ..."`, ordered by API key,
/// so log output is stable regardless of hash-map iteration order.
fn sorted_ranges_repr(ranges: &HashMap<ApiKey, ApiVersionRange>) -> String {
    let mut entries: Vec<(ApiKey, ApiVersionRange)> =
        ranges.iter().map(|(&key, &range)| (key, range)).collect();
    // Keys come from a map and are therefore unique, so an unstable sort is exact.
    entries.sort_unstable_by_key(|&(key, _)| key);

    entries
        .iter()
        .map(|(key, range)| format!("{key:?}: {range}"))
        .collect::<Vec<_>>()
        .join(", ")
}
/// Pick the highest version contained in both inclusive ranges.
///
/// Returns `None` when the ranges do not overlap. Ranges that touch at a
/// single version yield that version.
fn match_versions(range_a: ApiVersionRange, range_b: ApiVersionRange) -> Option<ApiVersion> {
    // Two closed intervals are disjoint exactly when one starts after the other ends.
    let disjoint = range_a.min() > range_b.max() || range_b.min() > range_a.max();
    if disjoint {
        return None;
    }

    Some(std::cmp::min(range_a.max(), range_b.max()))
}
/// Drop guard that removes a request's pending-map entry unless
/// `message_sent` was called first (i.e. the request future was dropped
/// before its message was fully written).
struct CleanupRequestStateOnCancel {
    /// Shared messenger state containing the pending-request map.
    state: Arc<Mutex<MessengerState>>,
    /// Key of the entry to remove.
    correlation_id: i32,
    /// Set by `message_sent`; suppresses the removal on drop.
    message_sent: bool,
}
impl CleanupRequestStateOnCancel {
    /// Create an armed guard for `correlation_id`.
    fn new(state: Arc<Mutex<MessengerState>>, correlation_id: i32) -> Self {
        let message_sent = false;
        Self {
            state,
            correlation_id,
            message_sent,
        }
    }

    /// Disarm and consume the guard: the message reached the wire, so the map
    /// entry must stay until its response arrives.
    fn message_sent(self) {
        let mut disarmed = self;
        disarmed.message_sent = true;
        drop(disarmed);
    }
}
impl Drop for CleanupRequestStateOnCancel {
    /// Remove the pending entry if the guard is still armed. A poisoned state
    /// has no map (it was drained), so there is nothing to remove.
    fn drop(&mut self) {
        if self.message_sent {
            return;
        }

        let mut guard = self.state.lock();
        if let MessengerState::RequestMap(pending) = &mut *guard {
            pending.remove(&self.correlation_id);
        }
    }
}
/// Wrapper that lets a future run to completion even if the wrapper is dropped.
///
/// If dropped before the inner future finished, the remaining work is spawned
/// onto the Tokio runtime (see the `Drop` impl).
struct CancellationSafeFuture<F>
where
    F: Future + Send + 'static,
{
    /// Whether the inner future has returned `Poll::Ready`.
    done: bool,
    /// The boxed inner future; an `Option` so `Drop` can move it out.
    inner: Option<BoxFuture<'static, F::Output>>,
}
impl<F> Drop for CancellationSafeFuture<F>
where
    F: Future + Send + 'static,
{
    /// Hand an unfinished inner future to the runtime so it still completes.
    /// Its output is discarded.
    fn drop(&mut self) {
        if self.done {
            return;
        }

        let pending = self.inner.take().expect("Double-drop?");
        // Wrap in an async block so the spawned task's output is `()`; this places
        // no extra bounds on `F::Output`.
        tokio::task::spawn(async move {
            pending.await;
        });
    }
}
impl<F> CancellationSafeFuture<F>
where
F: Future + Send,
{
fn new(fut: F) -> Self {
Self {
done: false,
inner: Some(Box::pin(fut)),
}
}
}
impl<F> Future for CancellationSafeFuture<F>
where
    F: Future + Send,
{
    type Output = F::Output;

    /// Poll the inner future, recording completion so `Drop` knows not to spawn it.
    fn poll(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        let inner = self.inner.as_mut().expect("no dropped");
        let outcome = inner.as_mut().poll(cx);
        if outcome.is_ready() {
            self.done = true;
        }
        outcome
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use assert_matches::assert_matches;
    use futures::{FutureExt, pin_mut};
    use tokio::{
        io::{AsyncReadExt, DuplexStream},
        sync::{Barrier, mpsc::UnboundedSender},
    };

    use super::*;
    use crate::{
        build_info::DEFAULT_CLIENT_ID,
        protocol::{
            error::Error as ApiError,
            messages::{
                ApiVersionsResponse, ApiVersionsResponseApiKey, ListOffsetsRequest, NORMAL_CONSUMER,
            },
            traits::WriteType,
        },
    };

    /// Overlapping ranges (including ranges touching at one version) yield the
    /// lower of the two maxima; disjoint ranges yield `None`.
    #[test]
    fn test_match_versions() {
        assert_eq!(
            match_versions(
                ApiVersionRange::new(ApiVersion(Int16(10)), ApiVersion(Int16(20))),
                ApiVersionRange::new(ApiVersion(Int16(10)), ApiVersion(Int16(20))),
            ),
            Some(ApiVersion(Int16(20))),
        );

        assert_eq!(
            match_versions(
                ApiVersionRange::new(ApiVersion(Int16(10)), ApiVersion(Int16(15))),
                ApiVersionRange::new(ApiVersion(Int16(13)), ApiVersion(Int16(20))),
            ),
            Some(ApiVersion(Int16(15))),
        );

        assert_eq!(
            match_versions(
                ApiVersionRange::new(ApiVersion(Int16(10)), ApiVersion(Int16(15))),
                ApiVersionRange::new(ApiVersion(Int16(15)), ApiVersion(Int16(20))),
            ),
            Some(ApiVersion(Int16(15))),
        );

        assert_eq!(
            match_versions(
                ApiVersionRange::new(ApiVersion(Int16(10)), ApiVersion(Int16(14))),
                ApiVersionRange::new(ApiVersion(Int16(15)), ApiVersion(Int16(20))),
            ),
            None,
        );
    }

    /// The first (highest-version) ApiVersions request succeeds and its ranges are stored.
    #[tokio::test]
    async fn test_sync_versions_ok() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        let mut msg = vec![];
        ResponseHeader {
            correlation_id: Int32(0),
            // The reader decodes header tagged fields separately, after the v0 header.
            tagged_fields: Default::default(),
        }
        .write_versioned(&mut msg, ApiVersion(Int16(0)))
        .unwrap();
        ApiVersionsResponse {
            error_code: None,
            api_keys: vec![ApiVersionsResponseApiKey {
                api_key: ApiKey::Produce,
                min_version: ApiVersion(Int16(1)),
                max_version: ApiVersion(Int16(5)),
                tagged_fields: Default::default(),
            }],
            throttle_time_ms: None,
            tagged_fields: None,
        }
        .write_versioned(&mut msg, ApiVersionsRequest::API_VERSION_RANGE.max())
        .unwrap();
        sim.push(msg);

        messenger.sync_versions().await.unwrap();

        let expected = HashMap::from([(
            (ApiKey::Produce),
            ApiVersionRange::new(ApiVersion(Int16(1)), ApiVersion(Int16(5))),
        )]);
        assert_eq!(messenger.version_ranges, expected);
    }

    /// A response with an error code makes `sync_versions` retry one version lower.
    #[tokio::test]
    async fn test_sync_versions_ignores_error_code() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        // First attempt (correlation ID 0): rejected via error code.
        let mut msg = vec![];
        ResponseHeader {
            correlation_id: Int32(0),
            tagged_fields: Default::default(),
        }
        .write_versioned(&mut msg, ApiVersion(Int16(0)))
        .unwrap();
        ApiVersionsResponse {
            error_code: Some(ApiError::CorruptMessage),
            api_keys: vec![ApiVersionsResponseApiKey {
                api_key: ApiKey::Produce,
                min_version: ApiVersion(Int16(2)),
                max_version: ApiVersion(Int16(3)),
                tagged_fields: Default::default(),
            }],
            throttle_time_ms: None,
            tagged_fields: None,
        }
        .write_versioned(&mut msg, ApiVersionsRequest::API_VERSION_RANGE.max())
        .unwrap();
        sim.push(msg);

        // Second attempt (correlation ID 1, one version lower): accepted.
        let mut msg = vec![];
        ResponseHeader {
            correlation_id: Int32(1),
            tagged_fields: Default::default(),
        }
        .write_versioned(&mut msg, ApiVersion(Int16(0)))
        .unwrap();
        ApiVersionsResponse {
            error_code: None,
            api_keys: vec![ApiVersionsResponseApiKey {
                api_key: ApiKey::Produce,
                min_version: ApiVersion(Int16(1)),
                max_version: ApiVersion(Int16(5)),
                tagged_fields: Default::default(),
            }],
            throttle_time_ms: None,
            tagged_fields: None,
        }
        .write_versioned(
            &mut msg,
            ApiVersion(Int16(ApiVersionsRequest::API_VERSION_RANGE.max().0.0 - 1)),
        )
        .unwrap();
        sim.push(msg);

        messenger.sync_versions().await.unwrap();

        let expected = HashMap::from([(
            (ApiKey::Produce),
            ApiVersionRange::new(ApiVersion(Int16(1)), ApiVersion(Int16(5))),
        )]);
        assert_eq!(messenger.version_ranges, expected);
    }

    /// A response whose body cannot be decoded makes `sync_versions` retry one version lower.
    #[tokio::test]
    async fn test_sync_versions_ignores_read_code() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        // First attempt: header followed by a single byte instead of a real body.
        let mut msg = vec![];
        ResponseHeader {
            correlation_id: Int32(0),
            tagged_fields: Default::default(),
        }
        .write_versioned(&mut msg, ApiVersion(Int16(0)))
        .unwrap();
        msg.push(b'\0');
        sim.push(msg);

        // Second attempt: a valid response one version lower.
        let mut msg = vec![];
        ResponseHeader {
            correlation_id: Int32(1),
            tagged_fields: Default::default(),
        }
        .write_versioned(&mut msg, ApiVersion(Int16(0)))
        .unwrap();
        ApiVersionsResponse {
            error_code: None,
            api_keys: vec![ApiVersionsResponseApiKey {
                api_key: ApiKey::Produce,
                min_version: ApiVersion(Int16(1)),
                max_version: ApiVersion(Int16(5)),
                tagged_fields: Default::default(),
            }],
            throttle_time_ms: None,
            tagged_fields: None,
        }
        .write_versioned(
            &mut msg,
            ApiVersion(Int16(ApiVersionsRequest::API_VERSION_RANGE.max().0.0 - 1)),
        )
        .unwrap();
        sim.push(msg);

        messenger.sync_versions().await.unwrap();

        let expected = HashMap::from([(
            (ApiKey::Produce),
            ApiVersionRange::new(ApiVersion(Int16(1)), ApiVersion(Int16(5))),
        )]);
        assert_eq!(messenger.version_ranges, expected);
    }

    /// A range with `min > max` aborts version sync with `FlippedVersionRange`.
    #[tokio::test]
    async fn test_sync_versions_err_flipped_range() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        let mut msg = vec![];
        ResponseHeader {
            correlation_id: Int32(0),
            tagged_fields: Default::default(),
        }
        .write_versioned(&mut msg, ApiVersion(Int16(0)))
        .unwrap();
        ApiVersionsResponse {
            error_code: None,
            api_keys: vec![ApiVersionsResponseApiKey {
                api_key: ApiKey::Produce,
                min_version: ApiVersion(Int16(2)),
                max_version: ApiVersion(Int16(1)),
                tagged_fields: Default::default(),
            }],
            throttle_time_ms: None,
            tagged_fields: None,
        }
        .write_versioned(&mut msg, ApiVersionsRequest::API_VERSION_RANGE.max())
        .unwrap();
        sim.push(msg);

        let err = messenger.sync_versions().await.unwrap_err();
        assert_matches!(err, SyncVersionsError::FlippedVersionRange { .. });
    }

    /// Trailing bytes after a valid body (TooMuchData) make `sync_versions`
    /// retry one version lower.
    #[tokio::test]
    async fn test_sync_versions_ignores_garbage() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        // First attempt: valid response plus one extra byte.
        let mut msg = vec![];
        ResponseHeader {
            correlation_id: Int32(0),
            tagged_fields: Default::default(),
        }
        .write_versioned(&mut msg, ApiVersion(Int16(0)))
        .unwrap();
        ApiVersionsResponse {
            error_code: None,
            api_keys: vec![ApiVersionsResponseApiKey {
                api_key: ApiKey::Produce,
                min_version: ApiVersion(Int16(1)),
                max_version: ApiVersion(Int16(2)),
                tagged_fields: Default::default(),
            }],
            throttle_time_ms: None,
            tagged_fields: None,
        }
        .write_versioned(&mut msg, ApiVersionsRequest::API_VERSION_RANGE.max())
        .unwrap();
        msg.push(b'\0');
        sim.push(msg);

        // Second attempt: a clean response one version lower.
        let mut msg = vec![];
        ResponseHeader {
            correlation_id: Int32(1),
            tagged_fields: Default::default(),
        }
        .write_versioned(&mut msg, ApiVersion(Int16(0)))
        .unwrap();
        ApiVersionsResponse {
            error_code: None,
            api_keys: vec![ApiVersionsResponseApiKey {
                api_key: ApiKey::Produce,
                min_version: ApiVersion(Int16(1)),
                max_version: ApiVersion(Int16(5)),
                tagged_fields: Default::default(),
            }],
            throttle_time_ms: None,
            tagged_fields: None,
        }
        .write_versioned(
            &mut msg,
            ApiVersion(Int16(ApiVersionsRequest::API_VERSION_RANGE.max().0.0 - 1)),
        )
        .unwrap();
        sim.push(msg);

        messenger.sync_versions().await.unwrap();

        let expected = HashMap::from([(
            (ApiKey::Produce),
            ApiVersionRange::new(ApiVersion(Int16(1)), ApiVersion(Int16(5))),
        )]);
        assert_eq!(messenger.version_ranges, expected);
    }

    /// If every version is rejected, `sync_versions` fails with `NoWorkingVersion`.
    #[tokio::test]
    async fn test_sync_versions_err_no_working_version() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        // One error response per attempted version, highest version first.
        for (i, v) in ((ApiVersionsRequest::API_VERSION_RANGE.min().0.0)
            ..=(ApiVersionsRequest::API_VERSION_RANGE.max().0.0))
            .rev()
            .enumerate()
        {
            let mut msg = vec![];
            ResponseHeader {
                correlation_id: Int32(i as i32),
                tagged_fields: Default::default(),
            }
            .write_versioned(&mut msg, ApiVersion(Int16(0)))
            .unwrap();
            ApiVersionsResponse {
                error_code: Some(ApiError::CorruptMessage),
                api_keys: vec![ApiVersionsResponseApiKey {
                    api_key: ApiKey::Produce,
                    min_version: ApiVersion(Int16(1)),
                    max_version: ApiVersion(Int16(5)),
                    tagged_fields: Default::default(),
                }],
                throttle_time_ms: None,
                tagged_fields: None,
            }
            .write_versioned(&mut msg, ApiVersion(Int16(v)))
            .unwrap();
            sim.push(msg);
        }

        let err = messenger.sync_versions().await.unwrap_err();
        assert_matches!(err, SyncVersionsError::NoWorkingVersion);
    }

    /// When the peer closes the connection, the request fails as `Poisoned`.
    #[tokio::test]
    async fn test_poison_hangup() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
        messenger.set_version_ranges(HashMap::from([(
            ApiKey::ListOffsets,
            ListOffsetsRequest::API_VERSION_RANGE,
        )]));

        // Stops the simulator task, dropping its end of the duplex pipe.
        sim.hang_up();

        let err = messenger
            .request(ListOffsetsRequest {
                replica_id: NORMAL_CONSUMER,
                isolation_level: None,
                topics: vec![],
            })
            .await
            .unwrap_err();
        assert_matches!(err, RequestError::Poisoned(_));
    }

    /// An invalid (`-1`) length prefix poisons the connection, and the poison
    /// persists for subsequent requests.
    #[tokio::test]
    async fn test_poison_negative_message_size() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
        messenger.set_version_ranges(HashMap::from([(
            ApiKey::ListOffsets,
            ListOffsetsRequest::API_VERSION_RANGE,
        )]));

        sim.negative_message_size();

        let err = messenger
            .request(ListOffsetsRequest {
                replica_id: NORMAL_CONSUMER,
                isolation_level: None,
                topics: vec![],
            })
            .await
            .unwrap_err();
        assert_matches!(err, RequestError::Poisoned(_));

        // Still poisoned on the next request.
        let err = messenger
            .request(ListOffsetsRequest {
                replica_id: NORMAL_CONSUMER,
                isolation_level: None,
                topics: vec![],
            })
            .await
            .unwrap_err();
        assert_matches!(err, RequestError::Poisoned(_));
    }

    /// A frame whose header cannot be decoded is skipped without poisoning;
    /// the following valid response is still delivered.
    #[tokio::test]
    async fn test_broken_msg_header_does_not_poison() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
        messenger.set_version_ranges(HashMap::from([(
            ApiKey::ApiVersions,
            ApiVersionsRequest::API_VERSION_RANGE,
        )]));

        // Three bytes: presumably too short for the i32 correlation ID.
        // `send` (not `push`) so the request is not consumed first.
        sim.send(b"foo".to_vec());

        let mut msg = vec![];
        ResponseHeader {
            correlation_id: Int32(0),
            tagged_fields: Default::default(),
        }
        .write_versioned(&mut msg, ApiVersion(Int16(0)))
        .unwrap();
        let resp = ApiVersionsResponse {
            error_code: Some(ApiError::CorruptMessage),
            api_keys: vec![],
            throttle_time_ms: Some(Int32(1)),
            tagged_fields: Some(TaggedFields::default()),
        };
        resp.write_versioned(&mut msg, ApiVersionsRequest::API_VERSION_RANGE.max())
            .unwrap();
        sim.push(msg);

        let actual = messenger
            .request(ApiVersionsRequest {
                client_software_name: Some(CompactString(String::new())),
                client_software_version: Some(CompactString(String::new())),
                tagged_fields: Some(TaggedFields::default()),
            })
            .await
            .unwrap();
        assert_eq!(actual, resp);
    }

    /// Cancelling a request mid-write must not corrupt the connection: the
    /// partially written frame is finished in the background, and a follow-up
    /// request completes normally.
    #[tokio::test]
    async fn test_cancel_request() {
        // messenger <-> [tx_front | rx_middle] <-> network task <-> [tx_middle | rx_back] <-> broker task.
        // Both duplex pipes have a 1-byte buffer, so writes advance byte by byte.
        let (tx_front, rx_middle) = tokio::io::duplex(1);
        let (tx_middle, mut rx_back) = tokio::io::duplex(1);

        let mut messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        // Barriers used to freeze the client->broker direction mid-message.
        let network_pause = Arc::new(Barrier::new(2));
        let network_pause_captured = Arc::clone(&network_pause);
        let network_continue = Arc::new(Barrier::new(2));
        let network_continue_captured = Arc::clone(&network_continue);
        let handle_network = tokio::spawn(async move {
            let (mut rx_middle_read, mut rx_middle_write) = tokio::io::split(rx_middle);
            let (mut tx_middle_read, mut tx_middle_write) = tokio::io::split(tx_middle);

            let direction_client_broker = async {
                for i in 0.. {
                    let mut buf = [0; 1];
                    rx_middle_read.read_exact(&mut buf).await.unwrap();
                    tx_middle_write.write_all(&buf).await.unwrap();

                    // After the 4th forwarded byte (inside the first request),
                    // wait until the test has cancelled that request.
                    if i == 3 {
                        network_pause_captured.wait().await;
                        network_continue_captured.wait().await;
                    }
                }
            };

            let direction_broker_client = async {
                loop {
                    let mut buf = [0; 1];
                    tx_middle_read.read_exact(&mut buf).await.unwrap();
                    rx_middle_write.write_all(&buf).await.unwrap();
                }
            };

            tokio::select! {
                _ = direction_client_broker => {}
                _ = direction_broker_client => {}
            }
        });

        // Fake broker: expects consecutive correlation IDs starting at 0, which
        // only holds if the cancelled request was still written out completely.
        let handle_broker = tokio::spawn(async move {
            for correlation_id in 0.. {
                let data = rx_back.read_message(1_000).await.unwrap();
                let mut data = Cursor::new(data);
                let header =
                    RequestHeader::read_versioned(&mut data, ApiVersion(Int16(1))).unwrap();
                assert_eq!(
                    header,
                    RequestHeader {
                        request_api_key: ApiKey::ApiVersions,
                        request_api_version: ApiVersion(Int16(0)),
                        correlation_id: Int32(correlation_id),
                        client_id: Some(NullableString(Some(String::from(env!("CARGO_PKG_NAME"))))),
                        tagged_fields: None,
                    }
                );
                // At version 0 the client software fields are presumably not
                // encoded, so they decode as `None`.
                let body =
                    ApiVersionsRequest::read_versioned(&mut data, ApiVersion(Int16(0))).unwrap();
                assert_eq!(
                    body,
                    ApiVersionsRequest {
                        client_software_name: None,
                        client_software_version: None,
                        tagged_fields: None,
                    }
                );
                assert_eq!(data.position() as usize, data.get_ref().len());

                let mut msg = vec![];
                ResponseHeader {
                    correlation_id: Int32(correlation_id),
                    tagged_fields: Default::default(),
                }
                .write_versioned(&mut msg, ApiVersion(Int16(0)))
                .unwrap();
                let resp = ApiVersionsResponse {
                    error_code: Some(ApiError::CorruptMessage),
                    api_keys: vec![],
                    throttle_time_ms: Some(Int32(1)),
                    tagged_fields: Some(TaggedFields::default()),
                };
                resp.write_versioned(&mut msg, ApiVersionsRequest::API_VERSION_RANGE.min())
                    .unwrap();
                rx_back.write_message(&msg).await.unwrap();
            }
        });

        messenger.set_version_ranges(HashMap::from([(
            ApiKey::ApiVersions,
            ApiVersionRange::new(ApiVersion(Int16(0)), ApiVersion(Int16(0))),
        )]));

        // First request: started, then cancelled while the network is paused.
        let task_to_cancel = (async {
            messenger
                .request(ApiVersionsRequest {
                    client_software_name: Some(CompactString(String::from("foo"))),
                    client_software_version: Some(CompactString(String::from("bar"))),
                    tagged_fields: Some(TaggedFields::default()),
                })
                .await
                .unwrap();
        })
        .fuse();

        {
            pin_mut!(task_to_cancel);

            futures::select_biased! {
                _ = &mut task_to_cancel => panic!("should not have finished"),
                _ = network_pause.wait().fuse() => {},
            }
            // Leaving this scope drops the pinned request future, cancelling it.
        }

        // Resume forwarding; the rest of the cancelled frame is expected to be
        // written by the spawned CancellationSafeFuture.
        network_continue.wait().await;

        // A follow-up request on the same connection must complete promptly.
        tokio::time::timeout(Duration::from_millis(100), async {
            messenger
                .request(ApiVersionsRequest {
                    client_software_name: Some(CompactString(String::from("foo"))),
                    client_software_version: Some(CompactString(String::from("bar"))),
                    tagged_fields: Some(TaggedFields::default()),
                })
                .await
                .unwrap();
        })
        .await
        .unwrap();

        handle_broker.abort();
        handle_network.abort();
    }

    /// Commands processed, in order, by the [`MessageSimulator`] task.
    #[derive(Debug)]
    enum Message {
        /// Write this framed message to the messenger.
        Send(Vec<u8>),
        /// Read and discard one framed message (a request) from the messenger.
        Consume,
        /// Write a raw `-1` length prefix.
        NegativeMessageSize,
        /// Stop the task, dropping the simulator's end of the pipe.
        HangUp,
    }

    /// Scripted fake broker on the far end of an in-memory duplex pipe.
    struct MessageSimulator {
        messages: UnboundedSender<Message>,
        join_handle: JoinHandle<()>,
    }

    impl MessageSimulator {
        /// Spawn the simulator task; returns the handle and the stream for the messenger.
        fn new() -> (Self, DuplexStream) {
            let (mut tx, rx) = tokio::io::duplex(1_000);
            let (msg_tx, mut msg_rx) = tokio::sync::mpsc::unbounded_channel();

            let join_handle = tokio::task::spawn(async move {
                loop {
                    let message = match msg_rx.recv().await {
                        Some(msg) => msg,
                        None => return,
                    };

                    match message {
                        Message::Consume => {
                            tx.read_message(1_000).await.unwrap();
                        }
                        Message::Send(data) => {
                            tx.write_message(&data).await.unwrap();
                        }
                        Message::NegativeMessageSize => {
                            let mut buf = vec![];
                            Int32(-1).write(&mut buf).unwrap();
                            tx.write_all(&buf).await.unwrap()
                        }
                        Message::HangUp => {
                            return;
                        }
                    }
                }
            });

            let this = Self {
                messages: msg_tx,
                join_handle,
            };
            (this, rx)
        }

        /// Wait for the next request, then answer it with `msg`.
        fn push(&self, msg: Vec<u8>) {
            self.consume();
            self.send(msg);
        }

        /// Read and discard the next request.
        fn consume(&self) {
            self.messages.send(Message::Consume).unwrap();
        }

        /// Send `msg` without waiting for a request first.
        fn send(&self, msg: Vec<u8>) {
            self.messages.send(Message::Send(msg)).unwrap();
        }

        /// Send an invalid (`-1`) length prefix.
        fn negative_message_size(&self) {
            self.messages.send(Message::NegativeMessageSize).unwrap();
        }

        /// Close the simulator's end of the connection.
        fn hang_up(&self) {
            self.messages.send(Message::HangUp).unwrap();
        }
    }

    impl Drop for MessageSimulator {
        fn drop(&mut self) {
            // Stop the simulator task along with the test.
            self.join_handle.abort();
        }
    }
}