lumina_node/p2p/header_ex.rs

1use std::io;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4
5use async_trait::async_trait;
6use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
7use celestia_types::ExtendedHeader;
8use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9use libp2p::core::transport::PortUse;
10use libp2p::{
11    core::Endpoint,
12    request_response::{self, Codec, InboundFailure, OutboundFailure, ProtocolSupport},
13    swarm::{
14        handler::ConnectionEvent, ConnectionDenied, ConnectionHandler, ConnectionHandlerEvent,
15        ConnectionId, FromSwarm, NetworkBehaviour, SubstreamProtocol, THandlerInEvent,
16        THandlerOutEvent, ToSwarm,
17    },
18    Multiaddr, PeerId, StreamProtocol,
19};
20use lumina_utils::time::timeout;
21use prost::Message;
22use tracing::{debug, instrument, warn};
23use web_time::{Duration, Instant};
24
25mod client;
26mod server;
27pub(crate) mod utils;
28
29use crate::p2p::header_ex::client::HeaderExClientHandler;
30use crate::p2p::header_ex::server::HeaderExServerHandler;
31use crate::p2p::P2pError;
32use crate::peer_tracker::PeerTracker;
33use crate::store::Store;
34use crate::utils::{protocol_id, OneshotResultSender};
35
/// Maximum number of bytes accepted for a single request.
const REQUEST_SIZE_LIMIT: usize = 1024;
/// Deadline for reading or writing a request.
const REQUEST_TIME_LIMIT: Duration = Duration::from_millis(1_000);
/// Maximum number of bytes accepted for a single response (10 MiB).
const RESPONSE_SIZE_LIMIT: usize = 10 << 20;
/// Deadline for reading or writing a response.
const RESPONSE_TIME_LIMIT: Duration = Duration::from_millis(5_000);
/// Deadline for negotiating the protocol on a new substream.
const NEGOTIATION_TIMEOUT: Duration = Duration::from_millis(1_000);
46
47type RequestType = HeaderRequest;
48type ResponseType = Vec<HeaderResponse>;
49type ReqRespBehaviour = request_response::Behaviour<HeaderCodec>;
50type ReqRespEvent = request_response::Event<RequestType, ResponseType>;
51type ReqRespMessage = request_response::Message<RequestType, ResponseType>;
52type ReqRespConnectionHandler = <ReqRespBehaviour as NetworkBehaviour>::ConnectionHandler;
53
54pub(crate) struct HeaderExBehaviour<S>
55where
56    S: Store + 'static,
57{
58    req_resp: ReqRespBehaviour,
59    client_handler: HeaderExClientHandler,
60    server_handler: HeaderExServerHandler<S>,
61}
62
63pub(crate) struct HeaderExConfig<'a, S> {
64    pub network_id: &'a str,
65    pub peer_tracker: Arc<PeerTracker>,
66    pub header_store: Arc<S>,
67}
68
69/// Representation of all the errors that can occur in `HeaderEx` component.
70#[derive(Debug, thiserror::Error)]
71pub enum HeaderExError {
72    /// Header not found.
73    #[error("Header not found")]
74    HeaderNotFound,
75
76    /// The response is invalid.
77    #[error("Invalid response")]
78    InvalidResponse,
79
80    /// The request is invalid.
81    #[error("Invalid request")]
82    InvalidRequest,
83
84    /// Error when handling connection from the client.
85    #[error("Inbound failure: {0}")]
86    InboundFailure(InboundFailure),
87
88    /// Error when handling connection to the server.
89    #[error("Outbound failure: {0}")]
90    OutboundFailure(OutboundFailure),
91
92    /// Request cancelled because [`Node`] is stopping.
93    ///
94    /// [`Node`]: crate::node::Node
95    #[error("Request cancelled because `Node` is stopping")]
96    RequestCancelled,
97}
98
99impl<S> HeaderExBehaviour<S>
100where
101    S: Store + 'static,
102{
103    pub(crate) fn new(config: HeaderExConfig<'_, S>) -> Self {
104        HeaderExBehaviour {
105            req_resp: ReqRespBehaviour::new(
106                [(
107                    protocol_id(config.network_id, "/header-ex/v0.0.3"),
108                    ProtocolSupport::Full,
109                )],
110                request_response::Config::default(),
111            ),
112            client_handler: HeaderExClientHandler::new(config.peer_tracker),
113            server_handler: HeaderExServerHandler::new(config.header_store),
114        }
115    }
116
117    #[instrument(level = "trace", skip(self, respond_to))]
118    pub(crate) fn send_request(
119        &mut self,
120        request: HeaderRequest,
121        respond_to: OneshotResultSender<Vec<ExtendedHeader>, P2pError>,
122    ) {
123        self.client_handler
124            .on_send_request(&mut self.req_resp, request, respond_to);
125    }
126
127    pub(crate) fn stop(&mut self) {
128        self.client_handler.on_stop();
129        self.server_handler.on_stop();
130    }
131
132    fn on_to_swarm(
133        &mut self,
134        ev: ToSwarm<ReqRespEvent, THandlerInEvent<ReqRespBehaviour>>,
135    ) -> Option<ToSwarm<(), THandlerInEvent<Self>>> {
136        match ev {
137            ToSwarm::GenerateEvent(ev) => {
138                self.on_req_resp_event(ev);
139                None
140            }
141            _ => Some(ev.map_out(|_| ())),
142        }
143    }
144
145    #[instrument(level = "trace", skip_all)]
146    fn on_req_resp_event(&mut self, ev: ReqRespEvent) {
147        match ev {
148            // Received a response for an ongoing outbound request
149            ReqRespEvent::Message {
150                message:
151                    ReqRespMessage::Response {
152                        request_id,
153                        response,
154                    },
155                peer,
156            } => {
157                self.client_handler
158                    .on_response_received(peer, request_id, response);
159            }
160
161            // Failure while client requests
162            ReqRespEvent::OutboundFailure {
163                peer,
164                request_id,
165                error,
166            } => {
167                self.client_handler.on_failure(peer, request_id, error);
168            }
169
170            // Received new inbound request
171            ReqRespEvent::Message {
172                message:
173                    ReqRespMessage::Request {
174                        request_id,
175                        request,
176                        channel,
177                    },
178                peer,
179            } => {
180                self.server_handler.on_request_received(
181                    peer,
182                    request_id,
183                    request,
184                    &mut self.req_resp,
185                    channel,
186                );
187            }
188
189            // Response to inbound request was sent
190            ReqRespEvent::ResponseSent { peer, request_id } => {
191                self.server_handler.on_response_sent(peer, request_id);
192            }
193
194            // Failure while server responds
195            ReqRespEvent::InboundFailure {
196                peer,
197                request_id,
198                error,
199            } => {
200                self.server_handler.on_failure(peer, request_id, error);
201            }
202        }
203    }
204}
205
206impl<S> NetworkBehaviour for HeaderExBehaviour<S>
207where
208    S: Store + 'static,
209{
210    type ConnectionHandler = ConnHandler;
211    type ToSwarm = ();
212
213    fn handle_established_inbound_connection(
214        &mut self,
215        connection_id: ConnectionId,
216        peer: PeerId,
217        local_addr: &Multiaddr,
218        remote_addr: &Multiaddr,
219    ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
220        self.req_resp
221            .handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)
222            .map(ConnHandler)
223    }
224
225    fn handle_established_outbound_connection(
226        &mut self,
227        connection_id: ConnectionId,
228        peer: PeerId,
229        addr: &Multiaddr,
230        role_override: Endpoint,
231        port_use: PortUse,
232    ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
233        self.req_resp
234            .handle_established_outbound_connection(
235                connection_id,
236                peer,
237                addr,
238                role_override,
239                port_use,
240            )
241            .map(ConnHandler)
242    }
243
244    fn handle_pending_inbound_connection(
245        &mut self,
246        connection_id: ConnectionId,
247        local_addr: &Multiaddr,
248        remote_addr: &Multiaddr,
249    ) -> Result<(), ConnectionDenied> {
250        self.req_resp
251            .handle_pending_inbound_connection(connection_id, local_addr, remote_addr)
252    }
253
254    fn handle_pending_outbound_connection(
255        &mut self,
256        connection_id: ConnectionId,
257        maybe_peer: Option<PeerId>,
258        addresses: &[Multiaddr],
259        effective_role: Endpoint,
260    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
261        self.req_resp.handle_pending_outbound_connection(
262            connection_id,
263            maybe_peer,
264            addresses,
265            effective_role,
266        )
267    }
268
269    fn on_swarm_event(&mut self, event: FromSwarm) {
270        self.req_resp.on_swarm_event(event)
271    }
272
273    fn on_connection_handler_event(
274        &mut self,
275        peer_id: PeerId,
276        connection_id: ConnectionId,
277        event: THandlerOutEvent<Self>,
278    ) {
279        self.req_resp
280            .on_connection_handler_event(peer_id, connection_id, event)
281    }
282
283    fn poll(
284        &mut self,
285        cx: &mut Context<'_>,
286    ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
287        loop {
288            if let Poll::Ready(ev) = self.req_resp.poll(cx) {
289                if let Some(ev) = self.on_to_swarm(ev) {
290                    return Poll::Ready(ev);
291                }
292
293                continue;
294            }
295
296            if self.client_handler.poll(cx).is_ready() {
297                continue;
298            }
299
300            if self.server_handler.poll(cx, &mut self.req_resp).is_ready() {
301                continue;
302            }
303
304            return Poll::Pending;
305        }
306    }
307}
308
309pub(crate) struct ConnHandler(ReqRespConnectionHandler);
310
311impl ConnectionHandler for ConnHandler {
312    type ToBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::ToBehaviour;
313    type FromBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::FromBehaviour;
314    type InboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::InboundProtocol;
315    type InboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::InboundOpenInfo;
316    type OutboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::OutboundProtocol;
317    type OutboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::OutboundOpenInfo;
318
319    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
320        self.0.listen_protocol().with_timeout(NEGOTIATION_TIMEOUT)
321    }
322
323    fn poll(
324        &mut self,
325        cx: &mut Context<'_>,
326    ) -> Poll<
327        ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
328    > {
329        match self.0.poll(cx) {
330            Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
331                Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
332                    protocol: protocol.with_timeout(NEGOTIATION_TIMEOUT),
333                })
334            }
335            ev => ev,
336        }
337    }
338
339    fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
340        self.0.on_behaviour_event(event)
341    }
342
343    fn on_connection_event(
344        &mut self,
345        event: ConnectionEvent<
346            '_,
347            Self::InboundProtocol,
348            Self::OutboundProtocol,
349            Self::InboundOpenInfo,
350            Self::OutboundOpenInfo,
351        >,
352    ) {
353        self.0.on_connection_event(event)
354    }
355
356    fn connection_keep_alive(&self) -> bool {
357        self.0.connection_keep_alive()
358    }
359
360    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
361        self.0.poll_close(cx)
362    }
363}
364
365#[derive(Clone, Copy, Debug, Default)]
366pub(crate) struct HeaderCodec;
367
368#[async_trait]
369impl Codec for HeaderCodec {
370    type Protocol = StreamProtocol;
371    type Request = HeaderRequest;
372    type Response = Vec<HeaderResponse>;
373
374    async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
375    where
376        T: AsyncRead + Unpin + Send,
377    {
378        let data = read_up_to(io, REQUEST_SIZE_LIMIT, REQUEST_TIME_LIMIT).await?;
379
380        if data.len() >= REQUEST_SIZE_LIMIT {
381            debug!("Message filled the whole buffer (len: {})", data.len());
382        }
383
384        parse_header_request(&data).ok_or_else(|| {
385            // There are two cases that can reach here:
386            //
387            // 1. The request is invalid
388            // 2. The request is incomplete because of the size limit or time limit
389            io::Error::new(io::ErrorKind::Other, "invalid or incomplete request")
390        })
391    }
392
393    async fn read_response<T>(
394        &mut self,
395        _: &Self::Protocol,
396        io: &mut T,
397    ) -> io::Result<Self::Response>
398    where
399        T: AsyncRead + Unpin + Send,
400    {
401        let data = read_up_to(io, RESPONSE_SIZE_LIMIT, RESPONSE_TIME_LIMIT).await?;
402
403        if data.len() >= RESPONSE_SIZE_LIMIT {
404            debug!("Message filled the whole buffer (len: {})", data.len());
405        }
406
407        let mut data = &data[..];
408        let mut msgs = Vec::new();
409
410        while let Some((header, rest)) = parse_header_response(data) {
411            msgs.push(header);
412            data = rest;
413        }
414
415        if msgs.is_empty() {
416            // There are two cases that can reach here:
417            //
418            // 1. The response is invalid
419            // 2. The response is incomplete because of the size limit or time limit
420            return Err(io::Error::new(
421                io::ErrorKind::Other,
422                "invalid or incomplete response",
423            ));
424        }
425
426        Ok(msgs)
427    }
428
429    async fn write_request<T>(
430        &mut self,
431        _: &Self::Protocol,
432        io: &mut T,
433        req: Self::Request,
434    ) -> io::Result<()>
435    where
436        T: AsyncWrite + Unpin + Send,
437    {
438        let mut buf = Vec::with_capacity(REQUEST_SIZE_LIMIT);
439
440        let _ = req.encode_length_delimited(&mut buf);
441
442        timeout(REQUEST_TIME_LIMIT, io.write_all(&buf))
443            .await
444            .map_err(|_| io::Error::new(io::ErrorKind::Other, "writing request timed out"))??;
445
446        Ok(())
447    }
448
449    async fn write_response<T>(
450        &mut self,
451        _: &Self::Protocol,
452        io: &mut T,
453        resps: Self::Response,
454    ) -> io::Result<()>
455    where
456        T: AsyncWrite + Unpin + Send,
457    {
458        let mut buf = Vec::with_capacity(RESPONSE_SIZE_LIMIT);
459
460        for resp in resps {
461            if resp.encode_length_delimited(&mut buf).is_err() {
462                // Error on encoding means the buffer is full.
463                // We will send a partial response back.
464                debug!("Sending partial response");
465                break;
466            }
467        }
468
469        timeout(RESPONSE_TIME_LIMIT, io.write_all(&buf))
470            .await
471            .map_err(|_| io::Error::new(io::ErrorKind::Other, "writing response timed out"))??;
472
473        Ok(())
474    }
475}
476
477/// Reads up to `size_limit` within `time_limit`.
478async fn read_up_to<T>(io: &mut T, size_limit: usize, time_limit: Duration) -> io::Result<Vec<u8>>
479where
480    T: AsyncRead + Unpin + Send,
481{
482    let mut buf = vec![0u8; size_limit];
483    let mut read_len = 0;
484    let now = Instant::now();
485
486    loop {
487        if read_len == buf.len() {
488            // No empty space. Buffer is full.
489            break;
490        }
491
492        let Some(time_limit) = time_limit.checked_sub(now.elapsed()) else {
493            break;
494        };
495
496        let len = match timeout(time_limit, io.read(&mut buf[read_len..])).await {
497            Ok(Ok(len)) => len,
498            Ok(Err(e)) => return Err(e),
499            Err(_) => break,
500        };
501
502        if len == 0 {
503            // EOF
504            break;
505        }
506
507        read_len += len;
508    }
509
510    buf.truncate(read_len);
511
512    Ok(buf)
513}
514
515fn parse_delimiter(mut buf: &[u8]) -> Option<(usize, &[u8])> {
516    if buf.is_empty() {
517        return None;
518    }
519
520    let Ok(len) = prost::decode_length_delimiter(&mut buf) else {
521        return None;
522    };
523
524    Some((len, buf))
525}
526
527fn parse_header_response(buf: &[u8]) -> Option<(HeaderResponse, &[u8])> {
528    let (len, rest) = parse_delimiter(buf)?;
529
530    if rest.len() < len {
531        debug!("Message is incomplete: {len}");
532        return None;
533    }
534
535    let Ok(msg) = HeaderResponse::decode(&rest[..len]) else {
536        return None;
537    };
538
539    Some((msg, &rest[len..]))
540}
541
542fn parse_header_request(buf: &[u8]) -> Option<HeaderRequest> {
543    let (len, rest) = parse_delimiter(buf)?;
544
545    if rest.len() < len {
546        debug!("Message is incomplete: {len}");
547        return None;
548    }
549
550    let Ok(msg) = HeaderRequest::decode(&rest[..len]) else {
551        return None;
552    };
553
554    Some(msg)
555}
556
#[cfg(test)]
mod tests {
    //! Tests for the header-ex codec, `read_up_to`, and the delimiter/message parsers.

    use super::*;
    use bytes::BytesMut;
    use celestia_proto::p2p::pb::header_request::Data;
    use futures::io::{Cursor, Error};
    use lumina_utils::test_utils::async_test;
    use prost::encode_length_delimiter;
    use std::io::ErrorKind;
    use std::pin::Pin;

    #[async_test]
    async fn test_decode_header_request_empty() {
        let header_request = HeaderRequest {
            amount: 0,
            data: None,
        };

        let encoded_header_request = header_request.encode_length_delimited_to_vec();

        let mut reader = Cursor::new(encoded_header_request);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_request = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(header_request, decoded_header_request);
    }

    #[async_test]
    async fn test_decode_multiple_small_header_response() {
        const MSG_COUNT: usize = 10;
        let header_response = HeaderResponse {
            body: vec![1, 2, 3],
            status_code: 1,
        };

        let encoded_header_response = header_response.encode_length_delimited_to_vec();

        let mut multi_msg = vec![];
        for _ in 0..MSG_COUNT {
            multi_msg.extend_from_slice(&encoded_header_response);
        }
        let mut reader = Cursor::new(multi_msg);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_response = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();

        for decoded_header in decoded_header_response.iter() {
            assert_eq!(&header_response, decoded_header);
        }
        assert_eq!(decoded_header_response.len(), MSG_COUNT);
    }

    /// Only a delimiter announcing more bytes than the limit is sent; the
    /// payload never arrives, so decoding must fail.
    #[async_test]
    async fn test_decode_header_request_too_large() {
        let too_long_message_len = REQUEST_SIZE_LIMIT + 1;
        let mut length_delimiter_buffer = BytesMut::new();
        prost::encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoding_error = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .expect_err("expected error for too large request");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    /// Response counterpart of `test_decode_header_request_too_large`.
    #[async_test]
    async fn test_decode_header_response_too_large() {
        let too_long_message_len = RESPONSE_SIZE_LIMIT + 1;
        let mut length_delimiter_buffer = BytesMut::new();
        encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoding_error = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .expect_err("expected error for too large response");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    #[test]
    fn test_invalid_varint() {
        // 10 consecutive bytes with continuation bit set + 1 byte, which is longer than allowed
        //    for length delimiter
        let varint = [
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b0000_0001,
        ];

        assert_eq!(parse_delimiter(&varint), None);
    }

    /// Varints padded with redundant zero groups still decode, and the bytes
    /// after the varint are returned untouched.
    #[test]
    fn parse_trailing_zero_varint() {
        let varint = [0b1000_0001, 0b0000_0000, 0b1111_1111];
        assert!(matches!(parse_delimiter(&varint), Some((1, [255]))));

        let varint = [0b1000_0000, 0b1000_0000, 0b1000_0000, 0b0000_0000];
        assert!(matches!(parse_delimiter(&varint), Some((0, []))));
    }

    #[async_test]
    async fn test_decode_header_double_response_data() {
        let mut header_response_buffer = BytesMut::with_capacity(512);
        let header_response0 = HeaderResponse {
            body: b"9999888877776666555544443333222211110000".to_vec(),
            status_code: 1,
        };
        let header_response1 = HeaderResponse {
            body: b"0000111122223333444455556666777788889999".to_vec(),
            status_code: 2,
        };
        header_response0
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        header_response1
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        let mut reader = Cursor::new(header_response_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_response = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();
        assert_eq!(header_response0, decoded_header_response[0]);
        assert_eq!(header_response1, decoded_header_response[1]);
    }

    /// A request delivered in 1-, 2-, 3- and 4-byte chunks must still decode.
    #[async_test]
    async fn test_decode_header_request_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_request = HeaderRequest {
            amount: 1,
            data: Some(Data::Hash(data.to_vec())),
        };
        let encoded_header_request = header_request.encode_length_delimited_to_vec();

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};
        {
            let mut reader =
                ChunkyAsyncRead::<_, 1>::new(Cursor::new(encoded_header_request.clone()));
            let decoded_header_request = codec
                .read_request(&stream_protocol, &mut reader)
                .await
                .unwrap();
            assert_eq!(header_request, decoded_header_request);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 2>::new(Cursor::new(encoded_header_request.clone()));
            let decoded_header_request = codec
                .read_request(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_request, decoded_header_request);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 3>::new(Cursor::new(encoded_header_request.clone()));
            let decoded_header_request = codec
                .read_request(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_request, decoded_header_request);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 4>::new(Cursor::new(encoded_header_request.clone()));
            let decoded_header_request = codec
                .read_request(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_request, decoded_header_request);
        }
    }

    /// A response delivered in 1-, 2-, 3- and 4-byte chunks must still decode.
    #[async_test]
    async fn test_decode_header_response_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_response = HeaderResponse {
            body: data.to_vec(),
            status_code: 2,
        };
        let encoded_header_response = header_response.encode_length_delimited_to_vec();

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};
        {
            let mut reader =
                ChunkyAsyncRead::<_, 1>::new(Cursor::new(encoded_header_response.clone()));
            let decoded_header_response = codec
                .read_response(&stream_protocol, &mut reader)
                .await
                .unwrap();
            assert_eq!(header_response, decoded_header_response[0]);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 2>::new(Cursor::new(encoded_header_response.clone()));
            let decoded_header_response = codec
                .read_response(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_response, decoded_header_response[0]);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 3>::new(Cursor::new(encoded_header_response.clone()));
            let decoded_header_response = codec
                .read_response(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_response, decoded_header_response[0]);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 4>::new(Cursor::new(encoded_header_response.clone()));
            let decoded_header_response = codec
                .read_response(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_response, decoded_header_response[0]);
        }
    }

    /// `read_up_to` returns exactly `size_limit` bytes when more data is available,
    /// returns everything up to EOF when the data is shorter, and returns nothing
    /// for a zero limit.
    #[async_test]
    async fn test_read_up_to_limits() {
        let data: Vec<u8> = (0..10_000usize).map(|i| (i % 251) as u8).collect();

        let mut reader = Cursor::new(data.clone());
        let read = read_up_to(&mut reader, 9_000, Duration::from_secs(5))
            .await
            .unwrap();
        assert_eq!(read.as_slice(), &data[..9_000]);

        let mut reader = ChunkyAsyncRead::<_, 64>::new(Cursor::new(data.clone()));
        let read = read_up_to(&mut reader, 20_000, Duration::from_secs(5))
            .await
            .unwrap();
        assert_eq!(read, data);

        let mut reader = Cursor::new(data.clone());
        let read = read_up_to(&mut reader, 0, Duration::from_secs(5))
            .await
            .unwrap();
        assert!(read.is_empty());
    }

    #[async_test]
    async fn test_chunky_async_read() {
        let read_data = "FOO123";
        let cur0 = Cursor::new(read_data);
        let mut chunky = ChunkyAsyncRead::<_, 3>::new(cur0);

        let mut output_buffer: BytesMut = b"BAR987".as_ref().into();

        let read_len = chunky.read(&mut output_buffer[..]).await.unwrap();
        assert_eq!(read_len, 3);
        assert_eq!(output_buffer, b"FOO987".as_ref());
    }

    /// Test reader that returns at most `CHUNK_SIZE` bytes per `poll_read`,
    /// simulating data arriving in small pieces.
    struct ChunkyAsyncRead<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> {
        inner: T,
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn new(inner: T) -> Self {
            ChunkyAsyncRead { inner }
        }
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> AsyncRead for ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize, Error>> {
            let len = buf.len().min(CHUNK_SIZE);
            Pin::new(&mut self.inner).poll_read(cx, &mut buf[..len])
        }
    }
}