1use std::io;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4use std::time::Duration;
5
6use async_trait::async_trait;
7use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
8use celestia_types::ExtendedHeader;
9use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10use libp2p::core::Endpoint;
11use libp2p::core::transport::PortUse;
12use libp2p::request_response::{self, Codec, InboundFailure, OutboundFailure, ProtocolSupport};
13use libp2p::swarm::handler::ConnectionEvent;
14use libp2p::swarm::{
15 ConnectionDenied, ConnectionHandler, ConnectionHandlerEvent, ConnectionId, FromSwarm,
16 NetworkBehaviour, SubstreamProtocol, THandlerInEvent, THandlerOutEvent, ToSwarm,
17};
18use libp2p::{Multiaddr, PeerId, StreamProtocol};
19use lumina_utils::time::{Instant, timeout};
20use prost::Message;
21use tracing::{debug, instrument, warn};
22
23mod client;
24mod server;
25pub(crate) mod utils;
26
27use crate::p2p::P2pError;
28use crate::p2p::header_ex::client::HeaderExClientHandler;
29use crate::p2p::header_ex::server::HeaderExServerHandler;
30use crate::peer_tracker::PeerTracker;
31use crate::store::Store;
32use crate::utils::{OneshotResultSender, protocol_id};
33
/// Maximum number of bytes read when receiving a request.
const REQUEST_SIZE_LIMIT: usize = 1 << 10;
/// Total time budget for reading or writing a single request.
const REQUEST_TIME_LIMIT: Duration = Duration::from_millis(1_000);
/// Maximum number of bytes read when receiving a response (10 MiB).
const RESPONSE_SIZE_LIMIT: usize = 10 << 20;
/// Total time budget for reading or writing a single response.
const RESPONSE_TIME_LIMIT: Duration = Duration::from_millis(5_000);
/// Timeout applied to inbound and outbound substream protocol negotiation.
const NEGOTIATION_TIMEOUT: Duration = Duration::from_millis(1_000);
44
/// Request type exchanged by [`HeaderCodec`].
type RequestType = HeaderRequest;
/// Response type: a sequence of `HeaderResponse` messages.
type ResponseType = Vec<HeaderResponse>;
/// Inner libp2p request-response behaviour driven by [`HeaderCodec`].
type ReqRespBehaviour = request_response::Behaviour<HeaderCodec>;
/// Events emitted by [`ReqRespBehaviour`].
type ReqRespEvent = request_response::Event<RequestType, ResponseType>;
/// Request or response message carried by a [`ReqRespEvent::Message`].
type ReqRespMessage = request_response::Message<RequestType, ResponseType>;
/// Per-connection handler of [`ReqRespBehaviour`]; wrapped by [`ConnHandler`].
type ReqRespConnectionHandler = <ReqRespBehaviour as NetworkBehaviour>::ConnectionHandler;
51
/// `NetworkBehaviour` implementing the header-ex protocol.
///
/// Wraps a libp2p request-response behaviour and routes its events either to the
/// client-side handler (requests we sent) or to the server-side handler (requests
/// we receive and answer from the header store `S`).
pub(crate) struct HeaderExBehaviour<S>
where
    S: Store + 'static,
{
    /// Inner request-response behaviour that owns the substreams.
    req_resp: ReqRespBehaviour,
    /// Handles outgoing requests and their responses/failures.
    client_handler: HeaderExClientHandler,
    /// Handles inbound requests, backed by the header store.
    server_handler: HeaderExServerHandler<S>,
}
60
/// Configuration used to construct a [`HeaderExBehaviour`].
pub(crate) struct HeaderExConfig<'a, S> {
    /// Network id, combined by `protocol_id` with `/header-ex/v0.0.3` to form the
    /// protocol name.
    pub network_id: &'a str,
    /// Store used by the server-side handler to answer inbound requests.
    pub header_store: Arc<S>,
}
65
/// Errors that can occur during header-ex communication.
#[derive(Debug, thiserror::Error)]
pub enum HeaderExError {
    /// The requested header was not found.
    #[error("Header not found")]
    HeaderNotFound,

    /// A response received from a peer was invalid.
    #[error("Invalid response")]
    InvalidResponse,

    /// A request was invalid.
    #[error("Invalid request")]
    InvalidRequest,

    /// The request-response layer reported a failure for an inbound request.
    #[error("Inbound failure: {0}")]
    InboundFailure(InboundFailure),

    /// The request-response layer reported a failure for an outbound request.
    #[error("Outbound failure: {0}")]
    OutboundFailure(OutboundFailure),

    /// The request was cancelled because the node is shutting down.
    #[error("Request cancelled because `Node` is stopping")]
    RequestCancelled,
}
95
96impl<S> HeaderExBehaviour<S>
97where
98 S: Store + 'static,
99{
100 pub(crate) fn new(config: HeaderExConfig<'_, S>) -> Self {
101 HeaderExBehaviour {
102 req_resp: ReqRespBehaviour::new(
103 [(
104 protocol_id(config.network_id, "/header-ex/v0.0.3"),
105 ProtocolSupport::Full,
106 )],
107 request_response::Config::default(),
108 ),
109 client_handler: HeaderExClientHandler::new(),
110 server_handler: HeaderExServerHandler::new(config.header_store),
111 }
112 }
113
114 #[instrument(level = "trace", skip(self, respond_to))]
115 pub(crate) fn send_request(
116 &mut self,
117 request: HeaderRequest,
118 respond_to: OneshotResultSender<Vec<ExtendedHeader>, P2pError>,
119 peer_tracker: &PeerTracker,
120 ) {
121 self.client_handler
122 .on_send_request(&mut self.req_resp, request, respond_to, peer_tracker);
123 }
124
125 pub(crate) fn stop(&mut self) {
126 self.client_handler.on_stop();
127 self.server_handler.on_stop();
128 }
129
130 fn on_to_swarm(
131 &mut self,
132 ev: ToSwarm<ReqRespEvent, THandlerInEvent<ReqRespBehaviour>>,
133 ) -> Option<ToSwarm<(), THandlerInEvent<Self>>> {
134 match ev {
135 ToSwarm::GenerateEvent(ev) => {
136 self.on_req_resp_event(ev);
137 None
138 }
139 _ => Some(ev.map_out(|_| ())),
140 }
141 }
142
143 #[instrument(level = "trace", skip_all)]
144 fn on_req_resp_event(&mut self, ev: ReqRespEvent) {
145 match ev {
146 ReqRespEvent::Message {
148 message:
149 ReqRespMessage::Response {
150 request_id,
151 response,
152 },
153 peer,
154 ..
155 } => {
156 self.client_handler
157 .on_response_received(peer, request_id, response);
158 }
159
160 ReqRespEvent::OutboundFailure {
162 peer,
163 request_id,
164 error,
165 ..
166 } => {
167 self.client_handler.on_failure(peer, request_id, error);
168 }
169
170 ReqRespEvent::Message {
172 message:
173 ReqRespMessage::Request {
174 request_id,
175 request,
176 channel,
177 },
178 peer,
179 ..
180 } => {
181 self.server_handler.on_request_received(
182 peer,
183 request_id,
184 request,
185 &mut self.req_resp,
186 channel,
187 );
188 }
189
190 ReqRespEvent::ResponseSent {
192 peer, request_id, ..
193 } => {
194 self.server_handler.on_response_sent(peer, request_id);
195 }
196
197 ReqRespEvent::InboundFailure {
199 peer,
200 request_id,
201 error,
202 ..
203 } => {
204 self.server_handler.on_failure(peer, request_id, error);
205 }
206 }
207 }
208}
209
210impl<S> NetworkBehaviour for HeaderExBehaviour<S>
211where
212 S: Store + 'static,
213{
214 type ConnectionHandler = ConnHandler;
215 type ToSwarm = ();
216
217 fn handle_established_inbound_connection(
218 &mut self,
219 connection_id: ConnectionId,
220 peer: PeerId,
221 local_addr: &Multiaddr,
222 remote_addr: &Multiaddr,
223 ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
224 self.req_resp
225 .handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)
226 .map(ConnHandler)
227 }
228
229 fn handle_established_outbound_connection(
230 &mut self,
231 connection_id: ConnectionId,
232 peer: PeerId,
233 addr: &Multiaddr,
234 role_override: Endpoint,
235 port_use: PortUse,
236 ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
237 self.req_resp
238 .handle_established_outbound_connection(
239 connection_id,
240 peer,
241 addr,
242 role_override,
243 port_use,
244 )
245 .map(ConnHandler)
246 }
247
248 fn handle_pending_inbound_connection(
249 &mut self,
250 connection_id: ConnectionId,
251 local_addr: &Multiaddr,
252 remote_addr: &Multiaddr,
253 ) -> Result<(), ConnectionDenied> {
254 self.req_resp
255 .handle_pending_inbound_connection(connection_id, local_addr, remote_addr)
256 }
257
258 fn handle_pending_outbound_connection(
259 &mut self,
260 connection_id: ConnectionId,
261 maybe_peer: Option<PeerId>,
262 addresses: &[Multiaddr],
263 effective_role: Endpoint,
264 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
265 self.req_resp.handle_pending_outbound_connection(
266 connection_id,
267 maybe_peer,
268 addresses,
269 effective_role,
270 )
271 }
272
273 fn on_swarm_event(&mut self, event: FromSwarm) {
274 self.req_resp.on_swarm_event(event)
275 }
276
277 fn on_connection_handler_event(
278 &mut self,
279 peer_id: PeerId,
280 connection_id: ConnectionId,
281 event: THandlerOutEvent<Self>,
282 ) {
283 self.req_resp
284 .on_connection_handler_event(peer_id, connection_id, event)
285 }
286
287 fn poll(
288 &mut self,
289 cx: &mut Context<'_>,
290 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
291 loop {
292 if let Poll::Ready(ev) = self.req_resp.poll(cx) {
293 if let Some(ev) = self.on_to_swarm(ev) {
294 return Poll::Ready(ev);
295 }
296
297 continue;
298 }
299
300 if self.client_handler.poll(cx).is_ready() {
301 continue;
302 }
303
304 if self.server_handler.poll(cx, &mut self.req_resp).is_ready() {
305 continue;
306 }
307
308 return Poll::Pending;
309 }
310 }
311}
312
/// Connection handler wrapping the request-response handler.
///
/// Its `ConnectionHandler` impl applies `NEGOTIATION_TIMEOUT` to inbound and
/// outbound substream negotiation.
pub(crate) struct ConnHandler(ReqRespConnectionHandler);
314
315impl ConnectionHandler for ConnHandler {
316 type ToBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::ToBehaviour;
317 type FromBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::FromBehaviour;
318 type InboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::InboundProtocol;
319 type InboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::InboundOpenInfo;
320 type OutboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::OutboundProtocol;
321 type OutboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::OutboundOpenInfo;
322
323 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
324 self.0.listen_protocol().with_timeout(NEGOTIATION_TIMEOUT)
325 }
326
327 fn poll(
328 &mut self,
329 cx: &mut Context<'_>,
330 ) -> Poll<
331 ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
332 > {
333 match self.0.poll(cx) {
334 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
335 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
336 protocol: protocol.with_timeout(NEGOTIATION_TIMEOUT),
337 })
338 }
339 ev => ev,
340 }
341 }
342
343 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
344 self.0.on_behaviour_event(event)
345 }
346
347 fn on_connection_event(
348 &mut self,
349 event: ConnectionEvent<
350 '_,
351 Self::InboundProtocol,
352 Self::OutboundProtocol,
353 Self::InboundOpenInfo,
354 Self::OutboundOpenInfo,
355 >,
356 ) {
357 self.0.on_connection_event(event)
358 }
359
360 fn connection_keep_alive(&self) -> bool {
361 self.0.connection_keep_alive()
362 }
363
364 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
365 self.0.poll_close(cx)
366 }
367}
368
/// Request-response codec for header-ex.
///
/// A request is a single length-delimited `HeaderRequest`. A response is a
/// concatenation of length-delimited `HeaderResponse` messages.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct HeaderCodec;
371
372#[async_trait]
373impl Codec for HeaderCodec {
374 type Protocol = StreamProtocol;
375 type Request = HeaderRequest;
376 type Response = Vec<HeaderResponse>;
377
378 async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
379 where
380 T: AsyncRead + Unpin + Send,
381 {
382 let data = read_up_to(io, REQUEST_SIZE_LIMIT, REQUEST_TIME_LIMIT).await?;
383
384 if data.len() >= REQUEST_SIZE_LIMIT {
385 debug!("Message filled the whole buffer (len: {})", data.len());
386 }
387
388 parse_header_request(&data).ok_or_else(|| {
389 io::Error::other("invalid or incomplete request")
394 })
395 }
396
397 async fn read_response<T>(
398 &mut self,
399 _: &Self::Protocol,
400 io: &mut T,
401 ) -> io::Result<Self::Response>
402 where
403 T: AsyncRead + Unpin + Send,
404 {
405 let data = read_up_to(io, RESPONSE_SIZE_LIMIT, RESPONSE_TIME_LIMIT).await?;
406
407 if data.len() >= RESPONSE_SIZE_LIMIT {
408 debug!("Message filled the whole buffer (len: {})", data.len());
409 }
410
411 let mut data = &data[..];
412 let mut msgs = Vec::new();
413
414 while let Some((header, rest)) = parse_header_response(data) {
415 msgs.push(header);
416 data = rest;
417 }
418
419 if msgs.is_empty() {
420 return Err(io::Error::other("invalid or incomplete response"));
425 }
426
427 Ok(msgs)
428 }
429
430 async fn write_request<T>(
431 &mut self,
432 _: &Self::Protocol,
433 io: &mut T,
434 req: Self::Request,
435 ) -> io::Result<()>
436 where
437 T: AsyncWrite + Unpin + Send,
438 {
439 let mut buf = Vec::with_capacity(REQUEST_SIZE_LIMIT);
440
441 let _ = req.encode_length_delimited(&mut buf);
442
443 timeout(REQUEST_TIME_LIMIT, io.write_all(&buf))
444 .await
445 .map_err(|_| io::Error::other("writing request timed out"))??;
446
447 Ok(())
448 }
449
450 async fn write_response<T>(
451 &mut self,
452 _: &Self::Protocol,
453 io: &mut T,
454 resps: Self::Response,
455 ) -> io::Result<()>
456 where
457 T: AsyncWrite + Unpin + Send,
458 {
459 let mut buf = Vec::with_capacity(RESPONSE_SIZE_LIMIT);
460
461 for resp in resps {
462 if resp.encode_length_delimited(&mut buf).is_err() {
463 debug!("Sending partial response");
466 break;
467 }
468 }
469
470 timeout(RESPONSE_TIME_LIMIT, io.write_all(&buf))
471 .await
472 .map_err(|_| io::Error::other("writing response timed out"))??;
473
474 Ok(())
475 }
476}
477
478async fn read_up_to<T>(io: &mut T, size_limit: usize, time_limit: Duration) -> io::Result<Vec<u8>>
480where
481 T: AsyncRead + Unpin + Send,
482{
483 let mut buf = vec![0u8; size_limit];
484 let mut read_len = 0;
485 let now = Instant::now();
486
487 loop {
488 if read_len == buf.len() {
489 break;
491 }
492
493 let Some(time_limit) = time_limit.checked_sub(now.elapsed()) else {
494 break;
495 };
496
497 let len = match timeout(time_limit, io.read(&mut buf[read_len..])).await {
498 Ok(Ok(len)) => len,
499 Ok(Err(e)) => return Err(e),
500 Err(_) => break,
501 };
502
503 if len == 0 {
504 break;
506 }
507
508 read_len += len;
509 }
510
511 buf.truncate(read_len);
512
513 Ok(buf)
514}
515
516fn parse_delimiter(mut buf: &[u8]) -> Option<(usize, &[u8])> {
517 if buf.is_empty() {
518 return None;
519 }
520
521 let Ok(len) = prost::decode_length_delimiter(&mut buf) else {
522 return None;
523 };
524
525 Some((len, buf))
526}
527
528fn parse_header_response(buf: &[u8]) -> Option<(HeaderResponse, &[u8])> {
529 let (len, rest) = parse_delimiter(buf)?;
530
531 if rest.len() < len {
532 debug!("Message is incomplete: {len}");
533 return None;
534 }
535
536 let Ok(msg) = HeaderResponse::decode(&rest[..len]) else {
537 return None;
538 };
539
540 Some((msg, &rest[len..]))
541}
542
543fn parse_header_request(buf: &[u8]) -> Option<HeaderRequest> {
544 let (len, rest) = parse_delimiter(buf)?;
545
546 if rest.len() < len {
547 debug!("Message is incomplete: {len}");
548 return None;
549 }
550
551 let Ok(msg) = HeaderRequest::decode(&rest[..len]) else {
552 return None;
553 };
554
555 Some(msg)
556}
557
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use celestia_proto::p2p::pb::header_request::Data;
    use futures::io::{Cursor, Error};
    use lumina_utils::test_utils::async_test;
    use prost::encode_length_delimiter;
    use std::io::ErrorKind;
    use std::pin::Pin;

    /// Runs `HeaderCodec::read_request` over `reader`.
    async fn decode_request<R>(reader: &mut R) -> io::Result<HeaderRequest>
    where
        R: AsyncRead + Unpin + Send,
    {
        let protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec;
        codec.read_request(&protocol, reader).await
    }

    /// Runs `HeaderCodec::read_response` over `reader`.
    async fn decode_response<R>(reader: &mut R) -> io::Result<Vec<HeaderResponse>>
    where
        R: AsyncRead + Unpin + Send,
    {
        let protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec;
        codec.read_response(&protocol, reader).await
    }

    /// Decodes `encoded` as a request from a reader yielding at most `CHUNK` bytes per read.
    async fn decode_request_chunked<const CHUNK: usize>(encoded: &[u8]) -> HeaderRequest {
        let mut reader = ChunkyAsyncRead::<_, CHUNK>::new(Cursor::new(encoded.to_vec()));
        decode_request(&mut reader).await.unwrap()
    }

    /// Decodes `encoded` as a response from a reader yielding at most `CHUNK` bytes per read.
    async fn decode_response_chunked<const CHUNK: usize>(encoded: &[u8]) -> Vec<HeaderResponse> {
        let mut reader = ChunkyAsyncRead::<_, CHUNK>::new(Cursor::new(encoded.to_vec()));
        decode_response(&mut reader).await.unwrap()
    }

    #[async_test]
    async fn test_decode_header_request_empty() {
        let request = HeaderRequest {
            amount: 0,
            data: None,
        };
        let mut reader = Cursor::new(request.encode_length_delimited_to_vec());

        let decoded = decode_request(&mut reader).await.unwrap();

        assert_eq!(request, decoded);
    }

    #[async_test]
    async fn test_decode_multiple_small_header_response() {
        const MSG_COUNT: usize = 10;
        let response = HeaderResponse {
            body: vec![1, 2, 3],
            status_code: 1,
        };
        let encoded = response.encode_length_delimited_to_vec();
        let mut reader = Cursor::new(encoded.repeat(MSG_COUNT));

        let decoded = decode_response(&mut reader).await.unwrap();

        for msg in &decoded {
            assert_eq!(&response, msg);
        }
        assert_eq!(decoded.len(), MSG_COUNT);
    }

    #[async_test]
    async fn test_decode_header_request_too_large() {
        // Only a length prefix announcing more than the limit, with no payload.
        let mut prefix = BytesMut::new();
        encode_length_delimiter(REQUEST_SIZE_LIMIT + 1, &mut prefix).unwrap();
        let mut reader = Cursor::new(prefix);

        let err = decode_request(&mut reader)
            .await
            .expect_err("expected error for too large request");

        assert_eq!(err.kind(), ErrorKind::Other);
    }

    #[async_test]
    async fn test_decode_header_response_too_large() {
        // Only a length prefix announcing more than the limit, with no payload.
        let mut prefix = BytesMut::new();
        encode_length_delimiter(RESPONSE_SIZE_LIMIT + 1, &mut prefix).unwrap();
        let mut reader = Cursor::new(prefix);

        let err = decode_response(&mut reader)
            .await
            .expect_err("expected error for too large request");

        assert_eq!(err.kind(), ErrorKind::Other);
    }

    #[test]
    fn test_invalid_varint() {
        // Ten continuation bytes plus a terminator: an 11-byte varint, longer
        // than any valid 64-bit varint (at most 10 bytes).
        let mut varint = [0b1000_0000u8; 11];
        varint[10] = 0b0000_0001;

        assert_eq!(parse_delimiter(&varint), None);
    }

    #[test]
    fn parse_trailing_zero_varint() {
        let varint = [0b1000_0001, 0b0000_0000, 0b1111_1111];
        assert!(matches!(parse_delimiter(&varint), Some((1, [255]))));

        let varint = [0b1000_0000, 0b1000_0000, 0b1000_0000, 0b0000_0000];
        assert!(matches!(parse_delimiter(&varint), Some((0, []))));
    }

    #[async_test]
    async fn test_decode_header_double_response_data() {
        let first = HeaderResponse {
            body: b"9999888877776666555544443333222211110000".to_vec(),
            status_code: 1,
        };
        let second = HeaderResponse {
            body: b"0000111122223333444455556666777788889999".to_vec(),
            status_code: 2,
        };

        let mut buf = BytesMut::with_capacity(512);
        for msg in [&first, &second] {
            msg.encode_length_delimited(&mut buf).unwrap();
        }
        let mut reader = Cursor::new(buf);

        let decoded = decode_response(&mut reader).await.unwrap();

        assert_eq!(first, decoded[0]);
        assert_eq!(second, decoded[1]);
    }

    #[async_test]
    async fn test_decode_header_request_chunked_data() {
        let request = HeaderRequest {
            amount: 1,
            data: Some(Data::Hash(
                b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ".to_vec(),
            )),
        };
        let encoded = request.encode_length_delimited_to_vec();

        assert_eq!(request, decode_request_chunked::<1>(&encoded).await);
        assert_eq!(request, decode_request_chunked::<2>(&encoded).await);
        assert_eq!(request, decode_request_chunked::<3>(&encoded).await);
        assert_eq!(request, decode_request_chunked::<4>(&encoded).await);
    }

    #[async_test]
    async fn test_decode_header_response_chunked_data() {
        let response = HeaderResponse {
            body: b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ".to_vec(),
            status_code: 2,
        };
        let encoded = response.encode_length_delimited_to_vec();

        assert_eq!(response, decode_response_chunked::<1>(&encoded).await[0]);
        assert_eq!(response, decode_response_chunked::<2>(&encoded).await[0]);
        assert_eq!(response, decode_response_chunked::<3>(&encoded).await[0]);
        assert_eq!(response, decode_response_chunked::<4>(&encoded).await[0]);
    }

    #[async_test]
    async fn test_chunky_async_read() {
        let mut chunky = ChunkyAsyncRead::<_, 3>::new(Cursor::new("FOO123"));
        let mut out: BytesMut = b"BAR987".as_ref().into();

        let _ = chunky.read(&mut out[..]).await.unwrap();

        // A single read delivers only one 3-byte chunk; the tail is untouched.
        assert_eq!(out, b"FOO987".as_ref());
    }

    /// Reader adapter that returns at most `CHUNK_SIZE` bytes per `poll_read`.
    /// It simulates a stream that delivers data in small pieces.
    struct ChunkyAsyncRead<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> {
        inner: T,
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn new(inner: T) -> Self {
            Self { inner }
        }
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> AsyncRead for ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize, Error>> {
            let limit = CHUNK_SIZE.min(buf.len());
            Pin::new(&mut self.inner).poll_read(cx, &mut buf[..limit])
        }
    }
}