1use std::io;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4use std::time::Duration;
5
6use async_trait::async_trait;
7use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
8use celestia_types::ExtendedHeader;
9use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10use libp2p::core::Endpoint;
11use libp2p::core::transport::PortUse;
12use libp2p::request_response::{self, Codec, InboundFailure, OutboundFailure, ProtocolSupport};
13use libp2p::swarm::handler::ConnectionEvent;
14use libp2p::swarm::{
15 ConnectionDenied, ConnectionHandler, ConnectionHandlerEvent, ConnectionId, FromSwarm,
16 NetworkBehaviour, SubstreamProtocol, THandlerInEvent, THandlerOutEvent, ToSwarm,
17};
18use libp2p::{Multiaddr, PeerId, StreamProtocol};
19use lumina_utils::time::{Instant, timeout};
20use prost::Message;
21use tracing::{debug, instrument, warn};
22
23mod client;
24mod server;
25pub(crate) mod utils;
26
27use crate::p2p::P2pError;
28use crate::p2p::header_ex::client::HeaderExClientHandler;
29use crate::p2p::header_ex::server::HeaderExServerHandler;
30use crate::peer_tracker::PeerTracker;
31use crate::store::Store;
32use crate::utils::{OneshotResultSender, protocol_id};
33
/// Maximum number of bytes read for a single inbound request.
const REQUEST_SIZE_LIMIT: usize = 1 << 10;
/// Deadline for reading or writing a request.
const REQUEST_TIME_LIMIT: Duration = Duration::from_millis(1_000);
/// Maximum number of bytes read for a single inbound response (10 MiB).
const RESPONSE_SIZE_LIMIT: usize = 10 << 20;
/// Deadline for reading or writing a response.
const RESPONSE_TIME_LIMIT: Duration = Duration::from_millis(5_000);
/// Timeout applied to inbound and outbound substream protocol negotiation.
const NEGOTIATION_TIMEOUT: Duration = Duration::from_millis(1_000);
44
45type RequestType = HeaderRequest;
46type ResponseType = Vec<HeaderResponse>;
47type ReqRespBehaviour = request_response::Behaviour<HeaderCodec>;
48type ReqRespEvent = request_response::Event<RequestType, ResponseType>;
49type ReqRespMessage = request_response::Message<RequestType, ResponseType>;
50type ReqRespConnectionHandler = <ReqRespBehaviour as NetworkBehaviour>::ConnectionHandler;
51
52pub(crate) struct Behaviour<S>
53where
54 S: Store + 'static,
55{
56 req_resp: ReqRespBehaviour,
57 client_handler: HeaderExClientHandler,
58 server_handler: HeaderExServerHandler<S>,
59}
60
/// Parameters for constructing a header-exchange [`Behaviour`].
pub(crate) struct Config<'a, S> {
    /// Network identifier, used to build the protocol id.
    pub network_id: &'a str,
    /// Header store handed to the server-side handler.
    pub header_store: Arc<S>,
}
65
66#[derive(Debug, thiserror::Error)]
68pub enum HeaderExError {
69 #[error("Header not found")]
71 HeaderNotFound,
72
73 #[error("Invalid response")]
75 InvalidResponse,
76
77 #[error("Invalid request")]
79 InvalidRequest,
80
81 #[error("Inbound failure: {0}")]
83 InboundFailure(InboundFailure),
84
85 #[error("Outbound failure: {0}")]
87 OutboundFailure(OutboundFailure),
88
89 #[error("Request cancelled because `Node` is stopping")]
93 RequestCancelled,
94}
95
/// Events surfaced by [`Behaviour`] to the swarm (produced by the client handler).
#[derive(Debug)]
pub(crate) enum Event {
    /// Presumably: queued requests are waiting and the owner should call
    /// `Behaviour::schedule_pending_requests` — confirm against the client handler.
    SchedulePendingRequests,
    /// Presumably: more trusted peers are needed to serve requests — confirm.
    NeedTrustedPeers,
    /// Presumably: more archival peers are needed to serve requests — confirm.
    NeedArchivalPeers,
}
102
103impl<S> Behaviour<S>
104where
105 S: Store + 'static,
106{
107 pub(crate) fn new(config: Config<'_, S>) -> Self {
108 Behaviour {
109 req_resp: ReqRespBehaviour::new(
110 [(
111 protocol_id(config.network_id, "/header-ex/v0.0.3"),
112 ProtocolSupport::Full,
113 )],
114 request_response::Config::default(),
115 ),
116 client_handler: HeaderExClientHandler::new(),
117 server_handler: HeaderExServerHandler::new(config.header_store),
118 }
119 }
120
121 #[instrument(level = "trace", skip(self, respond_to))]
122 pub(crate) fn send_request(
123 &mut self,
124 request: HeaderRequest,
125 respond_to: OneshotResultSender<Vec<ExtendedHeader>, P2pError>,
126 ) {
127 self.client_handler.on_send_request(request, respond_to);
128 }
129
130 pub(crate) fn schedule_pending_requests(&mut self, peer_tracker: &PeerTracker) {
131 self.client_handler
132 .schedule_pending_requests(&mut self.req_resp, peer_tracker);
133 }
134
135 pub(crate) fn stop(&mut self) {
136 self.client_handler.on_stop();
137 self.server_handler.on_stop();
138 }
139
140 fn on_to_swarm(
141 &mut self,
142 ev: ToSwarm<ReqRespEvent, THandlerInEvent<ReqRespBehaviour>>,
143 ) -> Option<ToSwarm<Event, THandlerInEvent<Self>>> {
144 match ev {
145 ToSwarm::GenerateEvent(ev) => {
146 self.on_req_resp_event(ev);
147 None
148 }
149 _ => Some(ev.map_out(|_| unreachable!("GenerateEvent handled"))),
150 }
151 }
152
153 #[instrument(level = "trace", skip_all)]
154 fn on_req_resp_event(&mut self, ev: ReqRespEvent) {
155 match ev {
156 ReqRespEvent::Message {
158 message:
159 ReqRespMessage::Response {
160 request_id,
161 response,
162 },
163 peer,
164 ..
165 } => {
166 self.client_handler
167 .on_response_received(peer, request_id, response);
168 }
169
170 ReqRespEvent::OutboundFailure {
172 peer,
173 request_id,
174 error,
175 ..
176 } => {
177 self.client_handler.on_failure(peer, request_id, error);
178 }
179
180 ReqRespEvent::Message {
182 message:
183 ReqRespMessage::Request {
184 request_id,
185 request,
186 channel,
187 },
188 peer,
189 ..
190 } => {
191 self.server_handler.on_request_received(
192 peer,
193 request_id,
194 request,
195 &mut self.req_resp,
196 channel,
197 );
198 }
199
200 ReqRespEvent::ResponseSent {
202 peer, request_id, ..
203 } => {
204 self.server_handler.on_response_sent(peer, request_id);
205 }
206
207 ReqRespEvent::InboundFailure {
209 peer,
210 request_id,
211 error,
212 ..
213 } => {
214 self.server_handler.on_failure(peer, request_id, error);
215 }
216 }
217 }
218}
219
220impl<S> NetworkBehaviour for Behaviour<S>
221where
222 S: Store + 'static,
223{
224 type ConnectionHandler = ConnHandler;
225 type ToSwarm = Event;
226
227 fn handle_established_inbound_connection(
228 &mut self,
229 connection_id: ConnectionId,
230 peer: PeerId,
231 local_addr: &Multiaddr,
232 remote_addr: &Multiaddr,
233 ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
234 self.req_resp
235 .handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)
236 .map(ConnHandler)
237 }
238
239 fn handle_established_outbound_connection(
240 &mut self,
241 connection_id: ConnectionId,
242 peer: PeerId,
243 addr: &Multiaddr,
244 role_override: Endpoint,
245 port_use: PortUse,
246 ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
247 self.req_resp
248 .handle_established_outbound_connection(
249 connection_id,
250 peer,
251 addr,
252 role_override,
253 port_use,
254 )
255 .map(ConnHandler)
256 }
257
258 fn handle_pending_inbound_connection(
259 &mut self,
260 connection_id: ConnectionId,
261 local_addr: &Multiaddr,
262 remote_addr: &Multiaddr,
263 ) -> Result<(), ConnectionDenied> {
264 self.req_resp
265 .handle_pending_inbound_connection(connection_id, local_addr, remote_addr)
266 }
267
268 fn handle_pending_outbound_connection(
269 &mut self,
270 connection_id: ConnectionId,
271 maybe_peer: Option<PeerId>,
272 addresses: &[Multiaddr],
273 effective_role: Endpoint,
274 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
275 self.req_resp.handle_pending_outbound_connection(
276 connection_id,
277 maybe_peer,
278 addresses,
279 effective_role,
280 )
281 }
282
283 fn on_swarm_event(&mut self, event: FromSwarm) {
284 self.req_resp.on_swarm_event(event)
285 }
286
287 fn on_connection_handler_event(
288 &mut self,
289 peer_id: PeerId,
290 connection_id: ConnectionId,
291 event: THandlerOutEvent<Self>,
292 ) {
293 self.req_resp
294 .on_connection_handler_event(peer_id, connection_id, event)
295 }
296
297 fn poll(
298 &mut self,
299 cx: &mut Context<'_>,
300 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
301 loop {
302 if let Poll::Ready(ev) = self.req_resp.poll(cx) {
303 if let Some(ev) = self.on_to_swarm(ev) {
304 return Poll::Ready(ev);
305 }
306
307 continue;
308 }
309
310 if let Poll::Ready(ev) = self.client_handler.poll(cx) {
311 return Poll::Ready(ToSwarm::GenerateEvent(ev));
312 }
313
314 if self.server_handler.poll(cx, &mut self.req_resp).is_ready() {
315 continue;
316 }
317
318 return Poll::Pending;
319 }
320 }
321}
322
323pub(crate) struct ConnHandler(ReqRespConnectionHandler);
324
325impl ConnectionHandler for ConnHandler {
326 type ToBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::ToBehaviour;
327 type FromBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::FromBehaviour;
328 type InboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::InboundProtocol;
329 type InboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::InboundOpenInfo;
330 type OutboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::OutboundProtocol;
331 type OutboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::OutboundOpenInfo;
332
333 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
334 self.0.listen_protocol().with_timeout(NEGOTIATION_TIMEOUT)
335 }
336
337 fn poll(
338 &mut self,
339 cx: &mut Context<'_>,
340 ) -> Poll<
341 ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
342 > {
343 match self.0.poll(cx) {
344 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
345 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
346 protocol: protocol.with_timeout(NEGOTIATION_TIMEOUT),
347 })
348 }
349 ev => ev,
350 }
351 }
352
353 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
354 self.0.on_behaviour_event(event)
355 }
356
357 fn on_connection_event(
358 &mut self,
359 event: ConnectionEvent<
360 '_,
361 Self::InboundProtocol,
362 Self::OutboundProtocol,
363 Self::InboundOpenInfo,
364 Self::OutboundOpenInfo,
365 >,
366 ) {
367 self.0.on_connection_event(event)
368 }
369
370 fn connection_keep_alive(&self) -> bool {
371 self.0.connection_keep_alive()
372 }
373
374 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
375 self.0.poll_close(cx)
376 }
377}
378
/// Stateless codec framing header-ex messages as length-delimited protobuf.
#[derive(Debug, Default, Clone, Copy)]
pub(crate) struct HeaderCodec;
381
382#[async_trait]
383impl Codec for HeaderCodec {
384 type Protocol = StreamProtocol;
385 type Request = HeaderRequest;
386 type Response = Vec<HeaderResponse>;
387
388 async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
389 where
390 T: AsyncRead + Unpin + Send,
391 {
392 let data = read_up_to(io, REQUEST_SIZE_LIMIT, REQUEST_TIME_LIMIT).await?;
393
394 if data.len() >= REQUEST_SIZE_LIMIT {
395 debug!("Message filled the whole buffer (len: {})", data.len());
396 }
397
398 parse_header_request(&data).ok_or_else(|| {
399 io::Error::other("invalid or incomplete request")
404 })
405 }
406
407 async fn read_response<T>(
408 &mut self,
409 _: &Self::Protocol,
410 io: &mut T,
411 ) -> io::Result<Self::Response>
412 where
413 T: AsyncRead + Unpin + Send,
414 {
415 let data = read_up_to(io, RESPONSE_SIZE_LIMIT, RESPONSE_TIME_LIMIT).await?;
416
417 if data.len() >= RESPONSE_SIZE_LIMIT {
418 debug!("Message filled the whole buffer (len: {})", data.len());
419 }
420
421 let mut data = &data[..];
422 let mut msgs = Vec::new();
423
424 while let Some((header, rest)) = parse_header_response(data) {
425 msgs.push(header);
426 data = rest;
427 }
428
429 if msgs.is_empty() {
430 return Err(io::Error::other("invalid or incomplete response"));
435 }
436
437 Ok(msgs)
438 }
439
440 async fn write_request<T>(
441 &mut self,
442 _: &Self::Protocol,
443 io: &mut T,
444 req: Self::Request,
445 ) -> io::Result<()>
446 where
447 T: AsyncWrite + Unpin + Send,
448 {
449 let mut buf = Vec::with_capacity(REQUEST_SIZE_LIMIT);
450
451 let _ = req.encode_length_delimited(&mut buf);
452
453 timeout(REQUEST_TIME_LIMIT, io.write_all(&buf))
454 .await
455 .map_err(|_| io::Error::other("writing request timed out"))??;
456
457 Ok(())
458 }
459
460 async fn write_response<T>(
461 &mut self,
462 _: &Self::Protocol,
463 io: &mut T,
464 resps: Self::Response,
465 ) -> io::Result<()>
466 where
467 T: AsyncWrite + Unpin + Send,
468 {
469 let mut buf = Vec::with_capacity(RESPONSE_SIZE_LIMIT);
470
471 for resp in resps {
472 if resp.encode_length_delimited(&mut buf).is_err() {
473 debug!("Sending partial response");
476 break;
477 }
478 }
479
480 timeout(RESPONSE_TIME_LIMIT, io.write_all(&buf))
481 .await
482 .map_err(|_| io::Error::other("writing response timed out"))??;
483
484 Ok(())
485 }
486}
487
488async fn read_up_to<T>(io: &mut T, size_limit: usize, time_limit: Duration) -> io::Result<Vec<u8>>
490where
491 T: AsyncRead + Unpin + Send,
492{
493 let mut buf = vec![0u8; size_limit];
494 let mut read_len = 0;
495 let now = Instant::now();
496
497 loop {
498 if read_len == buf.len() {
499 break;
501 }
502
503 let Some(time_limit) = time_limit.checked_sub(now.elapsed()) else {
504 break;
505 };
506
507 let len = match timeout(time_limit, io.read(&mut buf[read_len..])).await {
508 Ok(Ok(len)) => len,
509 Ok(Err(e)) => return Err(e),
510 Err(_) => break,
511 };
512
513 if len == 0 {
514 break;
516 }
517
518 read_len += len;
519 }
520
521 buf.truncate(read_len);
522
523 Ok(buf)
524}
525
526fn parse_delimiter(mut buf: &[u8]) -> Option<(usize, &[u8])> {
527 if buf.is_empty() {
528 return None;
529 }
530
531 let Ok(len) = prost::decode_length_delimiter(&mut buf) else {
532 return None;
533 };
534
535 Some((len, buf))
536}
537
538fn parse_header_response(buf: &[u8]) -> Option<(HeaderResponse, &[u8])> {
539 let (len, rest) = parse_delimiter(buf)?;
540
541 if rest.len() < len {
542 debug!("Message is incomplete: {len}");
543 return None;
544 }
545
546 let Ok(msg) = HeaderResponse::decode(&rest[..len]) else {
547 return None;
548 };
549
550 Some((msg, &rest[len..]))
551}
552
553fn parse_header_request(buf: &[u8]) -> Option<HeaderRequest> {
554 let (len, rest) = parse_delimiter(buf)?;
555
556 if rest.len() < len {
557 debug!("Message is incomplete: {len}");
558 return None;
559 }
560
561 let Ok(msg) = HeaderRequest::decode(&rest[..len]) else {
562 return None;
563 };
564
565 Some(msg)
566}
567
#[cfg(test)]
mod tests {
    // Unit tests for the header-ex codec and its framing helpers.
    use super::*;
    use bytes::BytesMut;
    use celestia_proto::p2p::pb::header_request::Data;
    use futures::io::{Cursor, Error};
    use lumina_utils::test_utils::async_test;
    use prost::encode_length_delimiter;
    use std::io::ErrorKind;
    use std::pin::Pin;

    /// Protocol handed to the codec; its value does not influence decoding.
    fn protocol() -> StreamProtocol {
        StreamProtocol::new("/foo/bar/v0.1")
    }

    /// A stream containing only a length prefix announcing `len` bytes, with no body.
    fn bare_length_prefix(len: usize) -> Cursor<BytesMut> {
        let mut prefix = BytesMut::new();
        encode_length_delimiter(len, &mut prefix).unwrap();
        Cursor::new(prefix)
    }

    /// Decodes `encoded` as a request while the reader yields at most `N` bytes per read.
    async fn read_request_in_chunks<const N: usize>(encoded: &[u8]) -> HeaderRequest {
        let mut reader = ChunkyAsyncRead::<_, N>::new(Cursor::new(encoded.to_vec()));
        let proto = protocol();
        let mut codec = HeaderCodec;
        codec.read_request(&proto, &mut reader).await.unwrap()
    }

    /// Decodes `encoded` as a response while the reader yields at most `N` bytes per read.
    async fn read_response_in_chunks<const N: usize>(encoded: &[u8]) -> Vec<HeaderResponse> {
        let mut reader = ChunkyAsyncRead::<_, N>::new(Cursor::new(encoded.to_vec()));
        let proto = protocol();
        let mut codec = HeaderCodec;
        codec.read_response(&proto, &mut reader).await.unwrap()
    }

    #[async_test]
    async fn test_decode_header_request_empty() {
        let expected = HeaderRequest {
            amount: 0,
            data: None,
        };
        let mut reader = Cursor::new(expected.encode_length_delimited_to_vec());
        let proto = protocol();
        let mut codec = HeaderCodec;

        let decoded = codec.read_request(&proto, &mut reader).await.unwrap();

        assert_eq!(expected, decoded);
    }

    #[async_test]
    async fn test_decode_multiple_small_header_response() {
        const MSG_COUNT: usize = 10;
        let expected = HeaderResponse {
            body: vec![1, 2, 3],
            status_code: 1,
        };
        let stream = expected.encode_length_delimited_to_vec().repeat(MSG_COUNT);
        let mut reader = Cursor::new(stream);
        let proto = protocol();
        let mut codec = HeaderCodec;

        let decoded = codec.read_response(&proto, &mut reader).await.unwrap();

        assert!(decoded.iter().all(|resp| *resp == expected));
        assert_eq!(decoded.len(), MSG_COUNT);
    }

    #[async_test]
    async fn test_decode_header_request_too_large() {
        let mut reader = bare_length_prefix(REQUEST_SIZE_LIMIT + 1);
        let proto = protocol();
        let mut codec = HeaderCodec;

        let err = codec
            .read_request(&proto, &mut reader)
            .await
            .expect_err("expected error for too large request");

        assert_eq!(err.kind(), ErrorKind::Other);
    }

    #[async_test]
    async fn test_decode_header_response_too_large() {
        let mut reader = bare_length_prefix(RESPONSE_SIZE_LIMIT + 1);
        let proto = protocol();
        let mut codec = HeaderCodec;

        let err = codec
            .read_response(&proto, &mut reader)
            .await
            .expect_err("expected error for too large request");

        assert_eq!(err.kind(), ErrorKind::Other);
    }

    #[test]
    fn test_invalid_varint() {
        // Ten continuation bytes followed by a terminator: longer than any valid varint.
        let mut varint = [0b1000_0000u8; 11];
        varint[10] = 0b0000_0001;

        assert_eq!(parse_delimiter(&varint), None);
    }

    #[test]
    fn parse_trailing_zero_varint() {
        let varint = [0b1000_0001, 0b0000_0000, 0b1111_1111];
        assert!(matches!(parse_delimiter(&varint), Some((1, [255]))));

        let varint = [0b1000_0000, 0b1000_0000, 0b1000_0000, 0b0000_0000];
        assert!(matches!(parse_delimiter(&varint), Some((0, []))));
    }

    #[async_test]
    async fn test_decode_header_double_response_data() {
        let first = HeaderResponse {
            body: b"9999888877776666555544443333222211110000".to_vec(),
            status_code: 1,
        };
        let second = HeaderResponse {
            body: b"0000111122223333444455556666777788889999".to_vec(),
            status_code: 2,
        };

        let mut stream = BytesMut::with_capacity(512);
        for resp in [&first, &second] {
            resp.encode_length_delimited(&mut stream).unwrap();
        }
        let mut reader = Cursor::new(stream);
        let proto = protocol();
        let mut codec = HeaderCodec;

        let decoded = codec.read_response(&proto, &mut reader).await.unwrap();

        assert_eq!(first, decoded[0]);
        assert_eq!(second, decoded[1]);
    }

    #[async_test]
    async fn test_decode_header_request_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let expected = HeaderRequest {
            amount: 1,
            data: Some(Data::Hash(data.to_vec())),
        };
        let encoded = expected.encode_length_delimited_to_vec();

        assert_eq!(expected, read_request_in_chunks::<1>(&encoded).await);
        assert_eq!(expected, read_request_in_chunks::<2>(&encoded).await);
        assert_eq!(expected, read_request_in_chunks::<3>(&encoded).await);
        assert_eq!(expected, read_request_in_chunks::<4>(&encoded).await);
    }

    #[async_test]
    async fn test_decode_header_response_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let expected = HeaderResponse {
            body: data.to_vec(),
            status_code: 2,
        };
        let encoded = expected.encode_length_delimited_to_vec();

        assert_eq!(expected, read_response_in_chunks::<1>(&encoded).await[0]);
        assert_eq!(expected, read_response_in_chunks::<2>(&encoded).await[0]);
        assert_eq!(expected, read_response_in_chunks::<3>(&encoded).await[0]);
        assert_eq!(expected, read_response_in_chunks::<4>(&encoded).await[0]);
    }

    #[async_test]
    async fn test_chunky_async_read() {
        let read_data = "FOO123";
        let mut chunky = ChunkyAsyncRead::<_, 3>::new(Cursor::new(read_data));

        let mut output_buffer: BytesMut = b"BAR987".as_ref().into();

        let _ = chunky.read(&mut output_buffer[..]).await.unwrap();
        assert_eq!(output_buffer, b"FOO987".as_ref());
    }

    /// Reader adapter that never returns more than `CHUNK_SIZE` bytes per read,
    /// used to exercise decoding of data that arrives in small pieces.
    struct ChunkyAsyncRead<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> {
        inner: T,
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn new(inner: T) -> Self {
            Self { inner }
        }
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> AsyncRead for ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize, Error>> {
            let capped = CHUNK_SIZE.min(buf.len());
            Pin::new(&mut self.inner).poll_read(cx, &mut buf[..capped])
        }
    }
}