lumina_node/p2p/
header_ex.rs

1use std::io;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4use std::time::Duration;
5
6use async_trait::async_trait;
7use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
8use celestia_types::ExtendedHeader;
9use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10use libp2p::core::Endpoint;
11use libp2p::core::transport::PortUse;
12use libp2p::request_response::{self, Codec, InboundFailure, OutboundFailure, ProtocolSupport};
13use libp2p::swarm::handler::ConnectionEvent;
14use libp2p::swarm::{
15    ConnectionDenied, ConnectionHandler, ConnectionHandlerEvent, ConnectionId, FromSwarm,
16    NetworkBehaviour, SubstreamProtocol, THandlerInEvent, THandlerOutEvent, ToSwarm,
17};
18use libp2p::{Multiaddr, PeerId, StreamProtocol};
19use lumina_utils::time::{Instant, timeout};
20use prost::Message;
21use tracing::{debug, instrument, warn};
22
23mod client;
24mod server;
25pub(crate) mod utils;
26
27use crate::p2p::P2pError;
28use crate::p2p::header_ex::client::HeaderExClientHandler;
29use crate::p2p::header_ex::server::HeaderExServerHandler;
30use crate::peer_tracker::PeerTracker;
31use crate::store::Store;
32use crate::utils::{OneshotResultSender, protocol_id};
33
/// Size limit of a request in bytes.
///
/// At most this many bytes are read from the stream when decoding a request.
const REQUEST_SIZE_LIMIT: usize = 1024;
/// Time limit on reading/writing a request.
const REQUEST_TIME_LIMIT: Duration = Duration::from_secs(1);
/// Size limit of a response in bytes (10 MiB).
///
/// At most this many bytes are read from the stream when decoding a response.
const RESPONSE_SIZE_LIMIT: usize = 10 * 1024 * 1024;
/// Time limit on reading/writing a response.
const RESPONSE_TIME_LIMIT: Duration = Duration::from_secs(5);
/// Substream negotiation timeout, applied to both inbound and outbound substreams.
const NEGOTIATION_TIMEOUT: Duration = Duration::from_secs(1);
44
/// Request message exchanged over the protocol.
type RequestType = HeaderRequest;
/// Response exchanged over the protocol: a list of `HeaderResponse` messages.
type ResponseType = Vec<HeaderResponse>;
/// The request-response behaviour specialised with [`HeaderCodec`].
type ReqRespBehaviour = request_response::Behaviour<HeaderCodec>;
/// Event emitted by [`ReqRespBehaviour`].
type ReqRespEvent = request_response::Event<RequestType, ResponseType>;
/// Message (request or response) carried inside a [`ReqRespEvent`].
type ReqRespMessage = request_response::Message<RequestType, ResponseType>;
/// Connection handler type used by [`ReqRespBehaviour`].
type ReqRespConnectionHandler = <ReqRespBehaviour as NetworkBehaviour>::ConnectionHandler;
51
/// Network behaviour of the `HeaderEx` component.
///
/// It wraps a request-response behaviour and routes its events: everything
/// related to outbound requests goes to the client handler, everything
/// related to inbound requests goes to the server handler.
pub(crate) struct Behaviour<S>
where
    S: Store + 'static,
{
    /// Underlying request-response behaviour using [`HeaderCodec`].
    req_resp: ReqRespBehaviour,
    /// Handles outbound requests, their responses and their failures.
    client_handler: HeaderExClientHandler,
    /// Handles inbound requests; constructed from the header store.
    server_handler: HeaderExServerHandler<S>,
}
60
/// Parameters needed to construct a [`Behaviour`].
pub(crate) struct Config<'a, S> {
    /// Network identifier used to build the protocol id of the exchange.
    pub network_id: &'a str,
    /// Store handed over to the server handler.
    pub header_store: Arc<S>,
}
65
/// Representation of all the errors that can occur in `HeaderEx` component.
#[derive(Debug, thiserror::Error)]
pub enum HeaderExError {
    /// Header not found.
    #[error("Header not found")]
    HeaderNotFound,

    /// The response is invalid.
    #[error("Invalid response")]
    InvalidResponse,

    /// The request is invalid.
    #[error("Invalid request")]
    InvalidRequest,

    /// Error when handling connection from the client.
    ///
    /// Carries the request-response [`InboundFailure`].
    #[error("Inbound failure: {0}")]
    InboundFailure(InboundFailure),

    /// Error when handling connection to the server.
    ///
    /// Carries the request-response [`OutboundFailure`].
    #[error("Outbound failure: {0}")]
    OutboundFailure(OutboundFailure),

    /// Request cancelled because [`Node`] is stopping.
    ///
    /// [`Node`]: crate::node::Node
    #[error("Request cancelled because `Node` is stopping")]
    RequestCancelled,
}
95
/// Events that [`Behaviour`] emits to the swarm.
///
/// They are produced by the client handler when it is polled.
// NOTE(review): the variant descriptions below are inferred from the names
// only; confirm against the client handler implementation.
#[derive(Debug)]
pub(crate) enum Event {
    /// Presumably asks the owner to call `schedule_pending_requests`.
    SchedulePendingRequests,
    /// Presumably signals that trusted peers are required.
    NeedTrustedPeers,
    /// Presumably signals that archival peers are required.
    NeedArchivalPeers,
}
102
impl<S> Behaviour<S>
where
    S: Store + 'static,
{
    /// Creates the behaviour.
    ///
    /// Registers the `/header-ex/v0.0.3` protocol (combined with
    /// `config.network_id` through `protocol_id`) with `ProtocolSupport::Full`
    /// and the default request-response configuration. The server handler
    /// receives `config.header_store`.
    pub(crate) fn new(config: Config<'_, S>) -> Self {
        Behaviour {
            req_resp: ReqRespBehaviour::new(
                [(
                    protocol_id(config.network_id, "/header-ex/v0.0.3"),
                    ProtocolSupport::Full,
                )],
                request_response::Config::default(),
            ),
            client_handler: HeaderExClientHandler::new(),
            server_handler: HeaderExServerHandler::new(config.header_store),
        }
    }

    /// Hands `request` over to the client handler.
    ///
    /// `respond_to` is passed along unchanged; the client handler is expected
    /// to report the outcome through it.
    #[instrument(level = "trace", skip(self, respond_to))]
    pub(crate) fn send_request(
        &mut self,
        request: HeaderRequest,
        respond_to: OneshotResultSender<Vec<ExtendedHeader>, P2pError>,
    ) {
        self.client_handler.on_send_request(request, respond_to);
    }

    /// Lets the client handler schedule its pending requests, giving it access
    /// to the request-response behaviour and to `peer_tracker`.
    pub(crate) fn schedule_pending_requests(&mut self, peer_tracker: &PeerTracker) {
        self.client_handler
            .schedule_pending_requests(&mut self.req_resp, peer_tracker);
    }

    /// Notifies both the client and the server handler that the component is stopping.
    pub(crate) fn stop(&mut self) {
        self.client_handler.on_stop();
        self.server_handler.on_stop();
    }

    /// Processes an event produced by the inner request-response behaviour.
    ///
    /// `GenerateEvent` is consumed internally by [`Self::on_req_resp_event`] and
    /// yields `None`. Any other `ToSwarm` command is re-typed with `map_out`
    /// and returned so the caller can forward it to the swarm.
    fn on_to_swarm(
        &mut self,
        ev: ToSwarm<ReqRespEvent, THandlerInEvent<ReqRespBehaviour>>,
    ) -> Option<ToSwarm<Event, THandlerInEvent<Self>>> {
        match ev {
            ToSwarm::GenerateEvent(ev) => {
                self.on_req_resp_event(ev);
                None
            }
            // `GenerateEvent` was matched by the arm above, so the closure that
            // converts the generated event is not expected to ever be called.
            _ => Some(ev.map_out(|_| unreachable!("GenerateEvent handled"))),
        }
    }

    /// Routes a request-response event to the client handler (responses and
    /// outbound failures) or to the server handler (requests, sent responses
    /// and inbound failures).
    #[instrument(level = "trace", skip_all)]
    fn on_req_resp_event(&mut self, ev: ReqRespEvent) {
        match ev {
            // Received a response for an ongoing outbound request
            ReqRespEvent::Message {
                message:
                    ReqRespMessage::Response {
                        request_id,
                        response,
                    },
                peer,
                ..
            } => {
                self.client_handler
                    .on_response_received(peer, request_id, response);
            }

            // Failure while client requests
            ReqRespEvent::OutboundFailure {
                peer,
                request_id,
                error,
                ..
            } => {
                self.client_handler.on_failure(peer, request_id, error);
            }

            // Received new inbound request
            ReqRespEvent::Message {
                message:
                    ReqRespMessage::Request {
                        request_id,
                        request,
                        channel,
                    },
                peer,
                ..
            } => {
                // The server handler gets the behaviour and the response
                // channel along with the request.
                self.server_handler.on_request_received(
                    peer,
                    request_id,
                    request,
                    &mut self.req_resp,
                    channel,
                );
            }

            // Response to inbound request was sent
            ReqRespEvent::ResponseSent {
                peer, request_id, ..
            } => {
                self.server_handler.on_response_sent(peer, request_id);
            }

            // Failure while server responds
            ReqRespEvent::InboundFailure {
                peer,
                request_id,
                error,
                ..
            } => {
                self.server_handler.on_failure(peer, request_id, error);
            }
        }
    }
}
219
/// `NetworkBehaviour` implementation that forwards the connection related
/// callbacks to the inner request-response behaviour and wraps the resulting
/// connection handlers in [`ConnHandler`].
impl<S> NetworkBehaviour for Behaviour<S>
where
    S: Store + 'static,
{
    type ConnectionHandler = ConnHandler;
    type ToSwarm = Event;

    /// Forwards to the inner behaviour and wraps its handler in [`ConnHandler`].
    fn handle_established_inbound_connection(
        &mut self,
        connection_id: ConnectionId,
        peer: PeerId,
        local_addr: &Multiaddr,
        remote_addr: &Multiaddr,
    ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
        self.req_resp
            .handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)
            .map(ConnHandler)
    }

    /// Forwards to the inner behaviour and wraps its handler in [`ConnHandler`].
    fn handle_established_outbound_connection(
        &mut self,
        connection_id: ConnectionId,
        peer: PeerId,
        addr: &Multiaddr,
        role_override: Endpoint,
        port_use: PortUse,
    ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
        self.req_resp
            .handle_established_outbound_connection(
                connection_id,
                peer,
                addr,
                role_override,
                port_use,
            )
            .map(ConnHandler)
    }

    /// Forwards to the inner behaviour unchanged.
    fn handle_pending_inbound_connection(
        &mut self,
        connection_id: ConnectionId,
        local_addr: &Multiaddr,
        remote_addr: &Multiaddr,
    ) -> Result<(), ConnectionDenied> {
        self.req_resp
            .handle_pending_inbound_connection(connection_id, local_addr, remote_addr)
    }

    /// Forwards to the inner behaviour unchanged.
    fn handle_pending_outbound_connection(
        &mut self,
        connection_id: ConnectionId,
        maybe_peer: Option<PeerId>,
        addresses: &[Multiaddr],
        effective_role: Endpoint,
    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
        self.req_resp.handle_pending_outbound_connection(
            connection_id,
            maybe_peer,
            addresses,
            effective_role,
        )
    }

    /// Forwards swarm events to the inner behaviour unchanged.
    fn on_swarm_event(&mut self, event: FromSwarm) {
        self.req_resp.on_swarm_event(event)
    }

    /// Forwards connection handler events to the inner behaviour unchanged.
    fn on_connection_handler_event(
        &mut self,
        peer_id: PeerId,
        connection_id: ConnectionId,
        event: THandlerOutEvent<Self>,
    ) {
        self.req_resp
            .on_connection_handler_event(peer_id, connection_id, event)
    }

    /// Polls the inner behaviour, then the client handler, then the server handler.
    ///
    /// Returns `Poll::Pending` only when all three had nothing to report in
    /// the same loop iteration.
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
        loop {
            if let Poll::Ready(ev) = self.req_resp.poll(cx) {
                // Commands for the swarm are returned right away. Generated
                // events were consumed internally, so poll again.
                if let Some(ev) = self.on_to_swarm(ev) {
                    return Poll::Ready(ev);
                }

                continue;
            }

            // Events of the client handler are the events of this behaviour.
            if let Poll::Ready(ev) = self.client_handler.poll(cx) {
                return Poll::Ready(ToSwarm::GenerateEvent(ev));
            }

            // The server handler was given access to the inner behaviour, so
            // when it reports progress the loop starts over and polls the
            // inner behaviour again.
            if self.server_handler.poll(cx, &mut self.req_resp).is_ready() {
                continue;
            }

            return Poll::Pending;
        }
    }
}
322
/// Wrapper around the request-response connection handler that applies
/// `NEGOTIATION_TIMEOUT` to substream negotiation.
pub(crate) struct ConnHandler(ReqRespConnectionHandler);
324
/// Delegates everything to the wrapped handler. The only changes are the
/// timeouts set in `listen_protocol` and on outbound substream requests in `poll`.
impl ConnectionHandler for ConnHandler {
    type ToBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::ToBehaviour;
    type FromBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::FromBehaviour;
    type InboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::InboundProtocol;
    type InboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::InboundOpenInfo;
    type OutboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::OutboundProtocol;
    type OutboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::OutboundOpenInfo;

    /// Returns the wrapped handler's listen protocol with `NEGOTIATION_TIMEOUT` set.
    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
        self.0.listen_protocol().with_timeout(NEGOTIATION_TIMEOUT)
    }

    /// Polls the wrapped handler and sets `NEGOTIATION_TIMEOUT` on every
    /// outbound substream request. All other events pass through untouched.
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<
        ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
    > {
        match self.0.poll(cx) {
            Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
                Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
                    protocol: protocol.with_timeout(NEGOTIATION_TIMEOUT),
                })
            }
            ev => ev,
        }
    }

    /// Forwards to the wrapped handler unchanged.
    fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
        self.0.on_behaviour_event(event)
    }

    /// Forwards to the wrapped handler unchanged.
    fn on_connection_event(
        &mut self,
        event: ConnectionEvent<
            '_,
            Self::InboundProtocol,
            Self::OutboundProtocol,
            Self::InboundOpenInfo,
            Self::OutboundOpenInfo,
        >,
    ) {
        self.0.on_connection_event(event)
    }

    /// Forwards to the wrapped handler unchanged.
    fn connection_keep_alive(&self) -> bool {
        self.0.connection_keep_alive()
    }

    /// Forwards to the wrapped handler unchanged.
    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
        self.0.poll_close(cx)
    }
}
378
/// Stateless codec that reads and writes length-delimited protobuf
/// `HeaderRequest` and `HeaderResponse` messages.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct HeaderCodec;
381
382#[async_trait]
383impl Codec for HeaderCodec {
384    type Protocol = StreamProtocol;
385    type Request = HeaderRequest;
386    type Response = Vec<HeaderResponse>;
387
388    async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
389    where
390        T: AsyncRead + Unpin + Send,
391    {
392        let data = read_up_to(io, REQUEST_SIZE_LIMIT, REQUEST_TIME_LIMIT).await?;
393
394        if data.len() >= REQUEST_SIZE_LIMIT {
395            debug!("Message filled the whole buffer (len: {})", data.len());
396        }
397
398        parse_header_request(&data).ok_or_else(|| {
399            // There are two cases that can reach here:
400            //
401            // 1. The request is invalid
402            // 2. The request is incomplete because of the size limit or time limit
403            io::Error::other("invalid or incomplete request")
404        })
405    }
406
407    async fn read_response<T>(
408        &mut self,
409        _: &Self::Protocol,
410        io: &mut T,
411    ) -> io::Result<Self::Response>
412    where
413        T: AsyncRead + Unpin + Send,
414    {
415        let data = read_up_to(io, RESPONSE_SIZE_LIMIT, RESPONSE_TIME_LIMIT).await?;
416
417        if data.len() >= RESPONSE_SIZE_LIMIT {
418            debug!("Message filled the whole buffer (len: {})", data.len());
419        }
420
421        let mut data = &data[..];
422        let mut msgs = Vec::new();
423
424        while let Some((header, rest)) = parse_header_response(data) {
425            msgs.push(header);
426            data = rest;
427        }
428
429        if msgs.is_empty() {
430            // There are two cases that can reach here:
431            //
432            // 1. The response is invalid
433            // 2. The response is incomplete because of the size limit or time limit
434            return Err(io::Error::other("invalid or incomplete response"));
435        }
436
437        Ok(msgs)
438    }
439
440    async fn write_request<T>(
441        &mut self,
442        _: &Self::Protocol,
443        io: &mut T,
444        req: Self::Request,
445    ) -> io::Result<()>
446    where
447        T: AsyncWrite + Unpin + Send,
448    {
449        let mut buf = Vec::with_capacity(REQUEST_SIZE_LIMIT);
450
451        let _ = req.encode_length_delimited(&mut buf);
452
453        timeout(REQUEST_TIME_LIMIT, io.write_all(&buf))
454            .await
455            .map_err(|_| io::Error::other("writing request timed out"))??;
456
457        Ok(())
458    }
459
460    async fn write_response<T>(
461        &mut self,
462        _: &Self::Protocol,
463        io: &mut T,
464        resps: Self::Response,
465    ) -> io::Result<()>
466    where
467        T: AsyncWrite + Unpin + Send,
468    {
469        let mut buf = Vec::with_capacity(RESPONSE_SIZE_LIMIT);
470
471        for resp in resps {
472            if resp.encode_length_delimited(&mut buf).is_err() {
473                // Error on encoding means the buffer is full.
474                // We will send a partial response back.
475                debug!("Sending partial response");
476                break;
477            }
478        }
479
480        timeout(RESPONSE_TIME_LIMIT, io.write_all(&buf))
481            .await
482            .map_err(|_| io::Error::other("writing response timed out"))??;
483
484        Ok(())
485    }
486}
487
488/// Reads up to `size_limit` within `time_limit`.
489async fn read_up_to<T>(io: &mut T, size_limit: usize, time_limit: Duration) -> io::Result<Vec<u8>>
490where
491    T: AsyncRead + Unpin + Send,
492{
493    let mut buf = vec![0u8; size_limit];
494    let mut read_len = 0;
495    let now = Instant::now();
496
497    loop {
498        if read_len == buf.len() {
499            // No empty space. Buffer is full.
500            break;
501        }
502
503        let Some(time_limit) = time_limit.checked_sub(now.elapsed()) else {
504            break;
505        };
506
507        let len = match timeout(time_limit, io.read(&mut buf[read_len..])).await {
508            Ok(Ok(len)) => len,
509            Ok(Err(e)) => return Err(e),
510            Err(_) => break,
511        };
512
513        if len == 0 {
514            // EOF
515            break;
516        }
517
518        read_len += len;
519    }
520
521    buf.truncate(read_len);
522
523    Ok(buf)
524}
525
526fn parse_delimiter(mut buf: &[u8]) -> Option<(usize, &[u8])> {
527    if buf.is_empty() {
528        return None;
529    }
530
531    let Ok(len) = prost::decode_length_delimiter(&mut buf) else {
532        return None;
533    };
534
535    Some((len, buf))
536}
537
538fn parse_header_response(buf: &[u8]) -> Option<(HeaderResponse, &[u8])> {
539    let (len, rest) = parse_delimiter(buf)?;
540
541    if rest.len() < len {
542        debug!("Message is incomplete: {len}");
543        return None;
544    }
545
546    let Ok(msg) = HeaderResponse::decode(&rest[..len]) else {
547        return None;
548    };
549
550    Some((msg, &rest[len..]))
551}
552
553fn parse_header_request(buf: &[u8]) -> Option<HeaderRequest> {
554    let (len, rest) = parse_delimiter(buf)?;
555
556    if rest.len() < len {
557        debug!("Message is incomplete: {len}");
558        return None;
559    }
560
561    let Ok(msg) = HeaderRequest::decode(&rest[..len]) else {
562        return None;
563    };
564
565    Some(msg)
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use celestia_proto::p2p::pb::header_request::Data;
    use futures::io::{Cursor, Error};
    use lumina_utils::test_utils::async_test;
    use prost::encode_length_delimiter;
    use std::io::ErrorKind;
    use std::pin::Pin;

    #[async_test]
    async fn test_decode_header_request_empty() {
        let header_request = HeaderRequest {
            amount: 0,
            data: None,
        };

        let encoded_header_request = header_request.encode_length_delimited_to_vec();

        let mut reader = Cursor::new(encoded_header_request);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_request = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(header_request, decoded_header_request);
    }

    #[async_test]
    async fn test_decode_multiple_small_header_response() {
        const MSG_COUNT: usize = 10;
        let header_response = HeaderResponse {
            body: vec![1, 2, 3],
            status_code: 1,
        };

        let encoded_header_response = header_response.encode_length_delimited_to_vec();

        let mut multi_msg = vec![];
        for _ in 0..MSG_COUNT {
            multi_msg.extend_from_slice(&encoded_header_response);
        }
        let mut reader = Cursor::new(multi_msg);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_response = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();

        for decoded_header in decoded_header_response.iter() {
            assert_eq!(&header_response, decoded_header);
        }
        assert_eq!(decoded_header_response.len(), MSG_COUNT);
    }

    #[async_test]
    async fn test_decode_header_request_too_large() {
        // Only the delimiter is sent; it announces more bytes than the limit allows.
        let too_long_message_len = REQUEST_SIZE_LIMIT + 1;
        let mut length_delimiter_buffer = BytesMut::new();
        encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoding_error = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .expect_err("expected error for too large request");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    #[async_test]
    async fn test_decode_header_response_too_large() {
        // Only the delimiter is sent; it announces more bytes than the limit allows.
        let too_long_message_len = RESPONSE_SIZE_LIMIT + 1;
        let mut length_delimiter_buffer = BytesMut::new();
        encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoding_error = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .expect_err("expected error for too large response");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    #[test]
    fn test_invalid_varint() {
        // 10 consecutive bytes with continuation bit set + 1 byte, which is longer than allowed
        //    for length delimiter
        let varint = [
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b0000_0001,
        ];

        assert_eq!(parse_delimiter(&varint), None);
    }

    #[test]
    fn parse_trailing_zero_varint() {
        let varint = [0b1000_0001, 0b0000_0000, 0b1111_1111];
        assert!(matches!(parse_delimiter(&varint), Some((1, [255]))));

        let varint = [0b1000_0000, 0b1000_0000, 0b1000_0000, 0b0000_0000];
        assert!(matches!(parse_delimiter(&varint), Some((0, []))));
    }

    #[async_test]
    async fn test_decode_header_double_response_data() {
        let mut header_response_buffer = BytesMut::with_capacity(512);
        let header_response0 = HeaderResponse {
            body: b"9999888877776666555544443333222211110000".to_vec(),
            status_code: 1,
        };
        let header_response1 = HeaderResponse {
            body: b"0000111122223333444455556666777788889999".to_vec(),
            status_code: 2,
        };
        header_response0
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        header_response1
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        let mut reader = Cursor::new(header_response_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_response = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();
        assert_eq!(header_response0, decoded_header_response[0]);
        assert_eq!(header_response1, decoded_header_response[1]);
    }

    #[async_test]
    async fn test_decode_header_request_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_request = HeaderRequest {
            amount: 1,
            data: Some(Data::Hash(data.to_vec())),
        };

        assert_request_roundtrip_chunked::<1>(&header_request).await;
        assert_request_roundtrip_chunked::<2>(&header_request).await;
        assert_request_roundtrip_chunked::<3>(&header_request).await;
        assert_request_roundtrip_chunked::<4>(&header_request).await;
    }

    #[async_test]
    async fn test_decode_header_response_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_response = HeaderResponse {
            body: data.to_vec(),
            status_code: 2,
        };

        assert_response_roundtrip_chunked::<1>(&header_response).await;
        assert_response_roundtrip_chunked::<2>(&header_response).await;
        assert_response_roundtrip_chunked::<3>(&header_response).await;
        assert_response_roundtrip_chunked::<4>(&header_response).await;
    }

    #[async_test]
    async fn test_chunky_async_read() {
        let read_data = "FOO123";
        let cur0 = Cursor::new(read_data);
        let mut chunky = ChunkyAsyncRead::<_, 3>::new(cur0);

        let mut output_buffer: BytesMut = b"BAR987".as_ref().into();

        // Only one chunk may be delivered even though the buffer has room for more.
        let read_len = chunky.read(&mut output_buffer[..]).await.unwrap();
        assert_eq!(read_len, 3);
        assert_eq!(output_buffer, b"FOO987".as_ref());
    }

    /// Encodes `expected`, feeds it to the codec at most `CHUNK_SIZE` bytes per
    /// read and asserts that the same request is decoded.
    async fn assert_request_roundtrip_chunked<const CHUNK_SIZE: usize>(expected: &HeaderRequest) {
        let encoded = expected.encode_length_delimited_to_vec();
        let mut reader = ChunkyAsyncRead::<_, CHUNK_SIZE>::new(Cursor::new(encoded));

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(expected, &decoded);
    }

    /// Encodes `expected`, feeds it to the codec at most `CHUNK_SIZE` bytes per
    /// read and asserts that exactly this one response is decoded.
    async fn assert_response_roundtrip_chunked<const CHUNK_SIZE: usize>(
        expected: &HeaderResponse,
    ) {
        let encoded = expected.encode_length_delimited_to_vec();
        let mut reader = ChunkyAsyncRead::<_, CHUNK_SIZE>::new(Cursor::new(encoded));

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(decoded.len(), 1);
        assert_eq!(expected, &decoded[0]);
    }

    /// Reader that hands out at most `CHUNK_SIZE` bytes per `poll_read` call,
    /// used to check that decoding works when data arrives in small pieces.
    struct ChunkyAsyncRead<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> {
        inner: T,
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn new(inner: T) -> Self {
            ChunkyAsyncRead { inner }
        }
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> AsyncRead for ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize, Error>> {
            // Shrink the destination so the inner reader can not fill more
            // than one chunk.
            let len = buf.len().min(CHUNK_SIZE);
            Pin::new(&mut self.inner).poll_read(cx, &mut buf[..len])
        }
    }
}