rskafka/
messenger.rs

1use std::{
2    collections::HashMap,
3    future::Future,
4    io::Cursor,
5    ops::DerefMut,
6    sync::{
7        atomic::{AtomicI32, Ordering},
8        Arc,
9    },
10    task::Poll,
11};
12
13use futures::future::BoxFuture;
14use parking_lot::Mutex;
15use thiserror::Error;
16use tokio::{
17    io::{AsyncRead, AsyncWrite, AsyncWriteExt, WriteHalf},
18    sync::{
19        oneshot::{channel, Sender},
20        Mutex as AsyncMutex,
21    },
22    task::JoinHandle,
23};
24use tracing::{debug, info, warn};
25
26use crate::protocol::{api_version::ApiVersionRange, primitives::CompactString};
27use crate::protocol::{messages::ApiVersionsRequest, traits::ReadType};
28use crate::{
29    backoff::ErrorOrThrottle,
30    protocol::{
31        api_key::ApiKey,
32        api_version::ApiVersion,
33        frame::{AsyncMessageRead, AsyncMessageWrite},
34        messages::{
35            ReadVersionedError, ReadVersionedType, RequestBody, RequestHeader, ResponseHeader,
36            WriteVersionedError, WriteVersionedType,
37        },
38        primitives::{Int16, Int32, NullableString, TaggedFields},
39    },
40    throttle::maybe_throttle,
41};
42
43#[derive(Debug)]
44struct Response {
45    #[allow(dead_code)]
46    header: ResponseHeader,
47    data: Cursor<Vec<u8>>,
48}
49
50#[derive(Debug)]
51struct ActiveRequest {
52    channel: Sender<Result<Response, RequestError>>,
53    use_tagged_fields_in_response: bool,
54}
55
56#[derive(Debug)]
57enum MessengerState {
58    /// Currently active requests by correlation ID.
59    ///
60    /// An active request is one that got prepared or send but the response wasn't received yet.
61    RequestMap(HashMap<i32, ActiveRequest>),
62
63    /// One or our streams died and we are unable to process any more requests.
64    Poison(Arc<RequestError>),
65}
66
67impl MessengerState {
68    fn poison(&mut self, err: RequestError) -> Arc<RequestError> {
69        match self {
70            Self::RequestMap(map) => {
71                let err = Arc::new(err);
72
73                // inform all active requests
74                for (_correlation_id, active_request) in map.drain() {
75                    // it's OK if the other side is gone
76                    active_request
77                        .channel
78                        .send(Err(RequestError::Poisoned(Arc::clone(&err))))
79                        .ok();
80                }
81
82                *self = Self::Poison(Arc::clone(&err));
83                err
84            }
85            Self::Poison(e) => {
86                // already poisoned, used existing error
87                Arc::clone(e)
88            }
89        }
90    }
91}
92
93/// A connection to a single broker
94///
95/// Note: Requests to the same [`Messenger`] will be pipelined by Kafka
96///
97#[derive(Debug)]
98pub struct Messenger<RW> {
99    /// The half of the stream that we use to send data TO the broker.
100    ///
101    /// This will be used by [`request`](Self::request) to queue up messages.
102    stream_write: Arc<AsyncMutex<WriteHalf<RW>>>,
103
104    /// Client ID.
105    client_id: Arc<str>,
106
107    /// The next correlation ID.
108    ///
109    /// This is used to map responses to active requests.
110    correlation_id: AtomicI32,
111
112    /// Version ranges that we think are supported by the broker.
113    ///
114    /// This needs to be bootstrapped by [`sync_versions`](Self::sync_versions).
115    version_ranges: HashMap<ApiKey, ApiVersionRange>,
116
117    /// Current stream state.
118    ///
119    /// Note that this and `stream_write` are separate struct to allow sending and receiving data concurrently.
120    state: Arc<Mutex<MessengerState>>,
121
122    /// Join handle for the background worker that fetches responses.
123    join_handle: JoinHandle<()>,
124}
125
126#[derive(Error, Debug)]
127#[non_exhaustive]
128pub enum RequestError {
129    #[error("Cannot find matching version for: {api_key:?}")]
130    NoVersionMatch { api_key: ApiKey },
131
132    #[error("Cannot write data: {0}")]
133    WriteError(#[from] WriteVersionedError),
134
135    #[error("Cannot write versioned data: {0}")]
136    WriteMessageError(#[from] crate::protocol::frame::WriteError),
137
138    #[error("Cannot read data: {0}")]
139    ReadError(#[from] crate::protocol::traits::ReadError),
140
141    #[error("Cannot read versioned data: {0}")]
142    ReadVersionedError(#[from] ReadVersionedError),
143
144    #[error("Cannot read/write data: {0}")]
145    IO(#[from] std::io::Error),
146
147    #[error(
148        "Data left at the end of the message. Got {message_size} bytes but only read {read} bytes. api_key={api_key:?} api_version={api_version}"
149    )]
150    TooMuchData {
151        message_size: u64,
152        read: u64,
153        api_key: ApiKey,
154        api_version: ApiVersion,
155    },
156
157    #[error("Cannot read framed message: {0}")]
158    ReadFramedMessageError(#[from] crate::protocol::frame::ReadError),
159
160    #[error("Connection is poisoned: {0}")]
161    Poisoned(Arc<RequestError>),
162}
163
164#[derive(Error, Debug)]
165#[non_exhaustive]
166pub enum SyncVersionsError {
167    #[error("Did not found a version for ApiVersion that works with that broker")]
168    NoWorkingVersion,
169
170    #[error("Request error: {0}")]
171    RequestError(#[from] RequestError),
172
173    #[error("Got flipped version from server for API key {api_key:?}: min={min:?} max={max:?}")]
174    FlippedVersionRange {
175        api_key: ApiKey,
176        min: ApiVersion,
177        max: ApiVersion,
178    },
179}
180
181impl<RW> Messenger<RW>
182where
183    RW: AsyncRead + AsyncWrite + Send + 'static,
184{
185    pub fn new(stream: RW, max_message_size: usize, client_id: Arc<str>) -> Self {
186        let (stream_read, stream_write) = tokio::io::split(stream);
187        let state = Arc::new(Mutex::new(MessengerState::RequestMap(HashMap::default())));
188        let state_captured = Arc::clone(&state);
189
190        let join_handle = tokio::spawn(async move {
191            let mut stream_read = stream_read;
192
193            loop {
194                match stream_read.read_message(max_message_size).await {
195                    Ok(msg) => {
196                        // message was read, so all subsequent errors should not poison the whole stream
197                        let mut cursor = Cursor::new(msg);
198
199                        // read header as version 0 (w/o tagged fields) first since this is a strict prefix or the more advanced
200                        // header version
201                        let mut header =
202                            match ResponseHeader::read_versioned(&mut cursor, ApiVersion(Int16(0)))
203                            {
204                                Ok(header) => header,
205                                Err(e) => {
206                                    warn!(%e, "Cannot read message header, ignoring message");
207                                    continue;
208                                }
209                            };
210
211                        let active_request = match state_captured.lock().deref_mut() {
212                            MessengerState::RequestMap(map) => {
213                                if let Some(active_request) = map.remove(&header.correlation_id.0) {
214                                    active_request
215                                } else {
216                                    warn!(
217                                        correlation_id = header.correlation_id.0,
218                                        "Got response for unknown request",
219                                    );
220                                    continue;
221                                }
222                            }
223                            MessengerState::Poison(_) => {
224                                // stream is poisoned, no need to anything
225                                return;
226                            }
227                        };
228
229                        // optionally read tagged fields from the header as well
230                        if active_request.use_tagged_fields_in_response {
231                            header.tagged_fields = match TaggedFields::read(&mut cursor) {
232                                Ok(fields) => Some(fields),
233                                Err(e) => {
234                                    // we don't care if the other side is gone
235                                    active_request
236                                        .channel
237                                        .send(Err(RequestError::ReadError(e)))
238                                        .ok();
239                                    continue;
240                                }
241                            };
242                        }
243
244                        // we don't care if the other side is gone
245                        active_request
246                            .channel
247                            .send(Ok(Response {
248                                header,
249                                data: cursor,
250                            }))
251                            .ok();
252                    }
253                    Err(e) => {
254                        state_captured
255                            .lock()
256                            .poison(RequestError::ReadFramedMessageError(e));
257                        return;
258                    }
259                }
260            }
261        });
262
263        Self {
264            stream_write: Arc::new(AsyncMutex::new(stream_write)),
265            client_id,
266            correlation_id: AtomicI32::new(0),
267            version_ranges: HashMap::new(),
268            state,
269            join_handle,
270        }
271    }
272
273    #[cfg(feature = "unstable-fuzzing")]
274    pub fn override_version_ranges(&mut self, ranges: HashMap<ApiKey, ApiVersionRange>) {
275        self.set_version_ranges(ranges);
276    }
277
278    /// Set supported version range.
279    fn set_version_ranges(&mut self, ranges: HashMap<ApiKey, ApiVersionRange>) {
280        self.version_ranges = ranges;
281    }
282
283    pub async fn request<R>(&self, msg: R) -> Result<R::ResponseBody, RequestError>
284    where
285        R: RequestBody + Send + WriteVersionedType<Vec<u8>>,
286        R::ResponseBody: ReadVersionedType<Cursor<Vec<u8>>>,
287    {
288        self.request_with_version_ranges(msg, &self.version_ranges)
289            .await
290    }
291
292    async fn request_with_version_ranges<R>(
293        &self,
294        msg: R,
295        version_ranges: &HashMap<ApiKey, ApiVersionRange>,
296    ) -> Result<R::ResponseBody, RequestError>
297    where
298        R: RequestBody + Send + WriteVersionedType<Vec<u8>>,
299        R::ResponseBody: ReadVersionedType<Cursor<Vec<u8>>>,
300    {
301        let body_api_version = version_ranges
302            .get(&R::API_KEY)
303            .and_then(|range_server| match_versions(*range_server, R::API_VERSION_RANGE))
304            .ok_or(RequestError::NoVersionMatch {
305                api_key: R::API_KEY,
306            })?;
307
308        // determine if our request and response headers shall contain tagged fields. This system is borrowed from
309        // rdkafka ("flexver"), see:
310        // - https://github.com/edenhill/librdkafka/blob/2b76b65212e5efda213961d5f84e565038036270/src/rdkafka_request.c#L973
311        // - https://github.com/edenhill/librdkafka/blob/2b76b65212e5efda213961d5f84e565038036270/src/rdkafka_buf.c#L167-L174
312        let use_tagged_fields_in_request =
313            body_api_version >= R::FIRST_TAGGED_FIELD_IN_REQUEST_VERSION;
314        let use_tagged_fields_in_response =
315            body_api_version >= R::FIRST_TAGGED_FIELD_IN_RESPONSE_VERSION;
316
317        // Correlation ID so that we can de-multiplex the responses.
318        let correlation_id = self.correlation_id.fetch_add(1, Ordering::SeqCst);
319
320        let header = RequestHeader {
321            request_api_key: R::API_KEY,
322            request_api_version: body_api_version,
323            correlation_id: Int32(correlation_id),
324            // Technically we don't need to send a client_id, but newer redpanda version fail to parse the message
325            // without it. See https://github.com/influxdata/rskafka/issues/169 .
326            client_id: Some(NullableString(Some(String::from(self.client_id.as_ref())))),
327            tagged_fields: Some(TaggedFields::default()),
328        };
329        let header_version = if use_tagged_fields_in_request {
330            ApiVersion(Int16(2))
331        } else {
332            ApiVersion(Int16(1))
333        };
334
335        let mut buf = Vec::new();
336        header
337            .write_versioned(&mut buf, header_version)
338            .expect("Writing header to buffer should always work");
339        msg.write_versioned(&mut buf, body_api_version)?;
340
341        let (tx, rx) = channel();
342
343        // to prevent stale data in inner state, ensure that we would remove the request again if we are cancelled while
344        // sending the request
345        let cleanup_on_cancel =
346            CleanupRequestStateOnCancel::new(Arc::clone(&self.state), correlation_id);
347
348        match self.state.lock().deref_mut() {
349            MessengerState::RequestMap(map) => {
350                map.insert(
351                    correlation_id,
352                    ActiveRequest {
353                        channel: tx,
354                        use_tagged_fields_in_response,
355                    },
356                );
357            }
358            MessengerState::Poison(e) => {
359                return Err(RequestError::Poisoned(Arc::clone(e)));
360            }
361        }
362
363        self.send_message(buf).await?;
364        cleanup_on_cancel.message_sent();
365
366        let mut response = rx.await.expect("Who closed this channel?!")?;
367        let body = R::ResponseBody::read_versioned(&mut response.data, body_api_version)?;
368
369        // check if we fully consumed the message, otherwise there might be a bug in our protocol code
370        let read_bytes = response.data.position();
371        let message_bytes = response.data.into_inner().len() as u64;
372        if read_bytes != message_bytes {
373            return Err(RequestError::TooMuchData {
374                message_size: message_bytes,
375                read: read_bytes,
376                api_key: R::API_KEY,
377                api_version: body_api_version,
378            });
379        }
380
381        Ok(body)
382    }
383
384    async fn send_message(&self, msg: Vec<u8>) -> Result<(), RequestError> {
385        match self.send_message_inner(msg).await {
386            Ok(()) => Ok(()),
387            Err(e) => {
388                // need to poison the stream because message framing might be out-of-sync
389                let mut state = self.state.lock();
390                Err(RequestError::Poisoned(state.poison(e)))
391            }
392        }
393    }
394
395    async fn send_message_inner(&self, msg: Vec<u8>) -> Result<(), RequestError> {
396        let mut stream_write = Arc::clone(&self.stream_write).lock_owned().await;
397
398        // use a wrapper so that cancelation doesn't cancel the send operation and leaves half-send messages on the wire
399        let fut = CancellationSafeFuture::new(async move {
400            stream_write.write_message(&msg).await?;
401            stream_write.flush().await?;
402            Ok(())
403        });
404
405        fut.await
406    }
407
408    /// Sync supported version range.
409    ///
410    /// Takes `&self mut` to ensure exclusive access.
411    pub async fn sync_versions(&mut self) -> Result<(), SyncVersionsError> {
412        'iter_upper_bound: for upper_bound in (ApiVersionsRequest::API_VERSION_RANGE.min().0 .0
413            ..=ApiVersionsRequest::API_VERSION_RANGE.max().0 .0)
414            .rev()
415        {
416            let version_ranges = HashMap::from([(
417                ApiKey::ApiVersions,
418                ApiVersionRange::new(
419                    ApiVersionsRequest::API_VERSION_RANGE.min(),
420                    ApiVersion(Int16(upper_bound)),
421                ),
422            )]);
423
424            let body = ApiVersionsRequest {
425                client_software_name: Some(CompactString(String::from(env!("CARGO_PKG_NAME")))),
426                client_software_version: Some(CompactString(String::from(env!(
427                    "CARGO_PKG_VERSION"
428                )))),
429                tagged_fields: Some(TaggedFields::default()),
430            };
431
432            'throttle: loop {
433                match self
434                    .request_with_version_ranges(&body, &version_ranges)
435                    .await
436                {
437                    Ok(response) => {
438                        if let Err(ErrorOrThrottle::Throttle(throttle)) =
439                            maybe_throttle::<SyncVersionsError>(response.throttle_time_ms)
440                        {
441                            info!(
442                                ?throttle,
443                                request_name = "version sync",
444                                "broker asked us to throttle"
445                            );
446                            tokio::time::sleep(throttle).await;
447                            continue 'throttle;
448                        }
449
450                        if let Some(e) = response.error_code {
451                            debug!(
452                                %e,
453                                version=upper_bound,
454                                "Got error during version sync, cannot use version for ApiVersionRequest",
455                            );
456                            continue 'iter_upper_bound;
457                        }
458
459                        // check range sanity
460                        for api_key in &response.api_keys {
461                            if api_key.min_version.0 > api_key.max_version.0 {
462                                return Err(SyncVersionsError::FlippedVersionRange {
463                                    api_key: api_key.api_key,
464                                    min: api_key.min_version,
465                                    max: api_key.max_version,
466                                });
467                            }
468                        }
469
470                        let ranges = response
471                            .api_keys
472                            .into_iter()
473                            .map(|x| {
474                                (
475                                    x.api_key,
476                                    ApiVersionRange::new(x.min_version, x.max_version),
477                                )
478                            })
479                            .collect();
480                        debug!(
481                            versions=%sorted_ranges_repr(&ranges),
482                            "Detected supported broker versions",
483                        );
484                        self.set_version_ranges(ranges);
485                        return Ok(());
486                    }
487                    Err(RequestError::NoVersionMatch { .. }) => {
488                        unreachable!("Just set to version range to a non-empty range")
489                    }
490                    Err(RequestError::ReadVersionedError(e)) => {
491                        debug!(
492                            %e,
493                            version=upper_bound,
494                            "Cannot read ApiVersionResponse for version",
495                        );
496                        continue 'iter_upper_bound;
497                    }
498                    Err(RequestError::ReadError(e)) => {
499                        debug!(
500                            %e,
501                            version=upper_bound,
502                            "Cannot read ApiVersionResponse for version",
503                        );
504                        continue 'iter_upper_bound;
505                    }
506                    Err(e @ RequestError::TooMuchData { .. }) => {
507                        debug!(
508                            %e,
509                            version=upper_bound,
510                            "Cannot read ApiVersionResponse for version",
511                        );
512                        continue 'iter_upper_bound;
513                    }
514                    Err(e) => {
515                        return Err(SyncVersionsError::RequestError(e));
516                    }
517                }
518            }
519        }
520
521        Err(SyncVersionsError::NoWorkingVersion)
522    }
523}
524
525impl<RW> Drop for Messenger<RW> {
526    fn drop(&mut self) {
527        self.join_handle.abort();
528    }
529}
530
531fn sorted_ranges_repr(ranges: &HashMap<ApiKey, ApiVersionRange>) -> String {
532    let mut ranges: Vec<_> = ranges.iter().map(|(key, range)| (*key, *range)).collect();
533    ranges.sort_by_key(|(key, _range)| *key);
534    let ranges: Vec<_> = ranges
535        .into_iter()
536        .map(|(key, range)| format!("{:?}: {}", key, range))
537        .collect();
538    ranges.join(", ")
539}
540
541fn match_versions(range_a: ApiVersionRange, range_b: ApiVersionRange) -> Option<ApiVersion> {
542    if range_a.min() <= range_b.max() && range_b.min() <= range_a.max() {
543        Some(range_a.max().min(range_b.max()))
544    } else {
545        None
546    }
547}
548
549/// Helper that ensures that a request is removed when a request is cancelled before it was actually sent out.
550struct CleanupRequestStateOnCancel {
551    state: Arc<Mutex<MessengerState>>,
552    correlation_id: i32,
553    message_sent: bool,
554}
555
556impl CleanupRequestStateOnCancel {
557    /// Create new helper.
558    ///
559    /// You must call [`message_sent`](Self::message_sent) when the request was sent.
560    fn new(state: Arc<Mutex<MessengerState>>, correlation_id: i32) -> Self {
561        Self {
562            state,
563            correlation_id,
564            message_sent: false,
565        }
566    }
567
568    /// Request was sent. Do NOT clean the state any longer.
569    fn message_sent(mut self) {
570        self.message_sent = true;
571    }
572}
573
574impl Drop for CleanupRequestStateOnCancel {
575    fn drop(&mut self) {
576        if !self.message_sent {
577            if let MessengerState::RequestMap(map) = self.state.lock().deref_mut() {
578                map.remove(&self.correlation_id);
579            }
580        }
581    }
582}
583
584/// Wrapper around a future that cannot be cancelled.
585///
586/// When the future is dropped/cancelled, we'll spawn a tokio task to _rescue_ it.
587struct CancellationSafeFuture<F>
588where
589    F: Future + Send + 'static,
590{
591    /// Mark if the inner future finished. If not, we must spawn a helper task on drop.
592    done: bool,
593
594    /// Inner future.
595    ///
596    /// Wrapped in an `Option` so we can extract it during drop. Inside that option however we also need a pinned
597    /// box because once this wrapper is polled, it will be pinned in memory -- even during drop. Now the inner
598    /// future does not necessarily implement `Unpin`, so we need a heap allocation to pin it in memory even when we
599    /// move it out of this option.
600    inner: Option<BoxFuture<'static, F::Output>>,
601}
602
603impl<F> Drop for CancellationSafeFuture<F>
604where
605    F: Future + Send + 'static,
606{
607    fn drop(&mut self) {
608        if !self.done {
609            let inner = self.inner.take().expect("Double-drop?");
610            tokio::task::spawn(async move {
611                inner.await;
612            });
613        }
614    }
615}
616
617impl<F> CancellationSafeFuture<F>
618where
619    F: Future + Send,
620{
621    fn new(fut: F) -> Self {
622        Self {
623            done: false,
624            inner: Some(Box::pin(fut)),
625        }
626    }
627}
628
629impl<F> Future for CancellationSafeFuture<F>
630where
631    F: Future + Send,
632{
633    type Output = F::Output;
634
635    fn poll(
636        mut self: std::pin::Pin<&mut Self>,
637        cx: &mut std::task::Context<'_>,
638    ) -> Poll<Self::Output> {
639        match self.inner.as_mut().expect("no dropped").as_mut().poll(cx) {
640            Poll::Ready(res) => {
641                self.done = true;
642                Poll::Ready(res)
643            }
644            Poll::Pending => Poll::Pending,
645        }
646    }
647}
648
649#[cfg(test)]
650mod tests {
651    use std::time::Duration;
652
653    use assert_matches::assert_matches;
654    use futures::{pin_mut, FutureExt};
655    use tokio::{
656        io::{AsyncReadExt, DuplexStream},
657        sync::{mpsc::UnboundedSender, Barrier},
658    };
659
660    use super::*;
661
662    use crate::{
663        build_info::DEFAULT_CLIENT_ID,
664        protocol::{
665            error::Error as ApiError,
666            messages::{
667                ApiVersionsResponse, ApiVersionsResponseApiKey, ListOffsetsRequest, NORMAL_CONSUMER,
668            },
669            traits::WriteType,
670        },
671    };
672
673    #[test]
674    fn test_match_versions() {
675        assert_eq!(
676            match_versions(
677                ApiVersionRange::new(ApiVersion(Int16(10)), ApiVersion(Int16(20))),
678                ApiVersionRange::new(ApiVersion(Int16(10)), ApiVersion(Int16(20))),
679            ),
680            Some(ApiVersion(Int16(20))),
681        );
682
683        assert_eq!(
684            match_versions(
685                ApiVersionRange::new(ApiVersion(Int16(10)), ApiVersion(Int16(15))),
686                ApiVersionRange::new(ApiVersion(Int16(13)), ApiVersion(Int16(20))),
687            ),
688            Some(ApiVersion(Int16(15))),
689        );
690
691        assert_eq!(
692            match_versions(
693                ApiVersionRange::new(ApiVersion(Int16(10)), ApiVersion(Int16(15))),
694                ApiVersionRange::new(ApiVersion(Int16(15)), ApiVersion(Int16(20))),
695            ),
696            Some(ApiVersion(Int16(15))),
697        );
698
699        assert_eq!(
700            match_versions(
701                ApiVersionRange::new(ApiVersion(Int16(10)), ApiVersion(Int16(14))),
702                ApiVersionRange::new(ApiVersion(Int16(15)), ApiVersion(Int16(20))),
703            ),
704            None,
705        );
706    }
707
708    #[tokio::test]
709    async fn test_sync_versions_ok() {
710        let (sim, rx) = MessageSimulator::new();
711        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
712
713        // construct response
714        let mut msg = vec![];
715        ResponseHeader {
716            correlation_id: Int32(0),
717            tagged_fields: Default::default(), // NOT serialized for ApiVersion!
718        }
719        .write_versioned(&mut msg, ApiVersion(Int16(0)))
720        .unwrap();
721        ApiVersionsResponse {
722            error_code: None,
723            api_keys: vec![ApiVersionsResponseApiKey {
724                api_key: ApiKey::Produce,
725                min_version: ApiVersion(Int16(1)),
726                max_version: ApiVersion(Int16(5)),
727                tagged_fields: Default::default(),
728            }],
729            throttle_time_ms: None,
730            tagged_fields: None,
731        }
732        .write_versioned(&mut msg, ApiVersionsRequest::API_VERSION_RANGE.max())
733        .unwrap();
734        sim.push(msg);
735
736        // sync versions
737        messenger.sync_versions().await.unwrap();
738        let expected = HashMap::from([(
739            (ApiKey::Produce),
740            ApiVersionRange::new(ApiVersion(Int16(1)), ApiVersion(Int16(5))),
741        )]);
742        assert_eq!(messenger.version_ranges, expected);
743    }
744
745    #[tokio::test]
746    async fn test_sync_versions_ignores_error_code() {
747        let (sim, rx) = MessageSimulator::new();
748        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
749
750        // construct error response
751        let mut msg = vec![];
752        ResponseHeader {
753            correlation_id: Int32(0),
754            tagged_fields: Default::default(), // NOT serialized for ApiVersion!
755        }
756        .write_versioned(&mut msg, ApiVersion(Int16(0)))
757        .unwrap();
758        ApiVersionsResponse {
759            error_code: Some(ApiError::CorruptMessage),
760            api_keys: vec![ApiVersionsResponseApiKey {
761                api_key: ApiKey::Produce,
762                min_version: ApiVersion(Int16(2)),
763                max_version: ApiVersion(Int16(3)),
764                tagged_fields: Default::default(),
765            }],
766            throttle_time_ms: None,
767            tagged_fields: None,
768        }
769        .write_versioned(&mut msg, ApiVersionsRequest::API_VERSION_RANGE.max())
770        .unwrap();
771        sim.push(msg);
772
773        // construct good response
774        let mut msg = vec![];
775        ResponseHeader {
776            correlation_id: Int32(1),
777            tagged_fields: Default::default(),
778        }
779        .write_versioned(&mut msg, ApiVersion(Int16(0)))
780        .unwrap();
781        ApiVersionsResponse {
782            error_code: None,
783            api_keys: vec![ApiVersionsResponseApiKey {
784                api_key: ApiKey::Produce,
785                min_version: ApiVersion(Int16(1)),
786                max_version: ApiVersion(Int16(5)),
787                tagged_fields: Default::default(),
788            }],
789            throttle_time_ms: None,
790            tagged_fields: None,
791        }
792        .write_versioned(
793            &mut msg,
794            ApiVersion(Int16(ApiVersionsRequest::API_VERSION_RANGE.max().0 .0 - 1)),
795        )
796        .unwrap();
797        sim.push(msg);
798
799        // sync versions
800        messenger.sync_versions().await.unwrap();
801        let expected = HashMap::from([(
802            (ApiKey::Produce),
803            ApiVersionRange::new(ApiVersion(Int16(1)), ApiVersion(Int16(5))),
804        )]);
805        assert_eq!(messenger.version_ranges, expected);
806    }
807
808    #[tokio::test]
809    async fn test_sync_versions_ignores_read_code() {
810        let (sim, rx) = MessageSimulator::new();
811        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
812
813        // construct error response
814        let mut msg = vec![];
815        ResponseHeader {
816            correlation_id: Int32(0),
817            tagged_fields: Default::default(), // NOT serialized for ApiVersion!
818        }
819        .write_versioned(&mut msg, ApiVersion(Int16(0)))
820        .unwrap();
821        msg.push(b'\0'); // malformed message body which can happen if the server doesn't really support this version
822        sim.push(msg);
823
824        // construct good response
825        let mut msg = vec![];
826        ResponseHeader {
827            correlation_id: Int32(1),
828            tagged_fields: Default::default(),
829        }
830        .write_versioned(&mut msg, ApiVersion(Int16(0)))
831        .unwrap();
832        ApiVersionsResponse {
833            error_code: None,
834            api_keys: vec![ApiVersionsResponseApiKey {
835                api_key: ApiKey::Produce,
836                min_version: ApiVersion(Int16(1)),
837                max_version: ApiVersion(Int16(5)),
838                tagged_fields: Default::default(),
839            }],
840            throttle_time_ms: None,
841            tagged_fields: None,
842        }
843        .write_versioned(
844            &mut msg,
845            ApiVersion(Int16(ApiVersionsRequest::API_VERSION_RANGE.max().0 .0 - 1)),
846        )
847        .unwrap();
848        sim.push(msg);
849
850        // sync versions
851        messenger.sync_versions().await.unwrap();
852        let expected = HashMap::from([(
853            (ApiKey::Produce),
854            ApiVersionRange::new(ApiVersion(Int16(1)), ApiVersion(Int16(5))),
855        )]);
856        assert_eq!(messenger.version_ranges, expected);
857    }
858
859    #[tokio::test]
860    async fn test_sync_versions_err_flipped_range() {
861        let (sim, rx) = MessageSimulator::new();
862        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
863
864        // construct response
865        let mut msg = vec![];
866        ResponseHeader {
867            correlation_id: Int32(0),
868            tagged_fields: Default::default(), // NOT serialized for ApiVersion!
869        }
870        .write_versioned(&mut msg, ApiVersion(Int16(0)))
871        .unwrap();
872        ApiVersionsResponse {
873            error_code: None,
874            api_keys: vec![ApiVersionsResponseApiKey {
875                api_key: ApiKey::Produce,
876                min_version: ApiVersion(Int16(2)),
877                max_version: ApiVersion(Int16(1)),
878                tagged_fields: Default::default(),
879            }],
880            throttle_time_ms: None,
881            tagged_fields: None,
882        }
883        .write_versioned(&mut msg, ApiVersionsRequest::API_VERSION_RANGE.max())
884        .unwrap();
885        sim.push(msg);
886
887        // sync versions
888        let err = messenger.sync_versions().await.unwrap_err();
889        assert_matches!(err, SyncVersionsError::FlippedVersionRange { .. });
890    }
891
892    #[tokio::test]
893    async fn test_sync_versions_ignores_garbage() {
894        let (sim, rx) = MessageSimulator::new();
895        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
896
897        // construct response
898        let mut msg = vec![];
899        ResponseHeader {
900            correlation_id: Int32(0),
901            tagged_fields: Default::default(), // NOT serialized for ApiVersion!
902        }
903        .write_versioned(&mut msg, ApiVersion(Int16(0)))
904        .unwrap();
905        ApiVersionsResponse {
906            error_code: None,
907            api_keys: vec![ApiVersionsResponseApiKey {
908                api_key: ApiKey::Produce,
909                min_version: ApiVersion(Int16(1)),
910                max_version: ApiVersion(Int16(2)),
911                tagged_fields: Default::default(),
912            }],
913            throttle_time_ms: None,
914            tagged_fields: None,
915        }
916        .write_versioned(&mut msg, ApiVersionsRequest::API_VERSION_RANGE.max())
917        .unwrap();
918        msg.push(b'\0'); // add junk to the end of the message to trigger `TooMuchData`
919        sim.push(msg);
920
921        // construct good response
922        let mut msg = vec![];
923        ResponseHeader {
924            correlation_id: Int32(1),
925            tagged_fields: Default::default(),
926        }
927        .write_versioned(&mut msg, ApiVersion(Int16(0)))
928        .unwrap();
929        ApiVersionsResponse {
930            error_code: None,
931            api_keys: vec![ApiVersionsResponseApiKey {
932                api_key: ApiKey::Produce,
933                min_version: ApiVersion(Int16(1)),
934                max_version: ApiVersion(Int16(5)),
935                tagged_fields: Default::default(),
936            }],
937            throttle_time_ms: None,
938            tagged_fields: None,
939        }
940        .write_versioned(
941            &mut msg,
942            ApiVersion(Int16(ApiVersionsRequest::API_VERSION_RANGE.max().0 .0 - 1)),
943        )
944        .unwrap();
945        sim.push(msg);
946
947        // sync versions
948        messenger.sync_versions().await.unwrap();
949        let expected = HashMap::from([(
950            (ApiKey::Produce),
951            ApiVersionRange::new(ApiVersion(Int16(1)), ApiVersion(Int16(5))),
952        )]);
953        assert_eq!(messenger.version_ranges, expected);
954    }
955
956    #[tokio::test]
957    async fn test_sync_versions_err_no_working_version() {
958        let (sim, rx) = MessageSimulator::new();
959        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
960
961        // construct error response
962        for (i, v) in ((ApiVersionsRequest::API_VERSION_RANGE.min().0 .0)
963            ..=(ApiVersionsRequest::API_VERSION_RANGE.max().0 .0))
964            .rev()
965            .enumerate()
966        {
967            let mut msg = vec![];
968            ResponseHeader {
969                correlation_id: Int32(i as i32),
970                tagged_fields: Default::default(),
971            }
972            .write_versioned(&mut msg, ApiVersion(Int16(0)))
973            .unwrap();
974            ApiVersionsResponse {
975                error_code: Some(ApiError::CorruptMessage),
976                api_keys: vec![ApiVersionsResponseApiKey {
977                    api_key: ApiKey::Produce,
978                    min_version: ApiVersion(Int16(1)),
979                    max_version: ApiVersion(Int16(5)),
980                    tagged_fields: Default::default(),
981                }],
982                throttle_time_ms: None,
983                tagged_fields: None,
984            }
985            .write_versioned(&mut msg, ApiVersion(Int16(v)))
986            .unwrap();
987            sim.push(msg);
988        }
989
990        // sync versions
991        let err = messenger.sync_versions().await.unwrap_err();
992        assert_matches!(err, SyncVersionsError::NoWorkingVersion);
993    }
994
995    #[tokio::test]
996    async fn test_poison_hangup() {
997        let (sim, rx) = MessageSimulator::new();
998        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
999        messenger.set_version_ranges(HashMap::from([(
1000            ApiKey::ListOffsets,
1001            ListOffsetsRequest::API_VERSION_RANGE,
1002        )]));
1003
1004        sim.hang_up();
1005
1006        let err = messenger
1007            .request(ListOffsetsRequest {
1008                replica_id: NORMAL_CONSUMER,
1009                isolation_level: None,
1010                topics: vec![],
1011            })
1012            .await
1013            .unwrap_err();
1014        assert_matches!(err, RequestError::Poisoned(_));
1015    }
1016
1017    #[tokio::test]
1018    async fn test_poison_negative_message_size() {
1019        let (sim, rx) = MessageSimulator::new();
1020        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1021        messenger.set_version_ranges(HashMap::from([(
1022            ApiKey::ListOffsets,
1023            ListOffsetsRequest::API_VERSION_RANGE,
1024        )]));
1025
1026        sim.negative_message_size();
1027
1028        let err = messenger
1029            .request(ListOffsetsRequest {
1030                replica_id: NORMAL_CONSUMER,
1031                isolation_level: None,
1032                topics: vec![],
1033            })
1034            .await
1035            .unwrap_err();
1036        assert_matches!(err, RequestError::Poisoned(_));
1037
1038        // follow-up message is broken as well
1039        let err = messenger
1040            .request(ListOffsetsRequest {
1041                replica_id: NORMAL_CONSUMER,
1042                isolation_level: None,
1043                topics: vec![],
1044            })
1045            .await
1046            .unwrap_err();
1047        assert_matches!(err, RequestError::Poisoned(_));
1048    }
1049
1050    #[tokio::test]
1051    async fn test_broken_msg_header_does_not_poison() {
1052        let (sim, rx) = MessageSimulator::new();
1053        let mut messenger = Messenger::new(rx, 1_000, Arc::from(DEFAULT_CLIENT_ID));
1054        messenger.set_version_ranges(HashMap::from([(
1055            ApiKey::ApiVersions,
1056            ApiVersionsRequest::API_VERSION_RANGE,
1057        )]));
1058
1059        // garbage
1060        sim.send(b"foo".to_vec());
1061
1062        // construct good response
1063        let mut msg = vec![];
1064        ResponseHeader {
1065            correlation_id: Int32(0),
1066            tagged_fields: Default::default(), // NOT serialized for ApiVersion!
1067        }
1068        .write_versioned(&mut msg, ApiVersion(Int16(0)))
1069        .unwrap();
1070        let resp = ApiVersionsResponse {
1071            error_code: Some(ApiError::CorruptMessage),
1072            api_keys: vec![],
1073            throttle_time_ms: Some(Int32(1)),
1074            tagged_fields: Some(TaggedFields::default()),
1075        };
1076        resp.write_versioned(&mut msg, ApiVersionsRequest::API_VERSION_RANGE.max())
1077            .unwrap();
1078        sim.push(msg);
1079
1080        let actual = messenger
1081            .request(ApiVersionsRequest {
1082                client_software_name: Some(CompactString(String::new())),
1083                client_software_version: Some(CompactString(String::new())),
1084                tagged_fields: Some(TaggedFields::default()),
1085            })
1086            .await
1087            .unwrap();
1088        assert_eq!(actual, resp);
1089    }
1090
    #[tokio::test]
    async fn test_cancel_request() {
        // Regression test for https://github.com/influxdata/rskafka/issues/103: a request whose future is dropped while
        // it is being written must not leave the connection in a state where a later request never gets answered.
        //
        // Use a "virtual" network between a simulated broker and a client. The network is intercepted in the middle to
        // pause client->broker traffic once the client's first 4 bytes have been forwarded (see `i == 3` below).
        let (tx_front, rx_middle) = tokio::io::duplex(1);
        let (tx_middle, mut rx_back) = tokio::io::duplex(1);

        let mut messenger = Messenger::new(tx_front, 1_000, Arc::from(DEFAULT_CLIENT_ID));

        // create two barriers:
        // - pause: will be passed once the client's first 4 bytes have been forwarded to the broker
        // - continue: will be passed to continue client->broker traffic
        //
        // The barriers do NOT affect broker->client traffic.
        //
        // Each barrier has size 2: one party is the network simulation task, the other is the main/control task.
        let network_pause = Arc::new(Barrier::new(2));
        let network_pause_captured = Arc::clone(&network_pause);
        let network_continue = Arc::new(Barrier::new(2));
        let network_continue_captured = Arc::clone(&network_continue);
        let handle_network = tokio::spawn(async move {
            // Need to split both directions into read and write halves so we can run full-duplex in two loops. Otherwise
            // the test might deadlock even though the code is just fine (TCP is full-duplex).
            let (mut rx_middle_read, mut rx_middle_write) = tokio::io::split(rx_middle);
            let (mut tx_middle_read, mut tx_middle_write) = tokio::io::split(tx_middle);

            // Client->broker: forward one byte at a time so the pause point is exact.
            let direction_client_broker = async {
                for i in 0.. {
                    let mut buf = [0; 1];
                    rx_middle_read.read_exact(&mut buf).await.unwrap();
                    tx_middle_write.write_all(&buf).await.unwrap();

                    // `i` is zero-based, so this fires exactly once: right after the 4th byte was forwarded.
                    if i == 3 {
                        network_pause_captured.wait().await;
                        network_continue_captured.wait().await;
                    }
                }
            };

            // Broker->client: forward one byte at a time, never paused.
            let direction_broker_client = async {
                loop {
                    let mut buf = [0; 1];
                    tx_middle_read.read_exact(&mut buf).await.unwrap();
                    rx_middle_write.write_all(&buf).await.unwrap();
                }
            };

            tokio::select! {
                _ = direction_client_broker => {}
                _ = direction_broker_client => {}
            }
        });

        // simulated broker, just reads messages and answers w/ "api versions" responses.
        // For every frame it asserts a v1 request header for `ApiVersions` v0 with consecutive correlation IDs starting
        // at 0, an (empty) v0 body, and that the whole frame was consumed. Any garbage on the wire therefore fails here.
        let handle_broker = tokio::spawn(async move {
            for correlation_id in 0.. {
                let data = rx_back.read_message(1_000).await.unwrap();
                let mut data = Cursor::new(data);
                let header =
                    RequestHeader::read_versioned(&mut data, ApiVersion(Int16(1))).unwrap();
                assert_eq!(
                    header,
                    RequestHeader {
                        request_api_key: ApiKey::ApiVersions,
                        request_api_version: ApiVersion(Int16(0)),
                        correlation_id: Int32(correlation_id),
                        client_id: Some(NullableString(Some(String::from(env!("CARGO_PKG_NAME"))))),
                        tagged_fields: None,
                    }
                );
                let body =
                    ApiVersionsRequest::read_versioned(&mut data, ApiVersion(Int16(0))).unwrap();
                assert_eq!(
                    body,
                    ApiVersionsRequest {
                        client_software_name: None,
                        client_software_version: None,
                        tagged_fields: None,
                    }
                );
                assert_eq!(data.position() as usize, data.get_ref().len());

                let mut msg = vec![];
                ResponseHeader {
                    correlation_id: Int32(correlation_id),
                    tagged_fields: Default::default(), // NOT serialized for ApiVersion!
                }
                .write_versioned(&mut msg, ApiVersion(Int16(0)))
                .unwrap();
                let resp = ApiVersionsResponse {
                    error_code: Some(ApiError::CorruptMessage),
                    api_keys: vec![],
                    throttle_time_ms: Some(Int32(1)),
                    tagged_fields: Some(TaggedFields::default()),
                };
                resp.write_versioned(&mut msg, ApiVersionsRequest::API_VERSION_RANGE.min())
                    .unwrap();
                rx_back.write_message(&msg).await.unwrap();
            }
        });

        // Pin the client to ApiVersions v0, which is what the simulated broker expects.
        messenger.set_version_ranges(HashMap::from([(
            ApiKey::ApiVersions,
            ApiVersionRange::new(ApiVersion(Int16(0)), ApiVersion(Int16(0))),
        )]));

        // send first message, this task will be canceled once the network paused (i.e. after the client's first 4 bytes
        // were forwarded to the broker).
        let task_to_cancel = (async {
            messenger
                .request(ApiVersionsRequest {
                    client_software_name: Some(CompactString(String::from("foo"))),
                    client_software_version: Some(CompactString(String::from("bar"))),
                    tagged_fields: Some(TaggedFields::default()),
                })
                .await
                .unwrap();
        })
        .fuse();

        {
            // cancel when we exit this block (the pinned future is dropped here)
            pin_mut!(task_to_cancel);

            // drive the request until the network simulation pauses, then leave the block, which cancels the task.
            futures::select_biased! {
                _ = &mut task_to_cancel => panic!("should not have finished"),
                _ = network_pause.wait().fuse() => {},
            }
        }

        // continue client->broker traffic
        network_continue.wait().await;

        // IF the original bug in https://github.com/influxdata/rskafka/issues/103 exists, then the following statement
        // will timeout because the broker got garbage and will wait forever to read the message.
        tokio::time::timeout(Duration::from_millis(100), async {
            messenger
                .request(ApiVersionsRequest {
                    client_software_name: Some(CompactString(String::from("foo"))),
                    client_software_version: Some(CompactString(String::from("bar"))),
                    tagged_fields: Some(TaggedFields::default()),
                })
                .await
                .unwrap();
        })
        .await
        .unwrap();

        // clean up helper tasks
        handle_broker.abort();
        handle_network.abort();
    }
1243
1244    #[derive(Debug)]
1245    enum Message {
1246        Send(Vec<u8>),
1247        Consume,
1248        NegativeMessageSize,
1249        HangUp,
1250    }
1251
    /// Scripted stand-in for a Kafka broker that drives the other end of the `DuplexStream` handed to the `Messenger`
    /// under test.
    // NOTE: field order is also drop order; keep it unchanged.
    struct MessageSimulator {
        /// Instruction queue consumed, in order, by the background broker task.
        messages: UnboundedSender<Message>,
        /// Background broker task that owns the broker end of the stream; aborted in `Drop`.
        join_handle: JoinHandle<()>,
    }
1256
1257    impl MessageSimulator {
1258        fn new() -> (Self, DuplexStream) {
1259            let (mut tx, rx) = tokio::io::duplex(1_000);
1260            let (msg_tx, mut msg_rx) = tokio::sync::mpsc::unbounded_channel();
1261
1262            let join_handle = tokio::task::spawn(async move {
1263                loop {
1264                    let message = match msg_rx.recv().await {
1265                        Some(msg) => msg,
1266                        None => return,
1267                    };
1268
1269                    match message {
1270                        Message::Consume => {
1271                            tx.read_message(1_000).await.unwrap();
1272                        }
1273                        Message::Send(data) => {
1274                            tx.write_message(&data).await.unwrap();
1275                        }
1276                        Message::NegativeMessageSize => {
1277                            let mut buf = vec![];
1278                            Int32(-1).write(&mut buf).unwrap();
1279                            tx.write_all(&buf).await.unwrap()
1280                        }
1281                        Message::HangUp => {
1282                            return;
1283                        }
1284                    }
1285                }
1286            });
1287
1288            let this = Self {
1289                messages: msg_tx,
1290                join_handle,
1291            };
1292            (this, rx)
1293        }
1294
1295        fn push(&self, msg: Vec<u8>) {
1296            // Must wait for the request message before reading the response, otherwise `Messenger` might read
1297            // our response that doesn't have a correlated request yet and throws it away. This is because
1298            // servers never send data without being asked to do so.
1299            self.consume();
1300            self.send(msg);
1301        }
1302
1303        fn consume(&self) {
1304            self.messages.send(Message::Consume).unwrap();
1305        }
1306
1307        fn send(&self, msg: Vec<u8>) {
1308            self.messages.send(Message::Send(msg)).unwrap();
1309        }
1310
1311        fn negative_message_size(&self) {
1312            self.messages.send(Message::NegativeMessageSize).unwrap();
1313        }
1314
1315        fn hang_up(&self) {
1316            self.messages.send(Message::HangUp).unwrap();
1317        }
1318    }
1319
1320    impl Drop for MessageSimulator {
1321        fn drop(&mut self) {
1322            // this will drop the future and therefore tx which will close th streams
1323            self.join_handle.abort();
1324        }
1325    }
1326}