// lumina_node/p2p/header_ex.rs

1use std::io;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4use std::time::Duration;
5
6use async_trait::async_trait;
7use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
8use celestia_types::ExtendedHeader;
9use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10use libp2p::core::transport::PortUse;
11use libp2p::{
12    core::Endpoint,
13    request_response::{self, Codec, InboundFailure, OutboundFailure, ProtocolSupport},
14    swarm::{
15        handler::ConnectionEvent, ConnectionDenied, ConnectionHandler, ConnectionHandlerEvent,
16        ConnectionId, FromSwarm, NetworkBehaviour, SubstreamProtocol, THandlerInEvent,
17        THandlerOutEvent, ToSwarm,
18    },
19    Multiaddr, PeerId, StreamProtocol,
20};
21use lumina_utils::time::{timeout, Instant};
22use prost::Message;
23use tracing::{debug, instrument, warn};
24
25mod client;
26mod server;
27pub(crate) mod utils;
28
29use crate::p2p::header_ex::client::HeaderExClientHandler;
30use crate::p2p::header_ex::server::HeaderExServerHandler;
31use crate::p2p::P2pError;
32use crate::peer_tracker::PeerTracker;
33use crate::store::Store;
34use crate::utils::{protocol_id, OneshotResultSender};
35
/// Maximum number of bytes read from a stream when receiving a request (1 KiB).
const REQUEST_SIZE_LIMIT: usize = 1024;
/// Maximum number of bytes read from a stream when receiving a response batch (10 MiB).
const RESPONSE_SIZE_LIMIT: usize = 10 * 1024 * 1024;

/// Overall deadline for reading, or for writing, one request.
const REQUEST_TIME_LIMIT: Duration = Duration::from_secs(1);
/// Overall deadline for reading, or for writing, one response batch.
const RESPONSE_TIME_LIMIT: Duration = Duration::from_secs(5);
/// Deadline for negotiating the protocol on a new substream.
const NEGOTIATION_TIMEOUT: Duration = Duration::from_secs(1);
46
47type RequestType = HeaderRequest;
48type ResponseType = Vec<HeaderResponse>;
49type ReqRespBehaviour = request_response::Behaviour<HeaderCodec>;
50type ReqRespEvent = request_response::Event<RequestType, ResponseType>;
51type ReqRespMessage = request_response::Message<RequestType, ResponseType>;
52type ReqRespConnectionHandler = <ReqRespBehaviour as NetworkBehaviour>::ConnectionHandler;
53
54pub(crate) struct HeaderExBehaviour<S>
55where
56    S: Store + 'static,
57{
58    req_resp: ReqRespBehaviour,
59    client_handler: HeaderExClientHandler,
60    server_handler: HeaderExServerHandler<S>,
61}
62
63pub(crate) struct HeaderExConfig<'a, S> {
64    pub network_id: &'a str,
65    pub peer_tracker: Arc<PeerTracker>,
66    pub header_store: Arc<S>,
67}
68
69/// Representation of all the errors that can occur in `HeaderEx` component.
70#[derive(Debug, thiserror::Error)]
71pub enum HeaderExError {
72    /// Header not found.
73    #[error("Header not found")]
74    HeaderNotFound,
75
76    /// The response is invalid.
77    #[error("Invalid response")]
78    InvalidResponse,
79
80    /// The request is invalid.
81    #[error("Invalid request")]
82    InvalidRequest,
83
84    /// Error when handling connection from the client.
85    #[error("Inbound failure: {0}")]
86    InboundFailure(InboundFailure),
87
88    /// Error when handling connection to the server.
89    #[error("Outbound failure: {0}")]
90    OutboundFailure(OutboundFailure),
91
92    /// Request cancelled because [`Node`] is stopping.
93    ///
94    /// [`Node`]: crate::node::Node
95    #[error("Request cancelled because `Node` is stopping")]
96    RequestCancelled,
97}
98
99impl<S> HeaderExBehaviour<S>
100where
101    S: Store + 'static,
102{
103    pub(crate) fn new(config: HeaderExConfig<'_, S>) -> Self {
104        HeaderExBehaviour {
105            req_resp: ReqRespBehaviour::new(
106                [(
107                    protocol_id(config.network_id, "/header-ex/v0.0.3"),
108                    ProtocolSupport::Full,
109                )],
110                request_response::Config::default(),
111            ),
112            client_handler: HeaderExClientHandler::new(config.peer_tracker),
113            server_handler: HeaderExServerHandler::new(config.header_store),
114        }
115    }
116
117    #[instrument(level = "trace", skip(self, respond_to))]
118    pub(crate) fn send_request(
119        &mut self,
120        request: HeaderRequest,
121        respond_to: OneshotResultSender<Vec<ExtendedHeader>, P2pError>,
122    ) {
123        self.client_handler
124            .on_send_request(&mut self.req_resp, request, respond_to);
125    }
126
127    pub(crate) fn stop(&mut self) {
128        self.client_handler.on_stop();
129        self.server_handler.on_stop();
130    }
131
132    fn on_to_swarm(
133        &mut self,
134        ev: ToSwarm<ReqRespEvent, THandlerInEvent<ReqRespBehaviour>>,
135    ) -> Option<ToSwarm<(), THandlerInEvent<Self>>> {
136        match ev {
137            ToSwarm::GenerateEvent(ev) => {
138                self.on_req_resp_event(ev);
139                None
140            }
141            _ => Some(ev.map_out(|_| ())),
142        }
143    }
144
145    #[instrument(level = "trace", skip_all)]
146    fn on_req_resp_event(&mut self, ev: ReqRespEvent) {
147        match ev {
148            // Received a response for an ongoing outbound request
149            ReqRespEvent::Message {
150                message:
151                    ReqRespMessage::Response {
152                        request_id,
153                        response,
154                    },
155                peer,
156            } => {
157                self.client_handler
158                    .on_response_received(peer, request_id, response);
159            }
160
161            // Failure while client requests
162            ReqRespEvent::OutboundFailure {
163                peer,
164                request_id,
165                error,
166            } => {
167                self.client_handler.on_failure(peer, request_id, error);
168            }
169
170            // Received new inbound request
171            ReqRespEvent::Message {
172                message:
173                    ReqRespMessage::Request {
174                        request_id,
175                        request,
176                        channel,
177                    },
178                peer,
179            } => {
180                self.server_handler.on_request_received(
181                    peer,
182                    request_id,
183                    request,
184                    &mut self.req_resp,
185                    channel,
186                );
187            }
188
189            // Response to inbound request was sent
190            ReqRespEvent::ResponseSent { peer, request_id } => {
191                self.server_handler.on_response_sent(peer, request_id);
192            }
193
194            // Failure while server responds
195            ReqRespEvent::InboundFailure {
196                peer,
197                request_id,
198                error,
199            } => {
200                self.server_handler.on_failure(peer, request_id, error);
201            }
202        }
203    }
204}
205
206impl<S> NetworkBehaviour for HeaderExBehaviour<S>
207where
208    S: Store + 'static,
209{
210    type ConnectionHandler = ConnHandler;
211    type ToSwarm = ();
212
213    fn handle_established_inbound_connection(
214        &mut self,
215        connection_id: ConnectionId,
216        peer: PeerId,
217        local_addr: &Multiaddr,
218        remote_addr: &Multiaddr,
219    ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
220        self.req_resp
221            .handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)
222            .map(ConnHandler)
223    }
224
225    fn handle_established_outbound_connection(
226        &mut self,
227        connection_id: ConnectionId,
228        peer: PeerId,
229        addr: &Multiaddr,
230        role_override: Endpoint,
231        port_use: PortUse,
232    ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
233        self.req_resp
234            .handle_established_outbound_connection(
235                connection_id,
236                peer,
237                addr,
238                role_override,
239                port_use,
240            )
241            .map(ConnHandler)
242    }
243
244    fn handle_pending_inbound_connection(
245        &mut self,
246        connection_id: ConnectionId,
247        local_addr: &Multiaddr,
248        remote_addr: &Multiaddr,
249    ) -> Result<(), ConnectionDenied> {
250        self.req_resp
251            .handle_pending_inbound_connection(connection_id, local_addr, remote_addr)
252    }
253
254    fn handle_pending_outbound_connection(
255        &mut self,
256        connection_id: ConnectionId,
257        maybe_peer: Option<PeerId>,
258        addresses: &[Multiaddr],
259        effective_role: Endpoint,
260    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
261        self.req_resp.handle_pending_outbound_connection(
262            connection_id,
263            maybe_peer,
264            addresses,
265            effective_role,
266        )
267    }
268
269    fn on_swarm_event(&mut self, event: FromSwarm) {
270        self.req_resp.on_swarm_event(event)
271    }
272
273    fn on_connection_handler_event(
274        &mut self,
275        peer_id: PeerId,
276        connection_id: ConnectionId,
277        event: THandlerOutEvent<Self>,
278    ) {
279        self.req_resp
280            .on_connection_handler_event(peer_id, connection_id, event)
281    }
282
283    fn poll(
284        &mut self,
285        cx: &mut Context<'_>,
286    ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
287        loop {
288            if let Poll::Ready(ev) = self.req_resp.poll(cx) {
289                if let Some(ev) = self.on_to_swarm(ev) {
290                    return Poll::Ready(ev);
291                }
292
293                continue;
294            }
295
296            if self.client_handler.poll(cx).is_ready() {
297                continue;
298            }
299
300            if self.server_handler.poll(cx, &mut self.req_resp).is_ready() {
301                continue;
302            }
303
304            return Poll::Pending;
305        }
306    }
307}
308
309pub(crate) struct ConnHandler(ReqRespConnectionHandler);
310
311impl ConnectionHandler for ConnHandler {
312    type ToBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::ToBehaviour;
313    type FromBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::FromBehaviour;
314    type InboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::InboundProtocol;
315    type InboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::InboundOpenInfo;
316    type OutboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::OutboundProtocol;
317    type OutboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::OutboundOpenInfo;
318
319    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
320        self.0.listen_protocol().with_timeout(NEGOTIATION_TIMEOUT)
321    }
322
323    fn poll(
324        &mut self,
325        cx: &mut Context<'_>,
326    ) -> Poll<
327        ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
328    > {
329        match self.0.poll(cx) {
330            Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
331                Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
332                    protocol: protocol.with_timeout(NEGOTIATION_TIMEOUT),
333                })
334            }
335            ev => ev,
336        }
337    }
338
339    fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
340        self.0.on_behaviour_event(event)
341    }
342
343    fn on_connection_event(
344        &mut self,
345        event: ConnectionEvent<
346            '_,
347            Self::InboundProtocol,
348            Self::OutboundProtocol,
349            Self::InboundOpenInfo,
350            Self::OutboundOpenInfo,
351        >,
352    ) {
353        self.0.on_connection_event(event)
354    }
355
356    fn connection_keep_alive(&self) -> bool {
357        self.0.connection_keep_alive()
358    }
359
360    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
361        self.0.poll_close(cx)
362    }
363}
364
365#[derive(Clone, Copy, Debug, Default)]
366pub(crate) struct HeaderCodec;
367
368#[async_trait]
369impl Codec for HeaderCodec {
370    type Protocol = StreamProtocol;
371    type Request = HeaderRequest;
372    type Response = Vec<HeaderResponse>;
373
374    async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
375    where
376        T: AsyncRead + Unpin + Send,
377    {
378        let data = read_up_to(io, REQUEST_SIZE_LIMIT, REQUEST_TIME_LIMIT).await?;
379
380        if data.len() >= REQUEST_SIZE_LIMIT {
381            debug!("Message filled the whole buffer (len: {})", data.len());
382        }
383
384        parse_header_request(&data).ok_or_else(|| {
385            // There are two cases that can reach here:
386            //
387            // 1. The request is invalid
388            // 2. The request is incomplete because of the size limit or time limit
389            io::Error::other("invalid or incomplete request")
390        })
391    }
392
393    async fn read_response<T>(
394        &mut self,
395        _: &Self::Protocol,
396        io: &mut T,
397    ) -> io::Result<Self::Response>
398    where
399        T: AsyncRead + Unpin + Send,
400    {
401        let data = read_up_to(io, RESPONSE_SIZE_LIMIT, RESPONSE_TIME_LIMIT).await?;
402
403        if data.len() >= RESPONSE_SIZE_LIMIT {
404            debug!("Message filled the whole buffer (len: {})", data.len());
405        }
406
407        let mut data = &data[..];
408        let mut msgs = Vec::new();
409
410        while let Some((header, rest)) = parse_header_response(data) {
411            msgs.push(header);
412            data = rest;
413        }
414
415        if msgs.is_empty() {
416            // There are two cases that can reach here:
417            //
418            // 1. The response is invalid
419            // 2. The response is incomplete because of the size limit or time limit
420            return Err(io::Error::other("invalid or incomplete response"));
421        }
422
423        Ok(msgs)
424    }
425
426    async fn write_request<T>(
427        &mut self,
428        _: &Self::Protocol,
429        io: &mut T,
430        req: Self::Request,
431    ) -> io::Result<()>
432    where
433        T: AsyncWrite + Unpin + Send,
434    {
435        let mut buf = Vec::with_capacity(REQUEST_SIZE_LIMIT);
436
437        let _ = req.encode_length_delimited(&mut buf);
438
439        timeout(REQUEST_TIME_LIMIT, io.write_all(&buf))
440            .await
441            .map_err(|_| io::Error::other("writing request timed out"))??;
442
443        Ok(())
444    }
445
446    async fn write_response<T>(
447        &mut self,
448        _: &Self::Protocol,
449        io: &mut T,
450        resps: Self::Response,
451    ) -> io::Result<()>
452    where
453        T: AsyncWrite + Unpin + Send,
454    {
455        let mut buf = Vec::with_capacity(RESPONSE_SIZE_LIMIT);
456
457        for resp in resps {
458            if resp.encode_length_delimited(&mut buf).is_err() {
459                // Error on encoding means the buffer is full.
460                // We will send a partial response back.
461                debug!("Sending partial response");
462                break;
463            }
464        }
465
466        timeout(RESPONSE_TIME_LIMIT, io.write_all(&buf))
467            .await
468            .map_err(|_| io::Error::other("writing response timed out"))??;
469
470        Ok(())
471    }
472}
473
474/// Reads up to `size_limit` within `time_limit`.
475async fn read_up_to<T>(io: &mut T, size_limit: usize, time_limit: Duration) -> io::Result<Vec<u8>>
476where
477    T: AsyncRead + Unpin + Send,
478{
479    let mut buf = vec![0u8; size_limit];
480    let mut read_len = 0;
481    let now = Instant::now();
482
483    loop {
484        if read_len == buf.len() {
485            // No empty space. Buffer is full.
486            break;
487        }
488
489        let Some(time_limit) = time_limit.checked_sub(now.elapsed()) else {
490            break;
491        };
492
493        let len = match timeout(time_limit, io.read(&mut buf[read_len..])).await {
494            Ok(Ok(len)) => len,
495            Ok(Err(e)) => return Err(e),
496            Err(_) => break,
497        };
498
499        if len == 0 {
500            // EOF
501            break;
502        }
503
504        read_len += len;
505    }
506
507    buf.truncate(read_len);
508
509    Ok(buf)
510}
511
512fn parse_delimiter(mut buf: &[u8]) -> Option<(usize, &[u8])> {
513    if buf.is_empty() {
514        return None;
515    }
516
517    let Ok(len) = prost::decode_length_delimiter(&mut buf) else {
518        return None;
519    };
520
521    Some((len, buf))
522}
523
524fn parse_header_response(buf: &[u8]) -> Option<(HeaderResponse, &[u8])> {
525    let (len, rest) = parse_delimiter(buf)?;
526
527    if rest.len() < len {
528        debug!("Message is incomplete: {len}");
529        return None;
530    }
531
532    let Ok(msg) = HeaderResponse::decode(&rest[..len]) else {
533        return None;
534    };
535
536    Some((msg, &rest[len..]))
537}
538
539fn parse_header_request(buf: &[u8]) -> Option<HeaderRequest> {
540    let (len, rest) = parse_delimiter(buf)?;
541
542    if rest.len() < len {
543        debug!("Message is incomplete: {len}");
544        return None;
545    }
546
547    let Ok(msg) = HeaderRequest::decode(&rest[..len]) else {
548        return None;
549    };
550
551    Some(msg)
552}
553
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use celestia_proto::p2p::pb::header_request::Data;
    use futures::io::{Cursor, Error};
    use lumina_utils::test_utils::async_test;
    use prost::encode_length_delimiter;
    use std::io::ErrorKind;
    use std::pin::Pin;

    /// Protocol handed to the codec in tests; `HeaderCodec` ignores it.
    fn test_protocol() -> StreamProtocol {
        StreamProtocol::new("/foo/bar/v0.1")
    }

    /// Decodes `encoded` as a request through a reader that yields at most `CHUNK` bytes per read.
    async fn read_request_in_chunks<const CHUNK: usize>(encoded: &[u8]) -> HeaderRequest {
        let mut reader = ChunkyAsyncRead::<_, CHUNK>::new(Cursor::new(encoded.to_vec()));
        let protocol = test_protocol();
        let mut codec = HeaderCodec;

        codec.read_request(&protocol, &mut reader).await.unwrap()
    }

    /// Decodes `encoded` as a response through a reader that yields at most `CHUNK` bytes per read.
    async fn read_response_in_chunks<const CHUNK: usize>(encoded: &[u8]) -> Vec<HeaderResponse> {
        let mut reader = ChunkyAsyncRead::<_, CHUNK>::new(Cursor::new(encoded.to_vec()));
        let protocol = test_protocol();
        let mut codec = HeaderCodec;

        codec.read_response(&protocol, &mut reader).await.unwrap()
    }

    #[async_test]
    async fn test_decode_header_request_empty() {
        let header_request = HeaderRequest {
            amount: 0,
            data: None,
        };
        let mut reader = Cursor::new(header_request.encode_length_delimited_to_vec());

        let protocol = test_protocol();
        let mut codec = HeaderCodec;

        let decoded = codec.read_request(&protocol, &mut reader).await.unwrap();

        assert_eq!(header_request, decoded);
    }

    #[async_test]
    async fn test_decode_multiple_small_header_response() {
        const MSG_COUNT: usize = 10;
        let header_response = HeaderResponse {
            body: vec![1, 2, 3],
            status_code: 1,
        };

        let multi_msg = header_response
            .encode_length_delimited_to_vec()
            .repeat(MSG_COUNT);
        let mut reader = Cursor::new(multi_msg);

        let protocol = test_protocol();
        let mut codec = HeaderCodec;

        let decoded = codec.read_response(&protocol, &mut reader).await.unwrap();

        assert_eq!(decoded, vec![header_response; MSG_COUNT]);
    }

    #[async_test]
    async fn test_decode_header_request_too_large() {
        // Only a length prefix announcing more bytes than the request limit.
        let mut length_delimiter_buffer = BytesMut::new();
        encode_length_delimiter(REQUEST_SIZE_LIMIT + 1, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let protocol = test_protocol();
        let mut codec = HeaderCodec;

        let decoding_error = codec
            .read_request(&protocol, &mut reader)
            .await
            .expect_err("expected error for too large request");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    #[async_test]
    async fn test_decode_header_response_too_large() {
        // Only a length prefix announcing more bytes than the response limit.
        let mut length_delimiter_buffer = BytesMut::new();
        encode_length_delimiter(RESPONSE_SIZE_LIMIT + 1, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let protocol = test_protocol();
        let mut codec = HeaderCodec;

        let decoding_error = codec
            .read_response(&protocol, &mut reader)
            .await
            .expect_err("expected error for too large response");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    #[test]
    fn test_invalid_varint() {
        // 10 consecutive bytes with the continuation bit set, plus 1 more byte.
        // That is longer than a length delimiter is allowed to be.
        let varint = [
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b0000_0001,
        ];

        assert_eq!(parse_delimiter(&varint), None);
    }

    #[test]
    fn parse_trailing_zero_varint() {
        // Non-canonical varints (redundant zero continuation bytes) are accepted.
        let varint = [0b1000_0001, 0b0000_0000, 0b1111_1111];
        assert!(matches!(parse_delimiter(&varint), Some((1, [255]))));

        let varint = [0b1000_0000, 0b1000_0000, 0b1000_0000, 0b0000_0000];
        assert!(matches!(parse_delimiter(&varint), Some((0, []))));
    }

    #[async_test]
    async fn test_decode_header_double_response_data() {
        let mut header_response_buffer = BytesMut::with_capacity(512);
        let header_response0 = HeaderResponse {
            body: b"9999888877776666555544443333222211110000".to_vec(),
            status_code: 1,
        };
        let header_response1 = HeaderResponse {
            body: b"0000111122223333444455556666777788889999".to_vec(),
            status_code: 2,
        };
        header_response0
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        header_response1
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        let mut reader = Cursor::new(header_response_buffer);

        let protocol = test_protocol();
        let mut codec = HeaderCodec;

        let decoded = codec.read_response(&protocol, &mut reader).await.unwrap();

        assert_eq!(decoded, vec![header_response0, header_response1]);
    }

    #[async_test]
    async fn test_decode_header_request_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_request = HeaderRequest {
            amount: 1,
            data: Some(Data::Hash(data.to_vec())),
        };
        let encoded = header_request.encode_length_delimited_to_vec();

        assert_eq!(header_request, read_request_in_chunks::<1>(&encoded).await);
        assert_eq!(header_request, read_request_in_chunks::<2>(&encoded).await);
        assert_eq!(header_request, read_request_in_chunks::<3>(&encoded).await);
        assert_eq!(header_request, read_request_in_chunks::<4>(&encoded).await);
    }

    #[async_test]
    async fn test_decode_header_response_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_response = HeaderResponse {
            body: data.to_vec(),
            status_code: 2,
        };
        let encoded = header_response.encode_length_delimited_to_vec();
        let expected = vec![header_response];

        assert_eq!(expected, read_response_in_chunks::<1>(&encoded).await);
        assert_eq!(expected, read_response_in_chunks::<2>(&encoded).await);
        assert_eq!(expected, read_response_in_chunks::<3>(&encoded).await);
        assert_eq!(expected, read_response_in_chunks::<4>(&encoded).await);
    }

    #[async_test]
    async fn test_write_read_request_roundtrip() {
        let header_request = HeaderRequest {
            amount: 3,
            data: Some(Data::Hash(vec![7; 32])),
        };

        let protocol = test_protocol();
        let mut codec = HeaderCodec;

        let mut writer = Cursor::new(Vec::new());
        codec
            .write_request(&protocol, &mut writer, header_request.clone())
            .await
            .unwrap();

        let mut reader = Cursor::new(writer.into_inner());
        let decoded = codec.read_request(&protocol, &mut reader).await.unwrap();

        assert_eq!(header_request, decoded);
    }

    #[async_test]
    async fn test_write_read_response_roundtrip() {
        let responses = vec![
            HeaderResponse {
                body: vec![1, 2, 3],
                status_code: 1,
            },
            HeaderResponse {
                body: vec![4, 5],
                status_code: 2,
            },
        ];

        let protocol = test_protocol();
        let mut codec = HeaderCodec;

        let mut writer = Cursor::new(Vec::new());
        codec
            .write_response(&protocol, &mut writer, responses.clone())
            .await
            .unwrap();

        let mut reader = Cursor::new(writer.into_inner());
        let decoded = codec.read_response(&protocol, &mut reader).await.unwrap();

        assert_eq!(responses, decoded);
    }

    #[async_test]
    async fn test_chunky_async_read() {
        let read_data = "FOO123";
        let cur0 = Cursor::new(read_data);
        let mut chunky = ChunkyAsyncRead::<_, 3>::new(cur0);

        let mut output_buffer: BytesMut = b"BAR987".as_ref().into();

        // A single read must deliver at most CHUNK_SIZE (3) bytes.
        let _ = chunky.read(&mut output_buffer[..]).await.unwrap();
        assert_eq!(output_buffer, b"FOO987".as_ref());
    }

    /// `AsyncRead` adapter that caps every read at `CHUNK_SIZE` bytes.
    ///
    /// Used to simulate data arriving in small pieces.
    struct ChunkyAsyncRead<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> {
        inner: T,
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn new(inner: T) -> Self {
            ChunkyAsyncRead { inner }
        }
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> AsyncRead for ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize, Error>> {
            let len = buf.len().min(CHUNK_SIZE);
            Pin::new(&mut self.inner).poll_read(cx, &mut buf[..len])
        }
    }
}