lumina_node/p2p/
header_ex.rs

1use std::io;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4use std::time::Duration;
5
6use async_trait::async_trait;
7use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
8use celestia_types::ExtendedHeader;
9use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10use libp2p::core::Endpoint;
11use libp2p::core::transport::PortUse;
12use libp2p::request_response::{self, Codec, InboundFailure, OutboundFailure, ProtocolSupport};
13use libp2p::swarm::handler::ConnectionEvent;
14use libp2p::swarm::{
15    ConnectionDenied, ConnectionHandler, ConnectionHandlerEvent, ConnectionId, FromSwarm,
16    NetworkBehaviour, SubstreamProtocol, THandlerInEvent, THandlerOutEvent, ToSwarm,
17};
18use libp2p::{Multiaddr, PeerId, StreamProtocol};
19use lumina_utils::time::{Instant, timeout};
20use prost::Message;
21use tracing::{debug, instrument, warn};
22
23mod client;
24mod server;
25pub(crate) mod utils;
26
27use crate::p2p::P2pError;
28use crate::p2p::header_ex::client::HeaderExClientHandler;
29use crate::p2p::header_ex::server::HeaderExServerHandler;
30use crate::peer_tracker::PeerTracker;
31use crate::store::Store;
32use crate::utils::{OneshotResultSender, protocol_id};
33
/// Maximum number of bytes read for a single inbound request.
///
/// Also used as the initial capacity of the buffer an outbound request is
/// encoded into.
const REQUEST_SIZE_LIMIT: usize = 1024;
/// Maximum time spent on reading or on writing a single request.
const REQUEST_TIME_LIMIT: Duration = Duration::from_secs(1);
/// Maximum number of bytes read for a single response (10 MiB).
const RESPONSE_SIZE_LIMIT: usize = 10 * 1024 * 1024;
/// Maximum time spent on reading or on writing a single response.
const RESPONSE_TIME_LIMIT: Duration = Duration::from_secs(5);
/// Timeout applied to the negotiation of inbound and outbound substreams.
const NEGOTIATION_TIMEOUT: Duration = Duration::from_secs(1);
44
45type RequestType = HeaderRequest;
46type ResponseType = Vec<HeaderResponse>;
47type ReqRespBehaviour = request_response::Behaviour<HeaderCodec>;
48type ReqRespEvent = request_response::Event<RequestType, ResponseType>;
49type ReqRespMessage = request_response::Message<RequestType, ResponseType>;
50type ReqRespConnectionHandler = <ReqRespBehaviour as NetworkBehaviour>::ConnectionHandler;
51
52pub(crate) struct HeaderExBehaviour<S>
53where
54    S: Store + 'static,
55{
56    req_resp: ReqRespBehaviour,
57    client_handler: HeaderExClientHandler,
58    server_handler: HeaderExServerHandler<S>,
59}
60
/// Parameters needed to construct a [`HeaderExBehaviour`].
pub(crate) struct HeaderExConfig<'a, S> {
    /// Network identifier used to derive the protocol id.
    pub network_id: &'a str,
    /// Store handed over to the server handler.
    pub header_store: Arc<S>,
}
65
66/// Representation of all the errors that can occur in `HeaderEx` component.
67#[derive(Debug, thiserror::Error)]
68pub enum HeaderExError {
69    /// Header not found.
70    #[error("Header not found")]
71    HeaderNotFound,
72
73    /// The response is invalid.
74    #[error("Invalid response")]
75    InvalidResponse,
76
77    /// The request is invalid.
78    #[error("Invalid request")]
79    InvalidRequest,
80
81    /// Error when handling connection from the client.
82    #[error("Inbound failure: {0}")]
83    InboundFailure(InboundFailure),
84
85    /// Error when handling connection to the server.
86    #[error("Outbound failure: {0}")]
87    OutboundFailure(OutboundFailure),
88
89    /// Request cancelled because [`Node`] is stopping.
90    ///
91    /// [`Node`]: crate::node::Node
92    #[error("Request cancelled because `Node` is stopping")]
93    RequestCancelled,
94}
95
96impl<S> HeaderExBehaviour<S>
97where
98    S: Store + 'static,
99{
100    pub(crate) fn new(config: HeaderExConfig<'_, S>) -> Self {
101        HeaderExBehaviour {
102            req_resp: ReqRespBehaviour::new(
103                [(
104                    protocol_id(config.network_id, "/header-ex/v0.0.3"),
105                    ProtocolSupport::Full,
106                )],
107                request_response::Config::default(),
108            ),
109            client_handler: HeaderExClientHandler::new(),
110            server_handler: HeaderExServerHandler::new(config.header_store),
111        }
112    }
113
114    #[instrument(level = "trace", skip(self, respond_to))]
115    pub(crate) fn send_request(
116        &mut self,
117        request: HeaderRequest,
118        respond_to: OneshotResultSender<Vec<ExtendedHeader>, P2pError>,
119        peer_tracker: &PeerTracker,
120    ) {
121        self.client_handler
122            .on_send_request(&mut self.req_resp, request, respond_to, peer_tracker);
123    }
124
125    pub(crate) fn stop(&mut self) {
126        self.client_handler.on_stop();
127        self.server_handler.on_stop();
128    }
129
130    fn on_to_swarm(
131        &mut self,
132        ev: ToSwarm<ReqRespEvent, THandlerInEvent<ReqRespBehaviour>>,
133    ) -> Option<ToSwarm<(), THandlerInEvent<Self>>> {
134        match ev {
135            ToSwarm::GenerateEvent(ev) => {
136                self.on_req_resp_event(ev);
137                None
138            }
139            _ => Some(ev.map_out(|_| ())),
140        }
141    }
142
143    #[instrument(level = "trace", skip_all)]
144    fn on_req_resp_event(&mut self, ev: ReqRespEvent) {
145        match ev {
146            // Received a response for an ongoing outbound request
147            ReqRespEvent::Message {
148                message:
149                    ReqRespMessage::Response {
150                        request_id,
151                        response,
152                    },
153                peer,
154                ..
155            } => {
156                self.client_handler
157                    .on_response_received(peer, request_id, response);
158            }
159
160            // Failure while client requests
161            ReqRespEvent::OutboundFailure {
162                peer,
163                request_id,
164                error,
165                ..
166            } => {
167                self.client_handler.on_failure(peer, request_id, error);
168            }
169
170            // Received new inbound request
171            ReqRespEvent::Message {
172                message:
173                    ReqRespMessage::Request {
174                        request_id,
175                        request,
176                        channel,
177                    },
178                peer,
179                ..
180            } => {
181                self.server_handler.on_request_received(
182                    peer,
183                    request_id,
184                    request,
185                    &mut self.req_resp,
186                    channel,
187                );
188            }
189
190            // Response to inbound request was sent
191            ReqRespEvent::ResponseSent {
192                peer, request_id, ..
193            } => {
194                self.server_handler.on_response_sent(peer, request_id);
195            }
196
197            // Failure while server responds
198            ReqRespEvent::InboundFailure {
199                peer,
200                request_id,
201                error,
202                ..
203            } => {
204                self.server_handler.on_failure(peer, request_id, error);
205            }
206        }
207    }
208}
209
210impl<S> NetworkBehaviour for HeaderExBehaviour<S>
211where
212    S: Store + 'static,
213{
214    type ConnectionHandler = ConnHandler;
215    type ToSwarm = ();
216
217    fn handle_established_inbound_connection(
218        &mut self,
219        connection_id: ConnectionId,
220        peer: PeerId,
221        local_addr: &Multiaddr,
222        remote_addr: &Multiaddr,
223    ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
224        self.req_resp
225            .handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)
226            .map(ConnHandler)
227    }
228
229    fn handle_established_outbound_connection(
230        &mut self,
231        connection_id: ConnectionId,
232        peer: PeerId,
233        addr: &Multiaddr,
234        role_override: Endpoint,
235        port_use: PortUse,
236    ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
237        self.req_resp
238            .handle_established_outbound_connection(
239                connection_id,
240                peer,
241                addr,
242                role_override,
243                port_use,
244            )
245            .map(ConnHandler)
246    }
247
248    fn handle_pending_inbound_connection(
249        &mut self,
250        connection_id: ConnectionId,
251        local_addr: &Multiaddr,
252        remote_addr: &Multiaddr,
253    ) -> Result<(), ConnectionDenied> {
254        self.req_resp
255            .handle_pending_inbound_connection(connection_id, local_addr, remote_addr)
256    }
257
258    fn handle_pending_outbound_connection(
259        &mut self,
260        connection_id: ConnectionId,
261        maybe_peer: Option<PeerId>,
262        addresses: &[Multiaddr],
263        effective_role: Endpoint,
264    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
265        self.req_resp.handle_pending_outbound_connection(
266            connection_id,
267            maybe_peer,
268            addresses,
269            effective_role,
270        )
271    }
272
273    fn on_swarm_event(&mut self, event: FromSwarm) {
274        self.req_resp.on_swarm_event(event)
275    }
276
277    fn on_connection_handler_event(
278        &mut self,
279        peer_id: PeerId,
280        connection_id: ConnectionId,
281        event: THandlerOutEvent<Self>,
282    ) {
283        self.req_resp
284            .on_connection_handler_event(peer_id, connection_id, event)
285    }
286
287    fn poll(
288        &mut self,
289        cx: &mut Context<'_>,
290    ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
291        loop {
292            if let Poll::Ready(ev) = self.req_resp.poll(cx) {
293                if let Some(ev) = self.on_to_swarm(ev) {
294                    return Poll::Ready(ev);
295                }
296
297                continue;
298            }
299
300            if self.client_handler.poll(cx).is_ready() {
301                continue;
302            }
303
304            if self.server_handler.poll(cx, &mut self.req_resp).is_ready() {
305                continue;
306            }
307
308            return Poll::Pending;
309        }
310    }
311}
312
313pub(crate) struct ConnHandler(ReqRespConnectionHandler);
314
315impl ConnectionHandler for ConnHandler {
316    type ToBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::ToBehaviour;
317    type FromBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::FromBehaviour;
318    type InboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::InboundProtocol;
319    type InboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::InboundOpenInfo;
320    type OutboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::OutboundProtocol;
321    type OutboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::OutboundOpenInfo;
322
323    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
324        self.0.listen_protocol().with_timeout(NEGOTIATION_TIMEOUT)
325    }
326
327    fn poll(
328        &mut self,
329        cx: &mut Context<'_>,
330    ) -> Poll<
331        ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
332    > {
333        match self.0.poll(cx) {
334            Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
335                Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
336                    protocol: protocol.with_timeout(NEGOTIATION_TIMEOUT),
337                })
338            }
339            ev => ev,
340        }
341    }
342
343    fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
344        self.0.on_behaviour_event(event)
345    }
346
347    fn on_connection_event(
348        &mut self,
349        event: ConnectionEvent<
350            '_,
351            Self::InboundProtocol,
352            Self::OutboundProtocol,
353            Self::InboundOpenInfo,
354            Self::OutboundOpenInfo,
355        >,
356    ) {
357        self.0.on_connection_event(event)
358    }
359
360    fn connection_keep_alive(&self) -> bool {
361        self.0.connection_keep_alive()
362    }
363
364    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
365        self.0.poll_close(cx)
366    }
367}
368
369#[derive(Clone, Copy, Debug, Default)]
370pub(crate) struct HeaderCodec;
371
372#[async_trait]
373impl Codec for HeaderCodec {
374    type Protocol = StreamProtocol;
375    type Request = HeaderRequest;
376    type Response = Vec<HeaderResponse>;
377
378    async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
379    where
380        T: AsyncRead + Unpin + Send,
381    {
382        let data = read_up_to(io, REQUEST_SIZE_LIMIT, REQUEST_TIME_LIMIT).await?;
383
384        if data.len() >= REQUEST_SIZE_LIMIT {
385            debug!("Message filled the whole buffer (len: {})", data.len());
386        }
387
388        parse_header_request(&data).ok_or_else(|| {
389            // There are two cases that can reach here:
390            //
391            // 1. The request is invalid
392            // 2. The request is incomplete because of the size limit or time limit
393            io::Error::other("invalid or incomplete request")
394        })
395    }
396
397    async fn read_response<T>(
398        &mut self,
399        _: &Self::Protocol,
400        io: &mut T,
401    ) -> io::Result<Self::Response>
402    where
403        T: AsyncRead + Unpin + Send,
404    {
405        let data = read_up_to(io, RESPONSE_SIZE_LIMIT, RESPONSE_TIME_LIMIT).await?;
406
407        if data.len() >= RESPONSE_SIZE_LIMIT {
408            debug!("Message filled the whole buffer (len: {})", data.len());
409        }
410
411        let mut data = &data[..];
412        let mut msgs = Vec::new();
413
414        while let Some((header, rest)) = parse_header_response(data) {
415            msgs.push(header);
416            data = rest;
417        }
418
419        if msgs.is_empty() {
420            // There are two cases that can reach here:
421            //
422            // 1. The response is invalid
423            // 2. The response is incomplete because of the size limit or time limit
424            return Err(io::Error::other("invalid or incomplete response"));
425        }
426
427        Ok(msgs)
428    }
429
430    async fn write_request<T>(
431        &mut self,
432        _: &Self::Protocol,
433        io: &mut T,
434        req: Self::Request,
435    ) -> io::Result<()>
436    where
437        T: AsyncWrite + Unpin + Send,
438    {
439        let mut buf = Vec::with_capacity(REQUEST_SIZE_LIMIT);
440
441        let _ = req.encode_length_delimited(&mut buf);
442
443        timeout(REQUEST_TIME_LIMIT, io.write_all(&buf))
444            .await
445            .map_err(|_| io::Error::other("writing request timed out"))??;
446
447        Ok(())
448    }
449
450    async fn write_response<T>(
451        &mut self,
452        _: &Self::Protocol,
453        io: &mut T,
454        resps: Self::Response,
455    ) -> io::Result<()>
456    where
457        T: AsyncWrite + Unpin + Send,
458    {
459        let mut buf = Vec::with_capacity(RESPONSE_SIZE_LIMIT);
460
461        for resp in resps {
462            if resp.encode_length_delimited(&mut buf).is_err() {
463                // Error on encoding means the buffer is full.
464                // We will send a partial response back.
465                debug!("Sending partial response");
466                break;
467            }
468        }
469
470        timeout(RESPONSE_TIME_LIMIT, io.write_all(&buf))
471            .await
472            .map_err(|_| io::Error::other("writing response timed out"))??;
473
474        Ok(())
475    }
476}
477
478/// Reads up to `size_limit` within `time_limit`.
479async fn read_up_to<T>(io: &mut T, size_limit: usize, time_limit: Duration) -> io::Result<Vec<u8>>
480where
481    T: AsyncRead + Unpin + Send,
482{
483    let mut buf = vec![0u8; size_limit];
484    let mut read_len = 0;
485    let now = Instant::now();
486
487    loop {
488        if read_len == buf.len() {
489            // No empty space. Buffer is full.
490            break;
491        }
492
493        let Some(time_limit) = time_limit.checked_sub(now.elapsed()) else {
494            break;
495        };
496
497        let len = match timeout(time_limit, io.read(&mut buf[read_len..])).await {
498            Ok(Ok(len)) => len,
499            Ok(Err(e)) => return Err(e),
500            Err(_) => break,
501        };
502
503        if len == 0 {
504            // EOF
505            break;
506        }
507
508        read_len += len;
509    }
510
511    buf.truncate(read_len);
512
513    Ok(buf)
514}
515
516fn parse_delimiter(mut buf: &[u8]) -> Option<(usize, &[u8])> {
517    if buf.is_empty() {
518        return None;
519    }
520
521    let Ok(len) = prost::decode_length_delimiter(&mut buf) else {
522        return None;
523    };
524
525    Some((len, buf))
526}
527
528fn parse_header_response(buf: &[u8]) -> Option<(HeaderResponse, &[u8])> {
529    let (len, rest) = parse_delimiter(buf)?;
530
531    if rest.len() < len {
532        debug!("Message is incomplete: {len}");
533        return None;
534    }
535
536    let Ok(msg) = HeaderResponse::decode(&rest[..len]) else {
537        return None;
538    };
539
540    Some((msg, &rest[len..]))
541}
542
543fn parse_header_request(buf: &[u8]) -> Option<HeaderRequest> {
544    let (len, rest) = parse_delimiter(buf)?;
545
546    if rest.len() < len {
547        debug!("Message is incomplete: {len}");
548        return None;
549    }
550
551    let Ok(msg) = HeaderRequest::decode(&rest[..len]) else {
552        return None;
553    };
554
555    Some(msg)
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use celestia_proto::p2p::pb::header_request::Data;
    use futures::io::{Cursor, Error};
    use lumina_utils::test_utils::async_test;
    use prost::encode_length_delimiter;
    use std::io::ErrorKind;
    use std::pin::Pin;

    /// Protocol passed to the codec. The codec ignores this argument.
    fn test_protocol() -> StreamProtocol {
        StreamProtocol::new("/foo/bar/v0.1")
    }

    /// Encodes `expected`, feeds it to the codec at most `CHUNK_SIZE` bytes
    /// per read and checks that the same request is decoded.
    async fn assert_request_decodes_in_chunks<const CHUNK_SIZE: usize>(expected: &HeaderRequest) {
        let encoded = expected.encode_length_delimited_to_vec();
        let mut reader = ChunkyAsyncRead::<_, CHUNK_SIZE>::new(Cursor::new(encoded));
        let stream_protocol = test_protocol();
        let mut codec = HeaderCodec;

        let decoded = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(expected, &decoded);
    }

    /// Encodes `expected`, feeds it to the codec at most `CHUNK_SIZE` bytes
    /// per read and checks that exactly that one response is decoded.
    async fn assert_response_decodes_in_chunks<const CHUNK_SIZE: usize>(
        expected: &HeaderResponse,
    ) {
        let encoded = expected.encode_length_delimited_to_vec();
        let mut reader = ChunkyAsyncRead::<_, CHUNK_SIZE>::new(Cursor::new(encoded));
        let stream_protocol = test_protocol();
        let mut codec = HeaderCodec;

        let decoded = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(decoded.len(), 1);
        assert_eq!(expected, &decoded[0]);
    }

    #[async_test]
    async fn test_decode_header_request_empty() {
        let header_request = HeaderRequest {
            amount: 0,
            data: None,
        };

        let encoded_header_request = header_request.encode_length_delimited_to_vec();
        let mut reader = Cursor::new(encoded_header_request);

        let stream_protocol = test_protocol();
        let mut codec = HeaderCodec;

        let decoded_header_request = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(header_request, decoded_header_request);
    }

    #[async_test]
    async fn test_decode_multiple_small_header_response() {
        const MSG_COUNT: usize = 10;
        let header_response = HeaderResponse {
            body: vec![1, 2, 3],
            status_code: 1,
        };

        let multi_msg = header_response
            .encode_length_delimited_to_vec()
            .repeat(MSG_COUNT);
        let mut reader = Cursor::new(multi_msg);

        let stream_protocol = test_protocol();
        let mut codec = HeaderCodec;

        let decoded_header_response = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(decoded_header_response.len(), MSG_COUNT);
        for decoded_header in &decoded_header_response {
            assert_eq!(&header_response, decoded_header);
        }
    }

    #[async_test]
    async fn test_decode_header_request_too_large() {
        // Only the delimiter is sent; it announces more than the limit allows.
        let too_long_message_len = REQUEST_SIZE_LIMIT + 1;
        let mut length_delimiter_buffer = BytesMut::new();
        encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let stream_protocol = test_protocol();
        let mut codec = HeaderCodec;

        let decoding_error = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .expect_err("expected error for too large request");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    #[async_test]
    async fn test_decode_header_response_too_large() {
        // Only the delimiter is sent; it announces more than the limit allows.
        let too_long_message_len = RESPONSE_SIZE_LIMIT + 1;
        let mut length_delimiter_buffer = BytesMut::new();
        encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let stream_protocol = test_protocol();
        let mut codec = HeaderCodec;

        let decoding_error = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .expect_err("expected error for too large response");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    #[test]
    fn test_invalid_varint() {
        // 10 consecutive bytes with continuation bit set + 1 byte, which is longer than allowed
        //    for length delimiter
        let mut varint = [0b1000_0000u8; 11];
        varint[10] = 0b0000_0001;

        assert_eq!(parse_delimiter(&varint), None);
    }

    #[test]
    fn parse_trailing_zero_varint() {
        let varint = [0b1000_0001, 0b0000_0000, 0b1111_1111];
        assert!(matches!(parse_delimiter(&varint), Some((1, [255]))));

        let varint = [0b1000_0000, 0b1000_0000, 0b1000_0000, 0b0000_0000];
        assert!(matches!(parse_delimiter(&varint), Some((0, []))));
    }

    #[async_test]
    async fn test_decode_header_double_response_data() {
        let mut header_response_buffer = BytesMut::with_capacity(512);
        let header_response0 = HeaderResponse {
            body: b"9999888877776666555544443333222211110000".to_vec(),
            status_code: 1,
        };
        let header_response1 = HeaderResponse {
            body: b"0000111122223333444455556666777788889999".to_vec(),
            status_code: 2,
        };
        header_response0
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        header_response1
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        let mut reader = Cursor::new(header_response_buffer);

        let stream_protocol = test_protocol();
        let mut codec = HeaderCodec;

        let decoded_header_response = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(decoded_header_response.len(), 2);
        assert_eq!(header_response0, decoded_header_response[0]);
        assert_eq!(header_response1, decoded_header_response[1]);
    }

    #[async_test]
    async fn test_decode_header_request_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_request = HeaderRequest {
            amount: 1,
            data: Some(Data::Hash(data.to_vec())),
        };

        assert_request_decodes_in_chunks::<1>(&header_request).await;
        assert_request_decodes_in_chunks::<2>(&header_request).await;
        assert_request_decodes_in_chunks::<3>(&header_request).await;
        assert_request_decodes_in_chunks::<4>(&header_request).await;
    }

    #[async_test]
    async fn test_decode_header_response_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_response = HeaderResponse {
            body: data.to_vec(),
            status_code: 2,
        };

        assert_response_decodes_in_chunks::<1>(&header_response).await;
        assert_response_decodes_in_chunks::<2>(&header_response).await;
        assert_response_decodes_in_chunks::<3>(&header_response).await;
        assert_response_decodes_in_chunks::<4>(&header_response).await;
    }

    #[async_test]
    async fn test_chunky_async_read() {
        let read_data = "FOO123";
        let cur0 = Cursor::new(read_data);
        let mut chunky = ChunkyAsyncRead::<_, 3>::new(cur0);

        let mut output_buffer: BytesMut = b"BAR987".as_ref().into();

        // A single read must not hand out more than one chunk.
        let read_len = chunky.read(&mut output_buffer[..]).await.unwrap();
        assert_eq!(read_len, 3);
        assert_eq!(output_buffer, b"FOO987".as_ref());
    }

    /// Reader that returns at most `CHUNK_SIZE` bytes per read call.
    struct ChunkyAsyncRead<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> {
        inner: T,
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn new(inner: T) -> Self {
            ChunkyAsyncRead { inner }
        }
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> AsyncRead for ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize, Error>> {
            // Shrink the destination so the inner reader cannot fill more
            // than one chunk.
            let len = buf.len().min(CHUNK_SIZE);
            Pin::new(&mut self.inner).poll_read(cx, &mut buf[..len])
        }
    }
}