lumina_node/p2p/
header_ex.rs

1use std::io;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4use std::time::Duration;
5
6use async_trait::async_trait;
7use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
8use celestia_types::ExtendedHeader;
9use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10use libp2p::core::transport::PortUse;
11use libp2p::{
12    core::Endpoint,
13    request_response::{self, Codec, InboundFailure, OutboundFailure, ProtocolSupport},
14    swarm::{
15        handler::ConnectionEvent, ConnectionDenied, ConnectionHandler, ConnectionHandlerEvent,
16        ConnectionId, FromSwarm, NetworkBehaviour, SubstreamProtocol, THandlerInEvent,
17        THandlerOutEvent, ToSwarm,
18    },
19    Multiaddr, PeerId, StreamProtocol,
20};
21use lumina_utils::time::{timeout, Instant};
22use prost::Message;
23use tracing::{debug, instrument, warn};
24
25mod client;
26mod server;
27pub(crate) mod utils;
28
29use crate::p2p::header_ex::client::HeaderExClientHandler;
30use crate::p2p::header_ex::server::HeaderExServerHandler;
31use crate::p2p::P2pError;
32use crate::peer_tracker::PeerTracker;
33use crate::store::Store;
34use crate::utils::{protocol_id, OneshotResultSender};
35
/// Maximum number of bytes consumed from a request stream (1 KiB).
const REQUEST_SIZE_LIMIT: usize = 1 << 10;
/// Time budget for reading a request, and separately for writing one.
const REQUEST_TIME_LIMIT: Duration = Duration::from_millis(1_000);
/// Maximum number of bytes consumed from a response stream (10 MiB).
const RESPONSE_SIZE_LIMIT: usize = 10 << 20;
/// Time budget for reading a response, and separately for writing one.
const RESPONSE_TIME_LIMIT: Duration = Duration::from_millis(5_000);
/// Time allowed for protocol negotiation on a newly opened substream.
const NEGOTIATION_TIMEOUT: Duration = Duration::from_millis(1_000);
46
47type RequestType = HeaderRequest;
48type ResponseType = Vec<HeaderResponse>;
49type ReqRespBehaviour = request_response::Behaviour<HeaderCodec>;
50type ReqRespEvent = request_response::Event<RequestType, ResponseType>;
51type ReqRespMessage = request_response::Message<RequestType, ResponseType>;
52type ReqRespConnectionHandler = <ReqRespBehaviour as NetworkBehaviour>::ConnectionHandler;
53
54pub(crate) struct HeaderExBehaviour<S>
55where
56    S: Store + 'static,
57{
58    req_resp: ReqRespBehaviour,
59    client_handler: HeaderExClientHandler,
60    server_handler: HeaderExServerHandler<S>,
61}
62
63pub(crate) struct HeaderExConfig<'a, S> {
64    pub network_id: &'a str,
65    pub peer_tracker: Arc<PeerTracker>,
66    pub header_store: Arc<S>,
67}
68
69/// Representation of all the errors that can occur in `HeaderEx` component.
70#[derive(Debug, thiserror::Error)]
71pub enum HeaderExError {
72    /// Header not found.
73    #[error("Header not found")]
74    HeaderNotFound,
75
76    /// The response is invalid.
77    #[error("Invalid response")]
78    InvalidResponse,
79
80    /// The request is invalid.
81    #[error("Invalid request")]
82    InvalidRequest,
83
84    /// Error when handling connection from the client.
85    #[error("Inbound failure: {0}")]
86    InboundFailure(InboundFailure),
87
88    /// Error when handling connection to the server.
89    #[error("Outbound failure: {0}")]
90    OutboundFailure(OutboundFailure),
91
92    /// Request cancelled because [`Node`] is stopping.
93    ///
94    /// [`Node`]: crate::node::Node
95    #[error("Request cancelled because `Node` is stopping")]
96    RequestCancelled,
97}
98
99impl<S> HeaderExBehaviour<S>
100where
101    S: Store + 'static,
102{
103    pub(crate) fn new(config: HeaderExConfig<'_, S>) -> Self {
104        HeaderExBehaviour {
105            req_resp: ReqRespBehaviour::new(
106                [(
107                    protocol_id(config.network_id, "/header-ex/v0.0.3"),
108                    ProtocolSupport::Full,
109                )],
110                request_response::Config::default(),
111            ),
112            client_handler: HeaderExClientHandler::new(config.peer_tracker),
113            server_handler: HeaderExServerHandler::new(config.header_store),
114        }
115    }
116
117    #[instrument(level = "trace", skip(self, respond_to))]
118    pub(crate) fn send_request(
119        &mut self,
120        request: HeaderRequest,
121        respond_to: OneshotResultSender<Vec<ExtendedHeader>, P2pError>,
122    ) {
123        self.client_handler
124            .on_send_request(&mut self.req_resp, request, respond_to);
125    }
126
127    pub(crate) fn stop(&mut self) {
128        self.client_handler.on_stop();
129        self.server_handler.on_stop();
130    }
131
132    fn on_to_swarm(
133        &mut self,
134        ev: ToSwarm<ReqRespEvent, THandlerInEvent<ReqRespBehaviour>>,
135    ) -> Option<ToSwarm<(), THandlerInEvent<Self>>> {
136        match ev {
137            ToSwarm::GenerateEvent(ev) => {
138                self.on_req_resp_event(ev);
139                None
140            }
141            _ => Some(ev.map_out(|_| ())),
142        }
143    }
144
145    #[instrument(level = "trace", skip_all)]
146    fn on_req_resp_event(&mut self, ev: ReqRespEvent) {
147        match ev {
148            // Received a response for an ongoing outbound request
149            ReqRespEvent::Message {
150                message:
151                    ReqRespMessage::Response {
152                        request_id,
153                        response,
154                    },
155                peer,
156                ..
157            } => {
158                self.client_handler
159                    .on_response_received(peer, request_id, response);
160            }
161
162            // Failure while client requests
163            ReqRespEvent::OutboundFailure {
164                peer,
165                request_id,
166                error,
167                ..
168            } => {
169                self.client_handler.on_failure(peer, request_id, error);
170            }
171
172            // Received new inbound request
173            ReqRespEvent::Message {
174                message:
175                    ReqRespMessage::Request {
176                        request_id,
177                        request,
178                        channel,
179                    },
180                peer,
181                ..
182            } => {
183                self.server_handler.on_request_received(
184                    peer,
185                    request_id,
186                    request,
187                    &mut self.req_resp,
188                    channel,
189                );
190            }
191
192            // Response to inbound request was sent
193            ReqRespEvent::ResponseSent {
194                peer, request_id, ..
195            } => {
196                self.server_handler.on_response_sent(peer, request_id);
197            }
198
199            // Failure while server responds
200            ReqRespEvent::InboundFailure {
201                peer,
202                request_id,
203                error,
204                ..
205            } => {
206                self.server_handler.on_failure(peer, request_id, error);
207            }
208        }
209    }
210}
211
212impl<S> NetworkBehaviour for HeaderExBehaviour<S>
213where
214    S: Store + 'static,
215{
216    type ConnectionHandler = ConnHandler;
217    type ToSwarm = ();
218
219    fn handle_established_inbound_connection(
220        &mut self,
221        connection_id: ConnectionId,
222        peer: PeerId,
223        local_addr: &Multiaddr,
224        remote_addr: &Multiaddr,
225    ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
226        self.req_resp
227            .handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)
228            .map(ConnHandler)
229    }
230
231    fn handle_established_outbound_connection(
232        &mut self,
233        connection_id: ConnectionId,
234        peer: PeerId,
235        addr: &Multiaddr,
236        role_override: Endpoint,
237        port_use: PortUse,
238    ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
239        self.req_resp
240            .handle_established_outbound_connection(
241                connection_id,
242                peer,
243                addr,
244                role_override,
245                port_use,
246            )
247            .map(ConnHandler)
248    }
249
250    fn handle_pending_inbound_connection(
251        &mut self,
252        connection_id: ConnectionId,
253        local_addr: &Multiaddr,
254        remote_addr: &Multiaddr,
255    ) -> Result<(), ConnectionDenied> {
256        self.req_resp
257            .handle_pending_inbound_connection(connection_id, local_addr, remote_addr)
258    }
259
260    fn handle_pending_outbound_connection(
261        &mut self,
262        connection_id: ConnectionId,
263        maybe_peer: Option<PeerId>,
264        addresses: &[Multiaddr],
265        effective_role: Endpoint,
266    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
267        self.req_resp.handle_pending_outbound_connection(
268            connection_id,
269            maybe_peer,
270            addresses,
271            effective_role,
272        )
273    }
274
275    fn on_swarm_event(&mut self, event: FromSwarm) {
276        self.req_resp.on_swarm_event(event)
277    }
278
279    fn on_connection_handler_event(
280        &mut self,
281        peer_id: PeerId,
282        connection_id: ConnectionId,
283        event: THandlerOutEvent<Self>,
284    ) {
285        self.req_resp
286            .on_connection_handler_event(peer_id, connection_id, event)
287    }
288
289    fn poll(
290        &mut self,
291        cx: &mut Context<'_>,
292    ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
293        loop {
294            if let Poll::Ready(ev) = self.req_resp.poll(cx) {
295                if let Some(ev) = self.on_to_swarm(ev) {
296                    return Poll::Ready(ev);
297                }
298
299                continue;
300            }
301
302            if self.client_handler.poll(cx).is_ready() {
303                continue;
304            }
305
306            if self.server_handler.poll(cx, &mut self.req_resp).is_ready() {
307                continue;
308            }
309
310            return Poll::Pending;
311        }
312    }
313}
314
315pub(crate) struct ConnHandler(ReqRespConnectionHandler);
316
317impl ConnectionHandler for ConnHandler {
318    type ToBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::ToBehaviour;
319    type FromBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::FromBehaviour;
320    type InboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::InboundProtocol;
321    type InboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::InboundOpenInfo;
322    type OutboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::OutboundProtocol;
323    type OutboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::OutboundOpenInfo;
324
325    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
326        self.0.listen_protocol().with_timeout(NEGOTIATION_TIMEOUT)
327    }
328
329    fn poll(
330        &mut self,
331        cx: &mut Context<'_>,
332    ) -> Poll<
333        ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
334    > {
335        match self.0.poll(cx) {
336            Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
337                Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
338                    protocol: protocol.with_timeout(NEGOTIATION_TIMEOUT),
339                })
340            }
341            ev => ev,
342        }
343    }
344
345    fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
346        self.0.on_behaviour_event(event)
347    }
348
349    fn on_connection_event(
350        &mut self,
351        event: ConnectionEvent<
352            '_,
353            Self::InboundProtocol,
354            Self::OutboundProtocol,
355            Self::InboundOpenInfo,
356            Self::OutboundOpenInfo,
357        >,
358    ) {
359        self.0.on_connection_event(event)
360    }
361
362    fn connection_keep_alive(&self) -> bool {
363        self.0.connection_keep_alive()
364    }
365
366    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
367        self.0.poll_close(cx)
368    }
369}
370
/// Wire codec for header-ex: length-delimited protobuf messages.
///
/// A request is a single [`HeaderRequest`]; a response is a sequence of
/// [`HeaderResponse`] messages written back to back.
#[derive(Debug, Default, Clone, Copy)]
pub(crate) struct HeaderCodec;
373
374#[async_trait]
375impl Codec for HeaderCodec {
376    type Protocol = StreamProtocol;
377    type Request = HeaderRequest;
378    type Response = Vec<HeaderResponse>;
379
380    async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
381    where
382        T: AsyncRead + Unpin + Send,
383    {
384        let data = read_up_to(io, REQUEST_SIZE_LIMIT, REQUEST_TIME_LIMIT).await?;
385
386        if data.len() >= REQUEST_SIZE_LIMIT {
387            debug!("Message filled the whole buffer (len: {})", data.len());
388        }
389
390        parse_header_request(&data).ok_or_else(|| {
391            // There are two cases that can reach here:
392            //
393            // 1. The request is invalid
394            // 2. The request is incomplete because of the size limit or time limit
395            io::Error::other("invalid or incomplete request")
396        })
397    }
398
399    async fn read_response<T>(
400        &mut self,
401        _: &Self::Protocol,
402        io: &mut T,
403    ) -> io::Result<Self::Response>
404    where
405        T: AsyncRead + Unpin + Send,
406    {
407        let data = read_up_to(io, RESPONSE_SIZE_LIMIT, RESPONSE_TIME_LIMIT).await?;
408
409        if data.len() >= RESPONSE_SIZE_LIMIT {
410            debug!("Message filled the whole buffer (len: {})", data.len());
411        }
412
413        let mut data = &data[..];
414        let mut msgs = Vec::new();
415
416        while let Some((header, rest)) = parse_header_response(data) {
417            msgs.push(header);
418            data = rest;
419        }
420
421        if msgs.is_empty() {
422            // There are two cases that can reach here:
423            //
424            // 1. The response is invalid
425            // 2. The response is incomplete because of the size limit or time limit
426            return Err(io::Error::other("invalid or incomplete response"));
427        }
428
429        Ok(msgs)
430    }
431
432    async fn write_request<T>(
433        &mut self,
434        _: &Self::Protocol,
435        io: &mut T,
436        req: Self::Request,
437    ) -> io::Result<()>
438    where
439        T: AsyncWrite + Unpin + Send,
440    {
441        let mut buf = Vec::with_capacity(REQUEST_SIZE_LIMIT);
442
443        let _ = req.encode_length_delimited(&mut buf);
444
445        timeout(REQUEST_TIME_LIMIT, io.write_all(&buf))
446            .await
447            .map_err(|_| io::Error::other("writing request timed out"))??;
448
449        Ok(())
450    }
451
452    async fn write_response<T>(
453        &mut self,
454        _: &Self::Protocol,
455        io: &mut T,
456        resps: Self::Response,
457    ) -> io::Result<()>
458    where
459        T: AsyncWrite + Unpin + Send,
460    {
461        let mut buf = Vec::with_capacity(RESPONSE_SIZE_LIMIT);
462
463        for resp in resps {
464            if resp.encode_length_delimited(&mut buf).is_err() {
465                // Error on encoding means the buffer is full.
466                // We will send a partial response back.
467                debug!("Sending partial response");
468                break;
469            }
470        }
471
472        timeout(RESPONSE_TIME_LIMIT, io.write_all(&buf))
473            .await
474            .map_err(|_| io::Error::other("writing response timed out"))??;
475
476        Ok(())
477    }
478}
479
480/// Reads up to `size_limit` within `time_limit`.
481async fn read_up_to<T>(io: &mut T, size_limit: usize, time_limit: Duration) -> io::Result<Vec<u8>>
482where
483    T: AsyncRead + Unpin + Send,
484{
485    let mut buf = vec![0u8; size_limit];
486    let mut read_len = 0;
487    let now = Instant::now();
488
489    loop {
490        if read_len == buf.len() {
491            // No empty space. Buffer is full.
492            break;
493        }
494
495        let Some(time_limit) = time_limit.checked_sub(now.elapsed()) else {
496            break;
497        };
498
499        let len = match timeout(time_limit, io.read(&mut buf[read_len..])).await {
500            Ok(Ok(len)) => len,
501            Ok(Err(e)) => return Err(e),
502            Err(_) => break,
503        };
504
505        if len == 0 {
506            // EOF
507            break;
508        }
509
510        read_len += len;
511    }
512
513    buf.truncate(read_len);
514
515    Ok(buf)
516}
517
518fn parse_delimiter(mut buf: &[u8]) -> Option<(usize, &[u8])> {
519    if buf.is_empty() {
520        return None;
521    }
522
523    let Ok(len) = prost::decode_length_delimiter(&mut buf) else {
524        return None;
525    };
526
527    Some((len, buf))
528}
529
530fn parse_header_response(buf: &[u8]) -> Option<(HeaderResponse, &[u8])> {
531    let (len, rest) = parse_delimiter(buf)?;
532
533    if rest.len() < len {
534        debug!("Message is incomplete: {len}");
535        return None;
536    }
537
538    let Ok(msg) = HeaderResponse::decode(&rest[..len]) else {
539        return None;
540    };
541
542    Some((msg, &rest[len..]))
543}
544
545fn parse_header_request(buf: &[u8]) -> Option<HeaderRequest> {
546    let (len, rest) = parse_delimiter(buf)?;
547
548    if rest.len() < len {
549        debug!("Message is incomplete: {len}");
550        return None;
551    }
552
553    let Ok(msg) = HeaderRequest::decode(&rest[..len]) else {
554        return None;
555    };
556
557    Some(msg)
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use celestia_proto::p2p::pb::header_request::Data;
    use futures::io::{Cursor, Error};
    use lumina_utils::test_utils::async_test;
    use prost::encode_length_delimiter;
    use std::io::ErrorKind;
    use std::pin::Pin;

    #[async_test]
    async fn test_decode_header_request_empty() {
        let header_request = HeaderRequest {
            amount: 0,
            data: None,
        };

        let encoded_header_request = header_request.encode_length_delimited_to_vec();

        let mut reader = Cursor::new(encoded_header_request);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_request = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(header_request, decoded_header_request);
    }

    #[async_test]
    async fn test_decode_multiple_small_header_response() {
        const MSG_COUNT: usize = 10;
        let header_response = HeaderResponse {
            body: vec![1, 2, 3],
            status_code: 1,
        };

        let encoded_header_response = header_response.encode_length_delimited_to_vec();

        let mut multi_msg = vec![];
        for _ in 0..MSG_COUNT {
            multi_msg.extend_from_slice(&encoded_header_response);
        }
        let mut reader = Cursor::new(multi_msg);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_response = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(decoded_header_response.len(), MSG_COUNT);
        for decoded_header in decoded_header_response.iter() {
            assert_eq!(&header_response, decoded_header);
        }
    }

    #[async_test]
    async fn test_decode_header_request_too_large() {
        let too_long_message_len = REQUEST_SIZE_LIMIT + 1;
        let mut length_delimiter_buffer = BytesMut::new();
        encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoding_error = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .expect_err("expected error for too large request");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    #[async_test]
    async fn test_decode_header_response_too_large() {
        let too_long_message_len = RESPONSE_SIZE_LIMIT + 1;
        let mut length_delimiter_buffer = BytesMut::new();
        encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoding_error = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .expect_err("expected error for too large response");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    #[test]
    fn test_invalid_varint() {
        // 10 consecutive bytes with continuation bit set + 1 byte, which is longer than allowed
        //    for length delimiter
        let varint = [
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b0000_0001,
        ];

        assert_eq!(parse_delimiter(&varint), None);
    }

    #[test]
    fn parse_trailing_zero_varint() {
        let varint = [0b1000_0001, 0b0000_0000, 0b1111_1111];
        assert!(matches!(parse_delimiter(&varint), Some((1, [255]))));

        let varint = [0b1000_0000, 0b1000_0000, 0b1000_0000, 0b0000_0000];
        assert!(matches!(parse_delimiter(&varint), Some((0, []))));
    }

    #[async_test]
    async fn test_decode_header_double_response_data() {
        let mut header_response_buffer = BytesMut::with_capacity(512);
        let header_response0 = HeaderResponse {
            body: b"9999888877776666555544443333222211110000".to_vec(),
            status_code: 1,
        };
        let header_response1 = HeaderResponse {
            body: b"0000111122223333444455556666777788889999".to_vec(),
            status_code: 2,
        };
        header_response0
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        header_response1
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        let mut reader = Cursor::new(header_response_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_response = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();
        assert_eq!(header_response0, decoded_header_response[0]);
        assert_eq!(header_response1, decoded_header_response[1]);
    }

    /// A request written by the codec must be readable by the codec.
    #[async_test]
    async fn test_header_request_roundtrip() {
        let header_request = HeaderRequest {
            amount: 3,
            data: Some(Data::Hash(b"roundtrip".to_vec())),
        };

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let mut encoded = Vec::new();
        codec
            .write_request(&stream_protocol, &mut encoded, header_request.clone())
            .await
            .unwrap();

        let mut reader = Cursor::new(encoded);
        let decoded_header_request = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(header_request, decoded_header_request);
    }

    /// A multi-message response written by the codec must be readable by the codec.
    #[async_test]
    async fn test_header_response_roundtrip() {
        let header_responses = vec![
            HeaderResponse {
                body: b"first".to_vec(),
                status_code: 1,
            },
            HeaderResponse {
                body: b"second".to_vec(),
                status_code: 2,
            },
            HeaderResponse {
                body: b"third".to_vec(),
                status_code: 1,
            },
        ];

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let mut encoded = Vec::new();
        codec
            .write_response(&stream_protocol, &mut encoded, header_responses.clone())
            .await
            .unwrap();

        let mut reader = Cursor::new(encoded);
        let decoded_header_responses = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(header_responses, decoded_header_responses);
    }

    #[async_test]
    async fn test_decode_header_request_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_request = HeaderRequest {
            amount: 1,
            data: Some(Data::Hash(data.to_vec())),
        };
        let encoded_header_request = header_request.encode_length_delimited_to_vec();

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};
        {
            let mut reader =
                ChunkyAsyncRead::<_, 1>::new(Cursor::new(encoded_header_request.clone()));
            let decoded_header_request = codec
                .read_request(&stream_protocol, &mut reader)
                .await
                .unwrap();
            assert_eq!(header_request, decoded_header_request);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 2>::new(Cursor::new(encoded_header_request.clone()));
            let decoded_header_request = codec
                .read_request(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_request, decoded_header_request);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 3>::new(Cursor::new(encoded_header_request.clone()));
            let decoded_header_request = codec
                .read_request(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_request, decoded_header_request);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 4>::new(Cursor::new(encoded_header_request.clone()));
            let decoded_header_request = codec
                .read_request(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_request, decoded_header_request);
        }
    }

    #[async_test]
    async fn test_decode_header_response_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_response = HeaderResponse {
            body: data.to_vec(),
            status_code: 2,
        };
        let encoded_header_response = header_response.encode_length_delimited_to_vec();

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};
        {
            let mut reader =
                ChunkyAsyncRead::<_, 1>::new(Cursor::new(encoded_header_response.clone()));
            let decoded_header_response = codec
                .read_response(&stream_protocol, &mut reader)
                .await
                .unwrap();
            assert_eq!(header_response, decoded_header_response[0]);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 2>::new(Cursor::new(encoded_header_response.clone()));
            let decoded_header_response = codec
                .read_response(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_response, decoded_header_response[0]);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 3>::new(Cursor::new(encoded_header_response.clone()));
            let decoded_header_response = codec
                .read_response(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_response, decoded_header_response[0]);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 4>::new(Cursor::new(encoded_header_response.clone()));
            let decoded_header_response = codec
                .read_response(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_response, decoded_header_response[0]);
        }
    }

    #[async_test]
    async fn test_chunky_async_read() {
        let read_data = "FOO123";
        let cur0 = Cursor::new(read_data);
        let mut chunky = ChunkyAsyncRead::<_, 3>::new(cur0);

        let mut output_buffer: BytesMut = b"BAR987".as_ref().into();

        let _ = chunky.read(&mut output_buffer[..]).await.unwrap();
        assert_eq!(output_buffer, b"FOO987".as_ref());
    }

    /// Reader adapter that returns at most `CHUNK_SIZE` bytes per `poll_read`.
    /// The codec tests use it to check that messages are reassembled correctly
    /// from short reads.
    struct ChunkyAsyncRead<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> {
        inner: T,
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn new(inner: T) -> Self {
            ChunkyAsyncRead { inner }
        }
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> AsyncRead for ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize, Error>> {
            let len = buf.len().min(CHUNK_SIZE);
            Pin::new(&mut self.inner).poll_read(cx, &mut buf[..len])
        }
    }
}