1use std::io;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4use std::time::Duration;
5
6use async_trait::async_trait;
7use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
8use celestia_types::ExtendedHeader;
9use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10use libp2p::core::transport::PortUse;
11use libp2p::{
12 core::Endpoint,
13 request_response::{self, Codec, InboundFailure, OutboundFailure, ProtocolSupport},
14 swarm::{
15 handler::ConnectionEvent, ConnectionDenied, ConnectionHandler, ConnectionHandlerEvent,
16 ConnectionId, FromSwarm, NetworkBehaviour, SubstreamProtocol, THandlerInEvent,
17 THandlerOutEvent, ToSwarm,
18 },
19 Multiaddr, PeerId, StreamProtocol,
20};
21use lumina_utils::time::{timeout, Instant};
22use prost::Message;
23use tracing::{debug, instrument, warn};
24
25mod client;
26mod server;
27pub(crate) mod utils;
28
29use crate::p2p::header_ex::client::HeaderExClientHandler;
30use crate::p2p::header_ex::server::HeaderExServerHandler;
31use crate::p2p::P2pError;
32use crate::peer_tracker::PeerTracker;
33use crate::store::Store;
34use crate::utils::{protocol_id, OneshotResultSender};
35
/// Maximum number of bytes read from a peer for a single request (1 KiB).
const REQUEST_SIZE_LIMIT: usize = 1 << 10;
/// Time budget for reading a request from a peer, and for writing one to a peer.
const REQUEST_TIME_LIMIT: Duration = Duration::from_secs(1);
/// Maximum number of bytes read from a peer for a single response (10 MiB), covering all
/// length-delimited `HeaderResponse` messages in that response together.
const RESPONSE_SIZE_LIMIT: usize = 10 << 20;
/// Time budget for reading a response from a peer, and for writing one to a peer.
const RESPONSE_TIME_LIMIT: Duration = Duration::from_secs(5);
/// Timeout applied by `ConnHandler` to inbound and outbound substream protocol negotiation.
const NEGOTIATION_TIMEOUT: Duration = Duration::from_secs(1);
46
47type RequestType = HeaderRequest;
48type ResponseType = Vec<HeaderResponse>;
49type ReqRespBehaviour = request_response::Behaviour<HeaderCodec>;
50type ReqRespEvent = request_response::Event<RequestType, ResponseType>;
51type ReqRespMessage = request_response::Message<RequestType, ResponseType>;
52type ReqRespConnectionHandler = <ReqRespBehaviour as NetworkBehaviour>::ConnectionHandler;
53
54pub(crate) struct HeaderExBehaviour<S>
55where
56 S: Store + 'static,
57{
58 req_resp: ReqRespBehaviour,
59 client_handler: HeaderExClientHandler,
60 server_handler: HeaderExServerHandler<S>,
61}
62
/// Construction parameters for [`HeaderExBehaviour::new`].
pub(crate) struct HeaderExConfig<'a, S> {
    /// Network identifier passed to `protocol_id` to namespace the header-ex protocol name.
    pub network_id: &'a str,
    /// Peer tracker handed to the client handler.
    pub peer_tracker: Arc<PeerTracker>,
    /// Header store handed to the server handler.
    pub header_store: Arc<S>,
}
68
/// Errors that can occur while exchanging headers over the header-ex protocol.
#[derive(Debug, thiserror::Error)]
pub enum HeaderExError {
    /// The requested header was not found.
    #[error("Header not found")]
    HeaderNotFound,

    /// A response could not be accepted as valid.
    #[error("Invalid response")]
    InvalidResponse,

    /// A request could not be accepted as valid.
    #[error("Invalid request")]
    InvalidRequest,

    /// Wraps a libp2p `request_response` inbound failure.
    #[error("Inbound failure: {0}")]
    InboundFailure(InboundFailure),

    /// Wraps a libp2p `request_response` outbound failure.
    #[error("Outbound failure: {0}")]
    OutboundFailure(OutboundFailure),

    /// The request was cancelled because the node is shutting down.
    #[error("Request cancelled because `Node` is stopping")]
    RequestCancelled,
}
98
99impl<S> HeaderExBehaviour<S>
100where
101 S: Store + 'static,
102{
103 pub(crate) fn new(config: HeaderExConfig<'_, S>) -> Self {
104 HeaderExBehaviour {
105 req_resp: ReqRespBehaviour::new(
106 [(
107 protocol_id(config.network_id, "/header-ex/v0.0.3"),
108 ProtocolSupport::Full,
109 )],
110 request_response::Config::default(),
111 ),
112 client_handler: HeaderExClientHandler::new(config.peer_tracker),
113 server_handler: HeaderExServerHandler::new(config.header_store),
114 }
115 }
116
117 #[instrument(level = "trace", skip(self, respond_to))]
118 pub(crate) fn send_request(
119 &mut self,
120 request: HeaderRequest,
121 respond_to: OneshotResultSender<Vec<ExtendedHeader>, P2pError>,
122 ) {
123 self.client_handler
124 .on_send_request(&mut self.req_resp, request, respond_to);
125 }
126
127 pub(crate) fn stop(&mut self) {
128 self.client_handler.on_stop();
129 self.server_handler.on_stop();
130 }
131
132 fn on_to_swarm(
133 &mut self,
134 ev: ToSwarm<ReqRespEvent, THandlerInEvent<ReqRespBehaviour>>,
135 ) -> Option<ToSwarm<(), THandlerInEvent<Self>>> {
136 match ev {
137 ToSwarm::GenerateEvent(ev) => {
138 self.on_req_resp_event(ev);
139 None
140 }
141 _ => Some(ev.map_out(|_| ())),
142 }
143 }
144
145 #[instrument(level = "trace", skip_all)]
146 fn on_req_resp_event(&mut self, ev: ReqRespEvent) {
147 match ev {
148 ReqRespEvent::Message {
150 message:
151 ReqRespMessage::Response {
152 request_id,
153 response,
154 },
155 peer,
156 ..
157 } => {
158 self.client_handler
159 .on_response_received(peer, request_id, response);
160 }
161
162 ReqRespEvent::OutboundFailure {
164 peer,
165 request_id,
166 error,
167 ..
168 } => {
169 self.client_handler.on_failure(peer, request_id, error);
170 }
171
172 ReqRespEvent::Message {
174 message:
175 ReqRespMessage::Request {
176 request_id,
177 request,
178 channel,
179 },
180 peer,
181 ..
182 } => {
183 self.server_handler.on_request_received(
184 peer,
185 request_id,
186 request,
187 &mut self.req_resp,
188 channel,
189 );
190 }
191
192 ReqRespEvent::ResponseSent {
194 peer, request_id, ..
195 } => {
196 self.server_handler.on_response_sent(peer, request_id);
197 }
198
199 ReqRespEvent::InboundFailure {
201 peer,
202 request_id,
203 error,
204 ..
205 } => {
206 self.server_handler.on_failure(peer, request_id, error);
207 }
208 }
209 }
210}
211
212impl<S> NetworkBehaviour for HeaderExBehaviour<S>
213where
214 S: Store + 'static,
215{
216 type ConnectionHandler = ConnHandler;
217 type ToSwarm = ();
218
219 fn handle_established_inbound_connection(
220 &mut self,
221 connection_id: ConnectionId,
222 peer: PeerId,
223 local_addr: &Multiaddr,
224 remote_addr: &Multiaddr,
225 ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
226 self.req_resp
227 .handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)
228 .map(ConnHandler)
229 }
230
231 fn handle_established_outbound_connection(
232 &mut self,
233 connection_id: ConnectionId,
234 peer: PeerId,
235 addr: &Multiaddr,
236 role_override: Endpoint,
237 port_use: PortUse,
238 ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
239 self.req_resp
240 .handle_established_outbound_connection(
241 connection_id,
242 peer,
243 addr,
244 role_override,
245 port_use,
246 )
247 .map(ConnHandler)
248 }
249
250 fn handle_pending_inbound_connection(
251 &mut self,
252 connection_id: ConnectionId,
253 local_addr: &Multiaddr,
254 remote_addr: &Multiaddr,
255 ) -> Result<(), ConnectionDenied> {
256 self.req_resp
257 .handle_pending_inbound_connection(connection_id, local_addr, remote_addr)
258 }
259
260 fn handle_pending_outbound_connection(
261 &mut self,
262 connection_id: ConnectionId,
263 maybe_peer: Option<PeerId>,
264 addresses: &[Multiaddr],
265 effective_role: Endpoint,
266 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
267 self.req_resp.handle_pending_outbound_connection(
268 connection_id,
269 maybe_peer,
270 addresses,
271 effective_role,
272 )
273 }
274
275 fn on_swarm_event(&mut self, event: FromSwarm) {
276 self.req_resp.on_swarm_event(event)
277 }
278
279 fn on_connection_handler_event(
280 &mut self,
281 peer_id: PeerId,
282 connection_id: ConnectionId,
283 event: THandlerOutEvent<Self>,
284 ) {
285 self.req_resp
286 .on_connection_handler_event(peer_id, connection_id, event)
287 }
288
289 fn poll(
290 &mut self,
291 cx: &mut Context<'_>,
292 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
293 loop {
294 if let Poll::Ready(ev) = self.req_resp.poll(cx) {
295 if let Some(ev) = self.on_to_swarm(ev) {
296 return Poll::Ready(ev);
297 }
298
299 continue;
300 }
301
302 if self.client_handler.poll(cx).is_ready() {
303 continue;
304 }
305
306 if self.server_handler.poll(cx, &mut self.req_resp).is_ready() {
307 continue;
308 }
309
310 return Poll::Pending;
311 }
312 }
313}
314
315pub(crate) struct ConnHandler(ReqRespConnectionHandler);
316
317impl ConnectionHandler for ConnHandler {
318 type ToBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::ToBehaviour;
319 type FromBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::FromBehaviour;
320 type InboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::InboundProtocol;
321 type InboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::InboundOpenInfo;
322 type OutboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::OutboundProtocol;
323 type OutboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::OutboundOpenInfo;
324
325 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
326 self.0.listen_protocol().with_timeout(NEGOTIATION_TIMEOUT)
327 }
328
329 fn poll(
330 &mut self,
331 cx: &mut Context<'_>,
332 ) -> Poll<
333 ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
334 > {
335 match self.0.poll(cx) {
336 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
337 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
338 protocol: protocol.with_timeout(NEGOTIATION_TIMEOUT),
339 })
340 }
341 ev => ev,
342 }
343 }
344
345 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
346 self.0.on_behaviour_event(event)
347 }
348
349 fn on_connection_event(
350 &mut self,
351 event: ConnectionEvent<
352 '_,
353 Self::InboundProtocol,
354 Self::OutboundProtocol,
355 Self::InboundOpenInfo,
356 Self::OutboundOpenInfo,
357 >,
358 ) {
359 self.0.on_connection_event(event)
360 }
361
362 fn connection_keep_alive(&self) -> bool {
363 self.0.connection_keep_alive()
364 }
365
366 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
367 self.0.poll_close(cx)
368 }
369}
370
/// Wire codec for the header-ex protocol: a request is one length-delimited `HeaderRequest`,
/// and a response is a sequence of length-delimited `HeaderResponse` messages.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct HeaderCodec;
373
374#[async_trait]
375impl Codec for HeaderCodec {
376 type Protocol = StreamProtocol;
377 type Request = HeaderRequest;
378 type Response = Vec<HeaderResponse>;
379
380 async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
381 where
382 T: AsyncRead + Unpin + Send,
383 {
384 let data = read_up_to(io, REQUEST_SIZE_LIMIT, REQUEST_TIME_LIMIT).await?;
385
386 if data.len() >= REQUEST_SIZE_LIMIT {
387 debug!("Message filled the whole buffer (len: {})", data.len());
388 }
389
390 parse_header_request(&data).ok_or_else(|| {
391 io::Error::other("invalid or incomplete request")
396 })
397 }
398
399 async fn read_response<T>(
400 &mut self,
401 _: &Self::Protocol,
402 io: &mut T,
403 ) -> io::Result<Self::Response>
404 where
405 T: AsyncRead + Unpin + Send,
406 {
407 let data = read_up_to(io, RESPONSE_SIZE_LIMIT, RESPONSE_TIME_LIMIT).await?;
408
409 if data.len() >= RESPONSE_SIZE_LIMIT {
410 debug!("Message filled the whole buffer (len: {})", data.len());
411 }
412
413 let mut data = &data[..];
414 let mut msgs = Vec::new();
415
416 while let Some((header, rest)) = parse_header_response(data) {
417 msgs.push(header);
418 data = rest;
419 }
420
421 if msgs.is_empty() {
422 return Err(io::Error::other("invalid or incomplete response"));
427 }
428
429 Ok(msgs)
430 }
431
432 async fn write_request<T>(
433 &mut self,
434 _: &Self::Protocol,
435 io: &mut T,
436 req: Self::Request,
437 ) -> io::Result<()>
438 where
439 T: AsyncWrite + Unpin + Send,
440 {
441 let mut buf = Vec::with_capacity(REQUEST_SIZE_LIMIT);
442
443 let _ = req.encode_length_delimited(&mut buf);
444
445 timeout(REQUEST_TIME_LIMIT, io.write_all(&buf))
446 .await
447 .map_err(|_| io::Error::other("writing request timed out"))??;
448
449 Ok(())
450 }
451
452 async fn write_response<T>(
453 &mut self,
454 _: &Self::Protocol,
455 io: &mut T,
456 resps: Self::Response,
457 ) -> io::Result<()>
458 where
459 T: AsyncWrite + Unpin + Send,
460 {
461 let mut buf = Vec::with_capacity(RESPONSE_SIZE_LIMIT);
462
463 for resp in resps {
464 if resp.encode_length_delimited(&mut buf).is_err() {
465 debug!("Sending partial response");
468 break;
469 }
470 }
471
472 timeout(RESPONSE_TIME_LIMIT, io.write_all(&buf))
473 .await
474 .map_err(|_| io::Error::other("writing response timed out"))??;
475
476 Ok(())
477 }
478}
479
480async fn read_up_to<T>(io: &mut T, size_limit: usize, time_limit: Duration) -> io::Result<Vec<u8>>
482where
483 T: AsyncRead + Unpin + Send,
484{
485 let mut buf = vec![0u8; size_limit];
486 let mut read_len = 0;
487 let now = Instant::now();
488
489 loop {
490 if read_len == buf.len() {
491 break;
493 }
494
495 let Some(time_limit) = time_limit.checked_sub(now.elapsed()) else {
496 break;
497 };
498
499 let len = match timeout(time_limit, io.read(&mut buf[read_len..])).await {
500 Ok(Ok(len)) => len,
501 Ok(Err(e)) => return Err(e),
502 Err(_) => break,
503 };
504
505 if len == 0 {
506 break;
508 }
509
510 read_len += len;
511 }
512
513 buf.truncate(read_len);
514
515 Ok(buf)
516}
517
518fn parse_delimiter(mut buf: &[u8]) -> Option<(usize, &[u8])> {
519 if buf.is_empty() {
520 return None;
521 }
522
523 let Ok(len) = prost::decode_length_delimiter(&mut buf) else {
524 return None;
525 };
526
527 Some((len, buf))
528}
529
530fn parse_header_response(buf: &[u8]) -> Option<(HeaderResponse, &[u8])> {
531 let (len, rest) = parse_delimiter(buf)?;
532
533 if rest.len() < len {
534 debug!("Message is incomplete: {len}");
535 return None;
536 }
537
538 let Ok(msg) = HeaderResponse::decode(&rest[..len]) else {
539 return None;
540 };
541
542 Some((msg, &rest[len..]))
543}
544
545fn parse_header_request(buf: &[u8]) -> Option<HeaderRequest> {
546 let (len, rest) = parse_delimiter(buf)?;
547
548 if rest.len() < len {
549 debug!("Message is incomplete: {len}");
550 return None;
551 }
552
553 let Ok(msg) = HeaderRequest::decode(&rest[..len]) else {
554 return None;
555 };
556
557 Some(msg)
558}
559
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use celestia_proto::p2p::pb::header_request::Data;
    use futures::io::{Cursor, Error};
    use lumina_utils::test_utils::async_test;
    use prost::encode_length_delimiter;
    use std::io::ErrorKind;
    use std::pin::Pin;

    #[async_test]
    async fn test_decode_header_request_empty() {
        let header_request = HeaderRequest {
            amount: 0,
            data: None,
        };

        let encoded_header_request = header_request.encode_length_delimited_to_vec();

        let mut reader = Cursor::new(encoded_header_request);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_request = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(header_request, decoded_header_request);
    }

    #[async_test]
    async fn test_decode_multiple_small_header_response() {
        const MSG_COUNT: usize = 10;
        let header_response = HeaderResponse {
            body: vec![1, 2, 3],
            status_code: 1,
        };

        let encoded_header_response = header_response.encode_length_delimited_to_vec();

        let mut multi_msg = vec![];
        for _ in 0..MSG_COUNT {
            multi_msg.extend_from_slice(&encoded_header_response);
        }
        let mut reader = Cursor::new(multi_msg);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_response = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();

        for decoded_header in decoded_header_response.iter() {
            assert_eq!(&header_response, decoded_header);
        }
        assert_eq!(decoded_header_response.len(), MSG_COUNT);
    }

    #[async_test]
    async fn test_decode_header_request_too_large() {
        // NOTE(review): the stream carries only the length prefix, so decoding also fails
        // because the body is missing; a body larger than the limit is not exercised.
        let too_long_message_len = REQUEST_SIZE_LIMIT + 1;
        let mut length_delimiter_buffer = BytesMut::new();
        prost::encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoding_error = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .expect_err("expected error for too large request");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    #[async_test]
    async fn test_decode_header_response_too_large() {
        // NOTE(review): as above, only the length prefix is sent, so this also fails as an
        // incomplete message.
        let too_long_message_len = RESPONSE_SIZE_LIMIT + 1;
        let mut length_delimiter_buffer = BytesMut::new();
        encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoding_error = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .expect_err("expected error for too large response");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    #[test]
    fn test_invalid_varint() {
        // Eleven bytes: longer than the longest valid varint encoding.
        let varint = [
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b0000_0001,
        ];

        assert_eq!(parse_delimiter(&varint), None);
    }

    #[test]
    fn parse_trailing_zero_varint() {
        let varint = [0b1000_0001, 0b0000_0000, 0b1111_1111];
        assert!(matches!(parse_delimiter(&varint), Some((1, [255]))));

        let varint = [0b1000_0000, 0b1000_0000, 0b1000_0000, 0b0000_0000];
        assert!(matches!(parse_delimiter(&varint), Some((0, []))));
    }

    #[async_test]
    async fn test_decode_header_double_response_data() {
        let mut header_response_buffer = BytesMut::with_capacity(512);
        let header_response0 = HeaderResponse {
            body: b"9999888877776666555544443333222211110000".to_vec(),
            status_code: 1,
        };
        let header_response1 = HeaderResponse {
            body: b"0000111122223333444455556666777788889999".to_vec(),
            status_code: 2,
        };
        header_response0
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        header_response1
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        let mut reader = Cursor::new(header_response_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_response = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();
        assert_eq!(header_response0, decoded_header_response[0]);
        assert_eq!(header_response1, decoded_header_response[1]);
    }

    #[async_test]
    async fn test_decode_header_request_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_request = HeaderRequest {
            amount: 1,
            data: Some(Data::Hash(data.to_vec())),
        };
        let encoded = header_request.encode_length_delimited_to_vec();

        assert_eq!(header_request, read_request_chunked::<1>(&encoded).await);
        assert_eq!(header_request, read_request_chunked::<2>(&encoded).await);
        assert_eq!(header_request, read_request_chunked::<3>(&encoded).await);
        assert_eq!(header_request, read_request_chunked::<4>(&encoded).await);
    }

    #[async_test]
    async fn test_decode_header_response_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_response = HeaderResponse {
            body: data.to_vec(),
            status_code: 2,
        };
        let encoded = header_response.encode_length_delimited_to_vec();

        assert_eq!(header_response, read_response_chunked::<1>(&encoded).await[0]);
        assert_eq!(header_response, read_response_chunked::<2>(&encoded).await[0]);
        assert_eq!(header_response, read_response_chunked::<3>(&encoded).await[0]);
        assert_eq!(header_response, read_response_chunked::<4>(&encoded).await[0]);
    }

    #[async_test]
    async fn test_chunky_async_read() {
        let read_data = "FOO123";
        let cur0 = Cursor::new(read_data);
        let mut chunky = ChunkyAsyncRead::<_, 3>::new(cur0);

        let mut output_buffer: BytesMut = b"BAR987".as_ref().into();

        let _ = chunky.read(&mut output_buffer[..]).await.unwrap();
        assert_eq!(output_buffer, b"FOO987".as_ref());
    }

    /// Decodes `encoded` with `read_request`, feeding the codec at most `CHUNK_SIZE` bytes
    /// per read call.
    async fn read_request_chunked<const CHUNK_SIZE: usize>(encoded: &[u8]) -> HeaderRequest {
        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};
        let mut reader = ChunkyAsyncRead::<_, CHUNK_SIZE>::new(Cursor::new(encoded.to_vec()));

        codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .unwrap()
    }

    /// Decodes `encoded` with `read_response`, feeding the codec at most `CHUNK_SIZE` bytes
    /// per read call.
    async fn read_response_chunked<const CHUNK_SIZE: usize>(
        encoded: &[u8],
    ) -> Vec<HeaderResponse> {
        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};
        let mut reader = ChunkyAsyncRead::<_, CHUNK_SIZE>::new(Cursor::new(encoded.to_vec()));

        codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap()
    }

    /// `AsyncRead` adapter that returns at most `CHUNK_SIZE` bytes per `poll_read`, to
    /// exercise decoding of data that arrives in small pieces.
    struct ChunkyAsyncRead<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> {
        inner: T,
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn new(inner: T) -> Self {
            ChunkyAsyncRead { inner }
        }
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> AsyncRead for ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize, Error>> {
            let len = buf.len().min(CHUNK_SIZE);
            Pin::new(&mut self.inner).poll_read(cx, &mut buf[..len])
        }
    }
}