1use std::io;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4use std::time::Duration;
5
6use async_trait::async_trait;
7use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
8use celestia_types::ExtendedHeader;
9use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10use libp2p::core::transport::PortUse;
11use libp2p::{
12 core::Endpoint,
13 request_response::{self, Codec, InboundFailure, OutboundFailure, ProtocolSupport},
14 swarm::{
15 handler::ConnectionEvent, ConnectionDenied, ConnectionHandler, ConnectionHandlerEvent,
16 ConnectionId, FromSwarm, NetworkBehaviour, SubstreamProtocol, THandlerInEvent,
17 THandlerOutEvent, ToSwarm,
18 },
19 Multiaddr, PeerId, StreamProtocol,
20};
21use lumina_utils::time::{timeout, Instant};
22use prost::Message;
23use tracing::{debug, instrument, warn};
24
25mod client;
26mod server;
27pub(crate) mod utils;
28
29use crate::p2p::header_ex::client::HeaderExClientHandler;
30use crate::p2p::header_ex::server::HeaderExServerHandler;
31use crate::p2p::P2pError;
32use crate::peer_tracker::PeerTracker;
33use crate::store::Store;
34use crate::utils::{protocol_id, OneshotResultSender};
35
/// Maximum number of bytes read from a stream for a single inbound request.
const REQUEST_SIZE_LIMIT: usize = 1024;
/// Time budget for reading an inbound request and for writing an outbound one.
const REQUEST_TIME_LIMIT: Duration = Duration::from_secs(1);
/// Maximum number of bytes read from a stream for a single response.
const RESPONSE_SIZE_LIMIT: usize = 10 * 1024 * 1024;
/// Time budget for reading a response and for writing one.
const RESPONSE_TIME_LIMIT: Duration = Duration::from_secs(5);
/// Timeout applied to protocol negotiation of inbound and outbound substreams.
const NEGOTIATION_TIMEOUT: Duration = Duration::from_secs(1);
46
// Shorthands for the libp2p request-response types specialised to `HeaderCodec`.
type RequestType = HeaderRequest;
type ResponseType = Vec<HeaderResponse>;
type ReqRespBehaviour = request_response::Behaviour<HeaderCodec>;
type ReqRespEvent = request_response::Event<RequestType, ResponseType>;
type ReqRespMessage = request_response::Message<RequestType, ResponseType>;
type ReqRespConnectionHandler = <ReqRespBehaviour as NetworkBehaviour>::ConnectionHandler;
53
/// libp2p behaviour implementing the header-ex request-response protocol.
///
/// Events from the underlying request-response behaviour are dispatched to
/// `client_handler` (responses to and failures of outbound requests) and to
/// `server_handler` (inbound requests, sent responses and inbound failures).
pub(crate) struct HeaderExBehaviour<S>
where
    S: Store + 'static,
{
    /// Underlying libp2p request-response behaviour using `HeaderCodec`.
    req_resp: ReqRespBehaviour,
    /// Handles requests issued by this node and the responses/failures for them.
    client_handler: HeaderExClientHandler,
    /// Handles requests received from remote peers.
    server_handler: HeaderExServerHandler<S>,
}
62
/// Parameters consumed by `HeaderExBehaviour::new`.
pub(crate) struct HeaderExConfig<'a, S> {
    /// Network identifier; combined with the protocol name to build the protocol id.
    pub network_id: &'a str,
    /// Peer tracker handed to the client handler.
    pub peer_tracker: Arc<PeerTracker>,
    /// Header store handed to the server handler.
    pub header_store: Arc<S>,
}
68
/// Errors that can be reported for a header-ex request.
#[derive(Debug, thiserror::Error)]
pub enum HeaderExError {
    /// The requested header was not found.
    #[error("Header not found")]
    HeaderNotFound,

    /// A response could not be accepted as valid.
    #[error("Invalid response")]
    InvalidResponse,

    /// A request could not be accepted as valid.
    #[error("Invalid request")]
    InvalidRequest,

    /// Wraps a libp2p request-response inbound failure.
    #[error("Inbound failure: {0}")]
    InboundFailure(InboundFailure),

    /// Wraps a libp2p request-response outbound failure.
    #[error("Outbound failure: {0}")]
    OutboundFailure(OutboundFailure),

    /// The request was dropped because the node is shutting down.
    #[error("Request cancelled because `Node` is stopping")]
    RequestCancelled,
}
98
99impl<S> HeaderExBehaviour<S>
100where
101 S: Store + 'static,
102{
103 pub(crate) fn new(config: HeaderExConfig<'_, S>) -> Self {
104 HeaderExBehaviour {
105 req_resp: ReqRespBehaviour::new(
106 [(
107 protocol_id(config.network_id, "/header-ex/v0.0.3"),
108 ProtocolSupport::Full,
109 )],
110 request_response::Config::default(),
111 ),
112 client_handler: HeaderExClientHandler::new(config.peer_tracker),
113 server_handler: HeaderExServerHandler::new(config.header_store),
114 }
115 }
116
117 #[instrument(level = "trace", skip(self, respond_to))]
118 pub(crate) fn send_request(
119 &mut self,
120 request: HeaderRequest,
121 respond_to: OneshotResultSender<Vec<ExtendedHeader>, P2pError>,
122 ) {
123 self.client_handler
124 .on_send_request(&mut self.req_resp, request, respond_to);
125 }
126
127 pub(crate) fn stop(&mut self) {
128 self.client_handler.on_stop();
129 self.server_handler.on_stop();
130 }
131
132 fn on_to_swarm(
133 &mut self,
134 ev: ToSwarm<ReqRespEvent, THandlerInEvent<ReqRespBehaviour>>,
135 ) -> Option<ToSwarm<(), THandlerInEvent<Self>>> {
136 match ev {
137 ToSwarm::GenerateEvent(ev) => {
138 self.on_req_resp_event(ev);
139 None
140 }
141 _ => Some(ev.map_out(|_| ())),
142 }
143 }
144
145 #[instrument(level = "trace", skip_all)]
146 fn on_req_resp_event(&mut self, ev: ReqRespEvent) {
147 match ev {
148 ReqRespEvent::Message {
150 message:
151 ReqRespMessage::Response {
152 request_id,
153 response,
154 },
155 peer,
156 } => {
157 self.client_handler
158 .on_response_received(peer, request_id, response);
159 }
160
161 ReqRespEvent::OutboundFailure {
163 peer,
164 request_id,
165 error,
166 } => {
167 self.client_handler.on_failure(peer, request_id, error);
168 }
169
170 ReqRespEvent::Message {
172 message:
173 ReqRespMessage::Request {
174 request_id,
175 request,
176 channel,
177 },
178 peer,
179 } => {
180 self.server_handler.on_request_received(
181 peer,
182 request_id,
183 request,
184 &mut self.req_resp,
185 channel,
186 );
187 }
188
189 ReqRespEvent::ResponseSent { peer, request_id } => {
191 self.server_handler.on_response_sent(peer, request_id);
192 }
193
194 ReqRespEvent::InboundFailure {
196 peer,
197 request_id,
198 error,
199 } => {
200 self.server_handler.on_failure(peer, request_id, error);
201 }
202 }
203 }
204}
205
206impl<S> NetworkBehaviour for HeaderExBehaviour<S>
207where
208 S: Store + 'static,
209{
210 type ConnectionHandler = ConnHandler;
211 type ToSwarm = ();
212
213 fn handle_established_inbound_connection(
214 &mut self,
215 connection_id: ConnectionId,
216 peer: PeerId,
217 local_addr: &Multiaddr,
218 remote_addr: &Multiaddr,
219 ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
220 self.req_resp
221 .handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)
222 .map(ConnHandler)
223 }
224
225 fn handle_established_outbound_connection(
226 &mut self,
227 connection_id: ConnectionId,
228 peer: PeerId,
229 addr: &Multiaddr,
230 role_override: Endpoint,
231 port_use: PortUse,
232 ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
233 self.req_resp
234 .handle_established_outbound_connection(
235 connection_id,
236 peer,
237 addr,
238 role_override,
239 port_use,
240 )
241 .map(ConnHandler)
242 }
243
244 fn handle_pending_inbound_connection(
245 &mut self,
246 connection_id: ConnectionId,
247 local_addr: &Multiaddr,
248 remote_addr: &Multiaddr,
249 ) -> Result<(), ConnectionDenied> {
250 self.req_resp
251 .handle_pending_inbound_connection(connection_id, local_addr, remote_addr)
252 }
253
254 fn handle_pending_outbound_connection(
255 &mut self,
256 connection_id: ConnectionId,
257 maybe_peer: Option<PeerId>,
258 addresses: &[Multiaddr],
259 effective_role: Endpoint,
260 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
261 self.req_resp.handle_pending_outbound_connection(
262 connection_id,
263 maybe_peer,
264 addresses,
265 effective_role,
266 )
267 }
268
269 fn on_swarm_event(&mut self, event: FromSwarm) {
270 self.req_resp.on_swarm_event(event)
271 }
272
273 fn on_connection_handler_event(
274 &mut self,
275 peer_id: PeerId,
276 connection_id: ConnectionId,
277 event: THandlerOutEvent<Self>,
278 ) {
279 self.req_resp
280 .on_connection_handler_event(peer_id, connection_id, event)
281 }
282
283 fn poll(
284 &mut self,
285 cx: &mut Context<'_>,
286 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
287 loop {
288 if let Poll::Ready(ev) = self.req_resp.poll(cx) {
289 if let Some(ev) = self.on_to_swarm(ev) {
290 return Poll::Ready(ev);
291 }
292
293 continue;
294 }
295
296 if self.client_handler.poll(cx).is_ready() {
297 continue;
298 }
299
300 if self.server_handler.poll(cx, &mut self.req_resp).is_ready() {
301 continue;
302 }
303
304 return Poll::Pending;
305 }
306 }
307}
308
/// Wrapper around the request-response connection handler that bounds substream protocol
/// negotiation (inbound and outbound) by `NEGOTIATION_TIMEOUT`.
pub(crate) struct ConnHandler(ReqRespConnectionHandler);
310
311impl ConnectionHandler for ConnHandler {
312 type ToBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::ToBehaviour;
313 type FromBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::FromBehaviour;
314 type InboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::InboundProtocol;
315 type InboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::InboundOpenInfo;
316 type OutboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::OutboundProtocol;
317 type OutboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::OutboundOpenInfo;
318
319 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
320 self.0.listen_protocol().with_timeout(NEGOTIATION_TIMEOUT)
321 }
322
323 fn poll(
324 &mut self,
325 cx: &mut Context<'_>,
326 ) -> Poll<
327 ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
328 > {
329 match self.0.poll(cx) {
330 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
331 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
332 protocol: protocol.with_timeout(NEGOTIATION_TIMEOUT),
333 })
334 }
335 ev => ev,
336 }
337 }
338
339 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
340 self.0.on_behaviour_event(event)
341 }
342
343 fn on_connection_event(
344 &mut self,
345 event: ConnectionEvent<
346 '_,
347 Self::InboundProtocol,
348 Self::OutboundProtocol,
349 Self::InboundOpenInfo,
350 Self::OutboundOpenInfo,
351 >,
352 ) {
353 self.0.on_connection_event(event)
354 }
355
356 fn connection_keep_alive(&self) -> bool {
357 self.0.connection_keep_alive()
358 }
359
360 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
361 self.0.poll_close(cx)
362 }
363}
364
/// Codec for header-ex streams: a request is one length-delimited `HeaderRequest`, a
/// response is a sequence of length-delimited `HeaderResponse`s.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct HeaderCodec;
367
368#[async_trait]
369impl Codec for HeaderCodec {
370 type Protocol = StreamProtocol;
371 type Request = HeaderRequest;
372 type Response = Vec<HeaderResponse>;
373
374 async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
375 where
376 T: AsyncRead + Unpin + Send,
377 {
378 let data = read_up_to(io, REQUEST_SIZE_LIMIT, REQUEST_TIME_LIMIT).await?;
379
380 if data.len() >= REQUEST_SIZE_LIMIT {
381 debug!("Message filled the whole buffer (len: {})", data.len());
382 }
383
384 parse_header_request(&data).ok_or_else(|| {
385 io::Error::other("invalid or incomplete request")
390 })
391 }
392
393 async fn read_response<T>(
394 &mut self,
395 _: &Self::Protocol,
396 io: &mut T,
397 ) -> io::Result<Self::Response>
398 where
399 T: AsyncRead + Unpin + Send,
400 {
401 let data = read_up_to(io, RESPONSE_SIZE_LIMIT, RESPONSE_TIME_LIMIT).await?;
402
403 if data.len() >= RESPONSE_SIZE_LIMIT {
404 debug!("Message filled the whole buffer (len: {})", data.len());
405 }
406
407 let mut data = &data[..];
408 let mut msgs = Vec::new();
409
410 while let Some((header, rest)) = parse_header_response(data) {
411 msgs.push(header);
412 data = rest;
413 }
414
415 if msgs.is_empty() {
416 return Err(io::Error::other("invalid or incomplete response"));
421 }
422
423 Ok(msgs)
424 }
425
426 async fn write_request<T>(
427 &mut self,
428 _: &Self::Protocol,
429 io: &mut T,
430 req: Self::Request,
431 ) -> io::Result<()>
432 where
433 T: AsyncWrite + Unpin + Send,
434 {
435 let mut buf = Vec::with_capacity(REQUEST_SIZE_LIMIT);
436
437 let _ = req.encode_length_delimited(&mut buf);
438
439 timeout(REQUEST_TIME_LIMIT, io.write_all(&buf))
440 .await
441 .map_err(|_| io::Error::other("writing request timed out"))??;
442
443 Ok(())
444 }
445
446 async fn write_response<T>(
447 &mut self,
448 _: &Self::Protocol,
449 io: &mut T,
450 resps: Self::Response,
451 ) -> io::Result<()>
452 where
453 T: AsyncWrite + Unpin + Send,
454 {
455 let mut buf = Vec::with_capacity(RESPONSE_SIZE_LIMIT);
456
457 for resp in resps {
458 if resp.encode_length_delimited(&mut buf).is_err() {
459 debug!("Sending partial response");
462 break;
463 }
464 }
465
466 timeout(RESPONSE_TIME_LIMIT, io.write_all(&buf))
467 .await
468 .map_err(|_| io::Error::other("writing response timed out"))??;
469
470 Ok(())
471 }
472}
473
474async fn read_up_to<T>(io: &mut T, size_limit: usize, time_limit: Duration) -> io::Result<Vec<u8>>
476where
477 T: AsyncRead + Unpin + Send,
478{
479 let mut buf = vec![0u8; size_limit];
480 let mut read_len = 0;
481 let now = Instant::now();
482
483 loop {
484 if read_len == buf.len() {
485 break;
487 }
488
489 let Some(time_limit) = time_limit.checked_sub(now.elapsed()) else {
490 break;
491 };
492
493 let len = match timeout(time_limit, io.read(&mut buf[read_len..])).await {
494 Ok(Ok(len)) => len,
495 Ok(Err(e)) => return Err(e),
496 Err(_) => break,
497 };
498
499 if len == 0 {
500 break;
502 }
503
504 read_len += len;
505 }
506
507 buf.truncate(read_len);
508
509 Ok(buf)
510}
511
512fn parse_delimiter(mut buf: &[u8]) -> Option<(usize, &[u8])> {
513 if buf.is_empty() {
514 return None;
515 }
516
517 let Ok(len) = prost::decode_length_delimiter(&mut buf) else {
518 return None;
519 };
520
521 Some((len, buf))
522}
523
524fn parse_header_response(buf: &[u8]) -> Option<(HeaderResponse, &[u8])> {
525 let (len, rest) = parse_delimiter(buf)?;
526
527 if rest.len() < len {
528 debug!("Message is incomplete: {len}");
529 return None;
530 }
531
532 let Ok(msg) = HeaderResponse::decode(&rest[..len]) else {
533 return None;
534 };
535
536 Some((msg, &rest[len..]))
537}
538
539fn parse_header_request(buf: &[u8]) -> Option<HeaderRequest> {
540 let (len, rest) = parse_delimiter(buf)?;
541
542 if rest.len() < len {
543 debug!("Message is incomplete: {len}");
544 return None;
545 }
546
547 let Ok(msg) = HeaderRequest::decode(&rest[..len]) else {
548 return None;
549 };
550
551 Some(msg)
552}
553
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use celestia_proto::p2p::pb::header_request::Data;
    use futures::io::{Cursor, Error};
    use lumina_utils::test_utils::async_test;
    use prost::encode_length_delimiter;
    use std::io::ErrorKind;
    use std::pin::Pin;

    // A request with no `data` must round-trip through `read_request`.
    #[async_test]
    async fn test_decode_header_request_empty() {
        let header_request = HeaderRequest {
            amount: 0,
            data: None,
        };

        let encoded_header_request = header_request.encode_length_delimited_to_vec();

        let mut reader = Cursor::new(encoded_header_request);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_request = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .unwrap();

        assert_eq!(header_request, decoded_header_request);
    }

    // Several identical responses concatenated in one stream are all decoded.
    #[async_test]
    async fn test_decode_multiple_small_header_response() {
        const MSG_COUNT: usize = 10;
        let header_response = HeaderResponse {
            body: vec![1, 2, 3],
            status_code: 1,
        };

        let encoded_header_response = header_response.encode_length_delimited_to_vec();

        let mut multi_msg = vec![];
        for _ in 0..MSG_COUNT {
            multi_msg.extend_from_slice(&encoded_header_response);
        }
        let mut reader = Cursor::new(multi_msg);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_response = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();

        for decoded_header in decoded_header_response.iter() {
            assert_eq!(&header_response, decoded_header);
        }
        assert_eq!(decoded_header_response.len(), MSG_COUNT);
    }

    // A length prefix announcing more than REQUEST_SIZE_LIMIT bytes (with no payload)
    // must be rejected.
    #[async_test]
    async fn test_decode_header_request_too_large() {
        let too_long_message_len = REQUEST_SIZE_LIMIT + 1;
        let mut length_delimiter_buffer = BytesMut::new();
        prost::encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoding_error = codec
            .read_request(&stream_protocol, &mut reader)
            .await
            .expect_err("expected error for too large request");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    // Same as above for responses and RESPONSE_SIZE_LIMIT.
    #[async_test]
    async fn test_decode_header_response_too_large() {
        let too_long_message_len = RESPONSE_SIZE_LIMIT + 1;
        let mut length_delimiter_buffer = BytesMut::new();
        encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
        let mut reader = Cursor::new(length_delimiter_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoding_error = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .expect_err("expected error for too large request");

        assert_eq!(decoding_error.kind(), ErrorKind::Other);
    }

    // An 11-byte varint (ten continuation bytes) is longer than a 64-bit varint can be and
    // must not parse.
    #[test]
    fn test_invalid_varint() {
        let varint = [
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b1000_0000,
            0b0000_0001,
        ];

        assert_eq!(parse_delimiter(&varint), None);
    }

    // Non-canonical varints padded with zero-valued continuation groups are accepted, and
    // the bytes after the varint are returned untouched.
    #[test]
    fn parse_trailing_zero_varint() {
        let varint = [0b1000_0001, 0b0000_0000, 0b1111_1111];
        assert!(matches!(parse_delimiter(&varint), Some((1, [255]))));

        let varint = [0b1000_0000, 0b1000_0000, 0b1000_0000, 0b0000_0000];
        assert!(matches!(parse_delimiter(&varint), Some((0, []))));
    }

    // Two different responses in one stream are decoded in order.
    #[async_test]
    async fn test_decode_header_double_response_data() {
        let mut header_response_buffer = BytesMut::with_capacity(512);
        let header_response0 = HeaderResponse {
            body: b"9999888877776666555544443333222211110000".to_vec(),
            status_code: 1,
        };
        let header_response1 = HeaderResponse {
            body: b"0000111122223333444455556666777788889999".to_vec(),
            status_code: 2,
        };
        header_response0
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        header_response1
            .encode_length_delimited(&mut header_response_buffer)
            .unwrap();
        let mut reader = Cursor::new(header_response_buffer);

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};

        let decoded_header_response = codec
            .read_response(&stream_protocol, &mut reader)
            .await
            .unwrap();
        assert_eq!(header_response0, decoded_header_response[0]);
        assert_eq!(header_response1, decoded_header_response[1]);
    }

    // A request delivered in 1..=4 byte chunks is still reassembled and decoded.
    #[async_test]
    async fn test_decode_header_request_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_request = HeaderRequest {
            amount: 1,
            data: Some(Data::Hash(data.to_vec())),
        };
        let encoded_header_request = header_request.encode_length_delimited_to_vec();

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};
        {
            let mut reader =
                ChunkyAsyncRead::<_, 1>::new(Cursor::new(encoded_header_request.clone()));
            let decoded_header_request = codec
                .read_request(&stream_protocol, &mut reader)
                .await
                .unwrap();
            assert_eq!(header_request, decoded_header_request);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 2>::new(Cursor::new(encoded_header_request.clone()));
            let decoded_header_request = codec
                .read_request(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_request, decoded_header_request);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 3>::new(Cursor::new(encoded_header_request.clone()));
            let decoded_header_request = codec
                .read_request(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_request, decoded_header_request);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 4>::new(Cursor::new(encoded_header_request.clone()));
            let decoded_header_request = codec
                .read_request(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_request, decoded_header_request);
        }
    }

    // A response delivered in 1..=4 byte chunks is still reassembled and decoded.
    #[async_test]
    async fn test_decode_header_response_chunked_data() {
        let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
        let header_response = HeaderResponse {
            body: data.to_vec(),
            status_code: 2,
        };
        let encoded_header_response = header_response.encode_length_delimited_to_vec();

        let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
        let mut codec = HeaderCodec {};
        {
            let mut reader =
                ChunkyAsyncRead::<_, 1>::new(Cursor::new(encoded_header_response.clone()));
            let decoded_header_response = codec
                .read_response(&stream_protocol, &mut reader)
                .await
                .unwrap();
            assert_eq!(header_response, decoded_header_response[0]);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 2>::new(Cursor::new(encoded_header_response.clone()));
            let decoded_header_response = codec
                .read_response(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_response, decoded_header_response[0]);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 3>::new(Cursor::new(encoded_header_response.clone()));
            let decoded_header_response = codec
                .read_response(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_response, decoded_header_response[0]);
        }
        {
            let mut reader =
                ChunkyAsyncRead::<_, 4>::new(Cursor::new(encoded_header_response.clone()));
            let decoded_header_response = codec
                .read_response(&stream_protocol, &mut reader)
                .await
                .unwrap();

            assert_eq!(header_response, decoded_header_response[0]);
        }
    }

    // Sanity check of the test helper: a single read yields at most CHUNK_SIZE bytes and
    // leaves the rest of the output buffer untouched.
    #[async_test]
    async fn test_chunky_async_read() {
        let read_data = "FOO123";
        let cur0 = Cursor::new(read_data);
        let mut chunky = ChunkyAsyncRead::<_, 3>::new(cur0);

        let mut output_buffer: BytesMut = b"BAR987".as_ref().into();

        let _ = chunky.read(&mut output_buffer[..]).await.unwrap();
        assert_eq!(output_buffer, b"FOO987".as_ref());
    }

    /// Test reader that returns at most `CHUNK_SIZE` bytes per `poll_read`, simulating data
    /// arriving in small fragments.
    struct ChunkyAsyncRead<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> {
        inner: T,
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn new(inner: T) -> Self {
            ChunkyAsyncRead { inner }
        }
    }

    impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> AsyncRead for ChunkyAsyncRead<T, CHUNK_SIZE> {
        fn poll_read(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<usize, Error>> {
            // Expose only the first CHUNK_SIZE bytes of the caller's buffer to the inner reader.
            let len = buf.len().min(CHUNK_SIZE);
            Pin::new(&mut self.inner).poll_read(cx, &mut buf[..len])
        }
    }
}