1use std::io;
2use std::sync::Arc;
3use std::task::{Context, Poll};
4
5use async_trait::async_trait;
6use celestia_proto::p2p::pb::{HeaderRequest, HeaderResponse};
7use celestia_types::ExtendedHeader;
8use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9use libp2p::core::transport::PortUse;
10use libp2p::{
11 core::Endpoint,
12 request_response::{self, Codec, InboundFailure, OutboundFailure, ProtocolSupport},
13 swarm::{
14 handler::ConnectionEvent, ConnectionDenied, ConnectionHandler, ConnectionHandlerEvent,
15 ConnectionId, FromSwarm, NetworkBehaviour, SubstreamProtocol, THandlerInEvent,
16 THandlerOutEvent, ToSwarm,
17 },
18 Multiaddr, PeerId, StreamProtocol,
19};
20use lumina_utils::time::timeout;
21use prost::Message;
22use tracing::{debug, instrument, warn};
23use web_time::{Duration, Instant};
24
25mod client;
26mod server;
27pub(crate) mod utils;
28
29use crate::p2p::header_ex::client::HeaderExClientHandler;
30use crate::p2p::header_ex::server::HeaderExServerHandler;
31use crate::p2p::P2pError;
32use crate::peer_tracker::PeerTracker;
33use crate::store::Store;
34use crate::utils::{protocol_id, OneshotResultSender};
35
/// Maximum number of bytes read from a peer for a single inbound request (1 KiB).
const REQUEST_SIZE_LIMIT: usize = 1 << 10;
/// Time budget for reading or writing a single request.
const REQUEST_TIME_LIMIT: Duration = Duration::from_secs(1);
/// Maximum number of bytes read from a peer for a single inbound response (10 MiB).
const RESPONSE_SIZE_LIMIT: usize = 10 << 20;
/// Time budget for reading or writing a single response.
const RESPONSE_TIME_LIMIT: Duration = Duration::from_secs(5);
/// Timeout applied to protocol negotiation of both inbound and outbound substreams.
const NEGOTIATION_TIMEOUT: Duration = Duration::from_secs(1);
46
47type RequestType = HeaderRequest;
48type ResponseType = Vec<HeaderResponse>;
49type ReqRespBehaviour = request_response::Behaviour<HeaderCodec>;
50type ReqRespEvent = request_response::Event<RequestType, ResponseType>;
51type ReqRespMessage = request_response::Message<RequestType, ResponseType>;
52type ReqRespConnectionHandler = <ReqRespBehaviour as NetworkBehaviour>::ConnectionHandler;
53
54pub(crate) struct HeaderExBehaviour<S>
55where
56 S: Store + 'static,
57{
58 req_resp: ReqRespBehaviour,
59 client_handler: HeaderExClientHandler,
60 server_handler: HeaderExServerHandler<S>,
61}
62
63pub(crate) struct HeaderExConfig<'a, S> {
64 pub network_id: &'a str,
65 pub peer_tracker: Arc<PeerTracker>,
66 pub header_store: Arc<S>,
67}
68
69#[derive(Debug, thiserror::Error)]
71pub enum HeaderExError {
72 #[error("Header not found")]
74 HeaderNotFound,
75
76 #[error("Invalid response")]
78 InvalidResponse,
79
80 #[error("Invalid request")]
82 InvalidRequest,
83
84 #[error("Inbound failure: {0}")]
86 InboundFailure(InboundFailure),
87
88 #[error("Outbound failure: {0}")]
90 OutboundFailure(OutboundFailure),
91
92 #[error("Request cancelled because `Node` is stopping")]
96 RequestCancelled,
97}
98
99impl<S> HeaderExBehaviour<S>
100where
101 S: Store + 'static,
102{
103 pub(crate) fn new(config: HeaderExConfig<'_, S>) -> Self {
104 HeaderExBehaviour {
105 req_resp: ReqRespBehaviour::new(
106 [(
107 protocol_id(config.network_id, "/header-ex/v0.0.3"),
108 ProtocolSupport::Full,
109 )],
110 request_response::Config::default(),
111 ),
112 client_handler: HeaderExClientHandler::new(config.peer_tracker),
113 server_handler: HeaderExServerHandler::new(config.header_store),
114 }
115 }
116
117 #[instrument(level = "trace", skip(self, respond_to))]
118 pub(crate) fn send_request(
119 &mut self,
120 request: HeaderRequest,
121 respond_to: OneshotResultSender<Vec<ExtendedHeader>, P2pError>,
122 ) {
123 self.client_handler
124 .on_send_request(&mut self.req_resp, request, respond_to);
125 }
126
127 pub(crate) fn stop(&mut self) {
128 self.client_handler.on_stop();
129 self.server_handler.on_stop();
130 }
131
132 fn on_to_swarm(
133 &mut self,
134 ev: ToSwarm<ReqRespEvent, THandlerInEvent<ReqRespBehaviour>>,
135 ) -> Option<ToSwarm<(), THandlerInEvent<Self>>> {
136 match ev {
137 ToSwarm::GenerateEvent(ev) => {
138 self.on_req_resp_event(ev);
139 None
140 }
141 _ => Some(ev.map_out(|_| ())),
142 }
143 }
144
145 #[instrument(level = "trace", skip_all)]
146 fn on_req_resp_event(&mut self, ev: ReqRespEvent) {
147 match ev {
148 ReqRespEvent::Message {
150 message:
151 ReqRespMessage::Response {
152 request_id,
153 response,
154 },
155 peer,
156 } => {
157 self.client_handler
158 .on_response_received(peer, request_id, response);
159 }
160
161 ReqRespEvent::OutboundFailure {
163 peer,
164 request_id,
165 error,
166 } => {
167 self.client_handler.on_failure(peer, request_id, error);
168 }
169
170 ReqRespEvent::Message {
172 message:
173 ReqRespMessage::Request {
174 request_id,
175 request,
176 channel,
177 },
178 peer,
179 } => {
180 self.server_handler.on_request_received(
181 peer,
182 request_id,
183 request,
184 &mut self.req_resp,
185 channel,
186 );
187 }
188
189 ReqRespEvent::ResponseSent { peer, request_id } => {
191 self.server_handler.on_response_sent(peer, request_id);
192 }
193
194 ReqRespEvent::InboundFailure {
196 peer,
197 request_id,
198 error,
199 } => {
200 self.server_handler.on_failure(peer, request_id, error);
201 }
202 }
203 }
204}
205
206impl<S> NetworkBehaviour for HeaderExBehaviour<S>
207where
208 S: Store + 'static,
209{
210 type ConnectionHandler = ConnHandler;
211 type ToSwarm = ();
212
213 fn handle_established_inbound_connection(
214 &mut self,
215 connection_id: ConnectionId,
216 peer: PeerId,
217 local_addr: &Multiaddr,
218 remote_addr: &Multiaddr,
219 ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
220 self.req_resp
221 .handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)
222 .map(ConnHandler)
223 }
224
225 fn handle_established_outbound_connection(
226 &mut self,
227 connection_id: ConnectionId,
228 peer: PeerId,
229 addr: &Multiaddr,
230 role_override: Endpoint,
231 port_use: PortUse,
232 ) -> Result<Self::ConnectionHandler, ConnectionDenied> {
233 self.req_resp
234 .handle_established_outbound_connection(
235 connection_id,
236 peer,
237 addr,
238 role_override,
239 port_use,
240 )
241 .map(ConnHandler)
242 }
243
244 fn handle_pending_inbound_connection(
245 &mut self,
246 connection_id: ConnectionId,
247 local_addr: &Multiaddr,
248 remote_addr: &Multiaddr,
249 ) -> Result<(), ConnectionDenied> {
250 self.req_resp
251 .handle_pending_inbound_connection(connection_id, local_addr, remote_addr)
252 }
253
254 fn handle_pending_outbound_connection(
255 &mut self,
256 connection_id: ConnectionId,
257 maybe_peer: Option<PeerId>,
258 addresses: &[Multiaddr],
259 effective_role: Endpoint,
260 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
261 self.req_resp.handle_pending_outbound_connection(
262 connection_id,
263 maybe_peer,
264 addresses,
265 effective_role,
266 )
267 }
268
269 fn on_swarm_event(&mut self, event: FromSwarm) {
270 self.req_resp.on_swarm_event(event)
271 }
272
273 fn on_connection_handler_event(
274 &mut self,
275 peer_id: PeerId,
276 connection_id: ConnectionId,
277 event: THandlerOutEvent<Self>,
278 ) {
279 self.req_resp
280 .on_connection_handler_event(peer_id, connection_id, event)
281 }
282
283 fn poll(
284 &mut self,
285 cx: &mut Context<'_>,
286 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
287 loop {
288 if let Poll::Ready(ev) = self.req_resp.poll(cx) {
289 if let Some(ev) = self.on_to_swarm(ev) {
290 return Poll::Ready(ev);
291 }
292
293 continue;
294 }
295
296 if self.client_handler.poll(cx).is_ready() {
297 continue;
298 }
299
300 if self.server_handler.poll(cx, &mut self.req_resp).is_ready() {
301 continue;
302 }
303
304 return Poll::Pending;
305 }
306 }
307}
308
309pub(crate) struct ConnHandler(ReqRespConnectionHandler);
310
311impl ConnectionHandler for ConnHandler {
312 type ToBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::ToBehaviour;
313 type FromBehaviour = <ReqRespConnectionHandler as ConnectionHandler>::FromBehaviour;
314 type InboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::InboundProtocol;
315 type InboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::InboundOpenInfo;
316 type OutboundProtocol = <ReqRespConnectionHandler as ConnectionHandler>::OutboundProtocol;
317 type OutboundOpenInfo = <ReqRespConnectionHandler as ConnectionHandler>::OutboundOpenInfo;
318
319 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
320 self.0.listen_protocol().with_timeout(NEGOTIATION_TIMEOUT)
321 }
322
323 fn poll(
324 &mut self,
325 cx: &mut Context<'_>,
326 ) -> Poll<
327 ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
328 > {
329 match self.0.poll(cx) {
330 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
331 Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
332 protocol: protocol.with_timeout(NEGOTIATION_TIMEOUT),
333 })
334 }
335 ev => ev,
336 }
337 }
338
339 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
340 self.0.on_behaviour_event(event)
341 }
342
343 fn on_connection_event(
344 &mut self,
345 event: ConnectionEvent<
346 '_,
347 Self::InboundProtocol,
348 Self::OutboundProtocol,
349 Self::InboundOpenInfo,
350 Self::OutboundOpenInfo,
351 >,
352 ) {
353 self.0.on_connection_event(event)
354 }
355
356 fn connection_keep_alive(&self) -> bool {
357 self.0.connection_keep_alive()
358 }
359
360 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
361 self.0.poll_close(cx)
362 }
363}
364
/// Stateless codec for header-ex: a request is one length-delimited protobuf
/// `HeaderRequest`, a response is a run of length-delimited `HeaderResponse`s.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct HeaderCodec;
367
368#[async_trait]
369impl Codec for HeaderCodec {
370 type Protocol = StreamProtocol;
371 type Request = HeaderRequest;
372 type Response = Vec<HeaderResponse>;
373
374 async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
375 where
376 T: AsyncRead + Unpin + Send,
377 {
378 let data = read_up_to(io, REQUEST_SIZE_LIMIT, REQUEST_TIME_LIMIT).await?;
379
380 if data.len() >= REQUEST_SIZE_LIMIT {
381 debug!("Message filled the whole buffer (len: {})", data.len());
382 }
383
384 parse_header_request(&data).ok_or_else(|| {
385 io::Error::new(io::ErrorKind::Other, "invalid or incomplete request")
390 })
391 }
392
393 async fn read_response<T>(
394 &mut self,
395 _: &Self::Protocol,
396 io: &mut T,
397 ) -> io::Result<Self::Response>
398 where
399 T: AsyncRead + Unpin + Send,
400 {
401 let data = read_up_to(io, RESPONSE_SIZE_LIMIT, RESPONSE_TIME_LIMIT).await?;
402
403 if data.len() >= RESPONSE_SIZE_LIMIT {
404 debug!("Message filled the whole buffer (len: {})", data.len());
405 }
406
407 let mut data = &data[..];
408 let mut msgs = Vec::new();
409
410 while let Some((header, rest)) = parse_header_response(data) {
411 msgs.push(header);
412 data = rest;
413 }
414
415 if msgs.is_empty() {
416 return Err(io::Error::new(
421 io::ErrorKind::Other,
422 "invalid or incomplete response",
423 ));
424 }
425
426 Ok(msgs)
427 }
428
429 async fn write_request<T>(
430 &mut self,
431 _: &Self::Protocol,
432 io: &mut T,
433 req: Self::Request,
434 ) -> io::Result<()>
435 where
436 T: AsyncWrite + Unpin + Send,
437 {
438 let mut buf = Vec::with_capacity(REQUEST_SIZE_LIMIT);
439
440 let _ = req.encode_length_delimited(&mut buf);
441
442 timeout(REQUEST_TIME_LIMIT, io.write_all(&buf))
443 .await
444 .map_err(|_| io::Error::new(io::ErrorKind::Other, "writing request timed out"))??;
445
446 Ok(())
447 }
448
449 async fn write_response<T>(
450 &mut self,
451 _: &Self::Protocol,
452 io: &mut T,
453 resps: Self::Response,
454 ) -> io::Result<()>
455 where
456 T: AsyncWrite + Unpin + Send,
457 {
458 let mut buf = Vec::with_capacity(RESPONSE_SIZE_LIMIT);
459
460 for resp in resps {
461 if resp.encode_length_delimited(&mut buf).is_err() {
462 debug!("Sending partial response");
465 break;
466 }
467 }
468
469 timeout(RESPONSE_TIME_LIMIT, io.write_all(&buf))
470 .await
471 .map_err(|_| io::Error::new(io::ErrorKind::Other, "writing response timed out"))??;
472
473 Ok(())
474 }
475}
476
477async fn read_up_to<T>(io: &mut T, size_limit: usize, time_limit: Duration) -> io::Result<Vec<u8>>
479where
480 T: AsyncRead + Unpin + Send,
481{
482 let mut buf = vec![0u8; size_limit];
483 let mut read_len = 0;
484 let now = Instant::now();
485
486 loop {
487 if read_len == buf.len() {
488 break;
490 }
491
492 let Some(time_limit) = time_limit.checked_sub(now.elapsed()) else {
493 break;
494 };
495
496 let len = match timeout(time_limit, io.read(&mut buf[read_len..])).await {
497 Ok(Ok(len)) => len,
498 Ok(Err(e)) => return Err(e),
499 Err(_) => break,
500 };
501
502 if len == 0 {
503 break;
505 }
506
507 read_len += len;
508 }
509
510 buf.truncate(read_len);
511
512 Ok(buf)
513}
514
515fn parse_delimiter(mut buf: &[u8]) -> Option<(usize, &[u8])> {
516 if buf.is_empty() {
517 return None;
518 }
519
520 let Ok(len) = prost::decode_length_delimiter(&mut buf) else {
521 return None;
522 };
523
524 Some((len, buf))
525}
526
527fn parse_header_response(buf: &[u8]) -> Option<(HeaderResponse, &[u8])> {
528 let (len, rest) = parse_delimiter(buf)?;
529
530 if rest.len() < len {
531 debug!("Message is incomplete: {len}");
532 return None;
533 }
534
535 let Ok(msg) = HeaderResponse::decode(&rest[..len]) else {
536 return None;
537 };
538
539 Some((msg, &rest[len..]))
540}
541
542fn parse_header_request(buf: &[u8]) -> Option<HeaderRequest> {
543 let (len, rest) = parse_delimiter(buf)?;
544
545 if rest.len() < len {
546 debug!("Message is incomplete: {len}");
547 return None;
548 }
549
550 let Ok(msg) = HeaderRequest::decode(&rest[..len]) else {
551 return None;
552 };
553
554 Some(msg)
555}
556
557#[cfg(test)]
558mod tests {
559 use super::*;
560 use bytes::BytesMut;
561 use celestia_proto::p2p::pb::header_request::Data;
562 use futures::io::{Cursor, Error};
563 use lumina_utils::test_utils::async_test;
564 use prost::encode_length_delimiter;
565 use std::io::ErrorKind;
566 use std::pin::Pin;
567
568 #[async_test]
569 async fn test_decode_header_request_empty() {
570 let header_request = HeaderRequest {
571 amount: 0,
572 data: None,
573 };
574
575 let encoded_header_request = header_request.encode_length_delimited_to_vec();
576
577 let mut reader = Cursor::new(encoded_header_request);
578
579 let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
580 let mut codec = HeaderCodec {};
581
582 let decoded_header_request = codec
583 .read_request(&stream_protocol, &mut reader)
584 .await
585 .unwrap();
586
587 assert_eq!(header_request, decoded_header_request);
588 }
589
590 #[async_test]
591 async fn test_decode_multiple_small_header_response() {
592 const MSG_COUNT: usize = 10;
593 let header_response = HeaderResponse {
594 body: vec![1, 2, 3],
595 status_code: 1,
596 };
597
598 let encoded_header_response = header_response.encode_length_delimited_to_vec();
599
600 let mut multi_msg = vec![];
601 for _ in 0..MSG_COUNT {
602 multi_msg.extend_from_slice(&encoded_header_response);
603 }
604 let mut reader = Cursor::new(multi_msg);
605
606 let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
607 let mut codec = HeaderCodec {};
608
609 let decoded_header_response = codec
610 .read_response(&stream_protocol, &mut reader)
611 .await
612 .unwrap();
613
614 for decoded_header in decoded_header_response.iter() {
615 assert_eq!(&header_response, decoded_header);
616 }
617 assert_eq!(decoded_header_response.len(), MSG_COUNT);
618 }
619
620 #[async_test]
621 async fn test_decode_header_request_too_large() {
622 let too_long_message_len = REQUEST_SIZE_LIMIT + 1;
623 let mut length_delimiter_buffer = BytesMut::new();
624 prost::encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
625 let mut reader = Cursor::new(length_delimiter_buffer);
626
627 let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
628 let mut codec = HeaderCodec {};
629
630 let decoding_error = codec
631 .read_request(&stream_protocol, &mut reader)
632 .await
633 .expect_err("expected error for too large request");
634
635 assert_eq!(decoding_error.kind(), ErrorKind::Other);
636 }
637
638 #[async_test]
639 async fn test_decode_header_response_too_large() {
640 let too_long_message_len = RESPONSE_SIZE_LIMIT + 1;
641 let mut length_delimiter_buffer = BytesMut::new();
642 encode_length_delimiter(too_long_message_len, &mut length_delimiter_buffer).unwrap();
643 let mut reader = Cursor::new(length_delimiter_buffer);
644
645 let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
646 let mut codec = HeaderCodec {};
647
648 let decoding_error = codec
649 .read_response(&stream_protocol, &mut reader)
650 .await
651 .expect_err("expected error for too large request");
652
653 assert_eq!(decoding_error.kind(), ErrorKind::Other);
654 }
655
656 #[test]
657 fn test_invalid_varint() {
658 let varint = [
661 0b1000_0000,
662 0b1000_0000,
663 0b1000_0000,
664 0b1000_0000,
665 0b1000_0000,
666 0b1000_0000,
667 0b1000_0000,
668 0b1000_0000,
669 0b1000_0000,
670 0b1000_0000,
671 0b0000_0001,
672 ];
673
674 assert_eq!(parse_delimiter(&varint), None);
675 }
676
677 #[test]
678 fn parse_trailing_zero_varint() {
679 let varint = [0b1000_0001, 0b0000_0000, 0b1111_1111];
680 assert!(matches!(parse_delimiter(&varint), Some((1, [255]))));
681
682 let varint = [0b1000_0000, 0b1000_0000, 0b1000_0000, 0b0000_0000];
683 assert!(matches!(parse_delimiter(&varint), Some((0, []))));
684 }
685
686 #[async_test]
687 async fn test_decode_header_double_response_data() {
688 let mut header_response_buffer = BytesMut::with_capacity(512);
689 let header_response0 = HeaderResponse {
690 body: b"9999888877776666555544443333222211110000".to_vec(),
691 status_code: 1,
692 };
693 let header_response1 = HeaderResponse {
694 body: b"0000111122223333444455556666777788889999".to_vec(),
695 status_code: 2,
696 };
697 header_response0
698 .encode_length_delimited(&mut header_response_buffer)
699 .unwrap();
700 header_response1
701 .encode_length_delimited(&mut header_response_buffer)
702 .unwrap();
703 let mut reader = Cursor::new(header_response_buffer);
704
705 let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
706 let mut codec = HeaderCodec {};
707
708 let decoded_header_response = codec
709 .read_response(&stream_protocol, &mut reader)
710 .await
711 .unwrap();
712 assert_eq!(header_response0, decoded_header_response[0]);
713 assert_eq!(header_response1, decoded_header_response[1]);
714 }
715
716 #[async_test]
717 async fn test_decode_header_request_chunked_data() {
718 let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
719 let header_request = HeaderRequest {
720 amount: 1,
721 data: Some(Data::Hash(data.to_vec())),
722 };
723 let encoded_header_request = header_request.encode_length_delimited_to_vec();
724
725 let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
726 let mut codec = HeaderCodec {};
727 {
728 let mut reader =
729 ChunkyAsyncRead::<_, 1>::new(Cursor::new(encoded_header_request.clone()));
730 let decoded_header_request = codec
731 .read_request(&stream_protocol, &mut reader)
732 .await
733 .unwrap();
734 assert_eq!(header_request, decoded_header_request);
735 }
736 {
737 let mut reader =
738 ChunkyAsyncRead::<_, 2>::new(Cursor::new(encoded_header_request.clone()));
739 let decoded_header_request = codec
740 .read_request(&stream_protocol, &mut reader)
741 .await
742 .unwrap();
743
744 assert_eq!(header_request, decoded_header_request);
745 }
746 {
747 let mut reader =
748 ChunkyAsyncRead::<_, 3>::new(Cursor::new(encoded_header_request.clone()));
749 let decoded_header_request = codec
750 .read_request(&stream_protocol, &mut reader)
751 .await
752 .unwrap();
753
754 assert_eq!(header_request, decoded_header_request);
755 }
756 {
757 let mut reader =
758 ChunkyAsyncRead::<_, 4>::new(Cursor::new(encoded_header_request.clone()));
759 let decoded_header_request = codec
760 .read_request(&stream_protocol, &mut reader)
761 .await
762 .unwrap();
763
764 assert_eq!(header_request, decoded_header_request);
765 }
766 }
767
768 #[async_test]
769 async fn test_decode_header_response_chunked_data() {
770 let data = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ";
771 let header_response = HeaderResponse {
772 body: data.to_vec(),
773 status_code: 2,
774 };
775 let encoded_header_response = header_response.encode_length_delimited_to_vec();
776
777 let stream_protocol = StreamProtocol::new("/foo/bar/v0.1");
778 let mut codec = HeaderCodec {};
779 {
780 let mut reader =
781 ChunkyAsyncRead::<_, 1>::new(Cursor::new(encoded_header_response.clone()));
782 let decoded_header_response = codec
783 .read_response(&stream_protocol, &mut reader)
784 .await
785 .unwrap();
786 assert_eq!(header_response, decoded_header_response[0]);
787 }
788 {
789 let mut reader =
790 ChunkyAsyncRead::<_, 2>::new(Cursor::new(encoded_header_response.clone()));
791 let decoded_header_response = codec
792 .read_response(&stream_protocol, &mut reader)
793 .await
794 .unwrap();
795
796 assert_eq!(header_response, decoded_header_response[0]);
797 }
798 {
799 let mut reader =
800 ChunkyAsyncRead::<_, 3>::new(Cursor::new(encoded_header_response.clone()));
801 let decoded_header_response = codec
802 .read_response(&stream_protocol, &mut reader)
803 .await
804 .unwrap();
805
806 assert_eq!(header_response, decoded_header_response[0]);
807 }
808 {
809 let mut reader =
810 ChunkyAsyncRead::<_, 4>::new(Cursor::new(encoded_header_response.clone()));
811 let decoded_header_response = codec
812 .read_response(&stream_protocol, &mut reader)
813 .await
814 .unwrap();
815
816 assert_eq!(header_response, decoded_header_response[0]);
817 }
818 }
819
820 #[async_test]
821 async fn test_chunky_async_read() {
822 let read_data = "FOO123";
823 let cur0 = Cursor::new(read_data);
824 let mut chunky = ChunkyAsyncRead::<_, 3>::new(cur0);
825
826 let mut output_buffer: BytesMut = b"BAR987".as_ref().into();
827
828 let _ = chunky.read(&mut output_buffer[..]).await.unwrap();
829 assert_eq!(output_buffer, b"FOO987".as_ref());
830 }
831
832 struct ChunkyAsyncRead<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> {
833 inner: T,
834 }
835
836 impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> ChunkyAsyncRead<T, CHUNK_SIZE> {
837 fn new(inner: T) -> Self {
838 ChunkyAsyncRead { inner }
839 }
840 }
841
842 impl<T: AsyncRead + Unpin, const CHUNK_SIZE: usize> AsyncRead for ChunkyAsyncRead<T, CHUNK_SIZE> {
843 fn poll_read(
844 mut self: Pin<&mut Self>,
845 cx: &mut Context<'_>,
846 buf: &mut [u8],
847 ) -> Poll<Result<usize, Error>> {
848 let len = buf.len().min(CHUNK_SIZE);
849 Pin::new(&mut self.inner).poll_read(cx, &mut buf[..len])
850 }
851 }
852}