1use std::{
2 collections::HashMap,
3 future::Future,
4 io::Cursor,
5 ops::DerefMut,
6 sync::{
7 atomic::{AtomicI32, Ordering},
8 Arc,
9 },
10 task::Poll,
11};
12
13use futures::future::BoxFuture;
14use parking_lot::Mutex;
15use thiserror::Error;
16use tokio::{
17 io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf},
18 sync::{
19 oneshot::{channel, Sender},
20 Mutex as AsyncMutex,
21 },
22 task::JoinHandle,
23};
24use tracing::{debug, info, warn};
25
26use crate::protocol::{api_version::ApiVersionRange, primitives::CompactString};
27use crate::protocol::{messages::ApiVersionsRequest, traits::ReadType};
28use crate::{
29 backoff::ErrorOrThrottle,
30 protocol::{
31 api_key::ApiKey,
32 api_version::ApiVersion,
33 frame::{AsyncMessageRead, AsyncMessageWrite},
34 messages::{
35 ReadVersionedError, ReadVersionedType, RequestBody, RequestHeader, ResponseHeader,
36 WriteVersionedError, WriteVersionedType,
37 },
38 primitives::{Int16, Int32, NullableString, TaggedFields},
39 },
40 throttle::maybe_throttle,
41};
42
43#[derive(Debug)]
44struct Response {
45 #[allow(dead_code)]
46 header: ResponseHeader,
47 data: Cursor<Vec<u8>>,
48}
49
50#[derive(Debug)]
51struct ActiveRequest {
52 channel: Sender<Result<Response, RequestError>>,
53 use_tagged_fields_in_response: bool,
54}
55
56#[derive(Debug)]
57enum MessengerState {
58 RequestMap(HashMap<i32, ActiveRequest>),
62
63 Poison(Arc<RequestError>),
65}
66
67impl MessengerState {
68 fn poison(&mut self, err: RequestError) -> Arc<RequestError> {
69 match self {
70 Self::RequestMap(map) => {
71 let err = Arc::new(err);
72
73 for (_correlation_id, active_request) in map.drain() {
75 active_request
77 .channel
78 .send(Err(RequestError::Poisoned(Arc::clone(&err))))
79 .ok();
80 }
81
82 *self = Self::Poison(Arc::clone(&err));
83 err
84 }
85 Self::Poison(e) => {
86 Arc::clone(e)
88 }
89 }
90 }
91}
92
93#[derive(Debug)]
98pub struct Messenger<RW> {
99 stream_write: Arc<AsyncMutex<WriteHalf<RW>>>,
103
104 client_id: Arc<str>,
106
107 correlation_id: AtomicI32,
111
112 version_ranges: HashMap<ApiKey, ApiVersionRange>,
116
117 state: Arc<Mutex<MessengerState>>,
121
122 join_handle: JoinHandle<()>,
124}
125
126#[derive(Error, Debug)]
127#[non_exhaustive]
128pub enum RequestError {
129 #[error("Cannot find matching version for: {api_key:?}")]
130 NoVersionMatch { api_key: ApiKey },
131
132 #[error("Cannot write data: {0}")]
133 WriteError(#[from] WriteVersionedError),
134
135 #[error("Cannot write versioned data: {0}")]
136 WriteMessageError(#[from] crate::protocol::frame::WriteError),
137
138 #[error("Cannot read data: {0}")]
139 ReadError(#[from] crate::protocol::traits::ReadError),
140
141 #[error("Cannot read versioned data: {0}")]
142 ReadVersionedError(#[from] ReadVersionedError),
143
144 #[error("Cannot read/write data: {0}")]
145 IO(#[from] std::io::Error),
146
147 #[error(
148 "Data left at the end of the message. Got {message_size} bytes but only read {read} bytes. api_key={api_key:?} api_version={api_version}"
149 )]
150 TooMuchData {
151 message_size: u64,
152 read: u64,
153 api_key: ApiKey,
154 api_version: ApiVersion,
155 },
156
157 #[error("Cannot read framed message: {0}")]
158 ReadFramedMessageError(#[from] crate::protocol::frame::ReadError),
159
160 #[error("Connection is poisoned: {0}")]
161 Poisoned(Arc<RequestError>),
162}
163
164#[derive(Error, Debug)]
165#[non_exhaustive]
166pub enum SyncVersionsError {
167 #[error("Did not found a version for ApiVersion that works with that broker")]
168 NoWorkingVersion,
169
170 #[error("Request error: {0}")]
171 RequestError(#[from] RequestError),
172
173 #[error("Got flipped version from server for API key {api_key:?}: min={min:?} max={max:?}")]
174 FlippedVersionRange {
175 api_key: ApiKey,
176 min: ApiVersion,
177 max: ApiVersion,
178 },
179}
180
181impl<RW> Messenger<RW>
182where
183 RW: AsyncRead + AsyncWrite + Send + 'static,
184{
185 pub fn new(stream: RW, max_message_size: usize, client_id: Arc<str>) -> Self {
186 let (stream_read, stream_write) = tokio::io::split(stream);
187 let state = Arc::new(Mutex::new(MessengerState::RequestMap(HashMap::default())));
188 let state_captured = Arc::clone(&state);
189
190 let join_handle = tokio::spawn(async move {
191 let mut stream_read = stream_read;
192
193 loop {
194 match stream_read.read_message(max_message_size).await {
195 Ok(msg) => {
196 let mut cursor = Cursor::new(msg);
198
199 let mut header =
202 match ResponseHeader::read_versioned(&mut cursor, ApiVersion(Int16(0)))
203 {
204 Ok(header) => header,
205 Err(e) => {
206 warn!(%e, "Cannot read message header, ignoring message");
207 continue;
208 }
209 };
210
211 let active_request = match state_captured.lock().deref_mut() {
212 MessengerState::RequestMap(map) => {
213 if let Some(active_request) = map.remove(&header.correlation_id.0) {
214 active_request
215 } else {
216 warn!(
217 correlation_id = header.correlation_id.0,
218 "Got response for unknown request",
219 );
220 continue;
221 }
222 }
223 MessengerState::Poison(_) => {
224 return;
226 }
227 };
228
229 if active_request.use_tagged_fields_in_response {
231 header.tagged_fields = match TaggedFields::read(&mut cursor) {
232 Ok(fields) => Some(fields),
233 Err(e) => {
234 active_request
236 .channel
237 .send(Err(RequestError::ReadError(e)))
238 .ok();
239 continue;
240 }
241 };
242 }
243
244 active_request
246 .channel
247 .send(Ok(Response {
248 header,
249 data: cursor,
250 }))
251 .ok();
252 }
253 Err(e) => {
254 state_captured
255 .lock()
256 .poison(RequestError::ReadFramedMessageError(e));
257 return;
258 }
259 }
260 }
261 });
262
263 Self {
264 stream_write: Arc::new(AsyncMutex::new(stream_write)),
265 client_id,
266 correlation_id: AtomicI32::new(0),
267 version_ranges: HashMap::new(),
268 state,
269 join_handle,
270 }
271 }
272
273 #[cfg(feature = "unstable-fuzzing")]
274 pub fn override_version_ranges(&mut self, ranges: HashMap<ApiKey, ApiVersionRange>) {
275 self.set_version_ranges(ranges);
276 }
277
278 fn set_version_ranges(&mut self, ranges: HashMap<ApiKey, ApiVersionRange>) {
280 self.version_ranges = ranges;
281 }
282
283 pub async fn request<R>(&self, msg: R) -> Result<R::ResponseBody, RequestError>
284 where
285 R: RequestBody + Send + WriteVersionedType<Vec<u8>>,
286 R::ResponseBody: ReadVersionedType<Cursor<Vec<u8>>>,
287 {
288 self.request_with_version_ranges(msg, &self.version_ranges)
289 .await
290 }
291
292 async fn request_with_version_ranges<R>(
293 &self,
294 msg: R,
295 version_ranges: &HashMap<ApiKey, ApiVersionRange>,
296 ) -> Result<R::ResponseBody, RequestError>
297 where
298 R: RequestBody + Send + WriteVersionedType<Vec<u8>>,
299 R::ResponseBody: ReadVersionedType<Cursor<Vec<u8>>>,
300 {
301 let body_api_version = version_ranges
302 .get(&R::API_KEY)
303 .and_then(|range_server| match_versions(*range_server, R::API_VERSION_RANGE))
304 .ok_or(RequestError::NoVersionMatch {
305 api_key: R::API_KEY,
306 })?;
307
308 let use_tagged_fields_in_request =
313 body_api_version >= R::FIRST_TAGGED_FIELD_IN_REQUEST_VERSION;
314 let use_tagged_fields_in_response =
315 body_api_version >= R::FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION;
316
317 let correlation_id = self.correlation_id.fetch_add(1, Ordering::SeqCst);
319
320 let header = RequestHeader {
321 request_api_key: R::API_KEY,
322 request_api_version: body_api_version,
323 correlation_id: Int32(correlation_id),
324 client_id: Some(NullableString(Some(String::from(self.client_id.as_ref())))),
327 tagged_fields: Some(TaggedFields::default()),
328 };
329 let header_version = if use_tagged_fields_in_request {
330 ApiVersion(Int16(2))
331 } else {
332 ApiVersion(Int16(1))
333 };
334
335 let mut buf = Vec::new();
336 header
337 .write_versioned(&mut buf, header_version)
338 .expect("Writing header to buffer should always work");
339 msg.write_versioned(&mut buf, body_api_version)?;
340
341 let (tx, rx) = channel();
342
343 let cleanup_on_cancel =
346 CleanupRequestStateOnCancel::new(Arc::clone(&self.state), correlation_id);
347
348 match self.state.lock().deref_mut() {
349 MessengerState::RequestMap(map) => {
350 map.insert(
351 correlation_id,
352 ActiveRequest {
353 channel: tx,
354 use_tagged_fields_in_response,
355 },
356 );
357 }
358 MessengerState::Poison(e) => {
359 return Err(RequestError::Poisoned(Arc::clone(e)));
360 }
361 }
362
363 self.send_message(buf).await?;
364 cleanup_on_cancel.message_sent();
365
366 let mut response = rx.await.expect("Who closed this channel?!")?;
367 let body = R::ResponseBody::read_versioned(&mut response.data, body_api_version)?;
368
369 let read_bytes = response.data.position();
371 let message_bytes = response.data.into_inner().len() as u64;
372 if read_bytes != message_bytes {
373 return Err(RequestError::TooMuchData {
374 message_size: message_bytes,
375 read: read_bytes,
376 api_key: R::API_KEY,
377 api_version: body_api_version,
378 });
379 }
380
381 Ok(body)
382 }
383
384 async fn send_message(&self, msg: Vec<u8>) -> Result<(), RequestError> {
385 match self.send_message_inner(msg).await {
386 Ok(()) => Ok(()),
387 Err(e) => {
388 let mut state = self.state.lock();
390 Err(RequestError::Poisoned(state.poison(e)))
391 }
392 }
393 }
394
395 async fn send_message_inner(&self, msg: Vec<u8>) -> Result<(), RequestError> {
396 let mut stream_write = Arc::clone(&self.stream_write).lock_owned().await;
397
398 let fut = CancellationSafeFuture::new(async move {
400 stream_write.write_message(&msg).await?;
401 stream_write.flush().await?;
402 Ok(())
403 });
404
405 fut.await
406 }
407
408 pub async fn sync_versions(&mut self) -> Result<(), SyncVersionsError> {
412 'iter_upper_bound: for upper_bound in (ApiVersionsRequest::API_VERSION_RANGE.min().0 .0
413 ..=ApiVersionsRequest::API_VERSION_RANGE.max().0 .0)
414 .rev()
415 {
416 let version_ranges = HashMap::from([(
417 ApiKey::ApiVersions,
418 ApiVersionRange::new(
419 ApiVersionsRequest::API_VERSION_RANGE.min(),
420 ApiVersion(Int16(upper_bound)),
421 ),
422 )]);
423
424 let body = ApiVersionsRequest {
425 client_software_name: Some(CompactString(String::from(env!("CARGO_PKG_NAME")))),
426 client_software_version: Some(CompactString(String::from(env!(
427 "CARGO_PKG_VERSION"
428 )))),
429 tagged_fields: Some(TaggedFields::default()),
430 };
431
432 'throttle: loop {
433 match self
434 .request_with_version_ranges(&body, &version_ranges)
435 .await
436 {
437 Ok(response) => {
438 if let Err(ErrorOrThrottle::Throttle(throttle)) =
439 maybe_throttle::<SyncVersionsError>(response.throttle_time_ms)
440 {
441 info!(
442 ?throttle,
443 request_name = "version sync",
444 "broker asked us to throttle"
445 );
446 tokio::time::sleep(throttle).await;
447 continue 'throttle;
448 }
449
450 if let Some(e) = response.error_code {
451 debug!(
452 %e,
453 version=upper_bound,
454 "Got error during version sync, cannot use version for ApiVersionRequest",
455 );
456 continue 'iter_upper_bound;
457 }
458
459 for api_key in &response.api_keys {
461 if api_key.min_version.0 > api_key.max_version.0 {
462 return Err(SyncVersionsError::FlippedVersionRange {
463 api_key: api_key.api_key,
464 min: api_key.min_version,
465 max: api_key.max_version,
466 });
467 }
468 }
469
470 let ranges = response
471 .api_keys
472 .into_iter()
473 .map(|x| {
474 (
475 x.api_key,
476 ApiVersionRange::new(x.min_version, x.max_version),
477 )
478 })
479 .collect();
480 debug!(
481 versions=%sorted_ranges_repr(&ranges),
482 "Detected supported broker versions",
483 );
484 self.set_version_ranges(ranges);
485 return Ok(());
486 }
487 Err(RequestError::NoVersionMatch { .. }) => {
488 unreachable!("Just set to version range to a non-empty range")
489 }
490 Err(RequestError::ReadVersionedError(e)) => {
491 debug!(
492 %e,
493 version=upper_bound,
494 "Cannot read ApiVersionResponse for version",
495 );
496 continue 'iter_upper_bound;
497 }
498 Err(RequestError::ReadError(e)) => {
499 debug!(
500 %e,
501 version=upper_bound,
502 "Cannot read ApiVersionResponse for version",
503 );
504 continue 'iter_upper_bound;
505 }
506 Err(e @ RequestError::TooMuchData { .. }) => {
507 debug!(
508 %e,
509 version=upper_bound,
510 "Cannot read ApiVersionResponse for version",
511 );
512 continue 'iter_upper_bound;
513 }
514 Err(e) => {
515 return Err(SyncVersionsError::RequestError(e));
516 }
517 }
518 }
519 }
520
521 Err(SyncVersionsError::NoWorkingVersion)
522 }
523}
524
525impl<RW> Drop for Messenger<RW> {
526 fn drop(&mut self) {
527 self.join_handle.abort();
528 }
529}
530
531fn sorted_ranges_repr(ranges: &HashMap<ApiKey, ApiVersionRange>) -> String {
532 let mut ranges: Vec<_> = ranges.iter().map(|(key, range)| (*key, *range)).collect();
533 ranges.sort_by_key(|(key, _range)| *key);
534 let ranges: Vec<_> = ranges
535 .into_iter()
536 .map(|(key, range)| format!("{:?}: {}", key, range))
537 .collect();
538 ranges.join(", ")
539}
540
541fn match_versions(range_a: ApiVersionRange, range_b: ApiVersionRange) -> Option<ApiVersion> {
542 if range_a.min() <= range_b.max() && range_b.min() <= range_a.max() {
543 Some(range_a.max().min(range_b.max()))
544 } else {
545 None
546 }
547}
548
549struct CleanupRequestStateOnCancel {
551 state: Arc<Mutex<MessengerState>>,
552 correlation_id: i32,
553 message_sent: bool,
554}
555
556impl CleanupRequestStateOnCancel {
557 fn new(state: Arc<Mutex<MessengerState>>, correlation_id: i32) -> Self {
561 Self {
562 state,
563 correlation_id,
564 message_sent: false,
565 }
566 }
567
568 fn message_sent(mut self) {
570 self.message_sent = true;
571 }
572}
573
574impl Drop for CleanupRequestStateOnCancel {
575 fn drop(&mut self) {
576 if !self.message_sent {
577 if let MessengerState::RequestMap(map) = self.state.lock().deref_mut() {
578 map.remove(&self.correlation_id);
579 }
580 }
581 }
582}
583
584struct CancellationSafeFuture<F>
588where
589 F: Future + Send + 'static,
590{
591 done: bool,
593
594 inner: Option<BoxFuture<'static, F::Output>>,
601}
602
603impl<F> Drop for CancellationSafeFuture<F>
604where
605 F: Future + Send + 'static,
606{
607 fn drop(&mut self) {
608 if !self.done {
609 let inner = self.inner.take().expect("Double-drop?");
610 tokio::task::spawn(async move {
611 inner.await;
612 });
613 }
614 }
615}
616
617impl<F> CancellationSafeFuture<F>
618where
619 F: Future + Send,
620{
621 fn new(fut: F) -> Self {
622 Self {
623 done: false,
624 inner: Some(Box::pin(fut)),
625 }
626 }
627}
628
629impl<F> Future for CancellationSafeFuture<F>
630where
631 F: Future + Send,
632{
633 type Output = F::Output;
634
635 fn poll(
636 mut self: std::pin::Pin<&mut Self>,
637 cx: &mut std::task::Context<'_>,
638 ) -> Poll<Self::Output> {
639 match self.inner.as_mut().expect("no dropped").as_mut().poll(cx) {
640 Poll::Ready(res) => {
641 self.done = true;
642 Poll::Ready(res)
643 }
644 Poll::Pending => Poll::Pending,
645 }
646 }
647}
648
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use assert_matches::assert_matches;
    use futures::{pin_mut, FutureExt};
    use tokio::{
        io::{AsyncReadExt, DuplexStream},
        sync::{mpsc::UnboundedSender, Barrier},
    };

    use super::*;

    use crate::{
        build_info::DEFAULT_CLIENT_ID,
        protocol::{
            error::Error as ApiError,
            messages::{
                ApiVersionsResponse, ApiVersionsResponseApiKey, ListOffsetsRequest, NORMAL_CONSUMER,
            },
            traits::WriteType,
        },
    };

    /// Shorthand for `ApiVersionRange::new` with plain integers.
    fn range(min: i16, max: i16) -> ApiVersionRange {
        ApiVersionRange::new(ApiVersion(Int16(min)), ApiVersion(Int16(max)))
    }

    /// The second-highest `ApiVersions` version supported by this client.
    fn api_versions_below_max() -> ApiVersion {
        ApiVersion(Int16(ApiVersionsRequest::API_VERSION_RANGE.max().0 .0 - 1))
    }

    /// Encode a version-0 response header carrying `correlation_id`.
    fn encode_response_header(correlation_id: i32) -> Vec<u8> {
        let mut msg = vec![];
        ResponseHeader {
            correlation_id: Int32(correlation_id),
            tagged_fields: Default::default(),
        }
        .write_versioned(&mut msg, ApiVersion(Int16(0)))
        .unwrap();
        msg
    }

    /// Encode a full response: header for `correlation_id` followed by `body` at `version`.
    fn encode_response<B>(correlation_id: i32, body: &B, version: ApiVersion) -> Vec<u8>
    where
        B: WriteVersionedType<Vec<u8>>,
    {
        let mut msg = encode_response_header(correlation_id);
        body.write_versioned(&mut msg, version).unwrap();
        msg
    }

    /// An `ApiVersionsResponse` that only lists `Produce` with versions `min..=max`.
    fn produce_only_response(
        error_code: Option<ApiError>,
        min: i16,
        max: i16,
    ) -> ApiVersionsResponse {
        ApiVersionsResponse {
            error_code,
            api_keys: vec![ApiVersionsResponseApiKey {
                api_key: ApiKey::Produce,
                min_version: ApiVersion(Int16(min)),
                max_version: ApiVersion(Int16(max)),
                tagged_fields: Default::default(),
            }],
            throttle_time_ms: None,
            tagged_fields: None,
        }
    }

    /// Version ranges expected after syncing against a `Produce`-only response.
    fn produce_only_ranges(min: i16, max: i16) -> HashMap<ApiKey, ApiVersionRange> {
        HashMap::from([(ApiKey::Produce, range(min, max))])
    }

    /// A minimal `ListOffsets` request, used to trigger a write.
    fn list_offsets_request() -> ListOffsetsRequest {
        ListOffsetsRequest {
            replica_id: NORMAL_CONSUMER,
            isolation_level: None,
            topics: vec![],
        }
    }

    /// Messenger whose version ranges allow `ListOffsets` requests.
    fn list_offsets_messenger(rx: DuplexStream) -> Messenger<DuplexStream> {
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
        messenger.set_version_ranges(HashMap::from([(
            ApiKey::ListOffsets,
            ListOffsetsRequest::API_VERSION_RANGE,
        )]));
        messenger
    }

    #[test]
    fn test_match_versions() {
        assert_eq!(
            match_versions(range(10, 20), range(10, 20)),
            Some(ApiVersion(Int16(20))),
        );

        assert_eq!(
            match_versions(range(10, 15), range(13, 20)),
            Some(ApiVersion(Int16(15))),
        );

        assert_eq!(
            match_versions(range(10, 15), range(15, 20)),
            Some(ApiVersion(Int16(15))),
        );

        assert_eq!(match_versions(range(10, 14), range(15, 20)), None);
    }

    #[tokio::test]
    async fn test_sync_versions_ok() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        sim.push(encode_response(
            0,
            &produce_only_response(None, 1, 5),
            ApiVersionsRequest::API_VERSION_RANGE.max(),
        ));

        messenger.sync_versions().await.unwrap();
        assert_eq!(messenger.version_ranges, produce_only_ranges(1, 5));
    }

    #[tokio::test]
    async fn test_sync_versions_ignores_error_code() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        // Highest version: the broker answers with an error code -> must be skipped.
        sim.push(encode_response(
            0,
            &produce_only_response(Some(ApiError::CorruptMessage), 2, 3),
            ApiVersionsRequest::API_VERSION_RANGE.max(),
        ));

        // Next lower version succeeds.
        sim.push(encode_response(
            1,
            &produce_only_response(None, 1, 5),
            api_versions_below_max(),
        ));

        messenger.sync_versions().await.unwrap();
        assert_eq!(messenger.version_ranges, produce_only_ranges(1, 5));
    }

    #[tokio::test]
    async fn test_sync_versions_ignores_read_code() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        // Highest version: the body is a single garbage byte -> cannot be decoded.
        let mut msg = encode_response_header(0);
        msg.push(b'\0');
        sim.push(msg);

        sim.push(encode_response(
            1,
            &produce_only_response(None, 1, 5),
            api_versions_below_max(),
        ));

        messenger.sync_versions().await.unwrap();
        assert_eq!(messenger.version_ranges, produce_only_ranges(1, 5));
    }

    #[tokio::test]
    async fn test_sync_versions_err_flipped_range() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        sim.push(encode_response(
            0,
            &produce_only_response(None, 2, 1),
            ApiVersionsRequest::API_VERSION_RANGE.max(),
        ));

        let err = messenger.sync_versions().await.unwrap_err();
        assert_matches!(err, SyncVersionsError::FlippedVersionRange { .. });
    }

    #[tokio::test]
    async fn test_sync_versions_ignores_garbage() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        // Highest version: a valid response followed by a trailing garbage byte.
        let mut msg = encode_response(
            0,
            &produce_only_response(None, 1, 2),
            ApiVersionsRequest::API_VERSION_RANGE.max(),
        );
        msg.push(b'\0');
        sim.push(msg);

        sim.push(encode_response(
            1,
            &produce_only_response(None, 1, 5),
            api_versions_below_max(),
        ));

        messenger.sync_versions().await.unwrap();
        assert_eq!(messenger.version_ranges, produce_only_ranges(1, 5));
    }

    #[tokio::test]
    async fn test_sync_versions_err_no_working_version() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        // Every probed version gets an error code.
        for (i, v) in (ApiVersionsRequest::API_VERSION_RANGE.min().0 .0
            ..=ApiVersionsRequest::API_VERSION_RANGE.max().0 .0)
            .rev()
            .enumerate()
        {
            sim.push(encode_response(
                i32::try_from(i).unwrap(),
                &produce_only_response(Some(ApiError::CorruptMessage), 1, 5),
                ApiVersion(Int16(v)),
            ));
        }

        let err = messenger.sync_versions().await.unwrap_err();
        assert_matches!(err, SyncVersionsError::NoWorkingVersion);
    }

    #[tokio::test]
    async fn test_poison_hangup() {
        let (sim, rx) = MessageSimulator::new();
        let messenger = list_offsets_messenger(rx);

        sim.hang_up();

        let err = messenger.request(list_offsets_request()).await.unwrap_err();
        assert_matches!(err, RequestError::Poisoned(_));
    }

    #[tokio::test]
    async fn test_poison_negative_message_size() {
        let (sim, rx) = MessageSimulator::new();
        let messenger = list_offsets_messenger(rx);

        sim.negative_message_size();

        let err = messenger.request(list_offsets_request()).await.unwrap_err();
        assert_matches!(err, RequestError::Poisoned(_));

        // The messenger stays poisoned for subsequent requests.
        let err = messenger.request(list_offsets_request()).await.unwrap_err();
        assert_matches!(err, RequestError::Poisoned(_));
    }

    #[tokio::test]
    async fn test_broken_msg_header_does_not_poison() {
        let (sim, rx) = MessageSimulator::new();
        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
        messenger.set_version_ranges(HashMap::from([(
            ApiKey::ApiVersions,
            ApiVersionsRequest::API_VERSION_RANGE,
        )]));

        // Too short to even contain a response header; must be ignored, not poison.
        sim.send(b"foo".to_vec());

        let resp = ApiVersionsResponse {
            error_code: Some(ApiError::CorruptMessage),
            api_keys: vec![],
            throttle_time_ms: Some(Int32(1)),
            tagged_fields: Some(TaggedFields::default()),
        };
        sim.push(encode_response(
            0,
            &resp,
            ApiVersionsRequest::API_VERSION_RANGE.max(),
        ));

        let actual = messenger
            .request(ApiVersionsRequest {
                client_software_name: Some(CompactString(String::new())),
                client_software_version: Some(CompactString(String::new())),
                tagged_fields: Some(TaggedFields::default()),
            })
            .await
            .unwrap();
        assert_eq!(actual, resp);
    }

    /// Cancelling a request while its message is only partially written must not corrupt the
    /// stream for subsequent requests.
    #[tokio::test]
    async fn test_cancel_request() {
        // client <-> network (middle) <-> broker, with 1-byte buffers so writes block.
        let (tx_front, rx_middle) = tokio::io::duplex(1);
        let (tx_middle, mut rx_back) = tokio::io::duplex(1);

        let mut messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        // The network pauses after forwarding a few bytes of the first request, so the test can
        // cancel that request while it is mid-write.
        let network_pause = Arc::new(Barrier::new(2));
        let network_pause_captured = Arc::clone(&network_pause);
        let network_continue = Arc::new(Barrier::new(2));
        let network_continue_captured = Arc::clone(&network_continue);
        let handle_network = tokio::spawn(async move {
            let (mut rx_middle_read, mut rx_middle_write) = tokio::io::split(rx_middle);
            let (mut tx_middle_read, mut tx_middle_write) = tokio::io::split(tx_middle);

            let direction_client_broker = async {
                for i in 0.. {
                    let mut buf = [0; 1];
                    rx_middle_read.read_exact(&mut buf).await.unwrap();
                    tx_middle_write.write_all(&buf).await.unwrap();

                    if i == 3 {
                        network_pause_captured.wait().await;
                        network_continue_captured.wait().await;
                    }
                }
            };

            let direction_broker_client = async {
                loop {
                    let mut buf = [0; 1];
                    tx_middle_read.read_exact(&mut buf).await.unwrap();
                    rx_middle_write.write_all(&buf).await.unwrap();
                }
            };

            tokio::select! {
                _ = direction_client_broker => {}
                _ = direction_broker_client => {}
            }
        });

        // Broker: checks every request is well-formed and answers it.
        let handle_broker = tokio::spawn(async move {
            for correlation_id in 0.. {
                let data = rx_back.read_message(1_000).await.unwrap();
                let mut data = Cursor::new(data);
                let header =
                    RequestHeader::read_versioned(&mut data, ApiVersion(Int16(1))).unwrap();
                assert_eq!(
                    header,
                    RequestHeader {
                        request_api_key: ApiKey::ApiVersions,
                        request_api_version: ApiVersion(Int16(0)),
                        correlation_id: Int32(correlation_id),
                        client_id: Some(NullableString(Some(String::from(env!("CARGO_PKG_NAME"))))),
                        tagged_fields: None,
                    }
                );
                let body =
                    ApiVersionsRequest::read_versioned(&mut data, ApiVersion(Int16(0))).unwrap();
                assert_eq!(
                    body,
                    ApiVersionsRequest {
                        client_software_name: None,
                        client_software_version: None,
                        tagged_fields: None,
                    }
                );
                assert_eq!(data.position() as usize, data.get_ref().len());

                let resp = ApiVersionsResponse {
                    error_code: Some(ApiError::CorruptMessage),
                    api_keys: vec![],
                    throttle_time_ms: Some(Int32(1)),
                    tagged_fields: Some(TaggedFields::default()),
                };
                let msg = encode_response(
                    correlation_id,
                    &resp,
                    ApiVersionsRequest::API_VERSION_RANGE.min(),
                );
                rx_back.write_message(&msg).await.unwrap();
            }
        });

        messenger.set_version_ranges(HashMap::from([(
            ApiKey::ApiVersions,
            ApiVersionRange::new(ApiVersion(Int16(0)), ApiVersion(Int16(0))),
        )]));

        let task_to_cancel = (async {
            messenger
                .request(ApiVersionsRequest {
                    client_software_name: Some(CompactString(String::from("foo"))),
                    client_software_version: Some(CompactString(String::from("bar"))),
                    tagged_fields: Some(TaggedFields::default()),
                })
                .await
                .unwrap();
        })
        .fuse();

        {
            pin_mut!(task_to_cancel);

            // Drive the request until the network paused mid-write.
            futures::select_biased! {
                _ = &mut task_to_cancel => panic!("should not have finished"),
                _ = network_pause.wait().fuse() => {},
            }
            // `task_to_cancel` is dropped (cancelled) at the end of this scope.
        }

        network_continue.wait().await;

        // A follow-up request must still go through quickly.
        tokio::time::timeout(Duration::from_millis(100), async {
            messenger
                .request(ApiVersionsRequest {
                    client_software_name: Some(CompactString(String::from("foo"))),
                    client_software_version: Some(CompactString(String::from("bar"))),
                    tagged_fields: Some(TaggedFields::default()),
                })
                .await
                .unwrap();
        })
        .await
        .unwrap();

        handle_broker.abort();
        handle_network.abort();
    }

    /// Commands for the [`MessageSimulator`] task.
    #[derive(Debug)]
    enum Message {
        /// Write the given framed message to the messenger.
        Send(Vec<u8>),
        /// Read (and discard) one framed message written by the messenger.
        Consume,
        /// Write a frame length of -1.
        NegativeMessageSize,
        /// Stop the task, closing the simulated broker end.
        HangUp,
    }

    /// Scripted fake broker on the other end of a duplex stream.
    struct MessageSimulator {
        messages: UnboundedSender<Message>,
        join_handle: JoinHandle<()>,
    }

    impl MessageSimulator {
        /// Create the simulator and return the stream end meant for the messenger.
        fn new() -> (Self, DuplexStream) {
            let (mut tx, rx) = tokio::io::duplex(1_000);
            let (msg_tx, mut msg_rx) = tokio::sync::mpsc::unbounded_channel();

            let join_handle = tokio::task::spawn(async move {
                loop {
                    let message = match msg_rx.recv().await {
                        Some(msg) => msg,
                        None => return,
                    };

                    match message {
                        Message::Consume => {
                            tx.read_message(1_000).await.unwrap();
                        }
                        Message::Send(data) => {
                            tx.write_message(&data).await.unwrap();
                        }
                        Message::NegativeMessageSize => {
                            let mut buf = vec![];
                            Int32(-1).write(&mut buf).unwrap();
                            tx.write_all(&buf).await.unwrap()
                        }
                        Message::HangUp => {
                            return;
                        }
                    }
                }
            });

            let this = Self {
                messages: msg_tx,
                join_handle,
            };
            (this, rx)
        }

        /// Wait for one request from the messenger, then answer with `msg`.
        fn push(&self, msg: Vec<u8>) {
            self.consume();
            self.send(msg);
        }

        fn consume(&self) {
            self.messages.send(Message::Consume).unwrap();
        }

        fn send(&self, msg: Vec<u8>) {
            self.messages.send(Message::Send(msg)).unwrap();
        }

        fn negative_message_size(&self) {
            self.messages.send(Message::NegativeMessageSize).unwrap();
        }

        fn hang_up(&self) {
            self.messages.send(Message::HangUp).unwrap();
        }
    }

    impl Drop for MessageSimulator {
        fn drop(&mut self) {
            self.join_handle.abort();
        }
    }
}